• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 19926459353

04 Dec 2025 10:53AM UTC coverage: 65.195% (+9.8%) from 55.404%
19926459353

Pull #10420

github

web-flow
Merge 1acf30985 into 20473482d
Pull Request #10420: graph: fix various races

3 of 5 new or added lines in 3 files covered. (60.0%)

25 existing lines in 8 files now uncovered.

137610 of 211074 relevant lines covered (65.2%)

20777.97 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        color "image/color"
11
        "iter"
12
        "maps"
13
        "math"
14
        "net"
15
        "slices"
16
        "strconv"
17
        "sync"
18
        "time"
19

20
        "github.com/btcsuite/btcd/btcec/v2"
21
        "github.com/btcsuite/btcd/btcutil"
22
        "github.com/btcsuite/btcd/chaincfg/chainhash"
23
        "github.com/btcsuite/btcd/wire"
24
        "github.com/lightningnetwork/lnd/aliasmgr"
25
        "github.com/lightningnetwork/lnd/batch"
26
        "github.com/lightningnetwork/lnd/fn/v2"
27
        "github.com/lightningnetwork/lnd/graph/db/models"
28
        "github.com/lightningnetwork/lnd/lnwire"
29
        "github.com/lightningnetwork/lnd/routing/route"
30
        "github.com/lightningnetwork/lnd/sqldb"
31
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
32
        "github.com/lightningnetwork/lnd/tlv"
33
        "github.com/lightningnetwork/lnd/tor"
34
)
35

36
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
// execute queries against the SQL graph tables. Narrowing the generated
// querier down to just the methods the graph store needs keeps the store's
// dependency surface explicit and makes it easy to substitute in tests.
//
//nolint:ll,interfacebloat
type SQLQueries interface {
	/*
		Node queries.
	*/
	UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
	UpsertSourceNode(ctx context.Context, arg sqlc.UpsertSourceNodeParams) (int64, error)
	GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
	GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error)
	GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
	GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
	ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
	ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
	IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
	DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
	DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
	DeleteNode(ctx context.Context, id int64) error

	// Node extra (custom TLV) type records.
	GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
	GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
	UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
	DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error

	// Node address records.
	UpsertNodeAddress(ctx context.Context, arg sqlc.UpsertNodeAddressParams) error
	GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
	GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
	DeleteNodeAddresses(ctx context.Context, nodeID int64) error

	// Node feature-bit records.
	InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
	GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
	GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
	GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
	DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error

	/*
		Source node queries.
	*/
	AddSourceNode(ctx context.Context, nodeID int64) error
	GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)

	/*
		Channel queries.
	*/
	CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
	AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
	GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
	GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
	GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
	GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
	GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
	GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
	GetChannelsByIDs(ctx context.Context, ids []int64) ([]sqlc.GetChannelsByIDsRow, error)
	GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
	HighestSCID(ctx context.Context, version int16) ([]byte, error)
	ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
	ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
	ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
	ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
	ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
	GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
	GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
	GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
	GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
	DeleteChannels(ctx context.Context, ids []int64) error

	// Channel extra (custom TLV) type and feature-bit records.
	UpsertChannelExtraType(ctx context.Context, arg sqlc.UpsertChannelExtraTypeParams) error
	GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
	InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
	GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)

	/*
		Channel Policy table queries.
	*/
	UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
	GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
	GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)

	// Channel policy extra (custom TLV) type records.
	UpsertChanPolicyExtraType(ctx context.Context, arg sqlc.UpsertChanPolicyExtraTypeParams) error
	GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
	DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error

	/*
		Zombie index queries.
	*/
	UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
	GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
	GetZombieChannelsSCIDs(ctx context.Context, arg sqlc.GetZombieChannelsSCIDsParams) ([]sqlc.GraphZombieChannel, error)
	CountZombieChannels(ctx context.Context, version int16) (int64, error)
	DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
	IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)

	/*
		Prune log table queries.
	*/
	GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
	GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
	GetPruneEntriesForHeights(ctx context.Context, heights []int64) ([]sqlc.GraphPruneLog, error)
	UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
	DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error

	/*
		Closed SCID table queries.
	*/
	InsertClosedChannel(ctx context.Context, scid []byte) error
	IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
	GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error)

	/*
		Migration specific queries.

		NOTE: these should not be used in code other than migrations.
		Once sqldbv2 is in place, these can be removed from this struct
		as then migrations will have their own dedicated queries
		structs.
	*/
	InsertNodeMig(ctx context.Context, arg sqlc.InsertNodeMigParams) (int64, error)
	InsertChannelMig(ctx context.Context, arg sqlc.InsertChannelMigParams) (int64, error)
	InsertEdgePolicyMig(ctx context.Context, arg sqlc.InsertEdgePolicyMigParams) (int64, error)
}
158

159
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
160
// database operations.
161
type BatchedSQLQueries interface {
162
        SQLQueries
163
        sqldb.BatchedTx[SQLQueries]
164
}
165

166
// SQLStore is an implementation of the V1Store interface that uses a SQL
// database as the backend.
type SQLStore struct {
	// cfg is the configuration the store was created with and db is the
	// query backend, which also supports batched transactions.
	cfg *SQLStoreConfig
	db  BatchedSQLQueries

	// cacheMu guards all caches (rejectCache and chanCache). If
	// this mutex will be acquired at the same time as the DB mutex then
	// the cacheMu MUST be acquired first to prevent deadlock.
	cacheMu     sync.RWMutex
	rejectCache *rejectCache
	chanCache   *channelCache

	// chanScheduler batches channel writes and is handed cacheMu when it
	// is constructed in NewSQLStore, presumably so the caches are locked
	// while a batch commits. nodeScheduler batches node writes (see
	// AddNode) and is constructed without a lock.
	chanScheduler batch.Scheduler[SQLQueries]
	nodeScheduler batch.Scheduler[SQLQueries]

	// srcNodes is keyed by gossip version and is initialised empty in
	// NewSQLStore. NOTE(review): it looks like a per-version cache of the
	// source node guarded by srcNodeMu; confirm against getSourceNode,
	// which is defined outside this view.
	srcNodes  map[lnwire.GossipVersion]*srcNodeInfo
	srcNodeMu sync.Mutex
}

// A compile-time assertion to ensure that SQLStore implements the V1Store
// interface.
var _ V1Store = (*SQLStore)(nil)
189

190
// SQLStoreConfig holds the configuration for the SQLStore.
type SQLStoreConfig struct {
	// ChainHash is the genesis hash for the chain that all the gossip
	// messages in this store are aimed at.
	ChainHash chainhash.Hash

	// QueryCfg holds configuration values for SQL queries.
	QueryCfg *sqldb.QueryConfig
}
199

200
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
201
// storage backend.
202
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
203
        options ...StoreOptionModifier) (*SQLStore, error) {
×
204

×
205
        opts := DefaultOptions()
×
206
        for _, o := range options {
×
207
                o(opts)
×
208
        }
×
209

210
        if opts.NoMigration {
×
211
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
212
                        "supported for SQL stores")
×
213
        }
×
214

215
        s := &SQLStore{
×
216
                cfg:         cfg,
×
217
                db:          db,
×
218
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
219
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
220
                srcNodes:    make(map[lnwire.GossipVersion]*srcNodeInfo),
×
221
        }
×
222

×
223
        s.chanScheduler = batch.NewTimeScheduler(
×
224
                db, &s.cacheMu, opts.BatchCommitInterval,
×
225
        )
×
226
        s.nodeScheduler = batch.NewTimeScheduler(
×
227
                db, nil, opts.BatchCommitInterval,
×
228
        )
×
229

×
230
        return s, nil
×
231
}
232

233
// AddNode adds a vertex/node to the graph database. If the node is not
234
// in the database from before, this will add a new, unconnected one to the
235
// graph. If it is present from before, this will update that node's
236
// information.
237
//
238
// NOTE: part of the V1Store interface.
239
func (s *SQLStore) AddNode(ctx context.Context,
240
        node *models.Node, opts ...batch.SchedulerOption) error {
×
241

×
242
        r := &batch.Request[SQLQueries]{
×
243
                Opts: batch.NewSchedulerOptions(opts...),
×
244
                Do: func(queries SQLQueries) error {
×
245
                        _, err := upsertNode(ctx, queries, node)
×
246

×
247
                        // It is possible that two of the same node
×
248
                        // announcements are both being processed in the same
×
249
                        // batch. This may case the UpsertNode conflict to
×
250
                        // be hit since we require at the db layer that the
×
251
                        // new last_update is greater than the existing
×
252
                        // last_update. We need to gracefully handle this here.
×
253
                        if errors.Is(err, sql.ErrNoRows) {
×
254
                                return nil
×
255
                        }
×
256

257
                        return err
×
258
                },
259
        }
260

261
        return s.nodeScheduler.Execute(ctx, r)
×
262
}
263

264
// FetchNode attempts to look up a target node by its identity public
265
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
266
// returned.
267
//
268
// NOTE: part of the V1Store interface.
269
func (s *SQLStore) FetchNode(ctx context.Context,
270
        pubKey route.Vertex) (*models.Node, error) {
×
271

×
272
        var node *models.Node
×
273
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
274
                var err error
×
275
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, pubKey)
×
276

×
277
                return err
×
278
        }, sqldb.NoOpReset)
×
279
        if err != nil {
×
280
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
281
        }
×
282

283
        return node, nil
×
284
}
285

286
// HasNode determines if the graph has a vertex identified by the
287
// target node identity public key. If the node exists in the database, a
288
// timestamp of when the data for the node was lasted updated is returned along
289
// with a true boolean. Otherwise, an empty time.Time is returned with a false
290
// boolean.
291
//
292
// NOTE: part of the V1Store interface.
293
func (s *SQLStore) HasNode(ctx context.Context,
294
        pubKey [33]byte) (time.Time, bool, error) {
×
295

×
296
        var (
×
297
                exists     bool
×
298
                lastUpdate time.Time
×
299
        )
×
300
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
301
                dbNode, err := db.GetNodeByPubKey(
×
302
                        ctx, sqlc.GetNodeByPubKeyParams{
×
303
                                Version: int16(lnwire.GossipVersion1),
×
304
                                PubKey:  pubKey[:],
×
305
                        },
×
306
                )
×
307
                if errors.Is(err, sql.ErrNoRows) {
×
308
                        return nil
×
309
                } else if err != nil {
×
310
                        return fmt.Errorf("unable to fetch node: %w", err)
×
311
                }
×
312

313
                exists = true
×
314

×
315
                if dbNode.LastUpdate.Valid {
×
316
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
317
                }
×
318

319
                return nil
×
320
        }, sqldb.NoOpReset)
321
        if err != nil {
×
322
                return time.Time{}, false,
×
323
                        fmt.Errorf("unable to fetch node: %w", err)
×
324
        }
×
325

326
        return lastUpdate, exists, nil
×
327
}
328

329
// AddrsForNode returns all known addresses for the target node public key
330
// that the graph DB is aware of. The returned boolean indicates if the
331
// given node is unknown to the graph DB or not.
332
//
333
// NOTE: part of the V1Store interface.
334
func (s *SQLStore) AddrsForNode(ctx context.Context,
335
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
336

×
337
        var (
×
338
                addresses []net.Addr
×
339
                known     bool
×
340
        )
×
341
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
342
                // First, check if the node exists and get its DB ID if it
×
343
                // does.
×
344
                dbID, err := db.GetNodeIDByPubKey(
×
345
                        ctx, sqlc.GetNodeIDByPubKeyParams{
×
346
                                Version: int16(lnwire.GossipVersion1),
×
347
                                PubKey:  nodePub.SerializeCompressed(),
×
348
                        },
×
349
                )
×
350
                if errors.Is(err, sql.ErrNoRows) {
×
351
                        return nil
×
352
                }
×
353

354
                known = true
×
355

×
356
                addresses, err = getNodeAddresses(ctx, db, dbID)
×
357
                if err != nil {
×
358
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
359
                                err)
×
360
                }
×
361

362
                return nil
×
363
        }, sqldb.NoOpReset)
364
        if err != nil {
×
365
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
366
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
367
        }
×
368

369
        return known, addresses, nil
×
370
}
371

372
// DeleteNode starts a new database transaction to remove a vertex/node
373
// from the database according to the node's public key.
374
//
375
// NOTE: part of the V1Store interface.
376
func (s *SQLStore) DeleteNode(ctx context.Context,
377
        pubKey route.Vertex) error {
×
378

×
379
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
380
                res, err := db.DeleteNodeByPubKey(
×
381
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
382
                                Version: int16(lnwire.GossipVersion1),
×
383
                                PubKey:  pubKey[:],
×
384
                        },
×
385
                )
×
386
                if err != nil {
×
387
                        return err
×
388
                }
×
389

390
                rows, err := res.RowsAffected()
×
391
                if err != nil {
×
392
                        return err
×
393
                }
×
394

395
                if rows == 0 {
×
396
                        return ErrGraphNodeNotFound
×
397
                } else if rows > 1 {
×
398
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
399
                }
×
400

401
                return err
×
402
        }, sqldb.NoOpReset)
403
        if err != nil {
×
404
                return fmt.Errorf("unable to delete node: %w", err)
×
405
        }
×
406

407
        return nil
×
408
}
409

410
// FetchNodeFeatures returns the features of the given node. If no features are
411
// known for the node, an empty feature vector is returned.
412
//
413
// NOTE: this is part of the graphdb.NodeTraverser interface.
414
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
415
        *lnwire.FeatureVector, error) {
×
416

×
417
        ctx := context.TODO()
×
418

×
419
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
420
}
×
421

422
// DisabledChannelIDs returns the channel ids of disabled channels.
423
// A channel is disabled when two of the associated ChanelEdgePolicies
424
// have their disabled bit on.
425
//
426
// NOTE: part of the V1Store interface.
427
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
428
        var (
×
429
                ctx     = context.TODO()
×
430
                chanIDs []uint64
×
431
        )
×
432
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
433
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
434
                if err != nil {
×
435
                        return fmt.Errorf("unable to fetch disabled "+
×
436
                                "channels: %w", err)
×
437
                }
×
438

439
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
440

×
441
                return nil
×
442
        }, sqldb.NoOpReset)
443
        if err != nil {
×
444
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
445
                        err)
×
446
        }
×
447

448
        return chanIDs, nil
×
449
}
450

451
// LookupAlias attempts to return the alias as advertised by the target node.
452
//
453
// NOTE: part of the V1Store interface.
454
func (s *SQLStore) LookupAlias(ctx context.Context,
455
        pub *btcec.PublicKey) (string, error) {
×
456

×
457
        var alias string
×
458
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
459
                dbNode, err := db.GetNodeByPubKey(
×
460
                        ctx, sqlc.GetNodeByPubKeyParams{
×
461
                                Version: int16(lnwire.GossipVersion1),
×
462
                                PubKey:  pub.SerializeCompressed(),
×
463
                        },
×
464
                )
×
465
                if errors.Is(err, sql.ErrNoRows) {
×
466
                        return ErrNodeAliasNotFound
×
467
                } else if err != nil {
×
468
                        return fmt.Errorf("unable to fetch node: %w", err)
×
469
                }
×
470

471
                if !dbNode.Alias.Valid {
×
472
                        return ErrNodeAliasNotFound
×
473
                }
×
474

475
                alias = dbNode.Alias.String
×
476

×
477
                return nil
×
478
        }, sqldb.NoOpReset)
479
        if err != nil {
×
480
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
481
        }
×
482

483
        return alias, nil
×
484
}
485

486
// SourceNode returns the source node of the graph. The source node is treated
487
// as the center node within a star-graph. This method may be used to kick off
488
// a path finding algorithm in order to explore the reachability of another
489
// node based off the source node.
490
//
491
// NOTE: part of the V1Store interface.
492
func (s *SQLStore) SourceNode(ctx context.Context) (*models.Node,
493
        error) {
×
494

×
495
        var node *models.Node
×
496
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
497
                _, nodePub, err := s.getSourceNode(
×
498
                        ctx, db, lnwire.GossipVersion1,
×
499
                )
×
500
                if err != nil {
×
501
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
502
                                err)
×
503
                }
×
504

505
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, nodePub)
×
506

×
507
                return err
×
508
        }, sqldb.NoOpReset)
509
        if err != nil {
×
510
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
511
        }
×
512

513
        return node, nil
×
514
}
515

516
// SetSourceNode sets the source node within the graph database. The source
517
// node is to be used as the center of a star-graph within path finding
518
// algorithms.
519
//
520
// NOTE: part of the V1Store interface.
521
func (s *SQLStore) SetSourceNode(ctx context.Context,
522
        node *models.Node) error {
×
523

×
524
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
525
                // For the source node, we use a less strict upsert that allows
×
526
                // updates even when the timestamp hasn't changed. This handles
×
527
                // the race condition where multiple goroutines (e.g.,
×
528
                // setSelfNode, createNewHiddenService, RPC updates) read the
×
529
                // same old timestamp, independently increment it, and try to
×
530
                // write concurrently. We want all parameter changes to persist,
×
531
                // even if timestamps collide.
×
532
                id, err := upsertSourceNode(ctx, db, node)
×
533
                if err != nil {
×
534
                        return fmt.Errorf("unable to upsert source node: %w",
×
535
                                err)
×
536
                }
×
537

538
                // Make sure that if a source node for this version is already
539
                // set, then the ID is the same as the one we are about to set.
540
                dbSourceNodeID, _, err := s.getSourceNode(
×
541
                        ctx, db, lnwire.GossipVersion1,
×
542
                )
×
543
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
544
                        return fmt.Errorf("unable to fetch source node: %w",
×
545
                                err)
×
546
                } else if err == nil {
×
547
                        if dbSourceNodeID != id {
×
548
                                return fmt.Errorf("v1 source node already "+
×
549
                                        "set to a different node: %d vs %d",
×
550
                                        dbSourceNodeID, id)
×
551
                        }
×
552

553
                        return nil
×
554
                }
555

556
                return db.AddSourceNode(ctx, id)
×
557
        }, sqldb.NoOpReset)
558
}
559

560
// NodeUpdatesInHorizon returns all the known lightning node which have an
561
// update timestamp within the passed range. This method can be used by two
562
// nodes to quickly determine if they have the same set of up to date node
563
// announcements.
564
//
565
// NOTE: This is part of the V1Store interface.
566
func (s *SQLStore) NodeUpdatesInHorizon(startTime, endTime time.Time,
567
        opts ...IteratorOption) iter.Seq2[*models.Node, error] {
×
568

×
569
        cfg := defaultIteratorConfig()
×
570
        for _, opt := range opts {
×
571
                opt(cfg)
×
572
        }
×
573

574
        return func(yield func(*models.Node, error) bool) {
×
575
                var (
×
576
                        ctx            = context.TODO()
×
577
                        lastUpdateTime sql.NullInt64
×
578
                        lastPubKey     = make([]byte, 33)
×
579
                        hasMore        = true
×
580
                )
×
581

×
582
                // Each iteration, we'll read a batch amount of nodes, yield
×
583
                // them, then decide is we have more or not.
×
584
                for hasMore {
×
585
                        var batch []*models.Node
×
586

×
587
                        //nolint:ll
×
588
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
589
                                //nolint:ll
×
590
                                params := sqlc.GetNodesByLastUpdateRangeParams{
×
591
                                        StartTime: sqldb.SQLInt64(
×
592
                                                startTime.Unix(),
×
593
                                        ),
×
594
                                        EndTime: sqldb.SQLInt64(
×
595
                                                endTime.Unix(),
×
596
                                        ),
×
597
                                        LastUpdate: lastUpdateTime,
×
598
                                        LastPubKey: lastPubKey,
×
599
                                        OnlyPublic: sql.NullBool{
×
600
                                                Bool:  cfg.iterPublicNodes,
×
601
                                                Valid: true,
×
602
                                        },
×
603
                                        MaxResults: sqldb.SQLInt32(
×
604
                                                cfg.nodeUpdateIterBatchSize,
×
605
                                        ),
×
606
                                }
×
607
                                rows, err := db.GetNodesByLastUpdateRange(
×
608
                                        ctx, params,
×
609
                                )
×
610
                                if err != nil {
×
611
                                        return err
×
612
                                }
×
613

614
                                hasMore = len(rows) == cfg.nodeUpdateIterBatchSize
×
615

×
616
                                err = forEachNodeInBatch(
×
617
                                        ctx, s.cfg.QueryCfg, db, rows,
×
618
                                        func(_ int64, node *models.Node) error {
×
619
                                                batch = append(batch, node)
×
620

×
621
                                                // Update pagination cursors
×
622
                                                // based on the last processed
×
623
                                                // node.
×
624
                                                lastUpdateTime = sql.NullInt64{
×
625
                                                        Int64: node.LastUpdate.
×
626
                                                                Unix(),
×
627
                                                        Valid: true,
×
628
                                                }
×
629
                                                lastPubKey = node.PubKeyBytes[:]
×
630

×
631
                                                return nil
×
632
                                        },
×
633
                                )
634
                                if err != nil {
×
635
                                        return fmt.Errorf("unable to build "+
×
636
                                                "nodes: %w", err)
×
637
                                }
×
638

639
                                return nil
×
640
                        }, func() {
×
641
                                batch = []*models.Node{}
×
642
                        })
×
643

644
                        if err != nil {
×
645
                                log.Errorf("NodeUpdatesInHorizon batch "+
×
646
                                        "error: %v", err)
×
647

×
648
                                yield(&models.Node{}, err)
×
649

×
650
                                return
×
651
                        }
×
652

653
                        for _, node := range batch {
×
654
                                if !yield(node, nil) {
×
655
                                        return
×
656
                                }
×
657
                        }
658

659
                        // If the batch didn't yield anything, then we're done.
660
                        if len(batch) == 0 {
×
661
                                break
×
662
                        }
663
                }
664
        }
665
}
666

667
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
668
// undirected edge from the two target nodes are created. The information stored
669
// denotes the static attributes of the channel, such as the channelID, the keys
670
// involved in creation of the channel, and the set of features that the channel
671
// supports. The chanPoint and chanID are used to uniquely identify the edge
672
// globally within the database.
673
//
674
// NOTE: part of the V1Store interface.
675
func (s *SQLStore) AddChannelEdge(ctx context.Context,
676
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
677

×
678
        var alreadyExists bool
×
679
        r := &batch.Request[SQLQueries]{
×
680
                Opts: batch.NewSchedulerOptions(opts...),
×
681
                Reset: func() {
×
682
                        alreadyExists = false
×
683
                },
×
684
                Do: func(tx SQLQueries) error {
×
685
                        chanIDB := channelIDToBytes(edge.ChannelID)
×
686

×
687
                        // Make sure that the channel doesn't already exist. We
×
688
                        // do this explicitly instead of relying on catching a
×
689
                        // unique constraint error because relying on SQL to
×
690
                        // throw that error would abort the entire batch of
×
691
                        // transactions.
×
692
                        _, err := tx.GetChannelBySCID(
×
693
                                ctx, sqlc.GetChannelBySCIDParams{
×
694
                                        Scid:    chanIDB,
×
695
                                        Version: int16(lnwire.GossipVersion1),
×
696
                                },
×
697
                        )
×
698
                        if err == nil {
×
699
                                alreadyExists = true
×
700
                                return nil
×
701
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
702
                                return fmt.Errorf("unable to fetch channel: %w",
×
703
                                        err)
×
704
                        }
×
705

706
                        return insertChannel(ctx, tx, edge)
×
707
                },
708
                OnCommit: func(err error) error {
×
709
                        switch {
×
710
                        case err != nil:
×
711
                                return err
×
712
                        case alreadyExists:
×
713
                                return ErrEdgeAlreadyExist
×
714
                        default:
×
715
                                s.rejectCache.remove(edge.ChannelID)
×
716
                                s.chanCache.remove(edge.ChannelID)
×
717
                                return nil
×
718
                        }
719
                },
720
        }
721

722
        return s.chanScheduler.Execute(ctx, r)
×
723
}
724

725
// HighestChanID returns the "highest" known channel ID in the channel graph.
726
// This represents the "newest" channel from the PoV of the chain. This method
727
// can be used by peers to quickly determine if their graphs are in sync.
728
//
729
// NOTE: This is part of the V1Store interface.
730
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
731
        var highestChanID uint64
×
732
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
733
                chanID, err := db.HighestSCID(ctx, int16(lnwire.GossipVersion1))
×
734
                if errors.Is(err, sql.ErrNoRows) {
×
735
                        return nil
×
736
                } else if err != nil {
×
737
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
738
                                err)
×
739
                }
×
740

741
                highestChanID = byteOrder.Uint64(chanID)
×
742

×
743
                return nil
×
744
        }, sqldb.NoOpReset)
745
        if err != nil {
×
746
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
747
        }
×
748

749
        return highestChanID, nil
×
750
}
751

752
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
753
// within the database for the referenced channel. The `flags` attribute within
754
// the ChannelEdgePolicy determines which of the directed edges are being
755
// updated. If the flag is 1, then the first node's information is being
756
// updated, otherwise it's the second node's information. The node ordering is
757
// determined by the lexicographical ordering of the identity public keys of the
758
// nodes on either side of the channel.
759
//
760
// NOTE: part of the V1Store interface.
761
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
762
        edge *models.ChannelEdgePolicy,
763
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
764

×
765
        var (
×
766
                isUpdate1    bool
×
767
                edgeNotFound bool
×
768
                from, to     route.Vertex
×
769
        )
×
770

×
771
        r := &batch.Request[SQLQueries]{
×
772
                Opts: batch.NewSchedulerOptions(opts...),
×
773
                Reset: func() {
×
774
                        isUpdate1 = false
×
775
                        edgeNotFound = false
×
776
                },
×
777
                Do: func(tx SQLQueries) error {
×
778
                        var err error
×
779
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
780
                                ctx, tx, edge,
×
781
                        )
×
782
                        // It is possible that two of the same policy
×
783
                        // announcements are both being processed in the same
×
784
                        // batch. This may case the UpsertEdgePolicy conflict to
×
785
                        // be hit since we require at the db layer that the
×
786
                        // new last_update is greater than the existing
×
787
                        // last_update. We need to gracefully handle this here.
×
788
                        if errors.Is(err, sql.ErrNoRows) {
×
789
                                return nil
×
790
                        } else if err != nil {
×
791
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
792
                        }
×
793

794
                        // Silence ErrEdgeNotFound so that the batch can
795
                        // succeed, but propagate the error via local state.
796
                        if errors.Is(err, ErrEdgeNotFound) {
×
797
                                edgeNotFound = true
×
798
                                return nil
×
799
                        }
×
800

801
                        return err
×
802
                },
803
                OnCommit: func(err error) error {
×
804
                        switch {
×
805
                        case err != nil:
×
806
                                return err
×
807
                        case edgeNotFound:
×
808
                                return ErrEdgeNotFound
×
809
                        default:
×
810
                                s.updateEdgeCache(edge, isUpdate1)
×
811
                                return nil
×
812
                        }
813
                },
814
        }
815

816
        err := s.chanScheduler.Execute(ctx, r)
×
817

×
818
        return from, to, err
×
819
}
820

821
// updateEdgeCache updates our reject and channel caches with the new
822
// edge policy information.
823
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
824
        isUpdate1 bool) {
×
825

×
826
        // If an entry for this channel is found in reject cache, we'll modify
×
827
        // the entry with the updated timestamp for the direction that was just
×
828
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
829
        // during the next query for this edge.
×
830
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
831
                if isUpdate1 {
×
832
                        entry.upd1Time = e.LastUpdate.Unix()
×
833
                } else {
×
834
                        entry.upd2Time = e.LastUpdate.Unix()
×
835
                }
×
836
                s.rejectCache.insert(e.ChannelID, entry)
×
837
        }
838

839
        // If an entry for this channel is found in channel cache, we'll modify
840
        // the entry with the updated policy for the direction that was just
841
        // written. If the edge doesn't exist, we'll defer loading the info and
842
        // policies and lazily read from disk during the next query.
843
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
844
                if isUpdate1 {
×
845
                        channel.Policy1 = e
×
846
                } else {
×
847
                        channel.Policy2 = e
×
848
                }
×
849
                s.chanCache.insert(e.ChannelID, channel)
×
850
        }
851
}
852

853
// ForEachSourceNodeChannel iterates through all channels of the source node,
854
// executing the passed callback on each. The call-back is provided with the
855
// channel's outpoint, whether we have a policy for the channel and the channel
856
// peer's node information.
857
//
858
// NOTE: part of the V1Store interface.
859
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
860
        cb func(chanPoint wire.OutPoint, havePolicy bool,
861
                otherNode *models.Node) error, reset func()) error {
×
862

×
863
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
864
                nodeID, nodePub, err := s.getSourceNode(
×
865
                        ctx, db, lnwire.GossipVersion1,
×
866
                )
×
867
                if err != nil {
×
868
                        return fmt.Errorf("unable to fetch source node: %w",
×
869
                                err)
×
870
                }
×
871

872
                return forEachNodeChannel(
×
873
                        ctx, db, s.cfg, nodeID,
×
874
                        func(info *models.ChannelEdgeInfo,
×
875
                                outPolicy *models.ChannelEdgePolicy,
×
876
                                _ *models.ChannelEdgePolicy) error {
×
877

×
878
                                // Fetch the other node.
×
879
                                var (
×
880
                                        otherNodePub [33]byte
×
881
                                        node1        = info.NodeKey1Bytes
×
882
                                        node2        = info.NodeKey2Bytes
×
883
                                )
×
884
                                switch {
×
885
                                case bytes.Equal(node1[:], nodePub[:]):
×
886
                                        otherNodePub = node2
×
887
                                case bytes.Equal(node2[:], nodePub[:]):
×
888
                                        otherNodePub = node1
×
889
                                default:
×
890
                                        return fmt.Errorf("node not " +
×
891
                                                "participating in this channel")
×
892
                                }
893

894
                                _, otherNode, err := getNodeByPubKey(
×
895
                                        ctx, s.cfg.QueryCfg, db, otherNodePub,
×
896
                                )
×
897
                                if err != nil {
×
898
                                        return fmt.Errorf("unable to fetch "+
×
899
                                                "other node(%x): %w",
×
900
                                                otherNodePub, err)
×
901
                                }
×
902

903
                                return cb(
×
904
                                        info.ChannelPoint, outPolicy != nil,
×
905
                                        otherNode,
×
906
                                )
×
907
                        },
908
                )
909
        }, reset)
910
}
911

912
// ForEachNode iterates through all the stored vertices/nodes in the graph,
913
// executing the passed callback with each node encountered. If the callback
914
// returns an error, then the transaction is aborted and the iteration stops
915
// early.
916
//
917
// NOTE: part of the V1Store interface.
918
func (s *SQLStore) ForEachNode(ctx context.Context,
919
        cb func(node *models.Node) error, reset func()) error {
×
920

×
921
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
922
                return forEachNodePaginated(
×
923
                        ctx, s.cfg.QueryCfg, db,
×
924
                        lnwire.GossipVersion1, func(_ context.Context, _ int64,
×
925
                                node *models.Node) error {
×
926

×
927
                                return cb(node)
×
928
                        },
×
929
                )
930
        }, reset)
931
}
932

933
// ForEachNodeDirectedChannel iterates through all channels of a given node,
934
// executing the passed callback on the directed edge representing the channel
935
// and its incoming policy. If the callback returns an error, then the iteration
936
// is halted with the error propagated back up to the caller.
937
//
938
// Unknown policies are passed into the callback as nil values.
939
//
940
// NOTE: this is part of the graphdb.NodeTraverser interface.
941
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
942
        cb func(channel *DirectedChannel) error, reset func()) error {
×
943

×
944
        var ctx = context.TODO()
×
945

×
946
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
947
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
948
        }, reset)
×
949
}
950

951
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
952
// graph, executing the passed callback with each node encountered. If the
953
// callback returns an error, then the transaction is aborted and the iteration
954
// stops early.
955
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
956
        cb func(route.Vertex, *lnwire.FeatureVector) error,
957
        reset func()) error {
×
958

×
959
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
960
                return forEachNodeCacheable(
×
961
                        ctx, s.cfg.QueryCfg, db,
×
962
                        func(_ int64, nodePub route.Vertex,
×
963
                                features *lnwire.FeatureVector) error {
×
964

×
965
                                return cb(nodePub, features)
×
966
                        },
×
967
                )
968
        }, reset)
969
        if err != nil {
×
970
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
971
        }
×
972

973
        return nil
×
974
}
975

976
// ForEachNodeChannel iterates through all channels of the given node,
977
// executing the passed callback with an edge info structure and the policies
978
// of each end of the channel. The first edge policy is the outgoing edge *to*
979
// the connecting node, while the second is the incoming edge *from* the
980
// connecting node. If the callback returns an error, then the iteration is
981
// halted with the error propagated back up to the caller.
982
//
983
// Unknown policies are passed into the callback as nil values.
984
//
985
// NOTE: part of the V1Store interface.
986
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
987
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
988
                *models.ChannelEdgePolicy) error, reset func()) error {
×
989

×
990
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
991
                dbNode, err := db.GetNodeByPubKey(
×
992
                        ctx, sqlc.GetNodeByPubKeyParams{
×
993
                                Version: int16(lnwire.GossipVersion1),
×
994
                                PubKey:  nodePub[:],
×
995
                        },
×
996
                )
×
997
                if errors.Is(err, sql.ErrNoRows) {
×
998
                        return nil
×
999
                } else if err != nil {
×
1000
                        return fmt.Errorf("unable to fetch node: %w", err)
×
1001
                }
×
1002

1003
                return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
×
1004
        }, reset)
1005
}
1006

1007
// extractMaxUpdateTime returns the maximum of the two policy update times.
1008
// This is used for pagination cursor tracking.
1009
func extractMaxUpdateTime(
1010
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) int64 {
×
1011

×
1012
        switch {
×
1013
        case row.Policy1LastUpdate.Valid && row.Policy2LastUpdate.Valid:
×
1014
                return max(row.Policy1LastUpdate.Int64,
×
1015
                        row.Policy2LastUpdate.Int64)
×
1016
        case row.Policy1LastUpdate.Valid:
×
1017
                return row.Policy1LastUpdate.Int64
×
1018
        case row.Policy2LastUpdate.Valid:
×
1019
                return row.Policy2LastUpdate.Int64
×
1020
        default:
×
1021
                return 0
×
1022
        }
1023
}
1024

1025
// buildChannelFromRow constructs a ChannelEdge from a database row.
1026
// This includes building the nodes, channel info, and policies.
1027
func (s *SQLStore) buildChannelFromRow(ctx context.Context, db SQLQueries,
1028
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) (ChannelEdge, error) {
×
1029

×
1030
        node1, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode)
×
1031
        if err != nil {
×
1032
                return ChannelEdge{}, fmt.Errorf("unable to build node1: %w",
×
1033
                        err)
×
1034
        }
×
1035

1036
        node2, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode_2)
×
1037
        if err != nil {
×
1038
                return ChannelEdge{}, fmt.Errorf("unable to build node2: %w",
×
1039
                        err)
×
1040
        }
×
1041

1042
        channel, err := getAndBuildEdgeInfo(
×
1043
                ctx, s.cfg, db,
×
1044
                row.GraphChannel, node1.PubKeyBytes,
×
1045
                node2.PubKeyBytes,
×
1046
        )
×
1047
        if err != nil {
×
1048
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1049
                        "channel info: %w", err)
×
1050
        }
×
1051

1052
        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1053
        if err != nil {
×
1054
                return ChannelEdge{}, fmt.Errorf("unable to extract "+
×
1055
                        "channel policies: %w", err)
×
1056
        }
×
1057

1058
        p1, p2, err := getAndBuildChanPolicies(
×
1059
                ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, channel.ChannelID,
×
1060
                node1.PubKeyBytes, node2.PubKeyBytes,
×
1061
        )
×
1062
        if err != nil {
×
1063
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1064
                        "channel policies: %w", err)
×
1065
        }
×
1066

1067
        return ChannelEdge{
×
1068
                Info:    channel,
×
1069
                Policy1: p1,
×
1070
                Policy2: p2,
×
1071
                Node1:   node1,
×
1072
                Node2:   node2,
×
1073
        }, nil
×
1074
}
1075

1076
// updateChanCacheBatch updates the channel cache with multiple edges at once.
1077
// This method acquires the cache lock only once for the entire batch.
1078
func (s *SQLStore) updateChanCacheBatch(edgesToCache map[uint64]ChannelEdge) {
×
1079
        if len(edgesToCache) == 0 {
×
1080
                return
×
1081
        }
×
1082

1083
        s.cacheMu.Lock()
×
1084
        defer s.cacheMu.Unlock()
×
1085

×
1086
        for chanID, edge := range edgesToCache {
×
1087
                s.chanCache.insert(chanID, edge)
×
1088
        }
×
1089
}
1090

1091
// ChanUpdatesInHorizon returns all the known channel edges which have at least
1092
// one edge that has an update timestamp within the specified horizon.
1093
//
1094
// Iterator Lifecycle:
1095
// 1. Initialize state (edgesSeen map, cache tracking, pagination cursors)
1096
// 2. Query batch of channels with policies in time range
1097
// 3. For each channel: check if seen, check cache, or build from DB
1098
// 4. Yield channels to caller
1099
// 5. Update cache after successful batch
1100
// 6. Repeat with updated pagination cursor until no more results
1101
//
1102
// NOTE: This is part of the V1Store interface.
1103
func (s *SQLStore) ChanUpdatesInHorizon(startTime, endTime time.Time,
1104
        opts ...IteratorOption) iter.Seq2[ChannelEdge, error] {
×
1105

×
1106
        // Apply options.
×
1107
        cfg := defaultIteratorConfig()
×
1108
        for _, opt := range opts {
×
1109
                opt(cfg)
×
1110
        }
×
1111

1112
        return func(yield func(ChannelEdge, error) bool) {
×
1113
                var (
×
1114
                        ctx            = context.TODO()
×
1115
                        edgesSeen      = make(map[uint64]struct{})
×
1116
                        edgesToCache   = make(map[uint64]ChannelEdge)
×
1117
                        hits           int
×
1118
                        total          int
×
1119
                        lastUpdateTime sql.NullInt64
×
1120
                        lastID         sql.NullInt64
×
1121
                        hasMore        = true
×
1122
                )
×
1123

×
1124
                // Each iteration, we'll read a batch amount of channel updates
×
1125
                // (consulting the cache along the way), yield them, then loop
×
1126
                // back to decide if we have any more updates to read out.
×
1127
                for hasMore {
×
1128
                        var batch []ChannelEdge
×
1129

×
1130
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(),
×
1131
                                func(db SQLQueries) error {
×
1132
                                        //nolint:ll
×
1133
                                        params := sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1134
                                                Version: int16(lnwire.GossipVersion1),
×
1135
                                                StartTime: sqldb.SQLInt64(
×
1136
                                                        startTime.Unix(),
×
1137
                                                ),
×
1138
                                                EndTime: sqldb.SQLInt64(
×
1139
                                                        endTime.Unix(),
×
1140
                                                ),
×
1141
                                                LastUpdateTime: lastUpdateTime,
×
1142
                                                LastID:         lastID,
×
1143
                                                MaxResults: sql.NullInt32{
×
1144
                                                        Int32: int32(
×
1145
                                                                cfg.chanUpdateIterBatchSize,
×
1146
                                                        ),
×
1147
                                                        Valid: true,
×
1148
                                                },
×
1149
                                        }
×
1150
                                        //nolint:ll
×
1151
                                        rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1152
                                                ctx, params,
×
1153
                                        )
×
1154
                                        if err != nil {
×
1155
                                                return err
×
1156
                                        }
×
1157

1158
                                        //nolint:ll
1159
                                        hasMore = len(rows) == cfg.chanUpdateIterBatchSize
×
1160

×
1161
                                        //nolint:ll
×
1162
                                        for _, row := range rows {
×
1163
                                                lastUpdateTime = sql.NullInt64{
×
1164
                                                        Int64: extractMaxUpdateTime(row),
×
1165
                                                        Valid: true,
×
1166
                                                }
×
1167
                                                lastID = sql.NullInt64{
×
1168
                                                        Int64: row.GraphChannel.ID,
×
1169
                                                        Valid: true,
×
1170
                                                }
×
1171

×
1172
                                                // Skip if we've already
×
1173
                                                // processed this channel.
×
1174
                                                chanIDInt := byteOrder.Uint64(
×
1175
                                                        row.GraphChannel.Scid,
×
1176
                                                )
×
1177
                                                _, ok := edgesSeen[chanIDInt]
×
1178
                                                if ok {
×
1179
                                                        continue
×
1180
                                                }
1181

1182
                                                s.cacheMu.RLock()
×
1183
                                                channel, ok := s.chanCache.get(
×
1184
                                                        chanIDInt,
×
1185
                                                )
×
1186
                                                s.cacheMu.RUnlock()
×
1187
                                                if ok {
×
1188
                                                        hits++
×
1189
                                                        total++
×
1190
                                                        edgesSeen[chanIDInt] = struct{}{}
×
1191
                                                        batch = append(batch, channel)
×
1192

×
1193
                                                        continue
×
1194
                                                }
1195

1196
                                                chanEdge, err := s.buildChannelFromRow(
×
1197
                                                        ctx, db, row,
×
1198
                                                )
×
1199
                                                if err != nil {
×
1200
                                                        return err
×
1201
                                                }
×
1202

1203
                                                edgesSeen[chanIDInt] = struct{}{}
×
1204
                                                edgesToCache[chanIDInt] = chanEdge
×
1205

×
1206
                                                batch = append(batch, chanEdge)
×
1207

×
1208
                                                total++
×
1209
                                        }
1210

1211
                                        return nil
×
1212
                                }, func() {
×
1213
                                        batch = nil
×
1214
                                        edgesSeen = make(map[uint64]struct{})
×
1215
                                        edgesToCache = make(
×
1216
                                                map[uint64]ChannelEdge,
×
1217
                                        )
×
1218
                                })
×
1219

1220
                        if err != nil {
×
1221
                                log.Errorf("ChanUpdatesInHorizon "+
×
1222
                                        "batch error: %v", err)
×
1223

×
1224
                                yield(ChannelEdge{}, err)
×
1225

×
1226
                                return
×
1227
                        }
×
1228

1229
                        for _, edge := range batch {
×
1230
                                if !yield(edge, nil) {
×
1231
                                        return
×
1232
                                }
×
1233
                        }
1234

1235
                        // Update cache after successful batch yield, setting
1236
                        // the cache lock only once for the entire batch.
1237
                        s.updateChanCacheBatch(edgesToCache)
×
1238
                        edgesToCache = make(map[uint64]ChannelEdge)
×
1239

×
1240
                        // If the batch didn't yield anything, then we're done.
×
1241
                        if len(batch) == 0 {
×
1242
                                break
×
1243
                        }
1244
                }
1245

1246
                if total > 0 {
×
1247
                        log.Debugf("ChanUpdatesInHorizon hit percentage: "+
×
1248
                                "%.2f (%d/%d)",
×
1249
                                float64(hits)*100/float64(total), hits, total)
×
1250
                } else {
×
1251
                        log.Debugf("ChanUpdatesInHorizon returned no edges "+
×
1252
                                "in horizon (%s, %s)", startTime, endTime)
×
1253
                }
×
1254
        }
1255
}
1256

1257
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1258
// data to the call-back. If withAddrs is true, then the call-back will also be
1259
// provided with the addresses associated with the node. The address retrieval
1260
// result in an additional round-trip to the database, so it should only be used
1261
// if the addresses are actually needed.
1262
//
1263
// NOTE: part of the V1Store interface.
1264
func (s *SQLStore) ForEachNodeCached(ctx context.Context, withAddrs bool,
1265
        cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
1266
                chans map[uint64]*DirectedChannel) error, reset func()) error {
×
1267

×
1268
        type nodeCachedBatchData struct {
×
1269
                features      map[int64][]int
×
1270
                addrs         map[int64][]nodeAddress
×
1271
                chanBatchData *batchChannelData
×
1272
                chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
×
1273
        }
×
1274

×
1275
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1276
                // pageQueryFunc is used to query the next page of nodes.
×
1277
                pageQueryFunc := func(ctx context.Context, lastID int64,
×
1278
                        limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
1279

×
1280
                        return db.ListNodeIDsAndPubKeys(
×
1281
                                ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
1282
                                        Version: int16(lnwire.GossipVersion1),
×
1283
                                        ID:      lastID,
×
1284
                                        Limit:   limit,
×
1285
                                },
×
1286
                        )
×
1287
                }
×
1288

1289
                // batchDataFunc is then used to batch load the data required
1290
                // for each page of nodes.
1291
                batchDataFunc := func(ctx context.Context,
×
1292
                        nodeIDs []int64) (*nodeCachedBatchData, error) {
×
1293

×
1294
                        // Batch load node features.
×
1295
                        nodeFeatures, err := batchLoadNodeFeaturesHelper(
×
1296
                                ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1297
                        )
×
1298
                        if err != nil {
×
1299
                                return nil, fmt.Errorf("unable to batch load "+
×
1300
                                        "node features: %w", err)
×
1301
                        }
×
1302

1303
                        // Maybe fetch the node's addresses if requested.
1304
                        var nodeAddrs map[int64][]nodeAddress
×
1305
                        if withAddrs {
×
1306
                                nodeAddrs, err = batchLoadNodeAddressesHelper(
×
1307
                                        ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1308
                                )
×
1309
                                if err != nil {
×
1310
                                        return nil, fmt.Errorf("unable to "+
×
1311
                                                "batch load node "+
×
1312
                                                "addresses: %w", err)
×
1313
                                }
×
1314
                        }
1315

1316
                        // Batch load ALL unique channels for ALL nodes in this
1317
                        // page.
1318
                        allChannels, err := db.ListChannelsForNodeIDs(
×
1319
                                ctx, sqlc.ListChannelsForNodeIDsParams{
×
1320
                                        Version:  int16(lnwire.GossipVersion1),
×
1321
                                        Node1Ids: nodeIDs,
×
1322
                                        Node2Ids: nodeIDs,
×
1323
                                },
×
1324
                        )
×
1325
                        if err != nil {
×
1326
                                return nil, fmt.Errorf("unable to batch "+
×
1327
                                        "fetch channels for nodes: %w", err)
×
1328
                        }
×
1329

1330
                        // Deduplicate channels and collect IDs.
1331
                        var (
×
1332
                                allChannelIDs []int64
×
1333
                                allPolicyIDs  []int64
×
1334
                        )
×
1335
                        uniqueChannels := make(
×
1336
                                map[int64]sqlc.ListChannelsForNodeIDsRow,
×
1337
                        )
×
1338

×
1339
                        for _, channel := range allChannels {
×
1340
                                channelID := channel.GraphChannel.ID
×
1341

×
1342
                                // Only process each unique channel once.
×
1343
                                _, exists := uniqueChannels[channelID]
×
1344
                                if exists {
×
1345
                                        continue
×
1346
                                }
1347

1348
                                uniqueChannels[channelID] = channel
×
1349
                                allChannelIDs = append(allChannelIDs, channelID)
×
1350

×
1351
                                if channel.Policy1ID.Valid {
×
1352
                                        allPolicyIDs = append(
×
1353
                                                allPolicyIDs,
×
1354
                                                channel.Policy1ID.Int64,
×
1355
                                        )
×
1356
                                }
×
1357
                                if channel.Policy2ID.Valid {
×
1358
                                        allPolicyIDs = append(
×
1359
                                                allPolicyIDs,
×
1360
                                                channel.Policy2ID.Int64,
×
1361
                                        )
×
1362
                                }
×
1363
                        }
1364

1365
                        // Batch load channel data for all unique channels.
1366
                        channelBatchData, err := batchLoadChannelData(
×
1367
                                ctx, s.cfg.QueryCfg, db, allChannelIDs,
×
1368
                                allPolicyIDs,
×
1369
                        )
×
1370
                        if err != nil {
×
1371
                                return nil, fmt.Errorf("unable to batch "+
×
1372
                                        "load channel data: %w", err)
×
1373
                        }
×
1374

1375
                        // Create map of node ID to channels that involve this
1376
                        // node.
1377
                        nodeIDSet := make(map[int64]bool)
×
1378
                        for _, nodeID := range nodeIDs {
×
1379
                                nodeIDSet[nodeID] = true
×
1380
                        }
×
1381

1382
                        nodeChannelMap := make(
×
1383
                                map[int64][]sqlc.ListChannelsForNodeIDsRow,
×
1384
                        )
×
1385
                        for _, channel := range uniqueChannels {
×
1386
                                // Add channel to both nodes if they're in our
×
1387
                                // current page.
×
1388
                                node1 := channel.GraphChannel.NodeID1
×
1389
                                if nodeIDSet[node1] {
×
1390
                                        nodeChannelMap[node1] = append(
×
1391
                                                nodeChannelMap[node1], channel,
×
1392
                                        )
×
1393
                                }
×
1394
                                node2 := channel.GraphChannel.NodeID2
×
1395
                                if nodeIDSet[node2] {
×
1396
                                        nodeChannelMap[node2] = append(
×
1397
                                                nodeChannelMap[node2], channel,
×
1398
                                        )
×
1399
                                }
×
1400
                        }
1401

1402
                        return &nodeCachedBatchData{
×
1403
                                features:      nodeFeatures,
×
1404
                                addrs:         nodeAddrs,
×
1405
                                chanBatchData: channelBatchData,
×
1406
                                chanMap:       nodeChannelMap,
×
1407
                        }, nil
×
1408
                }
1409

1410
                // processItem is used to process each node in the current page.
1411
                processItem := func(ctx context.Context,
×
1412
                        nodeData sqlc.ListNodeIDsAndPubKeysRow,
×
1413
                        batchData *nodeCachedBatchData) error {
×
1414

×
1415
                        // Build feature vector for this node.
×
1416
                        fv := lnwire.EmptyFeatureVector()
×
1417
                        features, exists := batchData.features[nodeData.ID]
×
1418
                        if exists {
×
1419
                                for _, bit := range features {
×
1420
                                        fv.Set(lnwire.FeatureBit(bit))
×
1421
                                }
×
1422
                        }
1423

1424
                        var nodePub route.Vertex
×
1425
                        copy(nodePub[:], nodeData.PubKey)
×
1426

×
1427
                        nodeChannels := batchData.chanMap[nodeData.ID]
×
1428

×
1429
                        toNodeCallback := func() route.Vertex {
×
1430
                                return nodePub
×
1431
                        }
×
1432

1433
                        // Build cached channels map for this node.
1434
                        channels := make(map[uint64]*DirectedChannel)
×
1435
                        for _, channelRow := range nodeChannels {
×
1436
                                directedChan, err := buildDirectedChannel(
×
1437
                                        s.cfg.ChainHash, nodeData.ID, nodePub,
×
1438
                                        channelRow, batchData.chanBatchData, fv,
×
1439
                                        toNodeCallback,
×
1440
                                )
×
1441
                                if err != nil {
×
1442
                                        return err
×
1443
                                }
×
1444

1445
                                channels[directedChan.ChannelID] = directedChan
×
1446
                        }
1447

1448
                        addrs, err := buildNodeAddresses(
×
1449
                                batchData.addrs[nodeData.ID],
×
1450
                        )
×
1451
                        if err != nil {
×
1452
                                return fmt.Errorf("unable to build node "+
×
1453
                                        "addresses: %w", err)
×
1454
                        }
×
1455

1456
                        return cb(ctx, nodePub, addrs, channels)
×
1457
                }
1458

1459
                return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
1460
                        ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
×
1461
                        func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
1462
                                return node.ID
×
1463
                        },
×
1464
                        func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
1465
                                error) {
×
1466

×
1467
                                return node.ID, nil
×
1468
                        },
×
1469
                        batchDataFunc, processItem,
1470
                )
1471
        }, reset)
1472
}
1473

1474
// ForEachChannelCacheable iterates through all the channel edges stored
1475
// within the graph and invokes the passed callback for each edge. The
1476
// callback takes two edges as since this is a directed graph, both the
1477
// in/out edges are visited. If the callback returns an error, then the
1478
// transaction is aborted and the iteration stops early.
1479
//
1480
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1481
// pointer for that particular channel edge routing policy will be
1482
// passed into the callback.
1483
//
1484
// NOTE: this method is like ForEachChannel but fetches only the data
1485
// required for the graph cache.
1486
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1487
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1488
        reset func()) error {
×
1489

×
1490
        ctx := context.TODO()
×
1491

×
1492
        handleChannel := func(_ context.Context,
×
1493
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
1494

×
1495
                node1, node2, err := buildNodeVertices(
×
1496
                        row.Node1Pubkey, row.Node2Pubkey,
×
1497
                )
×
1498
                if err != nil {
×
1499
                        return err
×
1500
                }
×
1501

1502
                edge := buildCacheableChannelInfo(
×
1503
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1504
                )
×
1505

×
1506
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1507
                if err != nil {
×
1508
                        return err
×
1509
                }
×
1510

1511
                pol1, pol2, err := buildCachedChanPolicies(
×
1512
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1513
                )
×
1514
                if err != nil {
×
1515
                        return err
×
1516
                }
×
1517

1518
                return cb(edge, pol1, pol2)
×
1519
        }
1520

1521
        extractCursor := func(
×
1522
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {
×
1523

×
1524
                return row.ID
×
1525
        }
×
1526

1527
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1528
                //nolint:ll
×
1529
                queryFunc := func(ctx context.Context, lastID int64,
×
1530
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
×
1531
                        error) {
×
1532

×
1533
                        return db.ListChannelsWithPoliciesForCachePaginated(
×
1534
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
1535
                                        Version: int16(lnwire.GossipVersion1),
×
1536
                                        ID:      lastID,
×
1537
                                        Limit:   limit,
×
1538
                                },
×
1539
                        )
×
1540
                }
×
1541

1542
                return sqldb.ExecutePaginatedQuery(
×
1543
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
1544
                        extractCursor, handleChannel,
×
1545
                )
×
1546
        }, reset)
1547
}
1548

1549
// ForEachChannel iterates through all the channel edges stored within the
1550
// graph and invokes the passed callback for each edge. The callback takes two
1551
// edges as since this is a directed graph, both the in/out edges are visited.
1552
// If the callback returns an error, then the transaction is aborted and the
1553
// iteration stops early.
1554
//
1555
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1556
// for that particular channel edge routing policy will be passed into the
1557
// callback.
1558
//
1559
// NOTE: part of the V1Store interface.
1560
func (s *SQLStore) ForEachChannel(ctx context.Context,
1561
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1562
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1563

×
1564
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1565
                return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
×
1566
        }, reset)
×
1567
}
1568

1569
// FilterChannelRange returns the channel ID's of all known channels which were
1570
// mined in a block height within the passed range. The channel IDs are grouped
1571
// by their common block height. This method can be used to quickly share with a
1572
// peer the set of channels we know of within a particular range to catch them
1573
// up after a period of time offline. If withTimestamps is true then the
1574
// timestamp info of the latest received channel update messages of the channel
1575
// will be included in the response.
1576
//
1577
// NOTE: This is part of the V1Store interface.
1578
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1579
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1580

×
1581
        var (
×
1582
                ctx       = context.TODO()
×
1583
                startSCID = &lnwire.ShortChannelID{
×
1584
                        BlockHeight: startHeight,
×
1585
                }
×
1586
                endSCID = lnwire.ShortChannelID{
×
1587
                        BlockHeight: endHeight,
×
1588
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1589
                        TxPosition:  math.MaxUint16,
×
1590
                }
×
1591
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1592
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1593
        )
×
1594

×
1595
        // 1) get all channels where channelID is between start and end chan ID.
×
1596
        // 2) skip if not public (ie, no channel_proof)
×
1597
        // 3) collect that channel.
×
1598
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1599
        //    and add those timestamps to the collected channel.
×
1600
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1601
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1602
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1603
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1604
                                StartScid: chanIDStart,
×
1605
                                EndScid:   chanIDEnd,
×
1606
                        },
×
1607
                )
×
1608
                if err != nil {
×
1609
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1610
                                err)
×
1611
                }
×
1612

1613
                for _, dbChan := range dbChans {
×
1614
                        cid := lnwire.NewShortChanIDFromInt(
×
1615
                                byteOrder.Uint64(dbChan.Scid),
×
1616
                        )
×
1617
                        chanInfo := NewChannelUpdateInfo(
×
1618
                                cid, time.Time{}, time.Time{},
×
1619
                        )
×
1620

×
1621
                        if !withTimestamps {
×
1622
                                channelsPerBlock[cid.BlockHeight] = append(
×
1623
                                        channelsPerBlock[cid.BlockHeight],
×
1624
                                        chanInfo,
×
1625
                                )
×
1626

×
1627
                                continue
×
1628
                        }
1629

1630
                        //nolint:ll
1631
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1632
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1633
                                        Version:   int16(lnwire.GossipVersion1),
×
1634
                                        ChannelID: dbChan.ID,
×
1635
                                        NodeID:    dbChan.NodeID1,
×
1636
                                },
×
1637
                        )
×
1638
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1639
                                return fmt.Errorf("unable to fetch node1 "+
×
1640
                                        "policy: %w", err)
×
1641
                        } else if err == nil {
×
1642
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1643
                                        node1Policy.LastUpdate.Int64, 0,
×
1644
                                )
×
1645
                        }
×
1646

1647
                        //nolint:ll
1648
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1649
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1650
                                        Version:   int16(lnwire.GossipVersion1),
×
1651
                                        ChannelID: dbChan.ID,
×
1652
                                        NodeID:    dbChan.NodeID2,
×
1653
                                },
×
1654
                        )
×
1655
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1656
                                return fmt.Errorf("unable to fetch node2 "+
×
1657
                                        "policy: %w", err)
×
1658
                        } else if err == nil {
×
1659
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1660
                                        node2Policy.LastUpdate.Int64, 0,
×
1661
                                )
×
1662
                        }
×
1663

1664
                        channelsPerBlock[cid.BlockHeight] = append(
×
1665
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1666
                        )
×
1667
                }
1668

1669
                return nil
×
1670
        }, func() {
×
1671
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1672
        })
×
1673
        if err != nil {
×
1674
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1675
        }
×
1676

1677
        if len(channelsPerBlock) == 0 {
×
1678
                return nil, nil
×
1679
        }
×
1680

1681
        // Return the channel ranges in ascending block height order.
1682
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1683
        slices.Sort(blocks)
×
1684

×
1685
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1686
                return BlockChannelRange{
×
1687
                        Height:   block,
×
1688
                        Channels: channelsPerBlock[block],
×
1689
                }
×
1690
        }), nil
×
1691
}
1692

1693
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1694
// zombie. This method is used on an ad-hoc basis, when channels need to be
1695
// marked as zombies outside the normal pruning cycle.
1696
//
1697
// NOTE: part of the V1Store interface.
1698
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1699
        pubKey1, pubKey2 [33]byte) error {
×
1700

×
1701
        ctx := context.TODO()
×
1702

×
1703
        s.cacheMu.Lock()
×
1704
        defer s.cacheMu.Unlock()
×
1705

×
1706
        chanIDB := channelIDToBytes(chanID)
×
1707

×
1708
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1709
                return db.UpsertZombieChannel(
×
1710
                        ctx, sqlc.UpsertZombieChannelParams{
×
1711
                                Version:  int16(lnwire.GossipVersion1),
×
1712
                                Scid:     chanIDB,
×
1713
                                NodeKey1: pubKey1[:],
×
1714
                                NodeKey2: pubKey2[:],
×
1715
                        },
×
1716
                )
×
1717
        }, sqldb.NoOpReset)
×
1718
        if err != nil {
×
1719
                return fmt.Errorf("unable to upsert zombie channel "+
×
1720
                        "(channel_id=%d): %w", chanID, err)
×
1721
        }
×
1722

1723
        s.rejectCache.remove(chanID)
×
1724
        s.chanCache.remove(chanID)
×
1725

×
1726
        return nil
×
1727
}
1728

1729
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1730
//
1731
// NOTE: part of the V1Store interface.
1732
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1733
        s.cacheMu.Lock()
×
1734
        defer s.cacheMu.Unlock()
×
1735

×
1736
        var (
×
1737
                ctx     = context.TODO()
×
1738
                chanIDB = channelIDToBytes(chanID)
×
1739
        )
×
1740

×
1741
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1742
                res, err := db.DeleteZombieChannel(
×
1743
                        ctx, sqlc.DeleteZombieChannelParams{
×
1744
                                Scid:    chanIDB,
×
1745
                                Version: int16(lnwire.GossipVersion1),
×
1746
                        },
×
1747
                )
×
1748
                if err != nil {
×
1749
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1750
                                err)
×
1751
                }
×
1752

1753
                rows, err := res.RowsAffected()
×
1754
                if err != nil {
×
1755
                        return err
×
1756
                }
×
1757

1758
                if rows == 0 {
×
1759
                        return ErrZombieEdgeNotFound
×
1760
                } else if rows > 1 {
×
1761
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1762
                                "expected 1", rows)
×
1763
                }
×
1764

1765
                return nil
×
1766
        }, sqldb.NoOpReset)
1767
        if err != nil {
×
1768
                return fmt.Errorf("unable to mark edge live "+
×
1769
                        "(channel_id=%d): %w", chanID, err)
×
1770
        }
×
1771

1772
        s.rejectCache.remove(chanID)
×
1773
        s.chanCache.remove(chanID)
×
1774

×
1775
        return err
×
1776
}
1777

1778
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1779
// zombie, then the two node public keys corresponding to this edge are also
1780
// returned.
1781
//
1782
// NOTE: part of the V1Store interface.
1783
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1784
        error) {
×
1785

×
1786
        var (
×
1787
                ctx              = context.TODO()
×
1788
                isZombie         bool
×
1789
                pubKey1, pubKey2 route.Vertex
×
1790
                chanIDB          = channelIDToBytes(chanID)
×
1791
        )
×
1792

×
1793
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1794
                zombie, err := db.GetZombieChannel(
×
1795
                        ctx, sqlc.GetZombieChannelParams{
×
1796
                                Scid:    chanIDB,
×
1797
                                Version: int16(lnwire.GossipVersion1),
×
1798
                        },
×
1799
                )
×
1800
                if errors.Is(err, sql.ErrNoRows) {
×
1801
                        return nil
×
1802
                }
×
1803
                if err != nil {
×
1804
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1805
                                err)
×
1806
                }
×
1807

1808
                copy(pubKey1[:], zombie.NodeKey1)
×
1809
                copy(pubKey2[:], zombie.NodeKey2)
×
1810
                isZombie = true
×
1811

×
1812
                return nil
×
1813
        }, sqldb.NoOpReset)
1814
        if err != nil {
×
1815
                return false, route.Vertex{}, route.Vertex{},
×
1816
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1817
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1818
        }
×
1819

1820
        return isZombie, pubKey1, pubKey2, nil
×
1821
}
1822

1823
// NumZombies returns the current number of zombie channels in the graph.
1824
//
1825
// NOTE: part of the V1Store interface.
1826
func (s *SQLStore) NumZombies() (uint64, error) {
×
1827
        var (
×
1828
                ctx        = context.TODO()
×
1829
                numZombies uint64
×
1830
        )
×
1831
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1832
                count, err := db.CountZombieChannels(
×
1833
                        ctx, int16(lnwire.GossipVersion1),
×
1834
                )
×
1835
                if err != nil {
×
1836
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1837
                                err)
×
1838
                }
×
1839

1840
                numZombies = uint64(count)
×
1841

×
1842
                return nil
×
1843
        }, sqldb.NoOpReset)
1844
        if err != nil {
×
1845
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1846
        }
×
1847

1848
        return numZombies, nil
×
1849
}
1850

1851
// DeleteChannelEdges removes edges with the given channel IDs from the
1852
// database and marks them as zombies. This ensures that we're unable to re-add
1853
// it to our database once again. If an edge does not exist within the
1854
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1855
// true, then when we mark these edges as zombies, we'll set up the keys such
1856
// that we require the node that failed to send the fresh update to be the one
1857
// that resurrects the channel from its zombie state. The markZombie bool
1858
// denotes whether to mark the channel as a zombie.
1859
//
1860
// NOTE: part of the V1Store interface.
1861
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1862
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1863

×
1864
        s.cacheMu.Lock()
×
1865
        defer s.cacheMu.Unlock()
×
1866

×
1867
        // Keep track of which channels we end up finding so that we can
×
1868
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
1869
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
1870
        for _, chanID := range chanIDs {
×
1871
                chanLookup[chanID] = struct{}{}
×
1872
        }
×
1873

1874
        var (
×
1875
                ctx   = context.TODO()
×
1876
                edges []*models.ChannelEdgeInfo
×
1877
        )
×
1878
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1879
                // First, collect all channel rows.
×
1880
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
1881
                chanCallBack := func(ctx context.Context,
×
1882
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1883

×
1884
                        // Deleting the entry from the map indicates that we
×
1885
                        // have found the channel.
×
1886
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1887
                        delete(chanLookup, scid)
×
1888

×
1889
                        channelRows = append(channelRows, row)
×
1890

×
1891
                        return nil
×
1892
                }
×
1893

1894
                err := s.forEachChanWithPoliciesInSCIDList(
×
1895
                        ctx, db, chanCallBack, chanIDs,
×
1896
                )
×
1897
                if err != nil {
×
1898
                        return err
×
1899
                }
×
1900

1901
                if len(chanLookup) > 0 {
×
1902
                        return ErrEdgeNotFound
×
1903
                }
×
1904

1905
                if len(channelRows) == 0 {
×
1906
                        return nil
×
1907
                }
×
1908

1909
                // Batch build all channel edges.
1910
                var chanIDsToDelete []int64
×
1911
                edges, chanIDsToDelete, err = batchBuildChannelInfo(
×
1912
                        ctx, s.cfg, db, channelRows,
×
1913
                )
×
1914
                if err != nil {
×
1915
                        return err
×
1916
                }
×
1917

1918
                if markZombie {
×
1919
                        for i, row := range channelRows {
×
1920
                                scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1921

×
1922
                                err := handleZombieMarking(
×
1923
                                        ctx, db, row, edges[i],
×
1924
                                        strictZombiePruning, scid,
×
1925
                                )
×
1926
                                if err != nil {
×
1927
                                        return fmt.Errorf("unable to mark "+
×
1928
                                                "channel as zombie: %w", err)
×
1929
                                }
×
1930
                        }
1931
                }
1932

1933
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1934
        }, func() {
×
1935
                edges = nil
×
1936

×
1937
                // Re-fill the lookup map.
×
1938
                for _, chanID := range chanIDs {
×
1939
                        chanLookup[chanID] = struct{}{}
×
1940
                }
×
1941
        })
1942
        if err != nil {
×
1943
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1944
                        err)
×
1945
        }
×
1946

1947
        for _, chanID := range chanIDs {
×
1948
                s.rejectCache.remove(chanID)
×
1949
                s.chanCache.remove(chanID)
×
1950
        }
×
1951

1952
        return edges, nil
×
1953
}
1954

1955
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1956
// channel identified by the channel ID. If the channel can't be found, then
1957
// ErrEdgeNotFound is returned. A struct which houses the general information
1958
// for the channel itself is returned as well as two structs that contain the
1959
// routing policies for the channel in either direction.
1960
//
1961
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1962
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1963
// the ChannelEdgeInfo will only include the public keys of each node.
1964
//
1965
// NOTE: part of the V1Store interface.
1966
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1967
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1968
        *models.ChannelEdgePolicy, error) {
×
1969

×
1970
        var (
×
1971
                ctx              = context.TODO()
×
1972
                edge             *models.ChannelEdgeInfo
×
1973
                policy1, policy2 *models.ChannelEdgePolicy
×
1974
                chanIDB          = channelIDToBytes(chanID)
×
1975
        )
×
1976
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1977
                row, err := db.GetChannelBySCIDWithPolicies(
×
1978
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1979
                                Scid:    chanIDB,
×
1980
                                Version: int16(lnwire.GossipVersion1),
×
1981
                        },
×
1982
                )
×
1983
                if errors.Is(err, sql.ErrNoRows) {
×
1984
                        // First check if this edge is perhaps in the zombie
×
1985
                        // index.
×
1986
                        zombie, err := db.GetZombieChannel(
×
1987
                                ctx, sqlc.GetZombieChannelParams{
×
1988
                                        Scid:    chanIDB,
×
1989
                                        Version: int16(lnwire.GossipVersion1),
×
1990
                                },
×
1991
                        )
×
1992
                        if errors.Is(err, sql.ErrNoRows) {
×
1993
                                return ErrEdgeNotFound
×
1994
                        } else if err != nil {
×
1995
                                return fmt.Errorf("unable to check if "+
×
1996
                                        "channel is zombie: %w", err)
×
1997
                        }
×
1998

1999
                        // At this point, we know the channel is a zombie, so
2000
                        // we'll return an error indicating this, and we will
2001
                        // populate the edge info with the public keys of each
2002
                        // party as this is the only information we have about
2003
                        // it.
2004
                        edge = &models.ChannelEdgeInfo{}
×
2005
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
2006
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
2007

×
2008
                        return ErrZombieEdge
×
2009
                } else if err != nil {
×
2010
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2011
                }
×
2012

2013
                node1, node2, err := buildNodeVertices(
×
2014
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
2015
                )
×
2016
                if err != nil {
×
2017
                        return err
×
2018
                }
×
2019

2020
                edge, err = getAndBuildEdgeInfo(
×
2021
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2022
                )
×
2023
                if err != nil {
×
2024
                        return fmt.Errorf("unable to build channel info: %w",
×
2025
                                err)
×
2026
                }
×
2027

2028
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2029
                if err != nil {
×
2030
                        return fmt.Errorf("unable to extract channel "+
×
2031
                                "policies: %w", err)
×
2032
                }
×
2033

2034
                policy1, policy2, err = getAndBuildChanPolicies(
×
2035
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2036
                        node1, node2,
×
2037
                )
×
2038
                if err != nil {
×
2039
                        return fmt.Errorf("unable to build channel "+
×
2040
                                "policies: %w", err)
×
2041
                }
×
2042

2043
                return nil
×
2044
        }, sqldb.NoOpReset)
2045
        if err != nil {
×
2046
                // If we are returning the ErrZombieEdge, then we also need to
×
2047
                // return the edge info as the method comment indicates that
×
2048
                // this will be populated when the edge is a zombie.
×
2049
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2050
                        err)
×
2051
        }
×
2052

2053
        return edge, policy1, policy2, nil
×
2054
}
2055

2056
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
2057
// the channel identified by the funding outpoint. If the channel can't be
2058
// found, then ErrEdgeNotFound is returned. A struct which houses the general
2059
// information for the channel itself is returned as well as two structs that
2060
// contain the routing policies for the channel in either direction.
2061
//
2062
// NOTE: part of the V1Store interface.
2063
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
2064
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
2065
        *models.ChannelEdgePolicy, error) {
×
2066

×
2067
        var (
×
2068
                ctx              = context.TODO()
×
2069
                edge             *models.ChannelEdgeInfo
×
2070
                policy1, policy2 *models.ChannelEdgePolicy
×
2071
        )
×
2072
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2073
                row, err := db.GetChannelByOutpointWithPolicies(
×
2074
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
2075
                                Outpoint: op.String(),
×
2076
                                Version:  int16(lnwire.GossipVersion1),
×
2077
                        },
×
2078
                )
×
2079
                if errors.Is(err, sql.ErrNoRows) {
×
2080
                        return ErrEdgeNotFound
×
2081
                } else if err != nil {
×
2082
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2083
                }
×
2084

2085
                node1, node2, err := buildNodeVertices(
×
2086
                        row.Node1Pubkey, row.Node2Pubkey,
×
2087
                )
×
2088
                if err != nil {
×
2089
                        return err
×
2090
                }
×
2091

2092
                edge, err = getAndBuildEdgeInfo(
×
2093
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2094
                )
×
2095
                if err != nil {
×
2096
                        return fmt.Errorf("unable to build channel info: %w",
×
2097
                                err)
×
2098
                }
×
2099

2100
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2101
                if err != nil {
×
2102
                        return fmt.Errorf("unable to extract channel "+
×
2103
                                "policies: %w", err)
×
2104
                }
×
2105

2106
                policy1, policy2, err = getAndBuildChanPolicies(
×
2107
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2108
                        node1, node2,
×
2109
                )
×
2110
                if err != nil {
×
2111
                        return fmt.Errorf("unable to build channel "+
×
2112
                                "policies: %w", err)
×
2113
                }
×
2114

2115
                return nil
×
2116
        }, sqldb.NoOpReset)
2117
        if err != nil {
×
2118
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2119
                        err)
×
2120
        }
×
2121

2122
        return edge, policy1, policy2, nil
×
2123
}
2124

2125
// HasChannelEdge returns true if the database knows of a channel edge with the
2126
// passed channel ID, and false otherwise. If an edge with that ID is found
2127
// within the graph, then two time stamps representing the last time the edge
2128
// was updated for both directed edges are returned along with the boolean. If
2129
// it is not found, then the zombie index is checked and its result is returned
2130
// as the second boolean.
2131
//
2132
// NOTE: part of the V1Store interface.
2133
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
2134
        bool, error) {
×
2135

×
2136
        ctx := context.TODO()
×
2137

×
2138
        var (
×
2139
                exists          bool
×
2140
                isZombie        bool
×
2141
                node1LastUpdate time.Time
×
2142
                node2LastUpdate time.Time
×
2143
        )
×
2144

×
2145
        // We'll query the cache with the shared lock held to allow multiple
×
2146
        // readers to access values in the cache concurrently if they exist.
×
2147
        s.cacheMu.RLock()
×
2148
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2149
                s.cacheMu.RUnlock()
×
2150
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2151
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2152
                exists, isZombie = entry.flags.unpack()
×
2153

×
2154
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2155
        }
×
2156
        s.cacheMu.RUnlock()
×
2157

×
2158
        s.cacheMu.Lock()
×
2159
        defer s.cacheMu.Unlock()
×
2160

×
2161
        // The item was not found with the shared lock, so we'll acquire the
×
2162
        // exclusive lock and check the cache again in case another method added
×
2163
        // the entry to the cache while no lock was held.
×
2164
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2165
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2166
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2167
                exists, isZombie = entry.flags.unpack()
×
2168

×
2169
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2170
        }
×
2171

2172
        chanIDB := channelIDToBytes(chanID)
×
2173
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2174
                channel, err := db.GetChannelBySCID(
×
2175
                        ctx, sqlc.GetChannelBySCIDParams{
×
2176
                                Scid:    chanIDB,
×
2177
                                Version: int16(lnwire.GossipVersion1),
×
2178
                        },
×
2179
                )
×
2180
                if errors.Is(err, sql.ErrNoRows) {
×
2181
                        // Check if it is a zombie channel.
×
2182
                        isZombie, err = db.IsZombieChannel(
×
2183
                                ctx, sqlc.IsZombieChannelParams{
×
2184
                                        Scid:    chanIDB,
×
2185
                                        Version: int16(lnwire.GossipVersion1),
×
2186
                                },
×
2187
                        )
×
2188
                        if err != nil {
×
2189
                                return fmt.Errorf("could not check if channel "+
×
2190
                                        "is zombie: %w", err)
×
2191
                        }
×
2192

2193
                        return nil
×
2194
                } else if err != nil {
×
2195
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2196
                }
×
2197

2198
                exists = true
×
2199

×
2200
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2201
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2202
                                Version:   int16(lnwire.GossipVersion1),
×
2203
                                ChannelID: channel.ID,
×
2204
                                NodeID:    channel.NodeID1,
×
2205
                        },
×
2206
                )
×
2207
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2208
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2209
                                err)
×
2210
                } else if err == nil {
×
2211
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2212
                }
×
2213

2214
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2215
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2216
                                Version:   int16(lnwire.GossipVersion1),
×
2217
                                ChannelID: channel.ID,
×
2218
                                NodeID:    channel.NodeID2,
×
2219
                        },
×
2220
                )
×
2221
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2222
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2223
                                err)
×
2224
                } else if err == nil {
×
2225
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2226
                }
×
2227

2228
                return nil
×
2229
        }, sqldb.NoOpReset)
2230
        if err != nil {
×
2231
                return time.Time{}, time.Time{}, false, false,
×
2232
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2233
        }
×
2234

2235
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2236
                upd1Time: node1LastUpdate.Unix(),
×
2237
                upd2Time: node2LastUpdate.Unix(),
×
2238
                flags:    packRejectFlags(exists, isZombie),
×
2239
        })
×
2240

×
2241
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2242
}
2243

2244
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2245
// passed channel point (outpoint). If the passed channel doesn't exist within
2246
// the database, then ErrEdgeNotFound is returned.
2247
//
2248
// NOTE: part of the V1Store interface.
2249
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2250
        var (
×
2251
                ctx       = context.TODO()
×
2252
                channelID uint64
×
2253
        )
×
2254
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2255
                chanID, err := db.GetSCIDByOutpoint(
×
2256
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2257
                                Outpoint: chanPoint.String(),
×
2258
                                Version:  int16(lnwire.GossipVersion1),
×
2259
                        },
×
2260
                )
×
2261
                if errors.Is(err, sql.ErrNoRows) {
×
2262
                        return ErrEdgeNotFound
×
2263
                } else if err != nil {
×
2264
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2265
                                err)
×
2266
                }
×
2267

2268
                channelID = byteOrder.Uint64(chanID)
×
2269

×
2270
                return nil
×
2271
        }, sqldb.NoOpReset)
2272
        if err != nil {
×
2273
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2274
        }
×
2275

2276
        return channelID, nil
×
2277
}
2278

2279
// IsPublicNode is a helper method that determines whether the node with the
2280
// given public key is seen as a public node in the graph from the graph's
2281
// source node's point of view.
2282
//
2283
// NOTE: part of the V1Store interface.
2284
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2285
        ctx := context.TODO()
×
2286

×
2287
        var isPublic bool
×
2288
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2289
                var err error
×
2290
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2291

×
2292
                return err
×
2293
        }, sqldb.NoOpReset)
×
2294
        if err != nil {
×
2295
                return false, fmt.Errorf("unable to check if node is "+
×
2296
                        "public: %w", err)
×
2297
        }
×
2298

2299
        return isPublic, nil
×
2300
}
2301

2302
// FetchChanInfos returns the set of channel edges that correspond to the passed
2303
// channel ID's. If an edge is the query is unknown to the database, it will
2304
// skipped and the result will contain only those edges that exist at the time
2305
// of the query. This can be used to respond to peer queries that are seeking to
2306
// fill in gaps in their view of the channel graph.
2307
//
2308
// NOTE: part of the V1Store interface.
2309
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2310
        var (
×
2311
                ctx   = context.TODO()
×
2312
                edges = make(map[uint64]ChannelEdge)
×
2313
        )
×
2314
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2315
                // First, collect all channel rows.
×
2316
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
2317
                chanCallBack := func(ctx context.Context,
×
2318
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2319

×
2320
                        channelRows = append(channelRows, row)
×
2321
                        return nil
×
2322
                }
×
2323

2324
                err := s.forEachChanWithPoliciesInSCIDList(
×
2325
                        ctx, db, chanCallBack, chanIDs,
×
2326
                )
×
2327
                if err != nil {
×
2328
                        return err
×
2329
                }
×
2330

2331
                if len(channelRows) == 0 {
×
2332
                        return nil
×
2333
                }
×
2334

2335
                // Batch build all channel edges.
2336
                chans, err := batchBuildChannelEdges(
×
2337
                        ctx, s.cfg, db, channelRows,
×
2338
                )
×
2339
                if err != nil {
×
2340
                        return fmt.Errorf("unable to build channel edges: %w",
×
2341
                                err)
×
2342
                }
×
2343

2344
                for _, c := range chans {
×
2345
                        edges[c.Info.ChannelID] = c
×
2346
                }
×
2347

2348
                return err
×
2349
        }, func() {
×
2350
                clear(edges)
×
2351
        })
×
2352
        if err != nil {
×
2353
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2354
        }
×
2355

2356
        res := make([]ChannelEdge, 0, len(edges))
×
2357
        for _, chanID := range chanIDs {
×
2358
                edge, ok := edges[chanID]
×
2359
                if !ok {
×
2360
                        continue
×
2361
                }
2362

2363
                res = append(res, edge)
×
2364
        }
2365

2366
        return res, nil
×
2367
}
2368

2369
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2370
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2371
// channels in a paginated manner.
2372
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2373
        db SQLQueries, cb func(ctx context.Context,
2374
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
2375
        chanIDs []uint64) error {
×
2376

×
2377
        queryWrapper := func(ctx context.Context,
×
2378
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
2379
                error) {
×
2380

×
2381
                return db.GetChannelsBySCIDWithPolicies(
×
2382
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
2383
                                Version: int16(lnwire.GossipVersion1),
×
2384
                                Scids:   scids,
×
2385
                        },
×
2386
                )
×
2387
        }
×
2388

2389
        return sqldb.ExecuteBatchQuery(
×
2390
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
×
2391
                cb,
×
2392
        )
×
2393
}
2394

2395
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2396
// ID's that we don't know and are not known zombies of the passed set. In other
2397
// words, we perform a set difference of our set of chan ID's and the ones
2398
// passed in. This method can be used by callers to determine the set of
2399
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2400
// known zombies is also returned.
2401
//
2402
// NOTE: part of the V1Store interface.
2403
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2404
        []ChannelUpdateInfo, error) {
×
2405

×
2406
        var (
×
2407
                ctx          = context.TODO()
×
2408
                newChanIDs   []uint64
×
2409
                knownZombies []ChannelUpdateInfo
×
2410
                infoLookup   = make(
×
2411
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2412
                )
×
2413
        )
×
2414

×
2415
        // We first build a lookup map of the channel ID's to the
×
2416
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
2417
        // already know about.
×
2418
        for _, chanInfo := range chansInfo {
×
2419
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
2420
        }
×
2421

2422
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2423
                // The call-back function deletes known channels from
×
2424
                // infoLookup, so that we can later check which channels are
×
2425
                // zombies by only looking at the remaining channels in the set.
×
2426
                cb := func(ctx context.Context,
×
2427
                        channel sqlc.GraphChannel) error {
×
2428

×
2429
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2430

×
2431
                        return nil
×
2432
                }
×
2433

2434
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
2435
                if err != nil {
×
2436
                        return fmt.Errorf("unable to iterate through "+
×
2437
                                "channels: %w", err)
×
2438
                }
×
2439

2440
                // We want to ensure that we deal with the channels in the
2441
                // same order that they were passed in, so we iterate over the
2442
                // original chansInfo slice and then check if that channel is
2443
                // still in the infoLookup map.
2444
                for _, chanInfo := range chansInfo {
×
2445
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2446
                        if _, ok := infoLookup[channelID]; !ok {
×
2447
                                continue
×
2448
                        }
2449

2450
                        isZombie, err := db.IsZombieChannel(
×
2451
                                ctx, sqlc.IsZombieChannelParams{
×
2452
                                        Scid:    channelIDToBytes(channelID),
×
2453
                                        Version: int16(lnwire.GossipVersion1),
×
2454
                                },
×
2455
                        )
×
2456
                        if err != nil {
×
2457
                                return fmt.Errorf("unable to fetch zombie "+
×
2458
                                        "channel: %w", err)
×
2459
                        }
×
2460

2461
                        if isZombie {
×
2462
                                knownZombies = append(knownZombies, chanInfo)
×
2463

×
2464
                                continue
×
2465
                        }
2466

2467
                        newChanIDs = append(newChanIDs, channelID)
×
2468
                }
2469

2470
                return nil
×
2471
        }, func() {
×
2472
                newChanIDs = nil
×
2473
                knownZombies = nil
×
2474
                // Rebuild the infoLookup map in case of a rollback.
×
2475
                for _, chanInfo := range chansInfo {
×
2476
                        scid := chanInfo.ShortChannelID.ToUint64()
×
2477
                        infoLookup[scid] = chanInfo
×
2478
                }
×
2479
        })
2480
        if err != nil {
×
2481
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2482
        }
×
2483

2484
        return newChanIDs, knownZombies, nil
×
2485
}
2486

2487
// forEachChanInSCIDList is a helper method that executes a paged query
2488
// against the database to fetch all channels that match the passed
2489
// ChannelUpdateInfo slice. The callback function is called for each channel
2490
// that is found.
2491
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2492
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
2493
        chansInfo []ChannelUpdateInfo) error {
×
2494

×
2495
        queryWrapper := func(ctx context.Context,
×
2496
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
2497

×
2498
                return db.GetChannelsBySCIDs(
×
2499
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
2500
                                Version: int16(lnwire.GossipVersion1),
×
2501
                                Scids:   scids,
×
2502
                        },
×
2503
                )
×
2504
        }
×
2505

2506
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
2507
                channelID := chanInfo.ShortChannelID.ToUint64()
×
2508

×
2509
                return channelIDToBytes(channelID)
×
2510
        }
×
2511

2512
        return sqldb.ExecuteBatchQuery(
×
2513
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
×
2514
                cb,
×
2515
        )
×
2516
}
2517

2518
// PruneGraphNodes is a garbage collection method which attempts to prune out
2519
// any nodes from the channel graph that are currently unconnected. This ensure
2520
// that we only maintain a graph of reachable nodes. In the event that a pruned
2521
// node gains more channels, it will be re-added back to the graph.
2522
//
2523
// NOTE: this prunes nodes across protocol versions. It will never prune the
2524
// source nodes.
2525
//
2526
// NOTE: part of the V1Store interface.
2527
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2528
        var ctx = context.TODO()
×
2529

×
2530
        var prunedNodes []route.Vertex
×
2531
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2532
                var err error
×
2533
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2534

×
2535
                return err
×
2536
        }, func() {
×
2537
                prunedNodes = nil
×
2538
        })
×
2539
        if err != nil {
×
2540
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2541
        }
×
2542

2543
        return prunedNodes, nil
×
2544
}
2545

2546
// PruneGraph prunes newly closed channels from the channel graph in response
2547
// to a new block being solved on the network. Any transactions which spend the
2548
// funding output of any known channels within he graph will be deleted.
2549
// Additionally, the "prune tip", or the last block which has been used to
2550
// prune the graph is stored so callers can ensure the graph is fully in sync
2551
// with the current UTXO state. A slice of channels that have been closed by
2552
// the target block along with any pruned nodes are returned if the function
2553
// succeeds without error.
2554
//
2555
// NOTE: part of the V1Store interface.
2556
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2557
        blockHash *chainhash.Hash, blockHeight uint32) (
2558
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2559

×
2560
        ctx := context.TODO()
×
2561

×
2562
        s.cacheMu.Lock()
×
2563
        defer s.cacheMu.Unlock()
×
2564

×
2565
        var (
×
2566
                closedChans []*models.ChannelEdgeInfo
×
2567
                prunedNodes []route.Vertex
×
2568
        )
×
2569
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2570
                // First, collect all channel rows that need to be pruned.
×
2571
                var channelRows []sqlc.GetChannelsByOutpointsRow
×
2572
                channelCallback := func(ctx context.Context,
×
2573
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2574

×
2575
                        channelRows = append(channelRows, row)
×
2576

×
2577
                        return nil
×
2578
                }
×
2579

2580
                err := s.forEachChanInOutpoints(
×
2581
                        ctx, db, spentOutputs, channelCallback,
×
2582
                )
×
2583
                if err != nil {
×
2584
                        return fmt.Errorf("unable to fetch channels by "+
×
2585
                                "outpoints: %w", err)
×
2586
                }
×
2587

2588
                if len(channelRows) == 0 {
×
2589
                        // There are no channels to prune. So we can exit early
×
2590
                        // after updating the prune log.
×
2591
                        err = db.UpsertPruneLogEntry(
×
2592
                                ctx, sqlc.UpsertPruneLogEntryParams{
×
2593
                                        BlockHash:   blockHash[:],
×
2594
                                        BlockHeight: int64(blockHeight),
×
2595
                                },
×
2596
                        )
×
2597
                        if err != nil {
×
2598
                                return fmt.Errorf("unable to insert prune log "+
×
2599
                                        "entry: %w", err)
×
2600
                        }
×
2601

2602
                        return nil
×
2603
                }
2604

2605
                // Batch build all channel edges for pruning.
2606
                var chansToDelete []int64
×
2607
                closedChans, chansToDelete, err = batchBuildChannelInfo(
×
2608
                        ctx, s.cfg, db, channelRows,
×
2609
                )
×
2610
                if err != nil {
×
2611
                        return err
×
2612
                }
×
2613

2614
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2615
                if err != nil {
×
2616
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2617
                }
×
2618

2619
                err = db.UpsertPruneLogEntry(
×
2620
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2621
                                BlockHash:   blockHash[:],
×
2622
                                BlockHeight: int64(blockHeight),
×
2623
                        },
×
2624
                )
×
2625
                if err != nil {
×
2626
                        return fmt.Errorf("unable to insert prune log "+
×
2627
                                "entry: %w", err)
×
2628
                }
×
2629

2630
                // Now that we've pruned some channels, we'll also prune any
2631
                // nodes that no longer have any channels.
2632
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2633
                if err != nil {
×
2634
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2635
                                err)
×
2636
                }
×
2637

2638
                return nil
×
2639
        }, func() {
×
2640
                prunedNodes = nil
×
2641
                closedChans = nil
×
2642
        })
×
2643
        if err != nil {
×
2644
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2645
        }
×
2646

2647
        for _, channel := range closedChans {
×
2648
                s.rejectCache.remove(channel.ChannelID)
×
2649
                s.chanCache.remove(channel.ChannelID)
×
2650
        }
×
2651

2652
        return closedChans, prunedNodes, nil
×
2653
}
2654

2655
// forEachChanInOutpoints is a helper function that executes a paginated
2656
// query to fetch channels by their outpoints and applies the given call-back
2657
// to each.
2658
//
2659
// NOTE: this fetches channels for all protocol versions.
2660
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2661
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2662
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
2663

×
2664
        // Create a wrapper that uses the transaction's db instance to execute
×
2665
        // the query.
×
2666
        queryWrapper := func(ctx context.Context,
×
2667
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
2668
                error) {
×
2669

×
2670
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
2671
        }
×
2672

2673
        // Define the conversion function from Outpoint to string.
2674
        outpointToString := func(outpoint *wire.OutPoint) string {
×
2675
                return outpoint.String()
×
2676
        }
×
2677

2678
        return sqldb.ExecuteBatchQuery(
×
2679
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
×
2680
                queryWrapper, cb,
×
2681
        )
×
2682
}
2683

2684
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
2685
        dbIDs []int64) error {
×
2686

×
2687
        // Create a wrapper that uses the transaction's db instance to execute
×
2688
        // the query.
×
2689
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
2690
                return nil, db.DeleteChannels(ctx, ids)
×
2691
        }
×
2692

2693
        idConverter := func(id int64) int64 {
×
2694
                return id
×
2695
        }
×
2696

2697
        return sqldb.ExecuteBatchQuery(
×
2698
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
×
2699
                queryWrapper, func(ctx context.Context, _ any) error {
×
2700
                        return nil
×
2701
                },
×
2702
        )
2703
}
2704

2705
// ChannelView returns the verifiable edge information for each active channel
2706
// within the known channel graph. The set of UTXOs (along with their scripts)
2707
// returned are the ones that need to be watched on chain to detect channel
2708
// closes on the resident blockchain.
2709
//
2710
// NOTE: part of the V1Store interface.
2711
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2712
        var (
×
2713
                ctx        = context.TODO()
×
2714
                edgePoints []EdgePoint
×
2715
        )
×
2716

×
2717
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2718
                handleChannel := func(_ context.Context,
×
2719
                        channel sqlc.ListChannelsPaginatedRow) error {
×
2720

×
2721
                        pkScript, err := genMultiSigP2WSH(
×
2722
                                channel.BitcoinKey1, channel.BitcoinKey2,
×
2723
                        )
×
2724
                        if err != nil {
×
2725
                                return err
×
2726
                        }
×
2727

2728
                        op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2729
                        if err != nil {
×
2730
                                return err
×
2731
                        }
×
2732

2733
                        edgePoints = append(edgePoints, EdgePoint{
×
2734
                                FundingPkScript: pkScript,
×
2735
                                OutPoint:        *op,
×
2736
                        })
×
2737

×
2738
                        return nil
×
2739
                }
2740

2741
                queryFunc := func(ctx context.Context, lastID int64,
×
2742
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {
×
2743

×
2744
                        return db.ListChannelsPaginated(
×
2745
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2746
                                        Version: int16(lnwire.GossipVersion1),
×
2747
                                        ID:      lastID,
×
2748
                                        Limit:   limit,
×
2749
                                },
×
2750
                        )
×
2751
                }
×
2752

2753
                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
×
2754
                        return row.ID
×
2755
                }
×
2756

2757
                return sqldb.ExecutePaginatedQuery(
×
2758
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
2759
                        extractCursor, handleChannel,
×
2760
                )
×
2761
        }, func() {
×
2762
                edgePoints = nil
×
2763
        })
×
2764
        if err != nil {
×
2765
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2766
        }
×
2767

2768
        return edgePoints, nil
×
2769
}
2770

2771
// PruneTip returns the block height and hash of the latest block that has been
2772
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2773
// to tell if the graph is currently in sync with the current best known UTXO
2774
// state.
2775
//
2776
// NOTE: part of the V1Store interface.
2777
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2778
        var (
×
2779
                ctx       = context.TODO()
×
2780
                tipHash   chainhash.Hash
×
2781
                tipHeight uint32
×
2782
        )
×
2783
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2784
                pruneTip, err := db.GetPruneTip(ctx)
×
2785
                if errors.Is(err, sql.ErrNoRows) {
×
2786
                        return ErrGraphNeverPruned
×
2787
                } else if err != nil {
×
2788
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2789
                }
×
2790

2791
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2792
                tipHeight = uint32(pruneTip.BlockHeight)
×
2793

×
2794
                return nil
×
2795
        }, sqldb.NoOpReset)
2796
        if err != nil {
×
2797
                return nil, 0, err
×
2798
        }
×
2799

2800
        return &tipHash, tipHeight, nil
×
2801
}
2802

2803
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2804
//
2805
// NOTE: this prunes nodes across protocol versions. It will never prune the
2806
// source nodes.
2807
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2808
        db SQLQueries) ([]route.Vertex, error) {
×
2809

×
2810
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2811
        if err != nil {
×
2812
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2813
                        "nodes: %w", err)
×
2814
        }
×
2815

2816
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2817
        for i, nodeKey := range nodeKeys {
×
2818
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2819
                if err != nil {
×
2820
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2821
                                "from bytes: %w", err)
×
2822
                }
×
2823

2824
                prunedNodes[i] = pub
×
2825
        }
2826

2827
        return prunedNodes, nil
×
2828
}
2829

2830
// DisconnectBlockAtHeight is used to indicate that the block specified
2831
// by the passed height has been disconnected from the main chain. This
2832
// will "rewind" the graph back to the height below, deleting channels
2833
// that are no longer confirmed from the graph. The prune log will be
2834
// set to the last prune height valid for the remaining chain.
2835
// Channels that were removed from the graph resulting from the
2836
// disconnected block are returned.
2837
//
2838
// NOTE: part of the V1Store interface.
2839
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2840
        []*models.ChannelEdgeInfo, error) {
×
2841

×
2842
        ctx := context.TODO()
×
2843

×
2844
        var (
×
2845
                // Every channel having a ShortChannelID starting at 'height'
×
2846
                // will no longer be confirmed.
×
2847
                startShortChanID = lnwire.ShortChannelID{
×
2848
                        BlockHeight: height,
×
2849
                }
×
2850

×
2851
                // Delete everything after this height from the db up until the
×
2852
                // SCID alias range.
×
2853
                endShortChanID = aliasmgr.StartingAlias
×
2854

×
2855
                removedChans []*models.ChannelEdgeInfo
×
2856

×
2857
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2858
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2859
        )
×
2860

×
2861
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2862
                rows, err := db.GetChannelsBySCIDRange(
×
2863
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2864
                                StartScid: chanIDStart,
×
2865
                                EndScid:   chanIDEnd,
×
2866
                        },
×
2867
                )
×
2868
                if err != nil {
×
2869
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2870
                }
×
2871

2872
                if len(rows) == 0 {
×
2873
                        // No channels to disconnect, but still clean up prune
×
2874
                        // log.
×
2875
                        return db.DeletePruneLogEntriesInRange(
×
2876
                                ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2877
                                        StartHeight: int64(height),
×
2878
                                        EndHeight: int64(
×
2879
                                                endShortChanID.BlockHeight,
×
2880
                                        ),
×
2881
                                },
×
2882
                        )
×
2883
                }
×
2884

2885
                // Batch build all channel edges for disconnection.
2886
                channelEdges, chanIDsToDelete, err := batchBuildChannelInfo(
×
2887
                        ctx, s.cfg, db, rows,
×
2888
                )
×
2889
                if err != nil {
×
2890
                        return err
×
2891
                }
×
2892

2893
                removedChans = channelEdges
×
2894

×
2895
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
2896
                if err != nil {
×
2897
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2898
                }
×
2899

2900
                return db.DeletePruneLogEntriesInRange(
×
2901
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2902
                                StartHeight: int64(height),
×
2903
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2904
                        },
×
2905
                )
×
2906
        }, func() {
×
2907
                removedChans = nil
×
2908
        })
×
2909
        if err != nil {
×
2910
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2911
                        "height: %w", err)
×
2912
        }
×
2913

NEW
2914
        s.cacheMu.Lock()
×
2915
        for _, channel := range removedChans {
×
2916
                s.rejectCache.remove(channel.ChannelID)
×
2917
                s.chanCache.remove(channel.ChannelID)
×
2918
        }
×
NEW
2919
        s.cacheMu.Unlock()
×
2920

×
2921
        return removedChans, nil
×
2922
}
2923

2924
// AddEdgeProof sets the proof of an existing edge in the graph database.
2925
//
2926
// NOTE: part of the V1Store interface.
2927
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2928
        proof *models.ChannelAuthProof) error {
×
2929

×
2930
        var (
×
2931
                ctx       = context.TODO()
×
2932
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2933
        )
×
2934

×
2935
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2936
                res, err := db.AddV1ChannelProof(
×
2937
                        ctx, sqlc.AddV1ChannelProofParams{
×
2938
                                Scid:              scidBytes,
×
2939
                                Node1Signature:    proof.NodeSig1Bytes,
×
2940
                                Node2Signature:    proof.NodeSig2Bytes,
×
2941
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2942
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2943
                        },
×
2944
                )
×
2945
                if err != nil {
×
2946
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2947
                }
×
2948

2949
                n, err := res.RowsAffected()
×
2950
                if err != nil {
×
2951
                        return err
×
2952
                }
×
2953

2954
                if n == 0 {
×
2955
                        return fmt.Errorf("no rows affected when adding edge "+
×
2956
                                "proof for SCID %v", scid)
×
2957
                } else if n > 1 {
×
2958
                        return fmt.Errorf("multiple rows affected when adding "+
×
2959
                                "edge proof for SCID %v: %d rows affected",
×
2960
                                scid, n)
×
2961
                }
×
2962

2963
                return nil
×
2964
        }, sqldb.NoOpReset)
2965
        if err != nil {
×
2966
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2967
        }
×
2968

2969
        return nil
×
2970
}
2971

2972
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2973
// that we can ignore channel announcements that we know to be closed without
2974
// having to validate them and fetch a block.
2975
//
2976
// NOTE: part of the V1Store interface.
2977
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2978
        var (
×
2979
                ctx     = context.TODO()
×
2980
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2981
        )
×
2982

×
2983
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2984
                return db.InsertClosedChannel(ctx, chanIDB)
×
2985
        }, sqldb.NoOpReset)
×
2986
}
2987

2988
// IsClosedScid checks whether a channel identified by the passed in scid is
2989
// closed. This helps avoid having to perform expensive validation checks.
2990
//
2991
// NOTE: part of the V1Store interface.
2992
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2993
        var (
×
2994
                ctx      = context.TODO()
×
2995
                isClosed bool
×
2996
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2997
        )
×
2998
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2999
                var err error
×
3000
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
3001
                if err != nil {
×
3002
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
3003
                                err)
×
3004
                }
×
3005

3006
                return nil
×
3007
        }, sqldb.NoOpReset)
3008
        if err != nil {
×
3009
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
3010
                        err)
×
3011
        }
×
3012

3013
        return isClosed, nil
×
3014
}
3015

3016
// GraphSession will provide the call-back with access to a NodeTraverser
3017
// instance which can be used to perform queries against the channel graph.
3018
//
3019
// NOTE: part of the V1Store interface.
3020
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
3021
        reset func()) error {
×
3022

×
3023
        var ctx = context.TODO()
×
3024

×
3025
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
3026
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
3027
        }, reset)
×
3028
}
3029

3030
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
3031
// read only transaction for a consistent view of the graph.
3032
type sqlNodeTraverser struct {
3033
        db    SQLQueries
3034
        chain chainhash.Hash
3035
}
3036

3037
// A compile-time assertion to ensure that sqlNodeTraverser implements the
3038
// NodeTraverser interface.
3039
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
3040

3041
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
3042
func newSQLNodeTraverser(db SQLQueries,
3043
        chain chainhash.Hash) *sqlNodeTraverser {
×
3044

×
3045
        return &sqlNodeTraverser{
×
3046
                db:    db,
×
3047
                chain: chain,
×
3048
        }
×
3049
}
×
3050

3051
// ForEachNodeDirectedChannel calls the callback for every channel of the given
3052
// node.
3053
//
3054
// NOTE: Part of the NodeTraverser interface.
3055
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
3056
        cb func(channel *DirectedChannel) error, _ func()) error {
×
3057

×
3058
        ctx := context.TODO()
×
3059

×
3060
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
3061
}
×
3062

3063
// FetchNodeFeatures returns the features of the given node. If the node is
3064
// unknown, assume no additional features are supported.
3065
//
3066
// NOTE: Part of the NodeTraverser interface.
3067
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
3068
        *lnwire.FeatureVector, error) {
×
3069

×
3070
        ctx := context.TODO()
×
3071

×
3072
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
3073
}
×
3074

3075
// forEachNodeDirectedChannel iterates through all channels of a given
3076
// node, executing the passed callback on the directed edge representing the
3077
// channel and its incoming policy. If the node is not found, no error is
3078
// returned.
3079
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
3080
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
3081

×
3082
        toNodeCallback := func() route.Vertex {
×
3083
                return nodePub
×
3084
        }
×
3085

3086
        dbID, err := db.GetNodeIDByPubKey(
×
3087
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
3088
                        Version: int16(lnwire.GossipVersion1),
×
3089
                        PubKey:  nodePub[:],
×
3090
                },
×
3091
        )
×
3092
        if errors.Is(err, sql.ErrNoRows) {
×
3093
                return nil
×
3094
        } else if err != nil {
×
3095
                return fmt.Errorf("unable to fetch node: %w", err)
×
3096
        }
×
3097

3098
        rows, err := db.ListChannelsByNodeID(
×
3099
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3100
                        Version: int16(lnwire.GossipVersion1),
×
3101
                        NodeID1: dbID,
×
3102
                },
×
3103
        )
×
3104
        if err != nil {
×
3105
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3106
        }
×
3107

3108
        // Exit early if there are no channels for this node so we don't
3109
        // do the unnecessary feature fetching.
3110
        if len(rows) == 0 {
×
3111
                return nil
×
3112
        }
×
3113

3114
        features, err := getNodeFeatures(ctx, db, dbID)
×
3115
        if err != nil {
×
3116
                return fmt.Errorf("unable to fetch node features: %w", err)
×
3117
        }
×
3118

3119
        for _, row := range rows {
×
3120
                node1, node2, err := buildNodeVertices(
×
3121
                        row.Node1Pubkey, row.Node2Pubkey,
×
3122
                )
×
3123
                if err != nil {
×
3124
                        return fmt.Errorf("unable to build node vertices: %w",
×
3125
                                err)
×
3126
                }
×
3127

3128
                edge := buildCacheableChannelInfo(
×
3129
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
3130
                        node1, node2,
×
3131
                )
×
3132

×
3133
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3134
                if err != nil {
×
3135
                        return err
×
3136
                }
×
3137

3138
                p1, p2, err := buildCachedChanPolicies(
×
3139
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3140
                )
×
3141
                if err != nil {
×
3142
                        return err
×
3143
                }
×
3144

3145
                // Determine the outgoing and incoming policy for this
3146
                // channel and node combo.
3147
                outPolicy, inPolicy := p1, p2
×
3148
                if p1 != nil && node2 == nodePub {
×
3149
                        outPolicy, inPolicy = p2, p1
×
3150
                } else if p2 != nil && node1 != nodePub {
×
3151
                        outPolicy, inPolicy = p2, p1
×
3152
                }
×
3153

3154
                var cachedInPolicy *models.CachedEdgePolicy
×
3155
                if inPolicy != nil {
×
3156
                        cachedInPolicy = inPolicy
×
3157
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
3158
                        cachedInPolicy.ToNodeFeatures = features
×
3159
                }
×
3160

3161
                directedChannel := &DirectedChannel{
×
3162
                        ChannelID:    edge.ChannelID,
×
3163
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
3164
                        OtherNode:    edge.NodeKey2Bytes,
×
3165
                        Capacity:     edge.Capacity,
×
3166
                        OutPolicySet: outPolicy != nil,
×
3167
                        InPolicy:     cachedInPolicy,
×
3168
                }
×
3169
                if outPolicy != nil {
×
3170
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3171
                                directedChannel.InboundFee = fee
×
3172
                        })
×
3173
                }
3174

3175
                if nodePub == edge.NodeKey2Bytes {
×
3176
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
3177
                }
×
3178

3179
                if err := cb(directedChannel); err != nil {
×
3180
                        return err
×
3181
                }
×
3182
        }
3183

3184
        return nil
×
3185
}
3186

3187
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
3188
// and executes the provided callback for each node. It does so via pagination
3189
// along with batch loading of the node feature bits.
3190
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
3191
        db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
3192
                features *lnwire.FeatureVector) error) error {
×
3193

×
3194
        handleNode := func(_ context.Context,
×
3195
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
×
3196
                featureBits map[int64][]int) error {
×
3197

×
3198
                fv := lnwire.EmptyFeatureVector()
×
3199
                if features, exists := featureBits[dbNode.ID]; exists {
×
3200
                        for _, bit := range features {
×
3201
                                fv.Set(lnwire.FeatureBit(bit))
×
3202
                        }
×
3203
                }
3204

3205
                var pub route.Vertex
×
3206
                copy(pub[:], dbNode.PubKey)
×
3207

×
3208
                return processNode(dbNode.ID, pub, fv)
×
3209
        }
3210

3211
        queryFunc := func(ctx context.Context, lastID int64,
×
3212
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
3213

×
3214
                return db.ListNodeIDsAndPubKeys(
×
3215
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
3216
                                Version: int16(lnwire.GossipVersion1),
×
3217
                                ID:      lastID,
×
3218
                                Limit:   limit,
×
3219
                        },
×
3220
                )
×
3221
        }
×
3222

3223
        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
3224
                return row.ID
×
3225
        }
×
3226

3227
        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
×
3228
                return node.ID, nil
×
3229
        }
×
3230

3231
        batchQueryFunc := func(ctx context.Context,
×
3232
                nodeIDs []int64) (map[int64][]int, error) {
×
3233

×
3234
                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
3235
        }
×
3236

3237
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
3238
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
×
3239
                batchQueryFunc, handleNode,
×
3240
        )
×
3241
}
3242

3243
// forEachNodeChannel iterates through all channels of a node, executing
3244
// the passed callback on each. The call-back is provided with the channel's
3245
// edge information, the outgoing policy and the incoming policy for the
3246
// channel and node combo.
3247
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3248
        cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
3249
                *models.ChannelEdgePolicy,
3250
                *models.ChannelEdgePolicy) error) error {
×
3251

×
3252
        // Get all the V1 channels for this node.
×
3253
        rows, err := db.ListChannelsByNodeID(
×
3254
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3255
                        Version: int16(lnwire.GossipVersion1),
×
3256
                        NodeID1: id,
×
3257
                },
×
3258
        )
×
3259
        if err != nil {
×
3260
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3261
        }
×
3262

3263
        // Collect all the channel and policy IDs.
3264
        var (
×
3265
                chanIDs   = make([]int64, 0, len(rows))
×
3266
                policyIDs = make([]int64, 0, 2*len(rows))
×
3267
        )
×
3268
        for _, row := range rows {
×
3269
                chanIDs = append(chanIDs, row.GraphChannel.ID)
×
3270

×
3271
                if row.Policy1ID.Valid {
×
3272
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
×
3273
                }
×
3274
                if row.Policy2ID.Valid {
×
3275
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
×
3276
                }
×
3277
        }
3278

3279
        batchData, err := batchLoadChannelData(
×
3280
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
×
3281
        )
×
3282
        if err != nil {
×
3283
                return fmt.Errorf("unable to batch load channel data: %w", err)
×
3284
        }
×
3285

3286
        // Call the call-back for each channel and its known policies.
3287
        for _, row := range rows {
×
3288
                node1, node2, err := buildNodeVertices(
×
3289
                        row.Node1Pubkey, row.Node2Pubkey,
×
3290
                )
×
3291
                if err != nil {
×
3292
                        return fmt.Errorf("unable to build node vertices: %w",
×
3293
                                err)
×
3294
                }
×
3295

3296
                edge, err := buildEdgeInfoWithBatchData(
×
3297
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
3298
                        batchData,
×
3299
                )
×
3300
                if err != nil {
×
3301
                        return fmt.Errorf("unable to build channel info: %w",
×
3302
                                err)
×
3303
                }
×
3304

3305
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3306
                if err != nil {
×
3307
                        return fmt.Errorf("unable to extract channel "+
×
3308
                                "policies: %w", err)
×
3309
                }
×
3310

3311
                p1, p2, err := buildChanPoliciesWithBatchData(
×
3312
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
3313
                )
×
3314
                if err != nil {
×
3315
                        return fmt.Errorf("unable to build channel "+
×
3316
                                "policies: %w", err)
×
3317
                }
×
3318

3319
                // Determine the outgoing and incoming policy for this
3320
                // channel and node combo.
3321
                p1ToNode := row.GraphChannel.NodeID2
×
3322
                p2ToNode := row.GraphChannel.NodeID1
×
3323
                outPolicy, inPolicy := p1, p2
×
3324
                if (p1 != nil && p1ToNode == id) ||
×
3325
                        (p2 != nil && p2ToNode != id) {
×
3326

×
3327
                        outPolicy, inPolicy = p2, p1
×
3328
                }
×
3329

3330
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3331
                        return err
×
3332
                }
×
3333
        }
3334

3335
        return nil
×
3336
}
3337

3338
// updateChanEdgePolicy upserts the channel policy info we have stored for
3339
// a channel we already know of.
3340
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3341
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3342
        error) {
×
3343

×
3344
        var (
×
3345
                node1Pub, node2Pub route.Vertex
×
3346
                isNode1            bool
×
3347
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3348
        )
×
3349

×
3350
        // Check that this edge policy refers to a channel that we already
×
3351
        // know of. We do this explicitly so that we can return the appropriate
×
3352
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3353
        // abort the transaction which would abort the entire batch.
×
3354
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3355
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3356
                        Scid:    chanIDB,
×
3357
                        Version: int16(lnwire.GossipVersion1),
×
3358
                },
×
3359
        )
×
3360
        if errors.Is(err, sql.ErrNoRows) {
×
3361
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3362
        } else if err != nil {
×
3363
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3364
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3365
        }
×
3366

3367
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3368
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3369

×
3370
        // Figure out which node this edge is from.
×
3371
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3372
        nodeID := dbChan.NodeID1
×
3373
        if !isNode1 {
×
3374
                nodeID = dbChan.NodeID2
×
3375
        }
×
3376

3377
        var (
×
3378
                inboundBase sql.NullInt64
×
3379
                inboundRate sql.NullInt64
×
3380
        )
×
3381
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3382
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3383
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3384
        })
×
3385

3386
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3387
                Version:     int16(lnwire.GossipVersion1),
×
3388
                ChannelID:   dbChan.ID,
×
3389
                NodeID:      nodeID,
×
3390
                Timelock:    int32(edge.TimeLockDelta),
×
3391
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3392
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3393
                MinHtlcMsat: int64(edge.MinHTLC),
×
3394
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3395
                Disabled: sql.NullBool{
×
3396
                        Valid: true,
×
3397
                        Bool:  edge.IsDisabled(),
×
3398
                },
×
3399
                MaxHtlcMsat: sql.NullInt64{
×
3400
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3401
                        Int64: int64(edge.MaxHTLC),
×
3402
                },
×
3403
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3404
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3405
                InboundBaseFeeMsat:      inboundBase,
×
3406
                InboundFeeRateMilliMsat: inboundRate,
×
3407
                Signature:               edge.SigBytes,
×
3408
        })
×
3409
        if err != nil {
×
3410
                return node1Pub, node2Pub, isNode1,
×
3411
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3412
        }
×
3413

3414
        // Convert the flat extra opaque data into a map of TLV types to
3415
        // values.
3416
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3417
        if err != nil {
×
3418
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3419
                        "marshal extra opaque data: %w", err)
×
3420
        }
×
3421

3422
        // Update the channel policy's extra signed fields.
3423
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3424
        if err != nil {
×
3425
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3426
                        "policy extra TLVs: %w", err)
×
3427
        }
×
3428

3429
        return node1Pub, node2Pub, isNode1, nil
×
3430
}
3431

3432
// getNodeByPubKey attempts to look up a target node by its public key.
3433
func getNodeByPubKey(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3434
        pubKey route.Vertex) (int64, *models.Node, error) {
×
3435

×
3436
        dbNode, err := db.GetNodeByPubKey(
×
3437
                ctx, sqlc.GetNodeByPubKeyParams{
×
3438
                        Version: int16(lnwire.GossipVersion1),
×
3439
                        PubKey:  pubKey[:],
×
3440
                },
×
3441
        )
×
3442
        if errors.Is(err, sql.ErrNoRows) {
×
3443
                return 0, nil, ErrGraphNodeNotFound
×
3444
        } else if err != nil {
×
3445
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3446
        }
×
3447

3448
        node, err := buildNode(ctx, cfg, db, dbNode)
×
3449
        if err != nil {
×
3450
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3451
        }
×
3452

3453
        return dbNode.ID, node, nil
×
3454
}
3455

3456
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3457
// provided parameters.
3458
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
3459
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3460

×
3461
        return &models.CachedEdgeInfo{
×
3462
                ChannelID:     byteOrder.Uint64(scid),
×
3463
                NodeKey1Bytes: node1Pub,
×
3464
                NodeKey2Bytes: node2Pub,
×
3465
                Capacity:      btcutil.Amount(capacity),
×
3466
        }
×
3467
}
×
3468

3469
// buildNode constructs a Node instance from the given database node
3470
// record. The node's features, addresses and extra signed fields are also
3471
// fetched from the database and set on the node.
3472
func buildNode(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3473
        dbNode sqlc.GraphNode) (*models.Node, error) {
×
3474

×
3475
        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
×
3476
        if err != nil {
×
3477
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
3478
                        err)
×
3479
        }
×
3480

3481
        return buildNodeWithBatchData(dbNode, data)
×
3482
}
3483

3484
// buildNodeWithBatchData builds a models.Node instance
3485
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
3486
// features/addresses/extra fields, then the corresponding fields are expected
3487
// to be present in the batchNodeData.
3488
func buildNodeWithBatchData(dbNode sqlc.GraphNode,
3489
        batchData *batchNodeData) (*models.Node, error) {
×
3490

×
3491
        if dbNode.Version != int16(lnwire.GossipVersion1) {
×
3492
                return nil, fmt.Errorf("unsupported node version: %d",
×
3493
                        dbNode.Version)
×
3494
        }
×
3495

3496
        var pub [33]byte
×
3497
        copy(pub[:], dbNode.PubKey)
×
3498

×
3499
        node := models.NewV1ShellNode(pub)
×
3500

×
3501
        if len(dbNode.Signature) == 0 {
×
3502
                return node, nil
×
3503
        }
×
3504

3505
        node.AuthSigBytes = dbNode.Signature
×
3506

×
3507
        if dbNode.Alias.Valid {
×
3508
                node.Alias = fn.Some(dbNode.Alias.String)
×
3509
        }
×
3510
        if dbNode.LastUpdate.Valid {
×
3511
                node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3512
        }
×
3513

3514
        var err error
×
3515
        if dbNode.Color.Valid {
×
3516
                nodeColor, err := DecodeHexColor(dbNode.Color.String)
×
3517
                if err != nil {
×
3518
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3519
                                err)
×
3520
                }
×
3521

3522
                node.Color = fn.Some(nodeColor)
×
3523
        }
3524

3525
        // Use preloaded features.
3526
        if features, exists := batchData.features[dbNode.ID]; exists {
×
3527
                fv := lnwire.EmptyFeatureVector()
×
3528
                for _, bit := range features {
×
3529
                        fv.Set(lnwire.FeatureBit(bit))
×
3530
                }
×
3531
                node.Features = fv
×
3532
        }
3533

3534
        // Use preloaded addresses.
3535
        addresses, exists := batchData.addresses[dbNode.ID]
×
3536
        if exists && len(addresses) > 0 {
×
3537
                node.Addresses, err = buildNodeAddresses(addresses)
×
3538
                if err != nil {
×
3539
                        return nil, fmt.Errorf("unable to build addresses "+
×
3540
                                "for node(%d): %w", dbNode.ID, err)
×
3541
                }
×
3542
        }
3543

3544
        // Use preloaded extra fields.
3545
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
×
3546
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
×
3547
                if err != nil {
×
3548
                        return nil, fmt.Errorf("unable to serialize extra "+
×
3549
                                "signed fields: %w", err)
×
3550
                }
×
3551
                if len(recs) != 0 {
×
3552
                        node.ExtraOpaqueData = recs
×
3553
                }
×
3554
        }
3555

3556
        return node, nil
×
3557
}
3558

3559
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
3560
// with the preloaded data, and executes the provided callback for each node.
3561
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
3562
        db SQLQueries, nodes []sqlc.GraphNode,
3563
        cb func(dbID int64, node *models.Node) error) error {
×
3564

×
3565
        // Extract node IDs for batch loading.
×
3566
        nodeIDs := make([]int64, len(nodes))
×
3567
        for i, node := range nodes {
×
3568
                nodeIDs[i] = node.ID
×
3569
        }
×
3570

3571
        // Batch load all related data for this page.
3572
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
3573
        if err != nil {
×
3574
                return fmt.Errorf("unable to batch load node data: %w", err)
×
3575
        }
×
3576

3577
        for _, dbNode := range nodes {
×
3578
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
3579
                if err != nil {
×
3580
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
3581
                                dbNode.ID, err)
×
3582
                }
×
3583

3584
                if err := cb(dbNode.ID, node); err != nil {
×
3585
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
3586
                                dbNode.ID, err)
×
3587
                }
×
3588
        }
3589

3590
        return nil
×
3591
}
3592

3593
// getNodeFeatures fetches the feature bits and constructs the feature vector
3594
// for a node with the given DB ID.
3595
func getNodeFeatures(ctx context.Context, db SQLQueries,
3596
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3597

×
3598
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3599
        if err != nil {
×
3600
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3601
                        nodeID, err)
×
3602
        }
×
3603

3604
        features := lnwire.EmptyFeatureVector()
×
3605
        for _, feature := range rows {
×
3606
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3607
        }
×
3608

3609
        return features, nil
×
3610
}
3611

3612
// upsertNodeAncillaryData updates the node's features, addresses, and extra
3613
// signed fields. This is common logic shared by upsertNode and
3614
// upsertSourceNode.
3615
func upsertNodeAncillaryData(ctx context.Context, db SQLQueries,
3616
        nodeID int64, node *models.Node) error {
×
3617

×
3618
        // Update the node's features.
×
3619
        err := upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3620
        if err != nil {
×
3621
                return fmt.Errorf("inserting node features: %w", err)
×
3622
        }
×
3623

3624
        // Update the node's addresses.
3625
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3626
        if err != nil {
×
3627
                return fmt.Errorf("inserting node addresses: %w", err)
×
3628
        }
×
3629

3630
        // Convert the flat extra opaque data into a map of TLV types to
3631
        // values.
3632
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3633
        if err != nil {
×
3634
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
3635
                        err)
×
3636
        }
×
3637

3638
        // Update the node's extra signed fields.
3639
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3640
        if err != nil {
×
3641
                return fmt.Errorf("inserting node extra TLVs: %w", err)
×
3642
        }
×
3643

3644
        return nil
×
3645
}
3646

3647
// populateNodeParams populates the common node parameters from a models.Node.
3648
// This is a helper for building UpsertNodeParams and UpsertSourceNodeParams.
3649
func populateNodeParams(node *models.Node,
3650
        setParams func(lastUpdate sql.NullInt64, alias,
3651
                colorStr sql.NullString, signature []byte)) error {
×
3652

×
3653
        if !node.HaveAnnouncement() {
×
3654
                return nil
×
3655
        }
×
3656

3657
        switch node.Version {
×
3658
        case lnwire.GossipVersion1:
×
3659
                lastUpdate := sqldb.SQLInt64(node.LastUpdate.Unix())
×
3660
                var alias, colorStr sql.NullString
×
3661

×
3662
                node.Color.WhenSome(func(rgba color.RGBA) {
×
3663
                        colorStr = sqldb.SQLStrValid(EncodeHexColor(rgba))
×
3664
                })
×
3665
                node.Alias.WhenSome(func(s string) {
×
3666
                        alias = sqldb.SQLStrValid(s)
×
3667
                })
×
3668

3669
                setParams(lastUpdate, alias, colorStr, node.AuthSigBytes)
×
3670

3671
        case lnwire.GossipVersion2:
×
3672
                // No-op for now.
3673

3674
        default:
×
3675
                return fmt.Errorf("unknown gossip version: %d", node.Version)
×
3676
        }
3677

3678
        return nil
×
3679
}
3680

3681
// buildNodeUpsertParams builds the parameters for upserting a node using the
3682
// strict UpsertNode query (requires timestamp to be increasing).
3683
func buildNodeUpsertParams(node *models.Node) (sqlc.UpsertNodeParams, error) {
×
3684
        params := sqlc.UpsertNodeParams{
×
3685
                Version: int16(lnwire.GossipVersion1),
×
3686
                PubKey:  node.PubKeyBytes[:],
×
3687
        }
×
3688

×
3689
        err := populateNodeParams(
×
3690
                node, func(lastUpdate sql.NullInt64, alias,
×
3691
                        colorStr sql.NullString,
×
3692
                        signature []byte) {
×
3693

×
3694
                        params.LastUpdate = lastUpdate
×
3695
                        params.Alias = alias
×
3696
                        params.Color = colorStr
×
3697
                        params.Signature = signature
×
3698
                })
×
3699

3700
        return params, err
×
3701
}
3702

3703
// buildSourceNodeUpsertParams builds the parameters for upserting the source
3704
// node using the lenient UpsertSourceNode query (allows same timestamp).
3705
func buildSourceNodeUpsertParams(node *models.Node) (
3706
        sqlc.UpsertSourceNodeParams, error) {
×
3707

×
3708
        params := sqlc.UpsertSourceNodeParams{
×
3709
                Version: int16(lnwire.GossipVersion1),
×
3710
                PubKey:  node.PubKeyBytes[:],
×
3711
        }
×
3712

×
3713
        err := populateNodeParams(
×
3714
                node, func(lastUpdate sql.NullInt64, alias,
×
3715
                        colorStr sql.NullString, signature []byte) {
×
3716

×
3717
                        params.LastUpdate = lastUpdate
×
3718
                        params.Alias = alias
×
3719
                        params.Color = colorStr
×
3720
                        params.Signature = signature
×
3721
                },
×
3722
        )
3723

3724
        return params, err
×
3725
}
3726

3727
// upsertSourceNode upserts the source node record into the database using a
3728
// less strict upsert that allows updates even when the timestamp hasn't
3729
// changed. This is necessary to handle concurrent updates to our own node
3730
// during startup and runtime. The node's features, addresses and extra TLV
3731
// types are also updated. The node's DB ID is returned.
3732
func upsertSourceNode(ctx context.Context, db SQLQueries,
3733
        node *models.Node) (int64, error) {
×
3734

×
3735
        params, err := buildSourceNodeUpsertParams(node)
×
3736
        if err != nil {
×
3737
                return 0, err
×
3738
        }
×
3739

3740
        nodeID, err := db.UpsertSourceNode(ctx, params)
×
3741
        if err != nil {
×
3742
                return 0, fmt.Errorf("upserting source node(%x): %w",
×
3743
                        node.PubKeyBytes, err)
×
3744
        }
×
3745

3746
        // We can exit here if we don't have the announcement yet.
3747
        if !node.HaveAnnouncement() {
×
3748
                return nodeID, nil
×
3749
        }
×
3750

3751
        // Update the ancillary node data (features, addresses, extra fields).
3752
        err = upsertNodeAncillaryData(ctx, db, nodeID, node)
×
3753
        if err != nil {
×
3754
                return 0, err
×
3755
        }
×
3756

3757
        return nodeID, nil
×
3758
}
3759

3760
// upsertNode upserts the node record into the database. If the node already
3761
// exists, then the node's information is updated. If the node doesn't exist,
3762
// then a new node is created. The node's features, addresses and extra TLV
3763
// types are also updated. The node's DB ID is returned.
3764
func upsertNode(ctx context.Context, db SQLQueries,
3765
        node *models.Node) (int64, error) {
×
3766

×
3767
        params, err := buildNodeUpsertParams(node)
×
3768
        if err != nil {
×
3769
                return 0, err
×
3770
        }
×
3771

3772
        nodeID, err := db.UpsertNode(ctx, params)
×
3773
        if err != nil {
×
3774
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3775
                        err)
×
3776
        }
×
3777

3778
        // We can exit here if we don't have the announcement yet.
3779
        if !node.HaveAnnouncement() {
×
3780
                return nodeID, nil
×
3781
        }
×
3782

3783
        // Update the ancillary node data (features, addresses, extra fields).
3784
        err = upsertNodeAncillaryData(ctx, db, nodeID, node)
×
3785
        if err != nil {
×
3786
                return 0, err
×
3787
        }
×
3788

3789
        return nodeID, nil
×
3790
}
3791

3792
// upsertNodeFeatures updates the node's features node_features table. This
3793
// includes deleting any feature bits no longer present and inserting any new
3794
// feature bits. If the feature bit does not yet exist in the features table,
3795
// then an entry is created in that table first.
3796
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3797
        features *lnwire.FeatureVector) error {
×
3798

×
3799
        // Get any existing features for the node.
×
3800
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3801
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3802
                return err
×
3803
        }
×
3804

3805
        // Copy the nodes latest set of feature bits.
3806
        newFeatures := make(map[int32]struct{})
×
3807
        if features != nil {
×
3808
                for feature := range features.Features() {
×
3809
                        newFeatures[int32(feature)] = struct{}{}
×
3810
                }
×
3811
        }
3812

3813
        // For any current feature that already exists in the DB, remove it from
3814
        // the in-memory map. For any existing feature that does not exist in
3815
        // the in-memory map, delete it from the database.
3816
        for _, feature := range existingFeatures {
×
3817
                // The feature is still present, so there are no updates to be
×
3818
                // made.
×
3819
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3820
                        delete(newFeatures, feature.FeatureBit)
×
3821
                        continue
×
3822
                }
3823

3824
                // The feature is no longer present, so we remove it from the
3825
                // database.
3826
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3827
                        NodeID:     nodeID,
×
3828
                        FeatureBit: feature.FeatureBit,
×
3829
                })
×
3830
                if err != nil {
×
3831
                        return fmt.Errorf("unable to delete node(%d) "+
×
3832
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3833
                                err)
×
3834
                }
×
3835
        }
3836

3837
        // Any remaining entries in newFeatures are new features that need to be
3838
        // added to the database for the first time.
3839
        for feature := range newFeatures {
×
3840
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3841
                        NodeID:     nodeID,
×
3842
                        FeatureBit: feature,
×
3843
                })
×
3844
                if err != nil {
×
3845
                        return fmt.Errorf("unable to insert node(%d) "+
×
3846
                                "feature(%v): %w", nodeID, feature, err)
×
3847
                }
×
3848
        }
3849

3850
        return nil
×
3851
}
3852

3853
// fetchNodeFeatures fetches the features for a node with the given public key.
3854
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3855
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3856

×
3857
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3858
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3859
                        PubKey:  nodePub[:],
×
3860
                        Version: int16(lnwire.GossipVersion1),
×
3861
                },
×
3862
        )
×
3863
        if err != nil {
×
3864
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3865
                        nodePub, err)
×
3866
        }
×
3867

3868
        features := lnwire.EmptyFeatureVector()
×
3869
        for _, bit := range rows {
×
3870
                features.Set(lnwire.FeatureBit(bit))
×
3871
        }
×
3872

3873
        return features, nil
×
3874
}
3875

3876
// dbAddressType identifies which kind of address a row in the node_addresses
// table holds, and therefore how the stored string must be serialized and
// deserialized. The numeric values are written to the database, so they must
// never be changed.
type dbAddressType uint8

const (
	// addressTypeIPv4 marks a *net.TCPAddr whose IP is an IPv4 address.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 marks a *net.TCPAddr whose IP is an IPv6 address.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 marks a *tor.OnionAddr with a v2-length service.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 marks a *tor.OnionAddr with a v3-length service.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeDNS marks an *lnwire.DNSAddress.
	addressTypeDNS dbAddressType = 5

	// addressTypeOpaque marks an *lnwire.OpaqueAddrs entry.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3889

3890
// collectAddressRecords collects the addresses from the provided
3891
// net.Addr slice and returns a map of dbAddressType to a slice of address
3892
// strings.
3893
func collectAddressRecords(addresses []net.Addr) (map[dbAddressType][]string,
3894
        error) {
×
3895

×
3896
        // Copy the nodes latest set of addresses.
×
3897
        newAddresses := map[dbAddressType][]string{
×
3898
                addressTypeIPv4:   {},
×
3899
                addressTypeIPv6:   {},
×
3900
                addressTypeTorV2:  {},
×
3901
                addressTypeTorV3:  {},
×
3902
                addressTypeDNS:    {},
×
3903
                addressTypeOpaque: {},
×
3904
        }
×
3905
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3906
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3907
        }
×
3908

3909
        for _, address := range addresses {
×
3910
                switch addr := address.(type) {
×
3911
                case *net.TCPAddr:
×
3912
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3913
                                addAddr(addressTypeIPv4, addr)
×
3914
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3915
                                addAddr(addressTypeIPv6, addr)
×
3916
                        } else {
×
3917
                                return nil, fmt.Errorf("unhandled IP "+
×
3918
                                        "address: %v", addr)
×
3919
                        }
×
3920

3921
                case *tor.OnionAddr:
×
3922
                        switch len(addr.OnionService) {
×
3923
                        case tor.V2Len:
×
3924
                                addAddr(addressTypeTorV2, addr)
×
3925
                        case tor.V3Len:
×
3926
                                addAddr(addressTypeTorV3, addr)
×
3927
                        default:
×
3928
                                return nil, fmt.Errorf("invalid length for " +
×
3929
                                        "a tor address")
×
3930
                        }
3931

3932
                case *lnwire.DNSAddress:
×
3933
                        addAddr(addressTypeDNS, addr)
×
3934

3935
                case *lnwire.OpaqueAddrs:
×
3936
                        addAddr(addressTypeOpaque, addr)
×
3937

3938
                default:
×
3939
                        return nil, fmt.Errorf("unhandled address type: %T",
×
3940
                                addr)
×
3941
                }
3942
        }
3943

3944
        return newAddresses, nil
×
3945
}
3946

3947
// upsertNodeAddresses updates the node's addresses in the database. This
3948
// includes deleting any existing addresses and inserting the new set of
3949
// addresses. The deletion is necessary since the ordering of the addresses may
3950
// change, and we need to ensure that the database reflects the latest set of
3951
// addresses so that at the time of reconstructing the node announcement, the
3952
// order is preserved and the signature over the message remains valid.
3953
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3954
        addresses []net.Addr) error {
×
3955

×
3956
        // Delete any existing addresses for the node. This is required since
×
3957
        // even if the new set of addresses is the same, the ordering may have
×
3958
        // changed for a given address type.
×
3959
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3960
        if err != nil {
×
3961
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3962
                        nodeID, err)
×
3963
        }
×
3964

3965
        newAddresses, err := collectAddressRecords(addresses)
×
3966
        if err != nil {
×
3967
                return err
×
3968
        }
×
3969

3970
        // Any remaining entries in newAddresses are new addresses that need to
3971
        // be added to the database for the first time.
3972
        for addrType, addrList := range newAddresses {
×
3973
                for position, addr := range addrList {
×
3974
                        err := db.UpsertNodeAddress(
×
3975
                                ctx, sqlc.UpsertNodeAddressParams{
×
3976
                                        NodeID:   nodeID,
×
3977
                                        Type:     int16(addrType),
×
3978
                                        Address:  addr,
×
3979
                                        Position: int32(position),
×
3980
                                },
×
3981
                        )
×
3982
                        if err != nil {
×
3983
                                return fmt.Errorf("unable to insert "+
×
3984
                                        "node(%d) address(%v): %w", nodeID,
×
3985
                                        addr, err)
×
3986
                        }
×
3987
                }
3988
        }
3989

3990
        return nil
×
3991
}
3992

3993
// getNodeAddresses fetches the addresses for a node with the given DB ID.
3994
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
3995
        error) {
×
3996

×
3997
        // GetNodeAddresses ensures that the addresses for a given type are
×
3998
        // returned in the same order as they were inserted.
×
3999
        rows, err := db.GetNodeAddresses(ctx, id)
×
4000
        if err != nil {
×
4001
                return nil, err
×
4002
        }
×
4003

4004
        addresses := make([]net.Addr, 0, len(rows))
×
4005
        for _, row := range rows {
×
4006
                address := row.Address
×
4007

×
4008
                addr, err := parseAddress(dbAddressType(row.Type), address)
×
4009
                if err != nil {
×
4010
                        return nil, fmt.Errorf("unable to parse address "+
×
4011
                                "for node(%d): %v: %w", id, address, err)
×
4012
                }
×
4013

4014
                addresses = append(addresses, addr)
×
4015
        }
4016

4017
        // If we have no addresses, then we'll return nil instead of an
4018
        // empty slice.
4019
        if len(addresses) == 0 {
×
4020
                addresses = nil
×
4021
        }
×
4022

4023
        return addresses, nil
×
4024
}
4025

4026
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
4027
// database. This includes updating any existing types, inserting any new types,
4028
// and deleting any types that are no longer present.
4029
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
4030
        nodeID int64, extraFields map[uint64][]byte) error {
×
4031

×
4032
        // Get any existing extra signed fields for the node.
×
4033
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
4034
        if err != nil {
×
4035
                return err
×
4036
        }
×
4037

4038
        // Make a lookup map of the existing field types so that we can use it
4039
        // to keep track of any fields we should delete.
4040
        m := make(map[uint64]bool)
×
4041
        for _, field := range existingFields {
×
4042
                m[uint64(field.Type)] = true
×
4043
        }
×
4044

4045
        // For all the new fields, we'll upsert them and remove them from the
4046
        // map of existing fields.
4047
        for tlvType, value := range extraFields {
×
4048
                err = db.UpsertNodeExtraType(
×
4049
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
4050
                                NodeID: nodeID,
×
4051
                                Type:   int64(tlvType),
×
4052
                                Value:  value,
×
4053
                        },
×
4054
                )
×
4055
                if err != nil {
×
4056
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
4057
                                "signed field(%v): %w", nodeID, tlvType, err)
×
4058
                }
×
4059

4060
                // Remove the field from the map of existing fields if it was
4061
                // present.
4062
                delete(m, tlvType)
×
4063
        }
4064

4065
        // For all the fields that are left in the map of existing fields, we'll
4066
        // delete them as they are no longer present in the new set of fields.
4067
        for tlvType := range m {
×
4068
                err = db.DeleteExtraNodeType(
×
4069
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
4070
                                NodeID: nodeID,
×
4071
                                Type:   int64(tlvType),
×
4072
                        },
×
4073
                )
×
4074
                if err != nil {
×
4075
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
4076
                                "signed field(%v): %w", nodeID, tlvType, err)
×
4077
                }
×
4078
        }
4079

4080
        return nil
×
4081
}
4082

4083
// srcNodeInfo holds the information about the source node of the graph.
4084
type srcNodeInfo struct {
4085
        // id is the DB level ID of the source node entry in the "nodes" table.
4086
        id int64
4087

4088
        // pub is the public key of the source node.
4089
        pub route.Vertex
4090
}
4091

4092
// sourceNode returns the DB node ID and pub key of the source node for the
4093
// specified protocol version.
4094
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
4095
        version lnwire.GossipVersion) (int64, route.Vertex, error) {
×
4096

×
4097
        s.srcNodeMu.Lock()
×
4098
        defer s.srcNodeMu.Unlock()
×
4099

×
4100
        // If we already have the source node ID and pub key cached, then
×
4101
        // return them.
×
4102
        if info, ok := s.srcNodes[version]; ok {
×
4103
                return info.id, info.pub, nil
×
4104
        }
×
4105

4106
        var pubKey route.Vertex
×
4107

×
4108
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
4109
        if err != nil {
×
4110
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
4111
                        err)
×
4112
        }
×
4113

4114
        if len(nodes) == 0 {
×
4115
                return 0, pubKey, ErrSourceNodeNotSet
×
4116
        } else if len(nodes) > 1 {
×
4117
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
4118
                        "protocol %s found", version)
×
4119
        }
×
4120

4121
        copy(pubKey[:], nodes[0].PubKey)
×
4122

×
4123
        s.srcNodes[version] = &srcNodeInfo{
×
4124
                id:  nodes[0].NodeID,
×
4125
                pub: pubKey,
×
4126
        }
×
4127

×
4128
        return nodes[0].NodeID, pubKey, nil
×
4129
}
4130

4131
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
4132
// This then produces a map from TLV type to value. If the input is not a
4133
// valid TLV stream, then an error is returned.
4134
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
4135
        r := bytes.NewReader(data)
×
4136

×
4137
        tlvStream, err := tlv.NewStream()
×
4138
        if err != nil {
×
4139
                return nil, err
×
4140
        }
×
4141

4142
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
4143
        // pass it into the P2P decoding variant.
4144
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
4145
        if err != nil {
×
4146
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
4147
        }
×
4148
        if len(parsedTypes) == 0 {
×
4149
                return nil, nil
×
4150
        }
×
4151

4152
        records := make(map[uint64][]byte)
×
4153
        for k, v := range parsedTypes {
×
4154
                records[uint64(k)] = v
×
4155
        }
×
4156

4157
        return records, nil
×
4158
}
4159

4160
// insertChannel inserts a new channel record into the database.
4161
func insertChannel(ctx context.Context, db SQLQueries,
4162
        edge *models.ChannelEdgeInfo) error {
×
4163

×
4164
        // Make sure that at least a "shell" entry for each node is present in
×
4165
        // the nodes table.
×
4166
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
4167
        if err != nil {
×
4168
                return fmt.Errorf("unable to create shell node: %w", err)
×
4169
        }
×
4170

4171
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
4172
        if err != nil {
×
4173
                return fmt.Errorf("unable to create shell node: %w", err)
×
4174
        }
×
4175

4176
        var capacity sql.NullInt64
×
4177
        if edge.Capacity != 0 {
×
4178
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
4179
        }
×
4180

4181
        createParams := sqlc.CreateChannelParams{
×
4182
                Version:     int16(lnwire.GossipVersion1),
×
4183
                Scid:        channelIDToBytes(edge.ChannelID),
×
4184
                NodeID1:     node1DBID,
×
4185
                NodeID2:     node2DBID,
×
4186
                Outpoint:    edge.ChannelPoint.String(),
×
4187
                Capacity:    capacity,
×
4188
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
4189
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
4190
        }
×
4191

×
4192
        if edge.AuthProof != nil {
×
4193
                proof := edge.AuthProof
×
4194

×
4195
                createParams.Node1Signature = proof.NodeSig1Bytes
×
4196
                createParams.Node2Signature = proof.NodeSig2Bytes
×
4197
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
4198
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
4199
        }
×
4200

4201
        // Insert the new channel record.
4202
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
4203
        if err != nil {
×
4204
                return err
×
4205
        }
×
4206

4207
        // Insert any channel features.
4208
        for feature := range edge.Features.Features() {
×
4209
                err = db.InsertChannelFeature(
×
4210
                        ctx, sqlc.InsertChannelFeatureParams{
×
4211
                                ChannelID:  dbChanID,
×
4212
                                FeatureBit: int32(feature),
×
4213
                        },
×
4214
                )
×
4215
                if err != nil {
×
4216
                        return fmt.Errorf("unable to insert channel(%d) "+
×
4217
                                "feature(%v): %w", dbChanID, feature, err)
×
4218
                }
×
4219
        }
4220

4221
        // Finally, insert any extra TLV fields in the channel announcement.
4222
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
4223
        if err != nil {
×
4224
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
4225
                        err)
×
4226
        }
×
4227

4228
        for tlvType, value := range extra {
×
4229
                err := db.UpsertChannelExtraType(
×
4230
                        ctx, sqlc.UpsertChannelExtraTypeParams{
×
4231
                                ChannelID: dbChanID,
×
4232
                                Type:      int64(tlvType),
×
4233
                                Value:     value,
×
4234
                        },
×
4235
                )
×
4236
                if err != nil {
×
4237
                        return fmt.Errorf("unable to upsert channel(%d) "+
×
4238
                                "extra signed field(%v): %w", edge.ChannelID,
×
4239
                                tlvType, err)
×
4240
                }
×
4241
        }
4242

4243
        return nil
×
4244
}
4245

4246
// maybeCreateShellNode checks if a shell node entry exists for the
4247
// given public key. If it does not exist, then a new shell node entry is
4248
// created. The ID of the node is returned. A shell node only has a protocol
4249
// version and public key persisted.
4250
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
4251
        pubKey route.Vertex) (int64, error) {
×
4252

×
4253
        dbNode, err := db.GetNodeByPubKey(
×
4254
                ctx, sqlc.GetNodeByPubKeyParams{
×
4255
                        PubKey:  pubKey[:],
×
4256
                        Version: int16(lnwire.GossipVersion1),
×
4257
                },
×
4258
        )
×
4259
        // The node exists. Return the ID.
×
4260
        if err == nil {
×
4261
                return dbNode.ID, nil
×
4262
        } else if !errors.Is(err, sql.ErrNoRows) {
×
4263
                return 0, err
×
4264
        }
×
4265

4266
        // Otherwise, the node does not exist, so we create a shell entry for
4267
        // it.
4268
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
4269
                Version: int16(lnwire.GossipVersion1),
×
4270
                PubKey:  pubKey[:],
×
4271
        })
×
4272
        if err != nil {
×
4273
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
4274
        }
×
4275

4276
        return id, nil
×
4277
}
4278

4279
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
4280
// the database. This includes deleting any existing types and then inserting
4281
// the new types.
4282
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
4283
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
4284

×
4285
        // Delete all existing extra signed fields for the channel policy.
×
4286
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
4287
        if err != nil {
×
4288
                return fmt.Errorf("unable to delete "+
×
4289
                        "existing policy extra signed fields for policy %d: %w",
×
4290
                        chanPolicyID, err)
×
4291
        }
×
4292

4293
        // Insert all new extra signed fields for the channel policy.
4294
        for tlvType, value := range extraFields {
×
4295
                err = db.UpsertChanPolicyExtraType(
×
4296
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
×
4297
                                ChannelPolicyID: chanPolicyID,
×
4298
                                Type:            int64(tlvType),
×
4299
                                Value:           value,
×
4300
                        },
×
4301
                )
×
4302
                if err != nil {
×
4303
                        return fmt.Errorf("unable to insert "+
×
4304
                                "channel_policy(%d) extra signed field(%v): %w",
×
4305
                                chanPolicyID, tlvType, err)
×
4306
                }
×
4307
        }
4308

4309
        return nil
×
4310
}
4311

4312
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
4313
// provided dbChanRow and also fetches any other required information
4314
// to construct the edge info.
4315
func getAndBuildEdgeInfo(ctx context.Context, cfg *SQLStoreConfig,
4316
        db SQLQueries, dbChan sqlc.GraphChannel, node1,
4317
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
4318

×
4319
        data, err := batchLoadChannelData(
×
4320
                ctx, cfg.QueryCfg, db, []int64{dbChan.ID}, nil,
×
4321
        )
×
4322
        if err != nil {
×
4323
                return nil, fmt.Errorf("unable to batch load channel data: %w",
×
4324
                        err)
×
4325
        }
×
4326

4327
        return buildEdgeInfoWithBatchData(
×
4328
                cfg.ChainHash, dbChan, node1, node2, data,
×
4329
        )
×
4330
}
4331

4332
// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
4333
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
4334
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
4335
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {
×
4336

×
4337
        if dbChan.Version != int16(lnwire.GossipVersion1) {
×
4338
                return nil, fmt.Errorf("unsupported channel version: %d",
×
4339
                        dbChan.Version)
×
4340
        }
×
4341

4342
        // Use pre-loaded features and extras types.
4343
        fv := lnwire.EmptyFeatureVector()
×
4344
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
×
4345
                for _, bit := range features {
×
4346
                        fv.Set(lnwire.FeatureBit(bit))
×
4347
                }
×
4348
        }
4349

4350
        var extras map[uint64][]byte
×
4351
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
×
4352
        if exists {
×
4353
                extras = channelExtras
×
4354
        } else {
×
4355
                extras = make(map[uint64][]byte)
×
4356
        }
×
4357

4358
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4359
        if err != nil {
×
4360
                return nil, err
×
4361
        }
×
4362

4363
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4364
        if err != nil {
×
4365
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4366
                        "fields: %w", err)
×
4367
        }
×
4368
        if recs == nil {
×
4369
                recs = make([]byte, 0)
×
4370
        }
×
4371

4372
        var btcKey1, btcKey2 route.Vertex
×
4373
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4374
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4375

×
4376
        channel := &models.ChannelEdgeInfo{
×
4377
                ChainHash:        chain,
×
4378
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4379
                NodeKey1Bytes:    node1,
×
4380
                NodeKey2Bytes:    node2,
×
4381
                BitcoinKey1Bytes: btcKey1,
×
4382
                BitcoinKey2Bytes: btcKey2,
×
4383
                ChannelPoint:     *op,
×
4384
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4385
                Features:         fv,
×
4386
                ExtraOpaqueData:  recs,
×
4387
        }
×
4388

×
4389
        // We always set all the signatures at the same time, so we can
×
4390
        // safely check if one signature is present to determine if we have the
×
4391
        // rest of the signatures for the auth proof.
×
4392
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4393
                channel.AuthProof = &models.ChannelAuthProof{
×
4394
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4395
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4396
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4397
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4398
                }
×
4399
        }
×
4400

4401
        return channel, nil
×
4402
}
4403

4404
// buildNodeVertices is a helper that converts raw node public keys
4405
// into route.Vertex instances.
4406
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4407
        route.Vertex, error) {
×
4408

×
4409
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4410
        if err != nil {
×
4411
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4412
                        "create vertex from node1 pubkey: %w", err)
×
4413
        }
×
4414

4415
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4416
        if err != nil {
×
4417
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4418
                        "create vertex from node2 pubkey: %w", err)
×
4419
        }
×
4420

4421
        return node1Vertex, node2Vertex, nil
×
4422
}
4423

4424
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4425
// retrieves all the extra info required to build the complete
4426
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4427
// the provided sqlc.GraphChannelPolicy records are nil.
4428
func getAndBuildChanPolicies(ctx context.Context, cfg *sqldb.QueryConfig,
4429
        db SQLQueries, dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4430
        channelID uint64, node1, node2 route.Vertex) (*models.ChannelEdgePolicy,
4431
        *models.ChannelEdgePolicy, error) {
×
4432

×
4433
        if dbPol1 == nil && dbPol2 == nil {
×
4434
                return nil, nil, nil
×
4435
        }
×
4436

4437
        var policyIDs = make([]int64, 0, 2)
×
4438
        if dbPol1 != nil {
×
4439
                policyIDs = append(policyIDs, dbPol1.ID)
×
4440
        }
×
4441
        if dbPol2 != nil {
×
4442
                policyIDs = append(policyIDs, dbPol2.ID)
×
4443
        }
×
4444

4445
        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
×
4446
        if err != nil {
×
4447
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
4448
                        "data: %w", err)
×
4449
        }
×
4450

4451
        pol1, err := buildChanPolicyWithBatchData(
×
4452
                dbPol1, channelID, node2, batchData,
×
4453
        )
×
4454
        if err != nil {
×
4455
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4456
        }
×
4457

4458
        pol2, err := buildChanPolicyWithBatchData(
×
4459
                dbPol2, channelID, node1, batchData,
×
4460
        )
×
4461
        if err != nil {
×
4462
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4463
        }
×
4464

4465
        return pol1, pol2, nil
×
4466
}
4467

4468
// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
4469
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
4470
// then nil is returned for it.
4471
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4472
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
4473
        *models.CachedEdgePolicy, error) {
×
4474

×
4475
        var p1, p2 *models.CachedEdgePolicy
×
4476
        if dbPol1 != nil {
×
4477
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
×
4478
                if err != nil {
×
4479
                        return nil, nil, err
×
4480
                }
×
4481

4482
                p1 = models.NewCachedPolicy(policy1)
×
4483
        }
4484
        if dbPol2 != nil {
×
4485
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
×
4486
                if err != nil {
×
4487
                        return nil, nil, err
×
4488
                }
×
4489

4490
                p2 = models.NewCachedPolicy(policy2)
×
4491
        }
4492

4493
        return p1, p2, nil
×
4494
}
4495

4496
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4497
// provided sqlc.GraphChannelPolicy and other required information.
4498
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4499
        extras map[uint64][]byte,
4500
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4501

×
4502
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4503
        if err != nil {
×
4504
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4505
                        "fields: %w", err)
×
4506
        }
×
4507

4508
        var inboundFee fn.Option[lnwire.Fee]
×
4509
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4510
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4511

×
4512
                inboundFee = fn.Some(lnwire.Fee{
×
4513
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4514
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4515
                })
×
4516
        }
×
4517

4518
        return &models.ChannelEdgePolicy{
×
4519
                SigBytes:  dbPolicy.Signature,
×
4520
                ChannelID: channelID,
×
4521
                LastUpdate: time.Unix(
×
4522
                        dbPolicy.LastUpdate.Int64, 0,
×
4523
                ),
×
4524
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4525
                        dbPolicy.MessageFlags,
×
4526
                ),
×
4527
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4528
                        dbPolicy.ChannelFlags,
×
4529
                ),
×
4530
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4531
                MinHTLC: lnwire.MilliSatoshi(
×
4532
                        dbPolicy.MinHtlcMsat,
×
4533
                ),
×
4534
                MaxHTLC: lnwire.MilliSatoshi(
×
4535
                        dbPolicy.MaxHtlcMsat.Int64,
×
4536
                ),
×
4537
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4538
                        dbPolicy.BaseFeeMsat,
×
4539
                ),
×
4540
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4541
                ToNode:                    toNode,
×
4542
                InboundFee:                inboundFee,
×
4543
                ExtraOpaqueData:           recs,
×
4544
        }, nil
×
4545
}
4546

4547
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the give
4548
// row which is expected to be a sqlc type that contains channel policy
4549
// information. It returns two policies, which may be nil if the policy
4550
// information is not present in the row.
4551
//
4552
//nolint:ll,dupl,funlen
4553
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4554
        *sqlc.GraphChannelPolicy, error) {
×
4555

×
4556
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4557
        switch r := row.(type) {
×
4558
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
×
4559
                if r.Policy1Timelock.Valid {
×
4560
                        policy1 = &sqlc.GraphChannelPolicy{
×
4561
                                Timelock:                r.Policy1Timelock.Int32,
×
4562
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4563
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4564
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4565
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4566
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4567
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4568
                                Disabled:                r.Policy1Disabled,
×
4569
                                MessageFlags:            r.Policy1MessageFlags,
×
4570
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4571
                        }
×
4572
                }
×
4573
                if r.Policy2Timelock.Valid {
×
4574
                        policy2 = &sqlc.GraphChannelPolicy{
×
4575
                                Timelock:                r.Policy2Timelock.Int32,
×
4576
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4577
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4578
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4579
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4580
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4581
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4582
                                Disabled:                r.Policy2Disabled,
×
4583
                                MessageFlags:            r.Policy2MessageFlags,
×
4584
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4585
                        }
×
4586
                }
×
4587

4588
                return policy1, policy2, nil
×
4589

4590
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
4591
                if r.Policy1ID.Valid {
×
4592
                        policy1 = &sqlc.GraphChannelPolicy{
×
4593
                                ID:                      r.Policy1ID.Int64,
×
4594
                                Version:                 r.Policy1Version.Int16,
×
4595
                                ChannelID:               r.GraphChannel.ID,
×
4596
                                NodeID:                  r.Policy1NodeID.Int64,
×
4597
                                Timelock:                r.Policy1Timelock.Int32,
×
4598
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4599
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4600
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4601
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4602
                                LastUpdate:              r.Policy1LastUpdate,
×
4603
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4604
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4605
                                Disabled:                r.Policy1Disabled,
×
4606
                                MessageFlags:            r.Policy1MessageFlags,
×
4607
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4608
                                Signature:               r.Policy1Signature,
×
4609
                        }
×
4610
                }
×
4611
                if r.Policy2ID.Valid {
×
4612
                        policy2 = &sqlc.GraphChannelPolicy{
×
4613
                                ID:                      r.Policy2ID.Int64,
×
4614
                                Version:                 r.Policy2Version.Int16,
×
4615
                                ChannelID:               r.GraphChannel.ID,
×
4616
                                NodeID:                  r.Policy2NodeID.Int64,
×
4617
                                Timelock:                r.Policy2Timelock.Int32,
×
4618
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4619
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4620
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4621
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4622
                                LastUpdate:              r.Policy2LastUpdate,
×
4623
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4624
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4625
                                Disabled:                r.Policy2Disabled,
×
4626
                                MessageFlags:            r.Policy2MessageFlags,
×
4627
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4628
                                Signature:               r.Policy2Signature,
×
4629
                        }
×
4630
                }
×
4631

4632
                return policy1, policy2, nil
×
4633

4634
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4635
                if r.Policy1ID.Valid {
×
4636
                        policy1 = &sqlc.GraphChannelPolicy{
×
4637
                                ID:                      r.Policy1ID.Int64,
×
4638
                                Version:                 r.Policy1Version.Int16,
×
4639
                                ChannelID:               r.GraphChannel.ID,
×
4640
                                NodeID:                  r.Policy1NodeID.Int64,
×
4641
                                Timelock:                r.Policy1Timelock.Int32,
×
4642
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4643
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4644
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4645
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4646
                                LastUpdate:              r.Policy1LastUpdate,
×
4647
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4648
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4649
                                Disabled:                r.Policy1Disabled,
×
4650
                                MessageFlags:            r.Policy1MessageFlags,
×
4651
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4652
                                Signature:               r.Policy1Signature,
×
4653
                        }
×
4654
                }
×
4655
                if r.Policy2ID.Valid {
×
4656
                        policy2 = &sqlc.GraphChannelPolicy{
×
4657
                                ID:                      r.Policy2ID.Int64,
×
4658
                                Version:                 r.Policy2Version.Int16,
×
4659
                                ChannelID:               r.GraphChannel.ID,
×
4660
                                NodeID:                  r.Policy2NodeID.Int64,
×
4661
                                Timelock:                r.Policy2Timelock.Int32,
×
4662
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4663
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4664
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4665
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4666
                                LastUpdate:              r.Policy2LastUpdate,
×
4667
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4668
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4669
                                Disabled:                r.Policy2Disabled,
×
4670
                                MessageFlags:            r.Policy2MessageFlags,
×
4671
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4672
                                Signature:               r.Policy2Signature,
×
4673
                        }
×
4674
                }
×
4675

4676
                return policy1, policy2, nil
×
4677

4678
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4679
                if r.Policy1ID.Valid {
×
4680
                        policy1 = &sqlc.GraphChannelPolicy{
×
4681
                                ID:                      r.Policy1ID.Int64,
×
4682
                                Version:                 r.Policy1Version.Int16,
×
4683
                                ChannelID:               r.GraphChannel.ID,
×
4684
                                NodeID:                  r.Policy1NodeID.Int64,
×
4685
                                Timelock:                r.Policy1Timelock.Int32,
×
4686
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4687
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4688
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4689
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4690
                                LastUpdate:              r.Policy1LastUpdate,
×
4691
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4692
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4693
                                Disabled:                r.Policy1Disabled,
×
4694
                                MessageFlags:            r.Policy1MessageFlags,
×
4695
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4696
                                Signature:               r.Policy1Signature,
×
4697
                        }
×
4698
                }
×
4699
                if r.Policy2ID.Valid {
×
4700
                        policy2 = &sqlc.GraphChannelPolicy{
×
4701
                                ID:                      r.Policy2ID.Int64,
×
4702
                                Version:                 r.Policy2Version.Int16,
×
4703
                                ChannelID:               r.GraphChannel.ID,
×
4704
                                NodeID:                  r.Policy2NodeID.Int64,
×
4705
                                Timelock:                r.Policy2Timelock.Int32,
×
4706
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4707
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4708
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4709
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4710
                                LastUpdate:              r.Policy2LastUpdate,
×
4711
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4712
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4713
                                Disabled:                r.Policy2Disabled,
×
4714
                                MessageFlags:            r.Policy2MessageFlags,
×
4715
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4716
                                Signature:               r.Policy2Signature,
×
4717
                        }
×
4718
                }
×
4719

4720
                return policy1, policy2, nil
×
4721

4722
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4723
                if r.Policy1ID.Valid {
×
4724
                        policy1 = &sqlc.GraphChannelPolicy{
×
4725
                                ID:                      r.Policy1ID.Int64,
×
4726
                                Version:                 r.Policy1Version.Int16,
×
4727
                                ChannelID:               r.GraphChannel.ID,
×
4728
                                NodeID:                  r.Policy1NodeID.Int64,
×
4729
                                Timelock:                r.Policy1Timelock.Int32,
×
4730
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4731
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4732
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4733
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4734
                                LastUpdate:              r.Policy1LastUpdate,
×
4735
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4736
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4737
                                Disabled:                r.Policy1Disabled,
×
4738
                                MessageFlags:            r.Policy1MessageFlags,
×
4739
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4740
                                Signature:               r.Policy1Signature,
×
4741
                        }
×
4742
                }
×
4743
                if r.Policy2ID.Valid {
×
4744
                        policy2 = &sqlc.GraphChannelPolicy{
×
4745
                                ID:                      r.Policy2ID.Int64,
×
4746
                                Version:                 r.Policy2Version.Int16,
×
4747
                                ChannelID:               r.GraphChannel.ID,
×
4748
                                NodeID:                  r.Policy2NodeID.Int64,
×
4749
                                Timelock:                r.Policy2Timelock.Int32,
×
4750
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4751
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4752
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4753
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4754
                                LastUpdate:              r.Policy2LastUpdate,
×
4755
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4756
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4757
                                Disabled:                r.Policy2Disabled,
×
4758
                                MessageFlags:            r.Policy2MessageFlags,
×
4759
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4760
                                Signature:               r.Policy2Signature,
×
4761
                        }
×
4762
                }
×
4763

4764
                return policy1, policy2, nil
×
4765

4766
        case sqlc.ListChannelsForNodeIDsRow:
×
4767
                if r.Policy1ID.Valid {
×
4768
                        policy1 = &sqlc.GraphChannelPolicy{
×
4769
                                ID:                      r.Policy1ID.Int64,
×
4770
                                Version:                 r.Policy1Version.Int16,
×
4771
                                ChannelID:               r.GraphChannel.ID,
×
4772
                                NodeID:                  r.Policy1NodeID.Int64,
×
4773
                                Timelock:                r.Policy1Timelock.Int32,
×
4774
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4775
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4776
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4777
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4778
                                LastUpdate:              r.Policy1LastUpdate,
×
4779
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4780
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4781
                                Disabled:                r.Policy1Disabled,
×
4782
                                MessageFlags:            r.Policy1MessageFlags,
×
4783
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4784
                                Signature:               r.Policy1Signature,
×
4785
                        }
×
4786
                }
×
4787
                if r.Policy2ID.Valid {
×
4788
                        policy2 = &sqlc.GraphChannelPolicy{
×
4789
                                ID:                      r.Policy2ID.Int64,
×
4790
                                Version:                 r.Policy2Version.Int16,
×
4791
                                ChannelID:               r.GraphChannel.ID,
×
4792
                                NodeID:                  r.Policy2NodeID.Int64,
×
4793
                                Timelock:                r.Policy2Timelock.Int32,
×
4794
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4795
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4796
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4797
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4798
                                LastUpdate:              r.Policy2LastUpdate,
×
4799
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4800
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4801
                                Disabled:                r.Policy2Disabled,
×
4802
                                MessageFlags:            r.Policy2MessageFlags,
×
4803
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4804
                                Signature:               r.Policy2Signature,
×
4805
                        }
×
4806
                }
×
4807

4808
                return policy1, policy2, nil
×
4809

4810
        case sqlc.ListChannelsByNodeIDRow:
×
4811
                if r.Policy1ID.Valid {
×
4812
                        policy1 = &sqlc.GraphChannelPolicy{
×
4813
                                ID:                      r.Policy1ID.Int64,
×
4814
                                Version:                 r.Policy1Version.Int16,
×
4815
                                ChannelID:               r.GraphChannel.ID,
×
4816
                                NodeID:                  r.Policy1NodeID.Int64,
×
4817
                                Timelock:                r.Policy1Timelock.Int32,
×
4818
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4819
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4820
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4821
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4822
                                LastUpdate:              r.Policy1LastUpdate,
×
4823
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4824
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4825
                                Disabled:                r.Policy1Disabled,
×
4826
                                MessageFlags:            r.Policy1MessageFlags,
×
4827
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4828
                                Signature:               r.Policy1Signature,
×
4829
                        }
×
4830
                }
×
4831
                if r.Policy2ID.Valid {
×
4832
                        policy2 = &sqlc.GraphChannelPolicy{
×
4833
                                ID:                      r.Policy2ID.Int64,
×
4834
                                Version:                 r.Policy2Version.Int16,
×
4835
                                ChannelID:               r.GraphChannel.ID,
×
4836
                                NodeID:                  r.Policy2NodeID.Int64,
×
4837
                                Timelock:                r.Policy2Timelock.Int32,
×
4838
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4839
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4840
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4841
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4842
                                LastUpdate:              r.Policy2LastUpdate,
×
4843
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4844
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4845
                                Disabled:                r.Policy2Disabled,
×
4846
                                MessageFlags:            r.Policy2MessageFlags,
×
4847
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4848
                                Signature:               r.Policy2Signature,
×
4849
                        }
×
4850
                }
×
4851

4852
                return policy1, policy2, nil
×
4853

4854
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4855
                if r.Policy1ID.Valid {
×
4856
                        policy1 = &sqlc.GraphChannelPolicy{
×
4857
                                ID:                      r.Policy1ID.Int64,
×
4858
                                Version:                 r.Policy1Version.Int16,
×
4859
                                ChannelID:               r.GraphChannel.ID,
×
4860
                                NodeID:                  r.Policy1NodeID.Int64,
×
4861
                                Timelock:                r.Policy1Timelock.Int32,
×
4862
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4863
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4864
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4865
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4866
                                LastUpdate:              r.Policy1LastUpdate,
×
4867
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4868
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4869
                                Disabled:                r.Policy1Disabled,
×
4870
                                MessageFlags:            r.Policy1MessageFlags,
×
4871
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4872
                                Signature:               r.Policy1Signature,
×
4873
                        }
×
4874
                }
×
4875
                if r.Policy2ID.Valid {
×
4876
                        policy2 = &sqlc.GraphChannelPolicy{
×
4877
                                ID:                      r.Policy2ID.Int64,
×
4878
                                Version:                 r.Policy2Version.Int16,
×
4879
                                ChannelID:               r.GraphChannel.ID,
×
4880
                                NodeID:                  r.Policy2NodeID.Int64,
×
4881
                                Timelock:                r.Policy2Timelock.Int32,
×
4882
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4883
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4884
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4885
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4886
                                LastUpdate:              r.Policy2LastUpdate,
×
4887
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4888
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4889
                                Disabled:                r.Policy2Disabled,
×
4890
                                MessageFlags:            r.Policy2MessageFlags,
×
4891
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4892
                                Signature:               r.Policy2Signature,
×
4893
                        }
×
4894
                }
×
4895

4896
                return policy1, policy2, nil
×
4897

4898
        case sqlc.GetChannelsByIDsRow:
×
4899
                if r.Policy1ID.Valid {
×
4900
                        policy1 = &sqlc.GraphChannelPolicy{
×
4901
                                ID:                      r.Policy1ID.Int64,
×
4902
                                Version:                 r.Policy1Version.Int16,
×
4903
                                ChannelID:               r.GraphChannel.ID,
×
4904
                                NodeID:                  r.Policy1NodeID.Int64,
×
4905
                                Timelock:                r.Policy1Timelock.Int32,
×
4906
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4907
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4908
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4909
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4910
                                LastUpdate:              r.Policy1LastUpdate,
×
4911
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4912
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4913
                                Disabled:                r.Policy1Disabled,
×
4914
                                MessageFlags:            r.Policy1MessageFlags,
×
4915
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4916
                                Signature:               r.Policy1Signature,
×
4917
                        }
×
4918
                }
×
4919
                if r.Policy2ID.Valid {
×
4920
                        policy2 = &sqlc.GraphChannelPolicy{
×
4921
                                ID:                      r.Policy2ID.Int64,
×
4922
                                Version:                 r.Policy2Version.Int16,
×
4923
                                ChannelID:               r.GraphChannel.ID,
×
4924
                                NodeID:                  r.Policy2NodeID.Int64,
×
4925
                                Timelock:                r.Policy2Timelock.Int32,
×
4926
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4927
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4928
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4929
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4930
                                LastUpdate:              r.Policy2LastUpdate,
×
4931
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4932
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4933
                                Disabled:                r.Policy2Disabled,
×
4934
                                MessageFlags:            r.Policy2MessageFlags,
×
4935
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4936
                                Signature:               r.Policy2Signature,
×
4937
                        }
×
4938
                }
×
4939

4940
                return policy1, policy2, nil
×
4941

4942
        default:
×
4943
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4944
                        "extractChannelPolicies: %T", r)
×
4945
        }
4946
}
4947

4948
// channelIDToBytes converts a channel ID (SCID) to a byte array
4949
// representation.
4950
func channelIDToBytes(channelID uint64) []byte {
×
4951
        var chanIDB [8]byte
×
4952
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4953

×
4954
        return chanIDB[:]
×
4955
}
×
4956

4957
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
4958
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
×
4959
        if len(addresses) == 0 {
×
4960
                return nil, nil
×
4961
        }
×
4962

4963
        result := make([]net.Addr, 0, len(addresses))
×
4964
        for _, addr := range addresses {
×
4965
                netAddr, err := parseAddress(addr.addrType, addr.address)
×
4966
                if err != nil {
×
4967
                        return nil, fmt.Errorf("unable to parse address %s "+
×
4968
                                "of type %d: %w", addr.address, addr.addrType,
×
4969
                                err)
×
4970
                }
×
4971
                if netAddr != nil {
×
4972
                        result = append(result, netAddr)
×
4973
                }
×
4974
        }
4975

4976
        // If we have no valid addresses, return nil instead of empty slice.
4977
        if len(result) == 0 {
×
4978
                return nil, nil
×
4979
        }
×
4980

4981
        return result, nil
×
4982
}
4983

4984
// parseAddress parses the given address string based on the address type
4985
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
4986
// and opaque addresses.
4987
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
×
4988
        switch addrType {
×
4989
        case addressTypeIPv4:
×
4990
                tcp, err := net.ResolveTCPAddr("tcp4", address)
×
4991
                if err != nil {
×
4992
                        return nil, err
×
4993
                }
×
4994

4995
                tcp.IP = tcp.IP.To4()
×
4996

×
4997
                return tcp, nil
×
4998

4999
        case addressTypeIPv6:
×
5000
                tcp, err := net.ResolveTCPAddr("tcp6", address)
×
5001
                if err != nil {
×
5002
                        return nil, err
×
5003
                }
×
5004

5005
                return tcp, nil
×
5006

5007
        case addressTypeTorV3, addressTypeTorV2:
×
5008
                service, portStr, err := net.SplitHostPort(address)
×
5009
                if err != nil {
×
5010
                        return nil, fmt.Errorf("unable to split tor "+
×
5011
                                "address: %v", address)
×
5012
                }
×
5013

5014
                port, err := strconv.Atoi(portStr)
×
5015
                if err != nil {
×
5016
                        return nil, err
×
5017
                }
×
5018

5019
                return &tor.OnionAddr{
×
5020
                        OnionService: service,
×
5021
                        Port:         port,
×
5022
                }, nil
×
5023

5024
        case addressTypeDNS:
×
5025
                hostname, portStr, err := net.SplitHostPort(address)
×
5026
                if err != nil {
×
5027
                        return nil, fmt.Errorf("unable to split DNS "+
×
5028
                                "address: %v", address)
×
5029
                }
×
5030

5031
                port, err := strconv.Atoi(portStr)
×
5032
                if err != nil {
×
5033
                        return nil, err
×
5034
                }
×
5035

5036
                return &lnwire.DNSAddress{
×
5037
                        Hostname: hostname,
×
5038
                        Port:     uint16(port),
×
5039
                }, nil
×
5040

5041
        case addressTypeOpaque:
×
5042
                opaque, err := hex.DecodeString(address)
×
5043
                if err != nil {
×
5044
                        return nil, fmt.Errorf("unable to decode opaque "+
×
5045
                                "address: %v", address)
×
5046
                }
×
5047

5048
                return &lnwire.OpaqueAddrs{
×
5049
                        Payload: opaque,
×
5050
                }, nil
×
5051

5052
        default:
×
5053
                return nil, fmt.Errorf("unknown address type: %v", addrType)
×
5054
        }
5055
}
5056

5057
// batchNodeData holds all the related data for a batch of nodes.
5058
type batchNodeData struct {
5059
        // features is a map from a DB node ID to the feature bits for that
5060
        // node.
5061
        features map[int64][]int
5062

5063
        // addresses is a map from a DB node ID to the node's addresses.
5064
        addresses map[int64][]nodeAddress
5065

5066
        // extraFields is a map from a DB node ID to the extra signed fields
5067
        // for that node.
5068
        extraFields map[int64]map[uint64][]byte
5069
}
5070

5071
// nodeAddress holds the address type, position and address string for a
5072
// node. This is used to batch the fetching of node addresses.
5073
type nodeAddress struct {
5074
        addrType dbAddressType
5075
        position int32
5076
        address  string
5077
}
5078

5079
// batchLoadNodeData loads all related data for a batch of node IDs using the
5080
// provided SQLQueries interface. It returns a batchNodeData instance containing
5081
// the node features, addresses and extra signed fields.
5082
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
5083
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {
×
5084

×
5085
        // Batch load the node features.
×
5086
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
5087
        if err != nil {
×
5088
                return nil, fmt.Errorf("unable to batch load node "+
×
5089
                        "features: %w", err)
×
5090
        }
×
5091

5092
        // Batch load the node addresses.
5093
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
×
5094
        if err != nil {
×
5095
                return nil, fmt.Errorf("unable to batch load node "+
×
5096
                        "addresses: %w", err)
×
5097
        }
×
5098

5099
        // Batch load the node extra signed fields.
5100
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
×
5101
        if err != nil {
×
5102
                return nil, fmt.Errorf("unable to batch load node extra "+
×
5103
                        "signed fields: %w", err)
×
5104
        }
×
5105

5106
        return &batchNodeData{
×
5107
                features:    features,
×
5108
                addresses:   addrs,
×
5109
                extraFields: extraTypes,
×
5110
        }, nil
×
5111
}
5112

5113
// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
5114
// using ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
5115
func batchLoadNodeFeaturesHelper(ctx context.Context,
5116
        cfg *sqldb.QueryConfig, db SQLQueries,
5117
        nodeIDs []int64) (map[int64][]int, error) {
×
5118

×
5119
        features := make(map[int64][]int)
×
5120

×
5121
        return features, sqldb.ExecuteBatchQuery(
×
5122
                ctx, cfg, nodeIDs,
×
5123
                func(id int64) int64 {
×
5124
                        return id
×
5125
                },
×
5126
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
5127
                        error) {
×
5128

×
5129
                        return db.GetNodeFeaturesBatch(ctx, ids)
×
5130
                },
×
5131
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
×
5132
                        features[feature.NodeID] = append(
×
5133
                                features[feature.NodeID],
×
5134
                                int(feature.FeatureBit),
×
5135
                        )
×
5136

×
5137
                        return nil
×
5138
                },
×
5139
        )
5140
}
5141

5142
// batchLoadNodeAddressesHelper loads node addresses using ExecuteBatchQuery
5143
// wrapper around the GetNodeAddressesBatch query. It returns a map from
5144
// node ID to a slice of nodeAddress structs.
5145
func batchLoadNodeAddressesHelper(ctx context.Context,
5146
        cfg *sqldb.QueryConfig, db SQLQueries,
5147
        nodeIDs []int64) (map[int64][]nodeAddress, error) {
×
5148

×
5149
        addrs := make(map[int64][]nodeAddress)
×
5150

×
5151
        return addrs, sqldb.ExecuteBatchQuery(
×
5152
                ctx, cfg, nodeIDs,
×
5153
                func(id int64) int64 {
×
5154
                        return id
×
5155
                },
×
5156
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
5157
                        error) {
×
5158

×
5159
                        return db.GetNodeAddressesBatch(ctx, ids)
×
5160
                },
×
5161
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
×
5162
                        addrs[addr.NodeID] = append(
×
5163
                                addrs[addr.NodeID], nodeAddress{
×
5164
                                        addrType: dbAddressType(addr.Type),
×
5165
                                        position: addr.Position,
×
5166
                                        address:  addr.Address,
×
5167
                                },
×
5168
                        )
×
5169

×
5170
                        return nil
×
5171
                },
×
5172
        )
5173
}
5174

5175
// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
5176
// node IDs using ExecuteBatchQuery wrapper around the GetNodeExtraTypesBatch
5177
// query.
5178
func batchLoadNodeExtraTypesHelper(ctx context.Context,
5179
        cfg *sqldb.QueryConfig, db SQLQueries,
5180
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5181

×
5182
        extraFields := make(map[int64]map[uint64][]byte)
×
5183

×
5184
        callback := func(ctx context.Context,
×
5185
                field sqlc.GraphNodeExtraType) error {
×
5186

×
5187
                if extraFields[field.NodeID] == nil {
×
5188
                        extraFields[field.NodeID] = make(map[uint64][]byte)
×
5189
                }
×
5190
                extraFields[field.NodeID][uint64(field.Type)] = field.Value
×
5191

×
5192
                return nil
×
5193
        }
5194

5195
        return extraFields, sqldb.ExecuteBatchQuery(
×
5196
                ctx, cfg, nodeIDs,
×
5197
                func(id int64) int64 {
×
5198
                        return id
×
5199
                },
×
5200
                func(ctx context.Context, ids []int64) (
5201
                        []sqlc.GraphNodeExtraType, error) {
×
5202

×
5203
                        return db.GetNodeExtraTypesBatch(ctx, ids)
×
5204
                },
×
5205
                callback,
5206
        )
5207
}
5208

5209
// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
5210
// from the provided sqlc.GraphChannelPolicy records and the
5211
// provided batchChannelData.
5212
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
5213
        channelID uint64, node1, node2 route.Vertex,
5214
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
5215
        *models.ChannelEdgePolicy, error) {
×
5216

×
5217
        pol1, err := buildChanPolicyWithBatchData(
×
5218
                dbPol1, channelID, node2, batchData,
×
5219
        )
×
5220
        if err != nil {
×
5221
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
5222
        }
×
5223

5224
        pol2, err := buildChanPolicyWithBatchData(
×
5225
                dbPol2, channelID, node1, batchData,
×
5226
        )
×
5227
        if err != nil {
×
5228
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
5229
        }
×
5230

5231
        return pol1, pol2, nil
×
5232
}
5233

5234
// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
5235
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
5236
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
5237
        channelID uint64, toNode route.Vertex,
5238
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {
×
5239

×
5240
        if dbPol == nil {
×
5241
                return nil, nil
×
5242
        }
×
5243

5244
        var dbPol1Extras map[uint64][]byte
×
5245
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
×
5246
                dbPol1Extras = extras
×
5247
        } else {
×
5248
                dbPol1Extras = make(map[uint64][]byte)
×
5249
        }
×
5250

5251
        return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode)
×
5252
}
5253

5254
// batchChannelData bundles the auxiliary data that has been loaded in bulk
// for a batch of channels and their policies.
type batchChannelData struct {
	// chanfeatures maps a DB channel ID to the feature bits set for that
	// channel.
	chanfeatures map[int64][]int

	// chanExtraTypes maps a DB channel ID to that channel's extra signed
	// fields, keyed by TLV type.
	chanExtraTypes map[int64]map[uint64][]byte

	// policyExtras maps a DB channel policy ID to that policy's extra
	// signed fields, keyed by TLV type.
	policyExtras map[int64]map[uint64][]byte
}
5267

5268
// batchLoadChannelData loads all related data for batches of channels and
5269
// policies.
5270
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
5271
        db SQLQueries, channelIDs []int64,
5272
        policyIDs []int64) (*batchChannelData, error) {
×
5273

×
5274
        batchData := &batchChannelData{
×
5275
                chanfeatures:   make(map[int64][]int),
×
5276
                chanExtraTypes: make(map[int64]map[uint64][]byte),
×
5277
                policyExtras:   make(map[int64]map[uint64][]byte),
×
5278
        }
×
5279

×
5280
        // Batch load channel features and extras
×
5281
        var err error
×
5282
        if len(channelIDs) > 0 {
×
5283
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
×
5284
                        ctx, cfg, db, channelIDs,
×
5285
                )
×
5286
                if err != nil {
×
5287
                        return nil, fmt.Errorf("unable to batch load "+
×
5288
                                "channel features: %w", err)
×
5289
                }
×
5290

5291
                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
×
5292
                        ctx, cfg, db, channelIDs,
×
5293
                )
×
5294
                if err != nil {
×
5295
                        return nil, fmt.Errorf("unable to batch load "+
×
5296
                                "channel extras: %w", err)
×
5297
                }
×
5298
        }
5299

5300
        if len(policyIDs) > 0 {
×
5301
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
×
5302
                        ctx, cfg, db, policyIDs,
×
5303
                )
×
5304
                if err != nil {
×
5305
                        return nil, fmt.Errorf("unable to batch load "+
×
5306
                                "policy extras: %w", err)
×
5307
                }
×
5308
                batchData.policyExtras = policyExtras
×
5309
        }
5310

5311
        return batchData, nil
×
5312
}
5313

5314
// batchLoadChannelFeaturesHelper loads channel features for a batch of
5315
// channel IDs using ExecuteBatchQuery wrapper around the
5316
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
5317
// slice of feature bits.
5318
func batchLoadChannelFeaturesHelper(ctx context.Context,
5319
        cfg *sqldb.QueryConfig, db SQLQueries,
5320
        channelIDs []int64) (map[int64][]int, error) {
×
5321

×
5322
        features := make(map[int64][]int)
×
5323

×
5324
        return features, sqldb.ExecuteBatchQuery(
×
5325
                ctx, cfg, channelIDs,
×
5326
                func(id int64) int64 {
×
5327
                        return id
×
5328
                },
×
5329
                func(ctx context.Context,
5330
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {
×
5331

×
5332
                        return db.GetChannelFeaturesBatch(ctx, ids)
×
5333
                },
×
5334
                func(ctx context.Context,
5335
                        feature sqlc.GraphChannelFeature) error {
×
5336

×
5337
                        features[feature.ChannelID] = append(
×
5338
                                features[feature.ChannelID],
×
5339
                                int(feature.FeatureBit),
×
5340
                        )
×
5341

×
5342
                        return nil
×
5343
                },
×
5344
        )
5345
}
5346

5347
// batchLoadChannelExtrasHelper loads channel extra types for a batch of
5348
// channel IDs using ExecuteBatchQuery wrapper around the GetChannelExtrasBatch
5349
// query. It returns a map from DB channel ID to a map of TLV type to extra
5350
// signed field bytes.
5351
func batchLoadChannelExtrasHelper(ctx context.Context,
5352
        cfg *sqldb.QueryConfig, db SQLQueries,
5353
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5354

×
5355
        extras := make(map[int64]map[uint64][]byte)
×
5356

×
5357
        cb := func(ctx context.Context,
×
5358
                extra sqlc.GraphChannelExtraType) error {
×
5359

×
5360
                if extras[extra.ChannelID] == nil {
×
5361
                        extras[extra.ChannelID] = make(map[uint64][]byte)
×
5362
                }
×
5363
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value
×
5364

×
5365
                return nil
×
5366
        }
5367

5368
        return extras, sqldb.ExecuteBatchQuery(
×
5369
                ctx, cfg, channelIDs,
×
5370
                func(id int64) int64 {
×
5371
                        return id
×
5372
                },
×
5373
                func(ctx context.Context,
5374
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {
×
5375

×
5376
                        return db.GetChannelExtrasBatch(ctx, ids)
×
5377
                }, cb,
×
5378
        )
5379
}
5380

5381
// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
5382
// batch of policy IDs using ExecuteBatchQuery wrapper around the
5383
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
5384
// a map of TLV type to extra signed field bytes.
5385
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
5386
        cfg *sqldb.QueryConfig, db SQLQueries,
5387
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5388

×
5389
        extras := make(map[int64]map[uint64][]byte)
×
5390

×
5391
        return extras, sqldb.ExecuteBatchQuery(
×
5392
                ctx, cfg, policyIDs,
×
5393
                func(id int64) int64 {
×
5394
                        return id
×
5395
                },
×
5396
                func(ctx context.Context, ids []int64) (
5397
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {
×
5398

×
5399
                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
×
5400
                },
×
5401
                func(ctx context.Context,
5402
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {
×
5403

×
5404
                        if extras[row.PolicyID] == nil {
×
5405
                                extras[row.PolicyID] = make(map[uint64][]byte)
×
5406
                        }
×
5407
                        extras[row.PolicyID][uint64(row.Type)] = row.Value
×
5408

×
5409
                        return nil
×
5410
                },
5411
        )
5412
}
5413

5414
// forEachNodePaginated executes a paginated query to process each node in the
5415
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
5416
// and applies the provided processNode function to each node.
5417
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
5418
        db SQLQueries, protocol lnwire.GossipVersion,
5419
        processNode func(context.Context, int64,
5420
                *models.Node) error) error {
×
5421

×
5422
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5423
                limit int32) ([]sqlc.GraphNode, error) {
×
5424

×
5425
                return db.ListNodesPaginated(
×
5426
                        ctx, sqlc.ListNodesPaginatedParams{
×
5427
                                Version: int16(protocol),
×
5428
                                ID:      lastID,
×
5429
                                Limit:   limit,
×
5430
                        },
×
5431
                )
×
5432
        }
×
5433

5434
        extractPageCursor := func(node sqlc.GraphNode) int64 {
×
5435
                return node.ID
×
5436
        }
×
5437

5438
        collectFunc := func(node sqlc.GraphNode) (int64, error) {
×
5439
                return node.ID, nil
×
5440
        }
×
5441

5442
        batchQueryFunc := func(ctx context.Context,
×
5443
                nodeIDs []int64) (*batchNodeData, error) {
×
5444

×
5445
                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
5446
        }
×
5447

5448
        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
×
5449
                batchData *batchNodeData) error {
×
5450

×
5451
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
5452
                if err != nil {
×
5453
                        return fmt.Errorf("unable to build "+
×
5454
                                "node(id=%d): %w", dbNode.ID, err)
×
5455
                }
×
5456

5457
                return processNode(ctx, dbNode.ID, node)
×
5458
        }
5459

5460
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5461
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5462
                collectFunc, batchQueryFunc, processItem,
×
5463
        )
×
5464
}
5465

5466
// forEachChannelWithPolicies executes a paginated query to process each channel
5467
// with policies in the graph.
5468
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
5469
        cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
5470
                *models.ChannelEdgePolicy,
5471
                *models.ChannelEdgePolicy) error) error {
×
5472

×
5473
        type channelBatchIDs struct {
×
5474
                channelID int64
×
5475
                policyIDs []int64
×
5476
        }
×
5477

×
5478
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5479
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
×
5480
                error) {
×
5481

×
5482
                return db.ListChannelsWithPoliciesPaginated(
×
5483
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
5484
                                Version: int16(lnwire.GossipVersion1),
×
5485
                                ID:      lastID,
×
5486
                                Limit:   limit,
×
5487
                        },
×
5488
                )
×
5489
        }
×
5490

5491
        extractPageCursor := func(
×
5492
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {
×
5493

×
5494
                return row.GraphChannel.ID
×
5495
        }
×
5496

5497
        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
×
5498
                channelBatchIDs, error) {
×
5499

×
5500
                ids := channelBatchIDs{
×
5501
                        channelID: row.GraphChannel.ID,
×
5502
                }
×
5503

×
5504
                // Extract policy IDs from the row.
×
5505
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5506
                if err != nil {
×
5507
                        return ids, err
×
5508
                }
×
5509

5510
                if dbPol1 != nil {
×
5511
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
×
5512
                }
×
5513
                if dbPol2 != nil {
×
5514
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
×
5515
                }
×
5516

5517
                return ids, nil
×
5518
        }
5519

5520
        batchDataFunc := func(ctx context.Context,
×
5521
                allIDs []channelBatchIDs) (*batchChannelData, error) {
×
5522

×
5523
                // Separate channel IDs from policy IDs.
×
5524
                var (
×
5525
                        channelIDs = make([]int64, len(allIDs))
×
5526
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
×
5527
                )
×
5528

×
5529
                for i, ids := range allIDs {
×
5530
                        channelIDs[i] = ids.channelID
×
5531
                        policyIDs = append(policyIDs, ids.policyIDs...)
×
5532
                }
×
5533

5534
                return batchLoadChannelData(
×
5535
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
5536
                )
×
5537
        }
5538

5539
        processItem := func(ctx context.Context,
×
5540
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
×
5541
                batchData *batchChannelData) error {
×
5542

×
5543
                node1, node2, err := buildNodeVertices(
×
5544
                        row.Node1Pubkey, row.Node2Pubkey,
×
5545
                )
×
5546
                if err != nil {
×
5547
                        return err
×
5548
                }
×
5549

5550
                edge, err := buildEdgeInfoWithBatchData(
×
5551
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
5552
                        batchData,
×
5553
                )
×
5554
                if err != nil {
×
5555
                        return fmt.Errorf("unable to build channel info: %w",
×
5556
                                err)
×
5557
                }
×
5558

5559
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5560
                if err != nil {
×
5561
                        return err
×
5562
                }
×
5563

5564
                p1, p2, err := buildChanPoliciesWithBatchData(
×
5565
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
5566
                )
×
5567
                if err != nil {
×
5568
                        return err
×
5569
                }
×
5570

5571
                return processChannel(edge, p1, p2)
×
5572
        }
5573

5574
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5575
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5576
                collectFunc, batchDataFunc, processItem,
×
5577
        )
×
5578
}
5579

5580
// buildDirectedChannel builds a DirectedChannel instance from the provided
5581
// data.
5582
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
5583
        nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
5584
        channelBatchData *batchChannelData, features *lnwire.FeatureVector,
5585
        toNodeCallback func() route.Vertex) (*DirectedChannel, error) {
×
5586

×
5587
        node1, node2, err := buildNodeVertices(
×
5588
                channelRow.Node1Pubkey, channelRow.Node2Pubkey,
×
5589
        )
×
5590
        if err != nil {
×
5591
                return nil, fmt.Errorf("unable to build node vertices: %w", err)
×
5592
        }
×
5593

5594
        edge, err := buildEdgeInfoWithBatchData(
×
5595
                chain, channelRow.GraphChannel, node1, node2, channelBatchData,
×
5596
        )
×
5597
        if err != nil {
×
5598
                return nil, fmt.Errorf("unable to build channel info: %w", err)
×
5599
        }
×
5600

5601
        dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
×
5602
        if err != nil {
×
5603
                return nil, fmt.Errorf("unable to extract channel policies: %w",
×
5604
                        err)
×
5605
        }
×
5606

5607
        p1, p2, err := buildChanPoliciesWithBatchData(
×
5608
                dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
5609
                channelBatchData,
×
5610
        )
×
5611
        if err != nil {
×
5612
                return nil, fmt.Errorf("unable to build channel policies: %w",
×
5613
                        err)
×
5614
        }
×
5615

5616
        // Determine outgoing and incoming policy for this specific node.
5617
        p1ToNode := channelRow.GraphChannel.NodeID2
×
5618
        p2ToNode := channelRow.GraphChannel.NodeID1
×
5619
        outPolicy, inPolicy := p1, p2
×
5620
        if (p1 != nil && p1ToNode == nodeID) ||
×
5621
                (p2 != nil && p2ToNode != nodeID) {
×
5622

×
5623
                outPolicy, inPolicy = p2, p1
×
5624
        }
×
5625

5626
        // Build cached policy.
5627
        var cachedInPolicy *models.CachedEdgePolicy
×
5628
        if inPolicy != nil {
×
5629
                cachedInPolicy = models.NewCachedPolicy(inPolicy)
×
5630
                cachedInPolicy.ToNodePubKey = toNodeCallback
×
5631
                cachedInPolicy.ToNodeFeatures = features
×
5632
        }
×
5633

5634
        // Extract inbound fee.
5635
        var inboundFee lnwire.Fee
×
5636
        if outPolicy != nil {
×
5637
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
5638
                        inboundFee = fee
×
5639
                })
×
5640
        }
5641

5642
        // Build directed channel.
5643
        directedChannel := &DirectedChannel{
×
5644
                ChannelID:    edge.ChannelID,
×
5645
                IsNode1:      nodePub == edge.NodeKey1Bytes,
×
5646
                OtherNode:    edge.NodeKey2Bytes,
×
5647
                Capacity:     edge.Capacity,
×
5648
                OutPolicySet: outPolicy != nil,
×
5649
                InPolicy:     cachedInPolicy,
×
5650
                InboundFee:   inboundFee,
×
5651
        }
×
5652

×
5653
        if nodePub == edge.NodeKey2Bytes {
×
5654
                directedChannel.OtherNode = edge.NodeKey1Bytes
×
5655
        }
×
5656

5657
        return directedChannel, nil
×
5658
}
5659

5660
// batchBuildChannelEdges builds a slice of ChannelEdge instances from the
5661
// provided rows. It uses batch loading for channels, policies, and nodes.
5662
func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context,
5663
        cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) {
×
5664

×
5665
        var (
×
5666
                channelIDs = make([]int64, len(rows))
×
5667
                policyIDs  = make([]int64, 0, len(rows)*2)
×
5668
                nodeIDs    = make([]int64, 0, len(rows)*2)
×
5669

×
5670
                // nodeIDSet is used to ensure we only collect unique node IDs.
×
5671
                nodeIDSet = make(map[int64]bool)
×
5672

×
5673
                // edges will hold the final channel edges built from the rows.
×
5674
                edges = make([]ChannelEdge, 0, len(rows))
×
5675
        )
×
5676

×
5677
        // Collect all IDs needed for batch loading.
×
5678
        for i, row := range rows {
×
5679
                channelIDs[i] = row.Channel().ID
×
5680

×
5681
                // Collect policy IDs
×
5682
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5683
                if err != nil {
×
5684
                        return nil, fmt.Errorf("unable to extract channel "+
×
5685
                                "policies: %w", err)
×
5686
                }
×
5687
                if dbPol1 != nil {
×
5688
                        policyIDs = append(policyIDs, dbPol1.ID)
×
5689
                }
×
5690
                if dbPol2 != nil {
×
5691
                        policyIDs = append(policyIDs, dbPol2.ID)
×
5692
                }
×
5693

5694
                var (
×
5695
                        node1ID = row.Node1().ID
×
5696
                        node2ID = row.Node2().ID
×
5697
                )
×
5698

×
5699
                // Collect unique node IDs.
×
5700
                if !nodeIDSet[node1ID] {
×
5701
                        nodeIDs = append(nodeIDs, node1ID)
×
5702
                        nodeIDSet[node1ID] = true
×
5703
                }
×
5704

5705
                if !nodeIDSet[node2ID] {
×
5706
                        nodeIDs = append(nodeIDs, node2ID)
×
5707
                        nodeIDSet[node2ID] = true
×
5708
                }
×
5709
        }
5710

5711
        // Batch the data for all the channels and policies.
5712
        channelBatchData, err := batchLoadChannelData(
×
5713
                ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
5714
        )
×
5715
        if err != nil {
×
5716
                return nil, fmt.Errorf("unable to batch load channel and "+
×
5717
                        "policy data: %w", err)
×
5718
        }
×
5719

5720
        // Batch the data for all the nodes.
5721
        nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
×
5722
        if err != nil {
×
5723
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
5724
                        err)
×
5725
        }
×
5726

5727
        // Build all channel edges using batch data.
5728
        for _, row := range rows {
×
5729
                // Build nodes using batch data.
×
5730
                node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData)
×
5731
                if err != nil {
×
5732
                        return nil, fmt.Errorf("unable to build node1: %w", err)
×
5733
                }
×
5734

5735
                node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData)
×
5736
                if err != nil {
×
5737
                        return nil, fmt.Errorf("unable to build node2: %w", err)
×
5738
                }
×
5739

5740
                // Build channel info using batch data.
5741
                channel, err := buildEdgeInfoWithBatchData(
×
5742
                        cfg.ChainHash, row.Channel(), node1.PubKeyBytes,
×
5743
                        node2.PubKeyBytes, channelBatchData,
×
5744
                )
×
5745
                if err != nil {
×
5746
                        return nil, fmt.Errorf("unable to build channel "+
×
5747
                                "info: %w", err)
×
5748
                }
×
5749

5750
                // Extract and build policies using batch data.
5751
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5752
                if err != nil {
×
5753
                        return nil, fmt.Errorf("unable to extract channel "+
×
5754
                                "policies: %w", err)
×
5755
                }
×
5756

5757
                p1, p2, err := buildChanPoliciesWithBatchData(
×
5758
                        dbPol1, dbPol2, channel.ChannelID,
×
5759
                        node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData,
×
5760
                )
×
5761
                if err != nil {
×
5762
                        return nil, fmt.Errorf("unable to build channel "+
×
5763
                                "policies: %w", err)
×
5764
                }
×
5765

5766
                edges = append(edges, ChannelEdge{
×
5767
                        Info:    channel,
×
5768
                        Policy1: p1,
×
5769
                        Policy2: p2,
×
5770
                        Node1:   node1,
×
5771
                        Node2:   node2,
×
5772
                })
×
5773
        }
5774

5775
        return edges, nil
×
5776
}
5777

5778
// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo
5779
// instances from the provided rows using batch loading for channel data.
5780
func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context,
5781
        cfg *SQLStoreConfig, db SQLQueries, rows []T) (
5782
        []*models.ChannelEdgeInfo, []int64, error) {
×
5783

×
5784
        if len(rows) == 0 {
×
5785
                return nil, nil, nil
×
5786
        }
×
5787

5788
        // Collect all the channel IDs needed for batch loading.
5789
        channelIDs := make([]int64, len(rows))
×
5790
        for i, row := range rows {
×
5791
                channelIDs[i] = row.Channel().ID
×
5792
        }
×
5793

5794
        // Batch load the channel data.
5795
        channelBatchData, err := batchLoadChannelData(
×
5796
                ctx, cfg.QueryCfg, db, channelIDs, nil,
×
5797
        )
×
5798
        if err != nil {
×
5799
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
5800
                        "data: %w", err)
×
5801
        }
×
5802

5803
        // Build all channel edges using batch data.
5804
        edges := make([]*models.ChannelEdgeInfo, 0, len(rows))
×
5805
        for _, row := range rows {
×
5806
                node1, node2, err := buildNodeVertices(
×
5807
                        row.Node1Pub(), row.Node2Pub(),
×
5808
                )
×
5809
                if err != nil {
×
5810
                        return nil, nil, err
×
5811
                }
×
5812

5813
                // Build channel info using batch data
5814
                info, err := buildEdgeInfoWithBatchData(
×
5815
                        cfg.ChainHash, row.Channel(), node1, node2,
×
5816
                        channelBatchData,
×
5817
                )
×
5818
                if err != nil {
×
5819
                        return nil, nil, err
×
5820
                }
×
5821

5822
                edges = append(edges, info)
×
5823
        }
5824

5825
        return edges, channelIDs, nil
×
5826
}
5827

5828
// handleZombieMarking is a helper function that handles the logic of
5829
// marking a channel as a zombie in the database. It takes into account whether
5830
// we are in strict zombie pruning mode, and adjusts the node public keys
5831
// accordingly based on the last update timestamps of the channel policies.
5832
func handleZombieMarking(ctx context.Context, db SQLQueries,
5833
        row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo,
5834
        strictZombiePruning bool, scid uint64) error {
×
5835

×
5836
        nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes
×
5837

×
5838
        if strictZombiePruning {
×
5839
                var e1UpdateTime, e2UpdateTime *time.Time
×
5840
                if row.Policy1LastUpdate.Valid {
×
5841
                        e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0)
×
5842
                        e1UpdateTime = &e1Time
×
5843
                }
×
5844
                if row.Policy2LastUpdate.Valid {
×
5845
                        e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0)
×
5846
                        e2UpdateTime = &e2Time
×
5847
                }
×
5848

5849
                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
5850
                        info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime,
×
5851
                        e2UpdateTime,
×
5852
                )
×
5853
        }
5854

5855
        return db.UpsertZombieChannel(
×
5856
                ctx, sqlc.UpsertZombieChannelParams{
×
5857
                        Version:  int16(lnwire.GossipVersion1),
×
5858
                        Scid:     channelIDToBytes(scid),
×
5859
                        NodeKey1: nodeKey1[:],
×
5860
                        NodeKey2: nodeKey2[:],
×
5861
                },
×
5862
        )
×
5863
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc