• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 18197857992

02 Oct 2025 03:32PM UTC coverage: 66.622% (-0.02%) from 66.646%
18197857992

Pull #10267

github

web-flow
Merge 0d9bfccfe into 1c2ff4a7e
Pull Request #10267: [g175] multi: small G175 preparations

24 of 141 new or added lines in 12 files covered. (17.02%)

64 existing lines in 20 files now uncovered.

137216 of 205963 relevant lines covered (66.62%)

21302.01 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "iter"
11
        "maps"
12
        "math"
13
        "net"
14
        "slices"
15
        "strconv"
16
        "sync"
17
        "time"
18

19
        "github.com/btcsuite/btcd/btcec/v2"
20
        "github.com/btcsuite/btcd/btcutil"
21
        "github.com/btcsuite/btcd/chaincfg/chainhash"
22
        "github.com/btcsuite/btcd/wire"
23
        "github.com/lightningnetwork/lnd/aliasmgr"
24
        "github.com/lightningnetwork/lnd/batch"
25
        "github.com/lightningnetwork/lnd/fn/v2"
26
        "github.com/lightningnetwork/lnd/graph/db/models"
27
        "github.com/lightningnetwork/lnd/lnwire"
28
        "github.com/lightningnetwork/lnd/routing/route"
29
        "github.com/lightningnetwork/lnd/sqldb"
30
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
31
        "github.com/lightningnetwork/lnd/tlv"
32
        "github.com/lightningnetwork/lnd/tor"
33
)
34

35
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
36
// execute queries against the SQL graph tables.
37
//
38
//nolint:ll,interfacebloat
39
type SQLQueries interface {
40
        /*
41
                Node queries.
42
        */
43
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
44
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
45
        GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error)
46
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
47
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
48
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
49
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
50
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
51
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
52
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
53
        DeleteNode(ctx context.Context, id int64) error
54

55
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
56
        GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
57
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
58
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
59

60
        UpsertNodeAddress(ctx context.Context, arg sqlc.UpsertNodeAddressParams) error
61
        GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
62
        GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
63
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
64

65
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
66
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
67
        GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
68
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
69
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
70

71
        /*
72
                Source node queries.
73
        */
74
        AddSourceNode(ctx context.Context, nodeID int64) error
75
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
76

77
        /*
78
                Channel queries.
79
        */
80
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
81
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
82
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
83
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
84
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
85
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
86
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
87
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
88
        GetChannelsByIDs(ctx context.Context, ids []int64) ([]sqlc.GetChannelsByIDsRow, error)
89
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
90
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
91
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
92
        ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
93
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
94
        ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
95
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
96
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
97
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
98
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
99
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
100
        DeleteChannels(ctx context.Context, ids []int64) error
101

102
        UpsertChannelExtraType(ctx context.Context, arg sqlc.UpsertChannelExtraTypeParams) error
103
        GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
104
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
105
        GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)
106

107
        /*
108
                Channel Policy table queries.
109
        */
110
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
111
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
112
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
113

114
        UpsertChanPolicyExtraType(ctx context.Context, arg sqlc.UpsertChanPolicyExtraTypeParams) error
115
        GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
116
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
117

118
        /*
119
                Zombie index queries.
120
        */
121
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
122
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
123
        GetZombieChannelsSCIDs(ctx context.Context, arg sqlc.GetZombieChannelsSCIDsParams) ([]sqlc.GraphZombieChannel, error)
124
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
125
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
126
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
127

128
        /*
129
                Prune log table queries.
130
        */
131
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
132
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
133
        GetPruneEntriesForHeights(ctx context.Context, heights []int64) ([]sqlc.GraphPruneLog, error)
134
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
135
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
136

137
        /*
138
                Closed SCID table queries.
139
        */
140
        InsertClosedChannel(ctx context.Context, scid []byte) error
141
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
142
        GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error)
143

144
        /*
145
                Migration specific queries.
146

147
                NOTE: these should not be used in code other than migrations.
148
                Once sqldbv2 is in place, these can be removed from this struct
149
                as then migrations will have their own dedicated queries
150
                structs.
151
        */
152
        InsertNodeMig(ctx context.Context, arg sqlc.InsertNodeMigParams) (int64, error)
153
        InsertChannelMig(ctx context.Context, arg sqlc.InsertChannelMigParams) (int64, error)
154
        InsertEdgePolicyMig(ctx context.Context, arg sqlc.InsertEdgePolicyMigParams) (int64, error)
155
}
156

157
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
158
// database operations.
159
type BatchedSQLQueries interface {
160
        SQLQueries
161
        sqldb.BatchedTx[SQLQueries]
162
}
163

164
// SQLStore is an implementation of the V1Store interface that uses a SQL
165
// database as the backend.
166
type SQLStore struct {
167
        cfg *SQLStoreConfig
168
        db  BatchedSQLQueries
169

170
        // cacheMu guards all caches (rejectCache and chanCache). If
171
        // this mutex will be acquired at the same time as the DB mutex then
172
        // the cacheMu MUST be acquired first to prevent deadlock.
173
        cacheMu     sync.RWMutex
174
        rejectCache *rejectCache
175
        chanCache   *channelCache
176

177
        chanScheduler batch.Scheduler[SQLQueries]
178
        nodeScheduler batch.Scheduler[SQLQueries]
179

180
        srcNodes  map[lnwire.GossipVersion]*srcNodeInfo
181
        srcNodeMu sync.Mutex
182
}
183

184
// A compile-time assertion to ensure that SQLStore implements the V1Store
185
// interface.
186
var _ V1Store = (*SQLStore)(nil)
187

188
// SQLStoreConfig holds the configuration for the SQLStore.
189
type SQLStoreConfig struct {
190
        // ChainHash is the genesis hash for the chain that all the gossip
191
        // messages in this store are aimed at.
192
        ChainHash chainhash.Hash
193

194
        // QueryConfig holds configuration values for SQL queries.
195
        QueryCfg *sqldb.QueryConfig
196
}
197

198
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
199
// storage backend.
200
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
201
        options ...StoreOptionModifier) (*SQLStore, error) {
×
202

×
203
        opts := DefaultOptions()
×
204
        for _, o := range options {
×
205
                o(opts)
×
206
        }
×
207

208
        if opts.NoMigration {
×
209
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
210
                        "supported for SQL stores")
×
211
        }
×
212

213
        s := &SQLStore{
×
214
                cfg:         cfg,
×
215
                db:          db,
×
216
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
217
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
NEW
218
                srcNodes:    make(map[lnwire.GossipVersion]*srcNodeInfo),
×
219
        }
×
220

×
221
        s.chanScheduler = batch.NewTimeScheduler(
×
222
                db, &s.cacheMu, opts.BatchCommitInterval,
×
223
        )
×
224
        s.nodeScheduler = batch.NewTimeScheduler(
×
225
                db, nil, opts.BatchCommitInterval,
×
226
        )
×
227

×
228
        return s, nil
×
229
}
230

231
// AddNode adds a vertex/node to the graph database. If the node is not
232
// in the database from before, this will add a new, unconnected one to the
233
// graph. If it is present from before, this will update that node's
234
// information.
235
//
236
// NOTE: part of the V1Store interface.
237
func (s *SQLStore) AddNode(ctx context.Context,
238
        node *models.Node, opts ...batch.SchedulerOption) error {
×
239

×
240
        r := &batch.Request[SQLQueries]{
×
241
                Opts: batch.NewSchedulerOptions(opts...),
×
242
                Do: func(queries SQLQueries) error {
×
243
                        _, err := upsertNode(ctx, queries, node)
×
244
                        return err
×
245
                },
×
246
        }
247

248
        return s.nodeScheduler.Execute(ctx, r)
×
249
}
250

251
// FetchNode attempts to look up a target node by its identity public
252
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
253
// returned.
254
//
255
// NOTE: part of the V1Store interface.
256
func (s *SQLStore) FetchNode(ctx context.Context,
257
        pubKey route.Vertex) (*models.Node, error) {
×
258

×
259
        var node *models.Node
×
260
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
261
                var err error
×
262
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, pubKey)
×
263

×
264
                return err
×
265
        }, sqldb.NoOpReset)
×
266
        if err != nil {
×
267
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
268
        }
×
269

270
        return node, nil
×
271
}
272

273
// HasNode determines if the graph has a vertex identified by the
274
// target node identity public key. If the node exists in the database, a
275
// timestamp of when the data for the node was lasted updated is returned along
276
// with a true boolean. Otherwise, an empty time.Time is returned with a false
277
// boolean.
278
//
279
// NOTE: part of the V1Store interface.
280
func (s *SQLStore) HasNode(ctx context.Context,
281
        pubKey [33]byte) (time.Time, bool, error) {
×
282

×
283
        var (
×
284
                exists     bool
×
285
                lastUpdate time.Time
×
286
        )
×
287
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
288
                dbNode, err := db.GetNodeByPubKey(
×
289
                        ctx, sqlc.GetNodeByPubKeyParams{
×
NEW
290
                                Version: int16(lnwire.GossipVersion1),
×
291
                                PubKey:  pubKey[:],
×
292
                        },
×
293
                )
×
294
                if errors.Is(err, sql.ErrNoRows) {
×
295
                        return nil
×
296
                } else if err != nil {
×
297
                        return fmt.Errorf("unable to fetch node: %w", err)
×
298
                }
×
299

300
                exists = true
×
301

×
302
                if dbNode.LastUpdate.Valid {
×
303
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
304
                }
×
305

306
                return nil
×
307
        }, sqldb.NoOpReset)
308
        if err != nil {
×
309
                return time.Time{}, false,
×
310
                        fmt.Errorf("unable to fetch node: %w", err)
×
311
        }
×
312

313
        return lastUpdate, exists, nil
×
314
}
315

316
// AddrsForNode returns all known addresses for the target node public key
317
// that the graph DB is aware of. The returned boolean indicates if the
318
// given node is unknown to the graph DB or not.
319
//
320
// NOTE: part of the V1Store interface.
321
func (s *SQLStore) AddrsForNode(ctx context.Context,
322
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
323

×
324
        var (
×
325
                addresses []net.Addr
×
326
                known     bool
×
327
        )
×
328
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
329
                // First, check if the node exists and get its DB ID if it
×
330
                // does.
×
331
                dbID, err := db.GetNodeIDByPubKey(
×
332
                        ctx, sqlc.GetNodeIDByPubKeyParams{
×
NEW
333
                                Version: int16(lnwire.GossipVersion1),
×
334
                                PubKey:  nodePub.SerializeCompressed(),
×
335
                        },
×
336
                )
×
337
                if errors.Is(err, sql.ErrNoRows) {
×
338
                        return nil
×
339
                }
×
340

341
                known = true
×
342

×
343
                addresses, err = getNodeAddresses(ctx, db, dbID)
×
344
                if err != nil {
×
345
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
346
                                err)
×
347
                }
×
348

349
                return nil
×
350
        }, sqldb.NoOpReset)
351
        if err != nil {
×
352
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
353
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
354
        }
×
355

356
        return known, addresses, nil
×
357
}
358

359
// DeleteNode starts a new database transaction to remove a vertex/node
360
// from the database according to the node's public key.
361
//
362
// NOTE: part of the V1Store interface.
363
func (s *SQLStore) DeleteNode(ctx context.Context,
364
        pubKey route.Vertex) error {
×
365

×
366
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
367
                res, err := db.DeleteNodeByPubKey(
×
368
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
NEW
369
                                Version: int16(lnwire.GossipVersion1),
×
370
                                PubKey:  pubKey[:],
×
371
                        },
×
372
                )
×
373
                if err != nil {
×
374
                        return err
×
375
                }
×
376

377
                rows, err := res.RowsAffected()
×
378
                if err != nil {
×
379
                        return err
×
380
                }
×
381

382
                if rows == 0 {
×
383
                        return ErrGraphNodeNotFound
×
384
                } else if rows > 1 {
×
385
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
386
                }
×
387

388
                return err
×
389
        }, sqldb.NoOpReset)
390
        if err != nil {
×
391
                return fmt.Errorf("unable to delete node: %w", err)
×
392
        }
×
393

394
        return nil
×
395
}
396

397
// FetchNodeFeatures returns the features of the given node. If no features are
398
// known for the node, an empty feature vector is returned.
399
//
400
// NOTE: this is part of the graphdb.NodeTraverser interface.
401
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
402
        *lnwire.FeatureVector, error) {
×
403

×
404
        ctx := context.TODO()
×
405

×
406
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
407
}
×
408

409
// DisabledChannelIDs returns the channel ids of disabled channels.
410
// A channel is disabled when two of the associated ChanelEdgePolicies
411
// have their disabled bit on.
412
//
413
// NOTE: part of the V1Store interface.
414
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
415
        var (
×
416
                ctx     = context.TODO()
×
417
                chanIDs []uint64
×
418
        )
×
419
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
420
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
421
                if err != nil {
×
422
                        return fmt.Errorf("unable to fetch disabled "+
×
423
                                "channels: %w", err)
×
424
                }
×
425

426
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
427

×
428
                return nil
×
429
        }, sqldb.NoOpReset)
430
        if err != nil {
×
431
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
432
                        err)
×
433
        }
×
434

435
        return chanIDs, nil
×
436
}
437

438
// LookupAlias attempts to return the alias as advertised by the target node.
439
//
440
// NOTE: part of the V1Store interface.
441
func (s *SQLStore) LookupAlias(ctx context.Context,
442
        pub *btcec.PublicKey) (string, error) {
×
443

×
444
        var alias string
×
445
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
446
                dbNode, err := db.GetNodeByPubKey(
×
447
                        ctx, sqlc.GetNodeByPubKeyParams{
×
NEW
448
                                Version: int16(lnwire.GossipVersion1),
×
449
                                PubKey:  pub.SerializeCompressed(),
×
450
                        },
×
451
                )
×
452
                if errors.Is(err, sql.ErrNoRows) {
×
453
                        return ErrNodeAliasNotFound
×
454
                } else if err != nil {
×
455
                        return fmt.Errorf("unable to fetch node: %w", err)
×
456
                }
×
457

458
                if !dbNode.Alias.Valid {
×
459
                        return ErrNodeAliasNotFound
×
460
                }
×
461

462
                alias = dbNode.Alias.String
×
463

×
464
                return nil
×
465
        }, sqldb.NoOpReset)
466
        if err != nil {
×
467
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
468
        }
×
469

470
        return alias, nil
×
471
}
472

473
// SourceNode returns the source node of the graph. The source node is treated
474
// as the center node within a star-graph. This method may be used to kick off
475
// a path finding algorithm in order to explore the reachability of another
476
// node based off the source node.
477
//
478
// NOTE: part of the V1Store interface.
479
func (s *SQLStore) SourceNode(ctx context.Context) (*models.Node,
480
        error) {
×
481

×
482
        var node *models.Node
×
483
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
484
                _, nodePub, err := s.getSourceNode(
×
NEW
485
                        ctx, db, lnwire.GossipVersion1,
×
NEW
486
                )
×
487
                if err != nil {
×
488
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
489
                                err)
×
490
                }
×
491

492
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, nodePub)
×
493

×
494
                return err
×
495
        }, sqldb.NoOpReset)
496
        if err != nil {
×
497
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
498
        }
×
499

500
        return node, nil
×
501
}
502

503
// SetSourceNode sets the source node within the graph database. The source
504
// node is to be used as the center of a star-graph within path finding
505
// algorithms.
506
//
507
// NOTE: part of the V1Store interface.
508
func (s *SQLStore) SetSourceNode(ctx context.Context,
509
        node *models.Node) error {
×
510

×
511
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
512
                id, err := upsertNode(ctx, db, node)
×
513
                if err != nil {
×
514
                        return fmt.Errorf("unable to upsert source node: %w",
×
515
                                err)
×
516
                }
×
517

518
                // Make sure that if a source node for this version is already
519
                // set, then the ID is the same as the one we are about to set.
NEW
520
                dbSourceNodeID, _, err := s.getSourceNode(
×
NEW
521
                        ctx, db, lnwire.GossipVersion1,
×
NEW
522
                )
×
523
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
524
                        return fmt.Errorf("unable to fetch source node: %w",
×
525
                                err)
×
526
                } else if err == nil {
×
527
                        if dbSourceNodeID != id {
×
528
                                return fmt.Errorf("v1 source node already "+
×
529
                                        "set to a different node: %d vs %d",
×
530
                                        dbSourceNodeID, id)
×
531
                        }
×
532

533
                        return nil
×
534
                }
535

536
                return db.AddSourceNode(ctx, id)
×
537
        }, sqldb.NoOpReset)
538
}
539

540
// NodeUpdatesInHorizon returns all the known lightning node which have an
541
// update timestamp within the passed range. This method can be used by two
542
// nodes to quickly determine if they have the same set of up to date node
543
// announcements.
544
//
545
// NOTE: This is part of the V1Store interface.
546
func (s *SQLStore) NodeUpdatesInHorizon(startTime, endTime time.Time,
547
        opts ...IteratorOption) iter.Seq2[models.Node, error] {
×
548

×
549
        cfg := defaultIteratorConfig()
×
550
        for _, opt := range opts {
×
551
                opt(cfg)
×
552
        }
×
553

554
        return func(yield func(models.Node, error) bool) {
×
555
                var (
×
556
                        ctx            = context.TODO()
×
557
                        lastUpdateTime sql.NullInt64
×
558
                        lastPubKey     = make([]byte, 33)
×
559
                        hasMore        = true
×
560
                )
×
561

×
562
                // Each iteration, we'll read a batch amount of nodes, yield
×
563
                // them, then decide is we have more or not.
×
564
                for hasMore {
×
565
                        var batch []models.Node
×
566

×
567
                        //nolint:ll
×
568
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
569
                                //nolint:ll
×
570
                                params := sqlc.GetNodesByLastUpdateRangeParams{
×
571
                                        StartTime: sqldb.SQLInt64(
×
572
                                                startTime.Unix(),
×
573
                                        ),
×
574
                                        EndTime: sqldb.SQLInt64(
×
575
                                                endTime.Unix(),
×
576
                                        ),
×
577
                                        LastUpdate: lastUpdateTime,
×
578
                                        LastPubKey: lastPubKey,
×
579
                                        OnlyPublic: sql.NullBool{
×
580
                                                Bool:  cfg.iterPublicNodes,
×
581
                                                Valid: true,
×
582
                                        },
×
583
                                        MaxResults: sqldb.SQLInt32(
×
584
                                                cfg.nodeUpdateIterBatchSize,
×
585
                                        ),
×
586
                                }
×
587
                                rows, err := db.GetNodesByLastUpdateRange(
×
588
                                        ctx, params,
×
589
                                )
×
590
                                if err != nil {
×
591
                                        return err
×
592
                                }
×
593

594
                                hasMore = len(rows) == cfg.nodeUpdateIterBatchSize
×
595

×
596
                                err = forEachNodeInBatch(
×
597
                                        ctx, s.cfg.QueryCfg, db, rows,
×
598
                                        func(_ int64, node *models.Node) error {
×
599
                                                batch = append(batch, *node)
×
600

×
601
                                                // Update pagination cursors
×
602
                                                // based on the last processed
×
603
                                                // node.
×
604
                                                lastUpdateTime = sql.NullInt64{
×
605
                                                        Int64: node.LastUpdate.
×
606
                                                                Unix(),
×
607
                                                        Valid: true,
×
608
                                                }
×
609
                                                lastPubKey = node.PubKeyBytes[:]
×
610

×
611
                                                return nil
×
612
                                        },
×
613
                                )
614
                                if err != nil {
×
615
                                        return fmt.Errorf("unable to build "+
×
616
                                                "nodes: %w", err)
×
617
                                }
×
618

619
                                return nil
×
620
                        }, func() {
×
621
                                batch = []models.Node{}
×
622
                        })
×
623

624
                        if err != nil {
×
625
                                log.Errorf("NodeUpdatesInHorizon batch "+
×
626
                                        "error: %v", err)
×
627

×
628
                                yield(models.Node{}, err)
×
629

×
630
                                return
×
631
                        }
×
632

633
                        for _, node := range batch {
×
634
                                if !yield(node, nil) {
×
635
                                        return
×
636
                                }
×
637
                        }
638

639
                        // If the batch didn't yield anything, then we're done.
640
                        if len(batch) == 0 {
×
641
                                break
×
642
                        }
643
                }
644
        }
645
}
646

647
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
648
// undirected edge from the two target nodes are created. The information stored
649
// denotes the static attributes of the channel, such as the channelID, the keys
650
// involved in creation of the channel, and the set of features that the channel
651
// supports. The chanPoint and chanID are used to uniquely identify the edge
652
// globally within the database.
653
//
654
// NOTE: part of the V1Store interface.
655
func (s *SQLStore) AddChannelEdge(ctx context.Context,
656
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
657

×
658
        var alreadyExists bool
×
659
        r := &batch.Request[SQLQueries]{
×
660
                Opts: batch.NewSchedulerOptions(opts...),
×
661
                Reset: func() {
×
662
                        alreadyExists = false
×
663
                },
×
664
                Do: func(tx SQLQueries) error {
×
665
                        chanIDB := channelIDToBytes(edge.ChannelID)
×
666

×
667
                        // Make sure that the channel doesn't already exist. We
×
668
                        // do this explicitly instead of relying on catching a
×
669
                        // unique constraint error because relying on SQL to
×
670
                        // throw that error would abort the entire batch of
×
671
                        // transactions.
×
672
                        _, err := tx.GetChannelBySCID(
×
673
                                ctx, sqlc.GetChannelBySCIDParams{
×
674
                                        Scid:    chanIDB,
×
NEW
675
                                        Version: int16(lnwire.GossipVersion1),
×
676
                                },
×
677
                        )
×
678
                        if err == nil {
×
679
                                alreadyExists = true
×
680
                                return nil
×
681
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
682
                                return fmt.Errorf("unable to fetch channel: %w",
×
683
                                        err)
×
684
                        }
×
685

686
                        return insertChannel(ctx, tx, edge)
×
687
                },
688
                OnCommit: func(err error) error {
×
689
                        switch {
×
690
                        case err != nil:
×
691
                                return err
×
692
                        case alreadyExists:
×
693
                                return ErrEdgeAlreadyExist
×
694
                        default:
×
695
                                s.rejectCache.remove(edge.ChannelID)
×
696
                                s.chanCache.remove(edge.ChannelID)
×
697
                                return nil
×
698
                        }
699
                },
700
        }
701

702
        return s.chanScheduler.Execute(ctx, r)
×
703
}
704

705
// HighestChanID returns the "highest" known channel ID in the channel graph.
706
// This represents the "newest" channel from the PoV of the chain. This method
707
// can be used by peers to quickly determine if their graphs are in sync.
708
//
709
// NOTE: This is part of the V1Store interface.
710
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
711
        var highestChanID uint64
×
712
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
713
                chanID, err := db.HighestSCID(ctx, int16(lnwire.GossipVersion1))
×
714
                if errors.Is(err, sql.ErrNoRows) {
×
715
                        return nil
×
716
                } else if err != nil {
×
717
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
718
                                err)
×
719
                }
×
720

721
                highestChanID = byteOrder.Uint64(chanID)
×
722

×
723
                return nil
×
724
        }, sqldb.NoOpReset)
725
        if err != nil {
×
726
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
727
        }
×
728

729
        return highestChanID, nil
×
730
}
731

732
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
733
// within the database for the referenced channel. The `flags` attribute within
734
// the ChannelEdgePolicy determines which of the directed edges are being
735
// updated. If the flag is 1, then the first node's information is being
736
// updated, otherwise it's the second node's information. The node ordering is
737
// determined by the lexicographical ordering of the identity public keys of the
738
// nodes on either side of the channel.
739
//
740
// NOTE: part of the V1Store interface.
741
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
742
        edge *models.ChannelEdgePolicy,
743
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
744

×
745
        var (
×
746
                isUpdate1    bool
×
747
                edgeNotFound bool
×
748
                from, to     route.Vertex
×
749
        )
×
750

×
751
        r := &batch.Request[SQLQueries]{
×
752
                Opts: batch.NewSchedulerOptions(opts...),
×
753
                Reset: func() {
×
754
                        isUpdate1 = false
×
755
                        edgeNotFound = false
×
756
                },
×
757
                Do: func(tx SQLQueries) error {
×
758
                        var err error
×
759
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
760
                                ctx, tx, edge,
×
761
                        )
×
762
                        if err != nil {
×
763
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
764
                        }
×
765

766
                        // Silence ErrEdgeNotFound so that the batch can
767
                        // succeed, but propagate the error via local state.
768
                        if errors.Is(err, ErrEdgeNotFound) {
×
769
                                edgeNotFound = true
×
770
                                return nil
×
771
                        }
×
772

773
                        return err
×
774
                },
775
                OnCommit: func(err error) error {
×
776
                        switch {
×
777
                        case err != nil:
×
778
                                return err
×
779
                        case edgeNotFound:
×
780
                                return ErrEdgeNotFound
×
781
                        default:
×
782
                                s.updateEdgeCache(edge, isUpdate1)
×
783
                                return nil
×
784
                        }
785
                },
786
        }
787

788
        err := s.chanScheduler.Execute(ctx, r)
×
789

×
790
        return from, to, err
×
791
}
792

793
// updateEdgeCache updates our reject and channel caches with the new
794
// edge policy information.
795
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
796
        isUpdate1 bool) {
×
797

×
798
        // If an entry for this channel is found in reject cache, we'll modify
×
799
        // the entry with the updated timestamp for the direction that was just
×
800
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
801
        // during the next query for this edge.
×
802
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
803
                if isUpdate1 {
×
804
                        entry.upd1Time = e.LastUpdate.Unix()
×
805
                } else {
×
806
                        entry.upd2Time = e.LastUpdate.Unix()
×
807
                }
×
808
                s.rejectCache.insert(e.ChannelID, entry)
×
809
        }
810

811
        // If an entry for this channel is found in channel cache, we'll modify
812
        // the entry with the updated policy for the direction that was just
813
        // written. If the edge doesn't exist, we'll defer loading the info and
814
        // policies and lazily read from disk during the next query.
815
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
816
                if isUpdate1 {
×
817
                        channel.Policy1 = e
×
818
                } else {
×
819
                        channel.Policy2 = e
×
820
                }
×
821
                s.chanCache.insert(e.ChannelID, channel)
×
822
        }
823
}
824

825
// ForEachSourceNodeChannel iterates through all channels of the source node,
826
// executing the passed callback on each. The call-back is provided with the
827
// channel's outpoint, whether we have a policy for the channel and the channel
828
// peer's node information.
829
//
830
// NOTE: part of the V1Store interface.
831
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
832
        cb func(chanPoint wire.OutPoint, havePolicy bool,
833
                otherNode *models.Node) error, reset func()) error {
×
834

×
835
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
836
                nodeID, nodePub, err := s.getSourceNode(
×
NEW
837
                        ctx, db, lnwire.GossipVersion1,
×
NEW
838
                )
×
839
                if err != nil {
×
840
                        return fmt.Errorf("unable to fetch source node: %w",
×
841
                                err)
×
842
                }
×
843

844
                return forEachNodeChannel(
×
845
                        ctx, db, s.cfg, nodeID,
×
846
                        func(info *models.ChannelEdgeInfo,
×
847
                                outPolicy *models.ChannelEdgePolicy,
×
848
                                _ *models.ChannelEdgePolicy) error {
×
849

×
850
                                // Fetch the other node.
×
851
                                var (
×
852
                                        otherNodePub [33]byte
×
853
                                        node1        = info.NodeKey1Bytes
×
854
                                        node2        = info.NodeKey2Bytes
×
855
                                )
×
856
                                switch {
×
857
                                case bytes.Equal(node1[:], nodePub[:]):
×
858
                                        otherNodePub = node2
×
859
                                case bytes.Equal(node2[:], nodePub[:]):
×
860
                                        otherNodePub = node1
×
861
                                default:
×
862
                                        return fmt.Errorf("node not " +
×
863
                                                "participating in this channel")
×
864
                                }
865

866
                                _, otherNode, err := getNodeByPubKey(
×
867
                                        ctx, s.cfg.QueryCfg, db, otherNodePub,
×
868
                                )
×
869
                                if err != nil {
×
870
                                        return fmt.Errorf("unable to fetch "+
×
871
                                                "other node(%x): %w",
×
872
                                                otherNodePub, err)
×
873
                                }
×
874

875
                                return cb(
×
876
                                        info.ChannelPoint, outPolicy != nil,
×
877
                                        otherNode,
×
878
                                )
×
879
                        },
880
                )
881
        }, reset)
882
}
883

884
// ForEachNode iterates through all the stored vertices/nodes in the graph,
885
// executing the passed callback with each node encountered. If the callback
886
// returns an error, then the transaction is aborted and the iteration stops
887
// early.
888
//
889
// NOTE: part of the V1Store interface.
890
func (s *SQLStore) ForEachNode(ctx context.Context,
891
        cb func(node *models.Node) error, reset func()) error {
×
892

×
893
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
894
                return forEachNodePaginated(
×
895
                        ctx, s.cfg.QueryCfg, db,
×
NEW
896
                        lnwire.GossipVersion1, func(_ context.Context, _ int64,
×
897
                                node *models.Node) error {
×
898

×
899
                                return cb(node)
×
900
                        },
×
901
                )
902
        }, reset)
903
}
904

905
// ForEachNodeDirectedChannel iterates through all channels of a given node,
906
// executing the passed callback on the directed edge representing the channel
907
// and its incoming policy. If the callback returns an error, then the iteration
908
// is halted with the error propagated back up to the caller.
909
//
910
// Unknown policies are passed into the callback as nil values.
911
//
912
// NOTE: this is part of the graphdb.NodeTraverser interface.
913
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
914
        cb func(channel *DirectedChannel) error, reset func()) error {
×
915

×
916
        var ctx = context.TODO()
×
917

×
918
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
919
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
920
        }, reset)
×
921
}
922

923
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
924
// graph, executing the passed callback with each node encountered. If the
925
// callback returns an error, then the transaction is aborted and the iteration
926
// stops early.
927
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
928
        cb func(route.Vertex, *lnwire.FeatureVector) error,
929
        reset func()) error {
×
930

×
931
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
932
                return forEachNodeCacheable(
×
933
                        ctx, s.cfg.QueryCfg, db,
×
934
                        func(_ int64, nodePub route.Vertex,
×
935
                                features *lnwire.FeatureVector) error {
×
936

×
937
                                return cb(nodePub, features)
×
938
                        },
×
939
                )
940
        }, reset)
941
        if err != nil {
×
942
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
943
        }
×
944

945
        return nil
×
946
}
947

948
// ForEachNodeChannel iterates through all channels of the given node,
949
// executing the passed callback with an edge info structure and the policies
950
// of each end of the channel. The first edge policy is the outgoing edge *to*
951
// the connecting node, while the second is the incoming edge *from* the
952
// connecting node. If the callback returns an error, then the iteration is
953
// halted with the error propagated back up to the caller.
954
//
955
// Unknown policies are passed into the callback as nil values.
956
//
957
// NOTE: part of the V1Store interface.
958
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
959
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
960
                *models.ChannelEdgePolicy) error, reset func()) error {
×
961

×
962
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
963
                dbNode, err := db.GetNodeByPubKey(
×
964
                        ctx, sqlc.GetNodeByPubKeyParams{
×
NEW
965
                                Version: int16(lnwire.GossipVersion1),
×
966
                                PubKey:  nodePub[:],
×
967
                        },
×
968
                )
×
969
                if errors.Is(err, sql.ErrNoRows) {
×
970
                        return nil
×
971
                } else if err != nil {
×
972
                        return fmt.Errorf("unable to fetch node: %w", err)
×
973
                }
×
974

975
                return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
×
976
        }, reset)
977
}
978

979
// extractMaxUpdateTime returns the maximum of the two policy update times.
980
// This is used for pagination cursor tracking.
981
func extractMaxUpdateTime(
982
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) int64 {
×
983

×
984
        switch {
×
985
        case row.Policy1LastUpdate.Valid && row.Policy2LastUpdate.Valid:
×
986
                return max(row.Policy1LastUpdate.Int64,
×
987
                        row.Policy2LastUpdate.Int64)
×
988
        case row.Policy1LastUpdate.Valid:
×
989
                return row.Policy1LastUpdate.Int64
×
990
        case row.Policy2LastUpdate.Valid:
×
991
                return row.Policy2LastUpdate.Int64
×
992
        default:
×
993
                return 0
×
994
        }
995
}
996

997
// buildChannelFromRow constructs a ChannelEdge from a database row.
998
// This includes building the nodes, channel info, and policies.
999
func (s *SQLStore) buildChannelFromRow(ctx context.Context, db SQLQueries,
1000
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) (ChannelEdge, error) {
×
1001

×
1002
        node1, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode)
×
1003
        if err != nil {
×
1004
                return ChannelEdge{}, fmt.Errorf("unable to build node1: %w",
×
1005
                        err)
×
1006
        }
×
1007

1008
        node2, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode_2)
×
1009
        if err != nil {
×
1010
                return ChannelEdge{}, fmt.Errorf("unable to build node2: %w",
×
1011
                        err)
×
1012
        }
×
1013

1014
        channel, err := getAndBuildEdgeInfo(
×
1015
                ctx, s.cfg, db,
×
1016
                row.GraphChannel, node1.PubKeyBytes,
×
1017
                node2.PubKeyBytes,
×
1018
        )
×
1019
        if err != nil {
×
1020
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1021
                        "channel info: %w", err)
×
1022
        }
×
1023

1024
        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1025
        if err != nil {
×
1026
                return ChannelEdge{}, fmt.Errorf("unable to extract "+
×
1027
                        "channel policies: %w", err)
×
1028
        }
×
1029

1030
        p1, p2, err := getAndBuildChanPolicies(
×
1031
                ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, channel.ChannelID,
×
1032
                node1.PubKeyBytes, node2.PubKeyBytes,
×
1033
        )
×
1034
        if err != nil {
×
1035
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1036
                        "channel policies: %w", err)
×
1037
        }
×
1038

1039
        return ChannelEdge{
×
1040
                Info:    channel,
×
1041
                Policy1: p1,
×
1042
                Policy2: p2,
×
1043
                Node1:   node1,
×
1044
                Node2:   node2,
×
1045
        }, nil
×
1046
}
1047

1048
// updateChanCacheBatch updates the channel cache with multiple edges at once.
1049
// This method acquires the cache lock only once for the entire batch.
1050
func (s *SQLStore) updateChanCacheBatch(edgesToCache map[uint64]ChannelEdge) {
×
1051
        if len(edgesToCache) == 0 {
×
1052
                return
×
1053
        }
×
1054

1055
        s.cacheMu.Lock()
×
1056
        defer s.cacheMu.Unlock()
×
1057

×
1058
        for chanID, edge := range edgesToCache {
×
1059
                s.chanCache.insert(chanID, edge)
×
1060
        }
×
1061
}
1062

1063
// ChanUpdatesInHorizon returns all the known channel edges which have at least
1064
// one edge that has an update timestamp within the specified horizon.
1065
//
1066
// Iterator Lifecycle:
1067
// 1. Initialize state (edgesSeen map, cache tracking, pagination cursors)
1068
// 2. Query batch of channels with policies in time range
1069
// 3. For each channel: check if seen, check cache, or build from DB
1070
// 4. Yield channels to caller
1071
// 5. Update cache after successful batch
1072
// 6. Repeat with updated pagination cursor until no more results
1073
//
1074
// NOTE: This is part of the V1Store interface.
1075
func (s *SQLStore) ChanUpdatesInHorizon(startTime, endTime time.Time,
1076
        opts ...IteratorOption) iter.Seq2[ChannelEdge, error] {
×
1077

×
1078
        // Apply options.
×
1079
        cfg := defaultIteratorConfig()
×
1080
        for _, opt := range opts {
×
1081
                opt(cfg)
×
1082
        }
×
1083

1084
        return func(yield func(ChannelEdge, error) bool) {
×
1085
                var (
×
1086
                        ctx            = context.TODO()
×
1087
                        edgesSeen      = make(map[uint64]struct{})
×
1088
                        edgesToCache   = make(map[uint64]ChannelEdge)
×
1089
                        hits           int
×
1090
                        total          int
×
1091
                        lastUpdateTime sql.NullInt64
×
1092
                        lastID         sql.NullInt64
×
1093
                        hasMore        = true
×
1094
                )
×
1095

×
1096
                // Each iteration, we'll read a batch amount of channel updates
×
1097
                // (consulting the cache along the way), yield them, then loop
×
1098
                // back to decide if we have any more updates to read out.
×
1099
                for hasMore {
×
1100
                        var batch []ChannelEdge
×
1101

×
1102
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(),
×
1103
                                func(db SQLQueries) error {
×
1104
                                        //nolint:ll
×
1105
                                        params := sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
NEW
1106
                                                Version: int16(lnwire.GossipVersion1),
×
1107
                                                StartTime: sqldb.SQLInt64(
×
1108
                                                        startTime.Unix(),
×
1109
                                                ),
×
1110
                                                EndTime: sqldb.SQLInt64(
×
1111
                                                        endTime.Unix(),
×
1112
                                                ),
×
1113
                                                LastUpdateTime: lastUpdateTime,
×
1114
                                                LastID:         lastID,
×
1115
                                                MaxResults: sql.NullInt32{
×
1116
                                                        Int32: int32(
×
1117
                                                                cfg.chanUpdateIterBatchSize,
×
1118
                                                        ),
×
1119
                                                        Valid: true,
×
1120
                                                },
×
1121
                                        }
×
1122
                                        //nolint:ll
×
1123
                                        rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1124
                                                ctx, params,
×
1125
                                        )
×
1126
                                        if err != nil {
×
1127
                                                return err
×
1128
                                        }
×
1129

1130
                                        //nolint:ll
1131
                                        hasMore = len(rows) == cfg.chanUpdateIterBatchSize
×
1132

×
1133
                                        //nolint:ll
×
1134
                                        for _, row := range rows {
×
1135
                                                lastUpdateTime = sql.NullInt64{
×
1136
                                                        Int64: extractMaxUpdateTime(row),
×
1137
                                                        Valid: true,
×
1138
                                                }
×
1139
                                                lastID = sql.NullInt64{
×
1140
                                                        Int64: row.GraphChannel.ID,
×
1141
                                                        Valid: true,
×
1142
                                                }
×
1143

×
1144
                                                // Skip if we've already
×
1145
                                                // processed this channel.
×
1146
                                                chanIDInt := byteOrder.Uint64(
×
1147
                                                        row.GraphChannel.Scid,
×
1148
                                                )
×
1149
                                                _, ok := edgesSeen[chanIDInt]
×
1150
                                                if ok {
×
1151
                                                        continue
×
1152
                                                }
1153

1154
                                                s.cacheMu.RLock()
×
1155
                                                channel, ok := s.chanCache.get(
×
1156
                                                        chanIDInt,
×
1157
                                                )
×
1158
                                                s.cacheMu.RUnlock()
×
1159
                                                if ok {
×
1160
                                                        hits++
×
1161
                                                        total++
×
1162
                                                        edgesSeen[chanIDInt] = struct{}{}
×
1163
                                                        batch = append(batch, channel)
×
1164

×
1165
                                                        continue
×
1166
                                                }
1167

1168
                                                chanEdge, err := s.buildChannelFromRow(
×
1169
                                                        ctx, db, row,
×
1170
                                                )
×
1171
                                                if err != nil {
×
1172
                                                        return err
×
1173
                                                }
×
1174

1175
                                                edgesSeen[chanIDInt] = struct{}{}
×
1176
                                                edgesToCache[chanIDInt] = chanEdge
×
1177

×
1178
                                                batch = append(batch, chanEdge)
×
1179

×
1180
                                                total++
×
1181
                                        }
1182

1183
                                        return nil
×
1184
                                }, func() {
×
1185
                                        batch = nil
×
1186
                                        edgesSeen = make(map[uint64]struct{})
×
1187
                                        edgesToCache = make(
×
1188
                                                map[uint64]ChannelEdge,
×
1189
                                        )
×
1190
                                })
×
1191

1192
                        if err != nil {
×
1193
                                log.Errorf("ChanUpdatesInHorizon "+
×
1194
                                        "batch error: %v", err)
×
1195

×
1196
                                yield(ChannelEdge{}, err)
×
1197

×
1198
                                return
×
1199
                        }
×
1200

1201
                        for _, edge := range batch {
×
1202
                                if !yield(edge, nil) {
×
1203
                                        return
×
1204
                                }
×
1205
                        }
1206

1207
                        // Update cache after successful batch yield, setting
1208
                        // the cache lock only once for the entire batch.
1209
                        s.updateChanCacheBatch(edgesToCache)
×
1210
                        edgesToCache = make(map[uint64]ChannelEdge)
×
1211

×
1212
                        // If the batch didn't yield anything, then we're done.
×
1213
                        if len(batch) == 0 {
×
1214
                                break
×
1215
                        }
1216
                }
1217

1218
                if total > 0 {
×
1219
                        log.Debugf("ChanUpdatesInHorizon hit percentage: "+
×
1220
                                "%.2f (%d/%d)",
×
1221
                                float64(hits)*100/float64(total), hits, total)
×
1222
                } else {
×
1223
                        log.Debugf("ChanUpdatesInHorizon returned no edges "+
×
1224
                                "in horizon (%s, %s)", startTime, endTime)
×
1225
                }
×
1226
        }
1227
}
1228

1229
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1230
// data to the call-back. If withAddrs is true, then the call-back will also be
1231
// provided with the addresses associated with the node. The address retrieval
1232
// result in an additional round-trip to the database, so it should only be used
1233
// if the addresses are actually needed.
1234
//
1235
// NOTE: part of the V1Store interface.
1236
func (s *SQLStore) ForEachNodeCached(ctx context.Context, withAddrs bool,
1237
        cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
1238
                chans map[uint64]*DirectedChannel) error, reset func()) error {
×
1239

×
1240
        type nodeCachedBatchData struct {
×
1241
                features      map[int64][]int
×
1242
                addrs         map[int64][]nodeAddress
×
1243
                chanBatchData *batchChannelData
×
1244
                chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
×
1245
        }
×
1246

×
1247
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1248
                // pageQueryFunc is used to query the next page of nodes.
×
1249
                pageQueryFunc := func(ctx context.Context, lastID int64,
×
1250
                        limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
1251

×
1252
                        return db.ListNodeIDsAndPubKeys(
×
1253
                                ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
NEW
1254
                                        Version: int16(lnwire.GossipVersion1),
×
1255
                                        ID:      lastID,
×
1256
                                        Limit:   limit,
×
1257
                                },
×
1258
                        )
×
1259
                }
×
1260

1261
                // batchDataFunc is then used to batch load the data required
1262
                // for each page of nodes.
1263
                batchDataFunc := func(ctx context.Context,
×
1264
                        nodeIDs []int64) (*nodeCachedBatchData, error) {
×
1265

×
1266
                        // Batch load node features.
×
1267
                        nodeFeatures, err := batchLoadNodeFeaturesHelper(
×
1268
                                ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1269
                        )
×
1270
                        if err != nil {
×
1271
                                return nil, fmt.Errorf("unable to batch load "+
×
1272
                                        "node features: %w", err)
×
1273
                        }
×
1274

1275
                        // Maybe fetch the node's addresses if requested.
1276
                        var nodeAddrs map[int64][]nodeAddress
×
1277
                        if withAddrs {
×
1278
                                nodeAddrs, err = batchLoadNodeAddressesHelper(
×
1279
                                        ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1280
                                )
×
1281
                                if err != nil {
×
1282
                                        return nil, fmt.Errorf("unable to "+
×
1283
                                                "batch load node "+
×
1284
                                                "addresses: %w", err)
×
1285
                                }
×
1286
                        }
1287

1288
                        // Batch load ALL unique channels for ALL nodes in this
1289
                        // page.
1290
                        allChannels, err := db.ListChannelsForNodeIDs(
×
1291
                                ctx, sqlc.ListChannelsForNodeIDsParams{
×
NEW
1292
                                        Version:  int16(lnwire.GossipVersion1),
×
1293
                                        Node1Ids: nodeIDs,
×
1294
                                        Node2Ids: nodeIDs,
×
1295
                                },
×
1296
                        )
×
1297
                        if err != nil {
×
1298
                                return nil, fmt.Errorf("unable to batch "+
×
1299
                                        "fetch channels for nodes: %w", err)
×
1300
                        }
×
1301

1302
                        // Deduplicate channels and collect IDs.
1303
                        var (
×
1304
                                allChannelIDs []int64
×
1305
                                allPolicyIDs  []int64
×
1306
                        )
×
1307
                        uniqueChannels := make(
×
1308
                                map[int64]sqlc.ListChannelsForNodeIDsRow,
×
1309
                        )
×
1310

×
1311
                        for _, channel := range allChannels {
×
1312
                                channelID := channel.GraphChannel.ID
×
1313

×
1314
                                // Only process each unique channel once.
×
1315
                                _, exists := uniqueChannels[channelID]
×
1316
                                if exists {
×
1317
                                        continue
×
1318
                                }
1319

1320
                                uniqueChannels[channelID] = channel
×
1321
                                allChannelIDs = append(allChannelIDs, channelID)
×
1322

×
1323
                                if channel.Policy1ID.Valid {
×
1324
                                        allPolicyIDs = append(
×
1325
                                                allPolicyIDs,
×
1326
                                                channel.Policy1ID.Int64,
×
1327
                                        )
×
1328
                                }
×
1329
                                if channel.Policy2ID.Valid {
×
1330
                                        allPolicyIDs = append(
×
1331
                                                allPolicyIDs,
×
1332
                                                channel.Policy2ID.Int64,
×
1333
                                        )
×
1334
                                }
×
1335
                        }
1336

1337
                        // Batch load channel data for all unique channels.
1338
                        channelBatchData, err := batchLoadChannelData(
×
1339
                                ctx, s.cfg.QueryCfg, db, allChannelIDs,
×
1340
                                allPolicyIDs,
×
1341
                        )
×
1342
                        if err != nil {
×
1343
                                return nil, fmt.Errorf("unable to batch "+
×
1344
                                        "load channel data: %w", err)
×
1345
                        }
×
1346

1347
                        // Create map of node ID to channels that involve this
1348
                        // node.
1349
                        nodeIDSet := make(map[int64]bool)
×
1350
                        for _, nodeID := range nodeIDs {
×
1351
                                nodeIDSet[nodeID] = true
×
1352
                        }
×
1353

1354
                        nodeChannelMap := make(
×
1355
                                map[int64][]sqlc.ListChannelsForNodeIDsRow,
×
1356
                        )
×
1357
                        for _, channel := range uniqueChannels {
×
1358
                                // Add channel to both nodes if they're in our
×
1359
                                // current page.
×
1360
                                node1 := channel.GraphChannel.NodeID1
×
1361
                                if nodeIDSet[node1] {
×
1362
                                        nodeChannelMap[node1] = append(
×
1363
                                                nodeChannelMap[node1], channel,
×
1364
                                        )
×
1365
                                }
×
1366
                                node2 := channel.GraphChannel.NodeID2
×
1367
                                if nodeIDSet[node2] {
×
1368
                                        nodeChannelMap[node2] = append(
×
1369
                                                nodeChannelMap[node2], channel,
×
1370
                                        )
×
1371
                                }
×
1372
                        }
1373

1374
                        return &nodeCachedBatchData{
×
1375
                                features:      nodeFeatures,
×
1376
                                addrs:         nodeAddrs,
×
1377
                                chanBatchData: channelBatchData,
×
1378
                                chanMap:       nodeChannelMap,
×
1379
                        }, nil
×
1380
                }
1381

1382
                // processItem is used to process each node in the current page.
1383
                processItem := func(ctx context.Context,
×
1384
                        nodeData sqlc.ListNodeIDsAndPubKeysRow,
×
1385
                        batchData *nodeCachedBatchData) error {
×
1386

×
1387
                        // Build feature vector for this node.
×
1388
                        fv := lnwire.EmptyFeatureVector()
×
1389
                        features, exists := batchData.features[nodeData.ID]
×
1390
                        if exists {
×
1391
                                for _, bit := range features {
×
1392
                                        fv.Set(lnwire.FeatureBit(bit))
×
1393
                                }
×
1394
                        }
1395

1396
                        var nodePub route.Vertex
×
1397
                        copy(nodePub[:], nodeData.PubKey)
×
1398

×
1399
                        nodeChannels := batchData.chanMap[nodeData.ID]
×
1400

×
1401
                        toNodeCallback := func() route.Vertex {
×
1402
                                return nodePub
×
1403
                        }
×
1404

1405
                        // Build cached channels map for this node.
1406
                        channels := make(map[uint64]*DirectedChannel)
×
1407
                        for _, channelRow := range nodeChannels {
×
1408
                                directedChan, err := buildDirectedChannel(
×
1409
                                        s.cfg.ChainHash, nodeData.ID, nodePub,
×
1410
                                        channelRow, batchData.chanBatchData, fv,
×
1411
                                        toNodeCallback,
×
1412
                                )
×
1413
                                if err != nil {
×
1414
                                        return err
×
1415
                                }
×
1416

1417
                                channels[directedChan.ChannelID] = directedChan
×
1418
                        }
1419

1420
                        addrs, err := buildNodeAddresses(
×
1421
                                batchData.addrs[nodeData.ID],
×
1422
                        )
×
1423
                        if err != nil {
×
1424
                                return fmt.Errorf("unable to build node "+
×
1425
                                        "addresses: %w", err)
×
1426
                        }
×
1427

1428
                        return cb(ctx, nodePub, addrs, channels)
×
1429
                }
1430

1431
                return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
1432
                        ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
×
1433
                        func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
1434
                                return node.ID
×
1435
                        },
×
1436
                        func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
1437
                                error) {
×
1438

×
1439
                                return node.ID, nil
×
1440
                        },
×
1441
                        batchDataFunc, processItem,
1442
                )
1443
        }, reset)
1444
}
1445

1446
// ForEachChannelCacheable iterates through all the channel edges stored
1447
// within the graph and invokes the passed callback for each edge. The
1448
// callback takes two edges as since this is a directed graph, both the
1449
// in/out edges are visited. If the callback returns an error, then the
1450
// transaction is aborted and the iteration stops early.
1451
//
1452
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1453
// pointer for that particular channel edge routing policy will be
1454
// passed into the callback.
1455
//
1456
// NOTE: this method is like ForEachChannel but fetches only the data
1457
// required for the graph cache.
1458
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1459
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1460
        reset func()) error {
×
1461

×
1462
        ctx := context.TODO()
×
1463

×
1464
        handleChannel := func(_ context.Context,
×
1465
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
1466

×
1467
                node1, node2, err := buildNodeVertices(
×
1468
                        row.Node1Pubkey, row.Node2Pubkey,
×
1469
                )
×
1470
                if err != nil {
×
1471
                        return err
×
1472
                }
×
1473

1474
                edge := buildCacheableChannelInfo(
×
1475
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1476
                )
×
1477

×
1478
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1479
                if err != nil {
×
1480
                        return err
×
1481
                }
×
1482

1483
                pol1, pol2, err := buildCachedChanPolicies(
×
1484
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1485
                )
×
1486
                if err != nil {
×
1487
                        return err
×
1488
                }
×
1489

1490
                return cb(edge, pol1, pol2)
×
1491
        }
1492

1493
        extractCursor := func(
×
1494
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {
×
1495

×
1496
                return row.ID
×
1497
        }
×
1498

1499
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1500
                //nolint:ll
×
1501
                queryFunc := func(ctx context.Context, lastID int64,
×
1502
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
×
1503
                        error) {
×
1504

×
1505
                        return db.ListChannelsWithPoliciesForCachePaginated(
×
1506
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
NEW
1507
                                        Version: int16(lnwire.GossipVersion1),
×
1508
                                        ID:      lastID,
×
1509
                                        Limit:   limit,
×
1510
                                },
×
1511
                        )
×
1512
                }
×
1513

1514
                return sqldb.ExecutePaginatedQuery(
×
1515
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
1516
                        extractCursor, handleChannel,
×
1517
                )
×
1518
        }, reset)
1519
}
1520

1521
// ForEachChannel iterates through all the channel edges stored within the
1522
// graph and invokes the passed callback for each edge. The callback takes two
1523
// edges as since this is a directed graph, both the in/out edges are visited.
1524
// If the callback returns an error, then the transaction is aborted and the
1525
// iteration stops early.
1526
//
1527
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1528
// for that particular channel edge routing policy will be passed into the
1529
// callback.
1530
//
1531
// NOTE: part of the V1Store interface.
1532
func (s *SQLStore) ForEachChannel(ctx context.Context,
1533
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1534
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1535

×
1536
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1537
                return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
×
1538
        }, reset)
×
1539
}
1540

1541
// FilterChannelRange returns the channel ID's of all known channels which were
1542
// mined in a block height within the passed range. The channel IDs are grouped
1543
// by their common block height. This method can be used to quickly share with a
1544
// peer the set of channels we know of within a particular range to catch them
1545
// up after a period of time offline. If withTimestamps is true then the
1546
// timestamp info of the latest received channel update messages of the channel
1547
// will be included in the response.
1548
//
1549
// NOTE: This is part of the V1Store interface.
1550
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1551
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1552

×
1553
        var (
×
1554
                ctx       = context.TODO()
×
1555
                startSCID = &lnwire.ShortChannelID{
×
1556
                        BlockHeight: startHeight,
×
1557
                }
×
1558
                endSCID = lnwire.ShortChannelID{
×
1559
                        BlockHeight: endHeight,
×
1560
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1561
                        TxPosition:  math.MaxUint16,
×
1562
                }
×
1563
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1564
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1565
        )
×
1566

×
1567
        // 1) get all channels where channelID is between start and end chan ID.
×
1568
        // 2) skip if not public (ie, no channel_proof)
×
1569
        // 3) collect that channel.
×
1570
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1571
        //    and add those timestamps to the collected channel.
×
1572
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1573
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1574
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1575
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1576
                                StartScid: chanIDStart,
×
1577
                                EndScid:   chanIDEnd,
×
1578
                        },
×
1579
                )
×
1580
                if err != nil {
×
1581
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1582
                                err)
×
1583
                }
×
1584

1585
                for _, dbChan := range dbChans {
×
1586
                        cid := lnwire.NewShortChanIDFromInt(
×
1587
                                byteOrder.Uint64(dbChan.Scid),
×
1588
                        )
×
1589
                        chanInfo := NewChannelUpdateInfo(
×
1590
                                cid, time.Time{}, time.Time{},
×
1591
                        )
×
1592

×
1593
                        if !withTimestamps {
×
1594
                                channelsPerBlock[cid.BlockHeight] = append(
×
1595
                                        channelsPerBlock[cid.BlockHeight],
×
1596
                                        chanInfo,
×
1597
                                )
×
1598

×
1599
                                continue
×
1600
                        }
1601

1602
                        //nolint:ll
1603
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1604
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
1605
                                        Version:   int16(lnwire.GossipVersion1),
×
1606
                                        ChannelID: dbChan.ID,
×
1607
                                        NodeID:    dbChan.NodeID1,
×
1608
                                },
×
1609
                        )
×
1610
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1611
                                return fmt.Errorf("unable to fetch node1 "+
×
1612
                                        "policy: %w", err)
×
1613
                        } else if err == nil {
×
1614
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1615
                                        node1Policy.LastUpdate.Int64, 0,
×
1616
                                )
×
1617
                        }
×
1618

1619
                        //nolint:ll
1620
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1621
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
1622
                                        Version:   int16(lnwire.GossipVersion1),
×
1623
                                        ChannelID: dbChan.ID,
×
1624
                                        NodeID:    dbChan.NodeID2,
×
1625
                                },
×
1626
                        )
×
1627
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1628
                                return fmt.Errorf("unable to fetch node2 "+
×
1629
                                        "policy: %w", err)
×
1630
                        } else if err == nil {
×
1631
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1632
                                        node2Policy.LastUpdate.Int64, 0,
×
1633
                                )
×
1634
                        }
×
1635

1636
                        channelsPerBlock[cid.BlockHeight] = append(
×
1637
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1638
                        )
×
1639
                }
1640

1641
                return nil
×
1642
        }, func() {
×
1643
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1644
        })
×
1645
        if err != nil {
×
1646
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1647
        }
×
1648

1649
        if len(channelsPerBlock) == 0 {
×
1650
                return nil, nil
×
1651
        }
×
1652

1653
        // Return the channel ranges in ascending block height order.
1654
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1655
        slices.Sort(blocks)
×
1656

×
1657
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1658
                return BlockChannelRange{
×
1659
                        Height:   block,
×
1660
                        Channels: channelsPerBlock[block],
×
1661
                }
×
1662
        }), nil
×
1663
}
1664

1665
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1666
// zombie. This method is used on an ad-hoc basis, when channels need to be
1667
// marked as zombies outside the normal pruning cycle.
1668
//
1669
// NOTE: part of the V1Store interface.
1670
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1671
        pubKey1, pubKey2 [33]byte) error {
×
1672

×
1673
        ctx := context.TODO()
×
1674

×
1675
        s.cacheMu.Lock()
×
1676
        defer s.cacheMu.Unlock()
×
1677

×
1678
        chanIDB := channelIDToBytes(chanID)
×
1679

×
1680
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1681
                return db.UpsertZombieChannel(
×
1682
                        ctx, sqlc.UpsertZombieChannelParams{
×
NEW
1683
                                Version:  int16(lnwire.GossipVersion1),
×
1684
                                Scid:     chanIDB,
×
1685
                                NodeKey1: pubKey1[:],
×
1686
                                NodeKey2: pubKey2[:],
×
1687
                        },
×
1688
                )
×
1689
        }, sqldb.NoOpReset)
×
1690
        if err != nil {
×
1691
                return fmt.Errorf("unable to upsert zombie channel "+
×
1692
                        "(channel_id=%d): %w", chanID, err)
×
1693
        }
×
1694

1695
        s.rejectCache.remove(chanID)
×
1696
        s.chanCache.remove(chanID)
×
1697

×
1698
        return nil
×
1699
}
1700

1701
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1702
//
1703
// NOTE: part of the V1Store interface.
1704
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1705
        s.cacheMu.Lock()
×
1706
        defer s.cacheMu.Unlock()
×
1707

×
1708
        var (
×
1709
                ctx     = context.TODO()
×
1710
                chanIDB = channelIDToBytes(chanID)
×
1711
        )
×
1712

×
1713
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1714
                res, err := db.DeleteZombieChannel(
×
1715
                        ctx, sqlc.DeleteZombieChannelParams{
×
1716
                                Scid:    chanIDB,
×
NEW
1717
                                Version: int16(lnwire.GossipVersion1),
×
1718
                        },
×
1719
                )
×
1720
                if err != nil {
×
1721
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1722
                                err)
×
1723
                }
×
1724

1725
                rows, err := res.RowsAffected()
×
1726
                if err != nil {
×
1727
                        return err
×
1728
                }
×
1729

1730
                if rows == 0 {
×
1731
                        return ErrZombieEdgeNotFound
×
1732
                } else if rows > 1 {
×
1733
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1734
                                "expected 1", rows)
×
1735
                }
×
1736

1737
                return nil
×
1738
        }, sqldb.NoOpReset)
1739
        if err != nil {
×
1740
                return fmt.Errorf("unable to mark edge live "+
×
1741
                        "(channel_id=%d): %w", chanID, err)
×
1742
        }
×
1743

1744
        s.rejectCache.remove(chanID)
×
1745
        s.chanCache.remove(chanID)
×
1746

×
1747
        return err
×
1748
}
1749

1750
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1751
// zombie, then the two node public keys corresponding to this edge are also
1752
// returned.
1753
//
1754
// NOTE: part of the V1Store interface.
1755
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1756
        error) {
×
1757

×
1758
        var (
×
1759
                ctx              = context.TODO()
×
1760
                isZombie         bool
×
1761
                pubKey1, pubKey2 route.Vertex
×
1762
                chanIDB          = channelIDToBytes(chanID)
×
1763
        )
×
1764

×
1765
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1766
                zombie, err := db.GetZombieChannel(
×
1767
                        ctx, sqlc.GetZombieChannelParams{
×
1768
                                Scid:    chanIDB,
×
NEW
1769
                                Version: int16(lnwire.GossipVersion1),
×
1770
                        },
×
1771
                )
×
1772
                if errors.Is(err, sql.ErrNoRows) {
×
1773
                        return nil
×
1774
                }
×
1775
                if err != nil {
×
1776
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1777
                                err)
×
1778
                }
×
1779

1780
                copy(pubKey1[:], zombie.NodeKey1)
×
1781
                copy(pubKey2[:], zombie.NodeKey2)
×
1782
                isZombie = true
×
1783

×
1784
                return nil
×
1785
        }, sqldb.NoOpReset)
1786
        if err != nil {
×
1787
                return false, route.Vertex{}, route.Vertex{},
×
1788
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1789
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1790
        }
×
1791

1792
        return isZombie, pubKey1, pubKey2, nil
×
1793
}
1794

1795
// NumZombies returns the current number of zombie channels in the graph.
1796
//
1797
// NOTE: part of the V1Store interface.
1798
func (s *SQLStore) NumZombies() (uint64, error) {
×
1799
        var (
×
1800
                ctx        = context.TODO()
×
1801
                numZombies uint64
×
1802
        )
×
1803
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1804
                count, err := db.CountZombieChannels(
×
NEW
1805
                        ctx, int16(lnwire.GossipVersion1),
×
NEW
1806
                )
×
1807
                if err != nil {
×
1808
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1809
                                err)
×
1810
                }
×
1811

1812
                numZombies = uint64(count)
×
1813

×
1814
                return nil
×
1815
        }, sqldb.NoOpReset)
1816
        if err != nil {
×
1817
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1818
        }
×
1819

1820
        return numZombies, nil
×
1821
}
1822

1823
// DeleteChannelEdges removes edges with the given channel IDs from the
1824
// database and marks them as zombies. This ensures that we're unable to re-add
1825
// it to our database once again. If an edge does not exist within the
1826
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1827
// true, then when we mark these edges as zombies, we'll set up the keys such
1828
// that we require the node that failed to send the fresh update to be the one
1829
// that resurrects the channel from its zombie state. The markZombie bool
1830
// denotes whether to mark the channel as a zombie.
1831
//
1832
// NOTE: part of the V1Store interface.
1833
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1834
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1835

×
1836
        s.cacheMu.Lock()
×
1837
        defer s.cacheMu.Unlock()
×
1838

×
1839
        // Keep track of which channels we end up finding so that we can
×
1840
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
1841
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
1842
        for _, chanID := range chanIDs {
×
1843
                chanLookup[chanID] = struct{}{}
×
1844
        }
×
1845

1846
        var (
×
1847
                ctx   = context.TODO()
×
1848
                edges []*models.ChannelEdgeInfo
×
1849
        )
×
1850
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1851
                // First, collect all channel rows.
×
1852
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
1853
                chanCallBack := func(ctx context.Context,
×
1854
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1855

×
1856
                        // Deleting the entry from the map indicates that we
×
1857
                        // have found the channel.
×
1858
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1859
                        delete(chanLookup, scid)
×
1860

×
1861
                        channelRows = append(channelRows, row)
×
1862

×
1863
                        return nil
×
1864
                }
×
1865

1866
                err := s.forEachChanWithPoliciesInSCIDList(
×
1867
                        ctx, db, chanCallBack, chanIDs,
×
1868
                )
×
1869
                if err != nil {
×
1870
                        return err
×
1871
                }
×
1872

1873
                if len(chanLookup) > 0 {
×
1874
                        return ErrEdgeNotFound
×
1875
                }
×
1876

1877
                if len(channelRows) == 0 {
×
1878
                        return nil
×
1879
                }
×
1880

1881
                // Batch build all channel edges.
1882
                var chanIDsToDelete []int64
×
1883
                edges, chanIDsToDelete, err = batchBuildChannelInfo(
×
1884
                        ctx, s.cfg, db, channelRows,
×
1885
                )
×
1886
                if err != nil {
×
1887
                        return err
×
1888
                }
×
1889

1890
                if markZombie {
×
1891
                        for i, row := range channelRows {
×
1892
                                scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1893

×
1894
                                err := handleZombieMarking(
×
1895
                                        ctx, db, row, edges[i],
×
1896
                                        strictZombiePruning, scid,
×
1897
                                )
×
1898
                                if err != nil {
×
1899
                                        return fmt.Errorf("unable to mark "+
×
1900
                                                "channel as zombie: %w", err)
×
1901
                                }
×
1902
                        }
1903
                }
1904

1905
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1906
        }, func() {
×
1907
                edges = nil
×
1908

×
1909
                // Re-fill the lookup map.
×
1910
                for _, chanID := range chanIDs {
×
1911
                        chanLookup[chanID] = struct{}{}
×
1912
                }
×
1913
        })
1914
        if err != nil {
×
1915
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1916
                        err)
×
1917
        }
×
1918

1919
        for _, chanID := range chanIDs {
×
1920
                s.rejectCache.remove(chanID)
×
1921
                s.chanCache.remove(chanID)
×
1922
        }
×
1923

1924
        return edges, nil
×
1925
}
1926

1927
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1928
// channel identified by the channel ID. If the channel can't be found, then
1929
// ErrEdgeNotFound is returned. A struct which houses the general information
1930
// for the channel itself is returned as well as two structs that contain the
1931
// routing policies for the channel in either direction.
1932
//
1933
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1934
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1935
// the ChannelEdgeInfo will only include the public keys of each node.
1936
//
1937
// NOTE: part of the V1Store interface.
1938
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1939
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1940
        *models.ChannelEdgePolicy, error) {
×
1941

×
1942
        var (
×
1943
                ctx              = context.TODO()
×
1944
                edge             *models.ChannelEdgeInfo
×
1945
                policy1, policy2 *models.ChannelEdgePolicy
×
1946
                chanIDB          = channelIDToBytes(chanID)
×
1947
        )
×
1948
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1949
                row, err := db.GetChannelBySCIDWithPolicies(
×
1950
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1951
                                Scid:    chanIDB,
×
NEW
1952
                                Version: int16(lnwire.GossipVersion1),
×
1953
                        },
×
1954
                )
×
1955
                if errors.Is(err, sql.ErrNoRows) {
×
1956
                        // First check if this edge is perhaps in the zombie
×
1957
                        // index.
×
1958
                        zombie, err := db.GetZombieChannel(
×
1959
                                ctx, sqlc.GetZombieChannelParams{
×
1960
                                        Scid:    chanIDB,
×
NEW
1961
                                        Version: int16(lnwire.GossipVersion1),
×
1962
                                },
×
1963
                        )
×
1964
                        if errors.Is(err, sql.ErrNoRows) {
×
1965
                                return ErrEdgeNotFound
×
1966
                        } else if err != nil {
×
1967
                                return fmt.Errorf("unable to check if "+
×
1968
                                        "channel is zombie: %w", err)
×
1969
                        }
×
1970

1971
                        // At this point, we know the channel is a zombie, so
1972
                        // we'll return an error indicating this, and we will
1973
                        // populate the edge info with the public keys of each
1974
                        // party as this is the only information we have about
1975
                        // it.
1976
                        edge = &models.ChannelEdgeInfo{}
×
1977
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1978
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1979

×
1980
                        return ErrZombieEdge
×
1981
                } else if err != nil {
×
1982
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1983
                }
×
1984

1985
                node1, node2, err := buildNodeVertices(
×
1986
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1987
                )
×
1988
                if err != nil {
×
1989
                        return err
×
1990
                }
×
1991

1992
                edge, err = getAndBuildEdgeInfo(
×
1993
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
1994
                )
×
1995
                if err != nil {
×
1996
                        return fmt.Errorf("unable to build channel info: %w",
×
1997
                                err)
×
1998
                }
×
1999

2000
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2001
                if err != nil {
×
2002
                        return fmt.Errorf("unable to extract channel "+
×
2003
                                "policies: %w", err)
×
2004
                }
×
2005

2006
                policy1, policy2, err = getAndBuildChanPolicies(
×
2007
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2008
                        node1, node2,
×
2009
                )
×
2010
                if err != nil {
×
2011
                        return fmt.Errorf("unable to build channel "+
×
2012
                                "policies: %w", err)
×
2013
                }
×
2014

2015
                return nil
×
2016
        }, sqldb.NoOpReset)
2017
        if err != nil {
×
2018
                // If we are returning the ErrZombieEdge, then we also need to
×
2019
                // return the edge info as the method comment indicates that
×
2020
                // this will be populated when the edge is a zombie.
×
2021
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2022
                        err)
×
2023
        }
×
2024

2025
        return edge, policy1, policy2, nil
×
2026
}
2027

2028
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
2029
// the channel identified by the funding outpoint. If the channel can't be
2030
// found, then ErrEdgeNotFound is returned. A struct which houses the general
2031
// information for the channel itself is returned as well as two structs that
2032
// contain the routing policies for the channel in either direction.
2033
//
2034
// NOTE: part of the V1Store interface.
2035
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
2036
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
2037
        *models.ChannelEdgePolicy, error) {
×
2038

×
2039
        var (
×
2040
                ctx              = context.TODO()
×
2041
                edge             *models.ChannelEdgeInfo
×
2042
                policy1, policy2 *models.ChannelEdgePolicy
×
2043
        )
×
2044
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2045
                row, err := db.GetChannelByOutpointWithPolicies(
×
2046
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
2047
                                Outpoint: op.String(),
×
NEW
2048
                                Version:  int16(lnwire.GossipVersion1),
×
2049
                        },
×
2050
                )
×
2051
                if errors.Is(err, sql.ErrNoRows) {
×
2052
                        return ErrEdgeNotFound
×
2053
                } else if err != nil {
×
2054
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2055
                }
×
2056

2057
                node1, node2, err := buildNodeVertices(
×
2058
                        row.Node1Pubkey, row.Node2Pubkey,
×
2059
                )
×
2060
                if err != nil {
×
2061
                        return err
×
2062
                }
×
2063

2064
                edge, err = getAndBuildEdgeInfo(
×
2065
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2066
                )
×
2067
                if err != nil {
×
2068
                        return fmt.Errorf("unable to build channel info: %w",
×
2069
                                err)
×
2070
                }
×
2071

2072
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2073
                if err != nil {
×
2074
                        return fmt.Errorf("unable to extract channel "+
×
2075
                                "policies: %w", err)
×
2076
                }
×
2077

2078
                policy1, policy2, err = getAndBuildChanPolicies(
×
2079
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2080
                        node1, node2,
×
2081
                )
×
2082
                if err != nil {
×
2083
                        return fmt.Errorf("unable to build channel "+
×
2084
                                "policies: %w", err)
×
2085
                }
×
2086

2087
                return nil
×
2088
        }, sqldb.NoOpReset)
2089
        if err != nil {
×
2090
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2091
                        err)
×
2092
        }
×
2093

2094
        return edge, policy1, policy2, nil
×
2095
}
2096

2097
// HasChannelEdge returns true if the database knows of a channel edge with the
2098
// passed channel ID, and false otherwise. If an edge with that ID is found
2099
// within the graph, then two time stamps representing the last time the edge
2100
// was updated for both directed edges are returned along with the boolean. If
2101
// it is not found, then the zombie index is checked and its result is returned
2102
// as the second boolean.
2103
//
2104
// NOTE: part of the V1Store interface.
2105
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
2106
        bool, error) {
×
2107

×
2108
        ctx := context.TODO()
×
2109

×
2110
        var (
×
2111
                exists          bool
×
2112
                isZombie        bool
×
2113
                node1LastUpdate time.Time
×
2114
                node2LastUpdate time.Time
×
2115
        )
×
2116

×
2117
        // We'll query the cache with the shared lock held to allow multiple
×
2118
        // readers to access values in the cache concurrently if they exist.
×
2119
        s.cacheMu.RLock()
×
2120
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2121
                s.cacheMu.RUnlock()
×
2122
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2123
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2124
                exists, isZombie = entry.flags.unpack()
×
2125

×
2126
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2127
        }
×
2128
        s.cacheMu.RUnlock()
×
2129

×
2130
        s.cacheMu.Lock()
×
2131
        defer s.cacheMu.Unlock()
×
2132

×
2133
        // The item was not found with the shared lock, so we'll acquire the
×
2134
        // exclusive lock and check the cache again in case another method added
×
2135
        // the entry to the cache while no lock was held.
×
2136
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2137
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2138
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2139
                exists, isZombie = entry.flags.unpack()
×
2140

×
2141
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2142
        }
×
2143

2144
        chanIDB := channelIDToBytes(chanID)
×
2145
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2146
                channel, err := db.GetChannelBySCID(
×
2147
                        ctx, sqlc.GetChannelBySCIDParams{
×
2148
                                Scid:    chanIDB,
×
NEW
2149
                                Version: int16(lnwire.GossipVersion1),
×
2150
                        },
×
2151
                )
×
2152
                if errors.Is(err, sql.ErrNoRows) {
×
2153
                        // Check if it is a zombie channel.
×
2154
                        isZombie, err = db.IsZombieChannel(
×
2155
                                ctx, sqlc.IsZombieChannelParams{
×
2156
                                        Scid:    chanIDB,
×
NEW
2157
                                        Version: int16(lnwire.GossipVersion1),
×
2158
                                },
×
2159
                        )
×
2160
                        if err != nil {
×
2161
                                return fmt.Errorf("could not check if channel "+
×
2162
                                        "is zombie: %w", err)
×
2163
                        }
×
2164

2165
                        return nil
×
2166
                } else if err != nil {
×
2167
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2168
                }
×
2169

2170
                exists = true
×
2171

×
2172
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2173
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
2174
                                Version:   int16(lnwire.GossipVersion1),
×
2175
                                ChannelID: channel.ID,
×
2176
                                NodeID:    channel.NodeID1,
×
2177
                        },
×
2178
                )
×
2179
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2180
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2181
                                err)
×
2182
                } else if err == nil {
×
2183
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2184
                }
×
2185

2186
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2187
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
2188
                                Version:   int16(lnwire.GossipVersion1),
×
2189
                                ChannelID: channel.ID,
×
2190
                                NodeID:    channel.NodeID2,
×
2191
                        },
×
2192
                )
×
2193
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2194
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2195
                                err)
×
2196
                } else if err == nil {
×
2197
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2198
                }
×
2199

2200
                return nil
×
2201
        }, sqldb.NoOpReset)
2202
        if err != nil {
×
2203
                return time.Time{}, time.Time{}, false, false,
×
2204
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2205
        }
×
2206

2207
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2208
                upd1Time: node1LastUpdate.Unix(),
×
2209
                upd2Time: node2LastUpdate.Unix(),
×
2210
                flags:    packRejectFlags(exists, isZombie),
×
2211
        })
×
2212

×
2213
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2214
}
2215

2216
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2217
// passed channel point (outpoint). If the passed channel doesn't exist within
2218
// the database, then ErrEdgeNotFound is returned.
2219
//
2220
// NOTE: part of the V1Store interface.
2221
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2222
        var (
×
2223
                ctx       = context.TODO()
×
2224
                channelID uint64
×
2225
        )
×
2226
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2227
                chanID, err := db.GetSCIDByOutpoint(
×
2228
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2229
                                Outpoint: chanPoint.String(),
×
NEW
2230
                                Version:  int16(lnwire.GossipVersion1),
×
2231
                        },
×
2232
                )
×
2233
                if errors.Is(err, sql.ErrNoRows) {
×
2234
                        return ErrEdgeNotFound
×
2235
                } else if err != nil {
×
2236
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2237
                                err)
×
2238
                }
×
2239

2240
                channelID = byteOrder.Uint64(chanID)
×
2241

×
2242
                return nil
×
2243
        }, sqldb.NoOpReset)
2244
        if err != nil {
×
2245
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2246
        }
×
2247

2248
        return channelID, nil
×
2249
}
2250

2251
// IsPublicNode is a helper method that determines whether the node with the
2252
// given public key is seen as a public node in the graph from the graph's
2253
// source node's point of view.
2254
//
2255
// NOTE: part of the V1Store interface.
2256
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2257
        ctx := context.TODO()
×
2258

×
2259
        var isPublic bool
×
2260
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2261
                var err error
×
2262
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2263

×
2264
                return err
×
2265
        }, sqldb.NoOpReset)
×
2266
        if err != nil {
×
2267
                return false, fmt.Errorf("unable to check if node is "+
×
2268
                        "public: %w", err)
×
2269
        }
×
2270

2271
        return isPublic, nil
×
2272
}
2273

2274
// FetchChanInfos returns the set of channel edges that correspond to the passed
2275
// channel ID's. If an edge is the query is unknown to the database, it will
2276
// skipped and the result will contain only those edges that exist at the time
2277
// of the query. This can be used to respond to peer queries that are seeking to
2278
// fill in gaps in their view of the channel graph.
2279
//
2280
// NOTE: part of the V1Store interface.
2281
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2282
        var (
×
2283
                ctx   = context.TODO()
×
2284
                edges = make(map[uint64]ChannelEdge)
×
2285
        )
×
2286
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2287
                // First, collect all channel rows.
×
2288
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
2289
                chanCallBack := func(ctx context.Context,
×
2290
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2291

×
2292
                        channelRows = append(channelRows, row)
×
2293
                        return nil
×
2294
                }
×
2295

2296
                err := s.forEachChanWithPoliciesInSCIDList(
×
2297
                        ctx, db, chanCallBack, chanIDs,
×
2298
                )
×
2299
                if err != nil {
×
2300
                        return err
×
2301
                }
×
2302

2303
                if len(channelRows) == 0 {
×
2304
                        return nil
×
2305
                }
×
2306

2307
                // Batch build all channel edges.
2308
                chans, err := batchBuildChannelEdges(
×
2309
                        ctx, s.cfg, db, channelRows,
×
2310
                )
×
2311
                if err != nil {
×
2312
                        return fmt.Errorf("unable to build channel edges: %w",
×
2313
                                err)
×
2314
                }
×
2315

2316
                for _, c := range chans {
×
2317
                        edges[c.Info.ChannelID] = c
×
2318
                }
×
2319

2320
                return err
×
2321
        }, func() {
×
2322
                clear(edges)
×
2323
        })
×
2324
        if err != nil {
×
2325
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2326
        }
×
2327

2328
        res := make([]ChannelEdge, 0, len(edges))
×
2329
        for _, chanID := range chanIDs {
×
2330
                edge, ok := edges[chanID]
×
2331
                if !ok {
×
2332
                        continue
×
2333
                }
2334

2335
                res = append(res, edge)
×
2336
        }
2337

2338
        return res, nil
×
2339
}
2340

2341
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2342
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2343
// channels in a paginated manner.
2344
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2345
        db SQLQueries, cb func(ctx context.Context,
2346
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
2347
        chanIDs []uint64) error {
×
2348

×
2349
        queryWrapper := func(ctx context.Context,
×
2350
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
2351
                error) {
×
2352

×
2353
                return db.GetChannelsBySCIDWithPolicies(
×
2354
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
NEW
2355
                                Version: int16(lnwire.GossipVersion1),
×
2356
                                Scids:   scids,
×
2357
                        },
×
2358
                )
×
2359
        }
×
2360

2361
        return sqldb.ExecuteBatchQuery(
×
2362
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
×
2363
                cb,
×
2364
        )
×
2365
}
2366

2367
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2368
// ID's that we don't know and are not known zombies of the passed set. In other
2369
// words, we perform a set difference of our set of chan ID's and the ones
2370
// passed in. This method can be used by callers to determine the set of
2371
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2372
// known zombies is also returned.
2373
//
2374
// NOTE: part of the V1Store interface.
2375
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2376
        []ChannelUpdateInfo, error) {
×
2377

×
2378
        var (
×
2379
                ctx          = context.TODO()
×
2380
                newChanIDs   []uint64
×
2381
                knownZombies []ChannelUpdateInfo
×
2382
                infoLookup   = make(
×
2383
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2384
                )
×
2385
        )
×
2386

×
2387
        // We first build a lookup map of the channel ID's to the
×
2388
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
2389
        // already know about.
×
2390
        for _, chanInfo := range chansInfo {
×
2391
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
2392
        }
×
2393

2394
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2395
                // The call-back function deletes known channels from
×
2396
                // infoLookup, so that we can later check which channels are
×
2397
                // zombies by only looking at the remaining channels in the set.
×
2398
                cb := func(ctx context.Context,
×
2399
                        channel sqlc.GraphChannel) error {
×
2400

×
2401
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2402

×
2403
                        return nil
×
2404
                }
×
2405

2406
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
2407
                if err != nil {
×
2408
                        return fmt.Errorf("unable to iterate through "+
×
2409
                                "channels: %w", err)
×
2410
                }
×
2411

2412
                // We want to ensure that we deal with the channels in the
2413
                // same order that they were passed in, so we iterate over the
2414
                // original chansInfo slice and then check if that channel is
2415
                // still in the infoLookup map.
2416
                for _, chanInfo := range chansInfo {
×
2417
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2418
                        if _, ok := infoLookup[channelID]; !ok {
×
2419
                                continue
×
2420
                        }
2421

2422
                        isZombie, err := db.IsZombieChannel(
×
2423
                                ctx, sqlc.IsZombieChannelParams{
×
2424
                                        Scid:    channelIDToBytes(channelID),
×
NEW
2425
                                        Version: int16(lnwire.GossipVersion1),
×
2426
                                },
×
2427
                        )
×
2428
                        if err != nil {
×
2429
                                return fmt.Errorf("unable to fetch zombie "+
×
2430
                                        "channel: %w", err)
×
2431
                        }
×
2432

2433
                        if isZombie {
×
2434
                                knownZombies = append(knownZombies, chanInfo)
×
2435

×
2436
                                continue
×
2437
                        }
2438

2439
                        newChanIDs = append(newChanIDs, channelID)
×
2440
                }
2441

2442
                return nil
×
2443
        }, func() {
×
2444
                newChanIDs = nil
×
2445
                knownZombies = nil
×
2446
                // Rebuild the infoLookup map in case of a rollback.
×
2447
                for _, chanInfo := range chansInfo {
×
2448
                        scid := chanInfo.ShortChannelID.ToUint64()
×
2449
                        infoLookup[scid] = chanInfo
×
2450
                }
×
2451
        })
2452
        if err != nil {
×
2453
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2454
        }
×
2455

2456
        return newChanIDs, knownZombies, nil
×
2457
}
2458

2459
// forEachChanInSCIDList is a helper method that executes a paged query
2460
// against the database to fetch all channels that match the passed
2461
// ChannelUpdateInfo slice. The callback function is called for each channel
2462
// that is found.
2463
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2464
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
2465
        chansInfo []ChannelUpdateInfo) error {
×
2466

×
2467
        queryWrapper := func(ctx context.Context,
×
2468
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
2469

×
2470
                return db.GetChannelsBySCIDs(
×
2471
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
NEW
2472
                                Version: int16(lnwire.GossipVersion1),
×
2473
                                Scids:   scids,
×
2474
                        },
×
2475
                )
×
2476
        }
×
2477

2478
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
2479
                channelID := chanInfo.ShortChannelID.ToUint64()
×
2480

×
2481
                return channelIDToBytes(channelID)
×
2482
        }
×
2483

2484
        return sqldb.ExecuteBatchQuery(
×
2485
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
×
2486
                cb,
×
2487
        )
×
2488
}
2489

2490
// PruneGraphNodes is a garbage collection method which attempts to prune out
2491
// any nodes from the channel graph that are currently unconnected. This ensure
2492
// that we only maintain a graph of reachable nodes. In the event that a pruned
2493
// node gains more channels, it will be re-added back to the graph.
2494
//
2495
// NOTE: this prunes nodes across protocol versions. It will never prune the
2496
// source nodes.
2497
//
2498
// NOTE: part of the V1Store interface.
2499
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2500
        var ctx = context.TODO()
×
2501

×
2502
        var prunedNodes []route.Vertex
×
2503
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2504
                var err error
×
2505
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2506

×
2507
                return err
×
2508
        }, func() {
×
2509
                prunedNodes = nil
×
2510
        })
×
2511
        if err != nil {
×
2512
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2513
        }
×
2514

2515
        return prunedNodes, nil
×
2516
}
2517

2518
// PruneGraph prunes newly closed channels from the channel graph in response
2519
// to a new block being solved on the network. Any transactions which spend the
2520
// funding output of any known channels within he graph will be deleted.
2521
// Additionally, the "prune tip", or the last block which has been used to
2522
// prune the graph is stored so callers can ensure the graph is fully in sync
2523
// with the current UTXO state. A slice of channels that have been closed by
2524
// the target block along with any pruned nodes are returned if the function
2525
// succeeds without error.
2526
//
2527
// NOTE: part of the V1Store interface.
2528
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2529
        blockHash *chainhash.Hash, blockHeight uint32) (
2530
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2531

×
2532
        ctx := context.TODO()
×
2533

×
2534
        s.cacheMu.Lock()
×
2535
        defer s.cacheMu.Unlock()
×
2536

×
2537
        var (
×
2538
                closedChans []*models.ChannelEdgeInfo
×
2539
                prunedNodes []route.Vertex
×
2540
        )
×
2541
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2542
                // First, collect all channel rows that need to be pruned.
×
2543
                var channelRows []sqlc.GetChannelsByOutpointsRow
×
2544
                channelCallback := func(ctx context.Context,
×
2545
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2546

×
2547
                        channelRows = append(channelRows, row)
×
2548

×
2549
                        return nil
×
2550
                }
×
2551

2552
                err := s.forEachChanInOutpoints(
×
2553
                        ctx, db, spentOutputs, channelCallback,
×
2554
                )
×
2555
                if err != nil {
×
2556
                        return fmt.Errorf("unable to fetch channels by "+
×
2557
                                "outpoints: %w", err)
×
2558
                }
×
2559

2560
                if len(channelRows) == 0 {
×
2561
                        // There are no channels to prune. So we can exit early
×
2562
                        // after updating the prune log.
×
2563
                        err = db.UpsertPruneLogEntry(
×
2564
                                ctx, sqlc.UpsertPruneLogEntryParams{
×
2565
                                        BlockHash:   blockHash[:],
×
2566
                                        BlockHeight: int64(blockHeight),
×
2567
                                },
×
2568
                        )
×
2569
                        if err != nil {
×
2570
                                return fmt.Errorf("unable to insert prune log "+
×
2571
                                        "entry: %w", err)
×
2572
                        }
×
2573

2574
                        return nil
×
2575
                }
2576

2577
                // Batch build all channel edges for pruning.
2578
                var chansToDelete []int64
×
2579
                closedChans, chansToDelete, err = batchBuildChannelInfo(
×
2580
                        ctx, s.cfg, db, channelRows,
×
2581
                )
×
2582
                if err != nil {
×
2583
                        return err
×
2584
                }
×
2585

2586
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2587
                if err != nil {
×
2588
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2589
                }
×
2590

2591
                err = db.UpsertPruneLogEntry(
×
2592
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2593
                                BlockHash:   blockHash[:],
×
2594
                                BlockHeight: int64(blockHeight),
×
2595
                        },
×
2596
                )
×
2597
                if err != nil {
×
2598
                        return fmt.Errorf("unable to insert prune log "+
×
2599
                                "entry: %w", err)
×
2600
                }
×
2601

2602
                // Now that we've pruned some channels, we'll also prune any
2603
                // nodes that no longer have any channels.
2604
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2605
                if err != nil {
×
2606
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2607
                                err)
×
2608
                }
×
2609

2610
                return nil
×
2611
        }, func() {
×
2612
                prunedNodes = nil
×
2613
                closedChans = nil
×
2614
        })
×
2615
        if err != nil {
×
2616
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2617
        }
×
2618

2619
        for _, channel := range closedChans {
×
2620
                s.rejectCache.remove(channel.ChannelID)
×
2621
                s.chanCache.remove(channel.ChannelID)
×
2622
        }
×
2623

2624
        return closedChans, prunedNodes, nil
×
2625
}
2626

2627
// forEachChanInOutpoints is a helper function that executes a paginated
2628
// query to fetch channels by their outpoints and applies the given call-back
2629
// to each.
2630
//
2631
// NOTE: this fetches channels for all protocol versions.
2632
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2633
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2634
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
2635

×
2636
        // Create a wrapper that uses the transaction's db instance to execute
×
2637
        // the query.
×
2638
        queryWrapper := func(ctx context.Context,
×
2639
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
2640
                error) {
×
2641

×
2642
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
2643
        }
×
2644

2645
        // Define the conversion function from Outpoint to string.
2646
        outpointToString := func(outpoint *wire.OutPoint) string {
×
2647
                return outpoint.String()
×
2648
        }
×
2649

2650
        return sqldb.ExecuteBatchQuery(
×
2651
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
×
2652
                queryWrapper, cb,
×
2653
        )
×
2654
}
2655

2656
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
2657
        dbIDs []int64) error {
×
2658

×
2659
        // Create a wrapper that uses the transaction's db instance to execute
×
2660
        // the query.
×
2661
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
2662
                return nil, db.DeleteChannels(ctx, ids)
×
2663
        }
×
2664

2665
        idConverter := func(id int64) int64 {
×
2666
                return id
×
2667
        }
×
2668

2669
        return sqldb.ExecuteBatchQuery(
×
2670
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
×
2671
                queryWrapper, func(ctx context.Context, _ any) error {
×
2672
                        return nil
×
2673
                },
×
2674
        )
2675
}
2676

2677
// ChannelView returns the verifiable edge information for each active channel
2678
// within the known channel graph. The set of UTXOs (along with their scripts)
2679
// returned are the ones that need to be watched on chain to detect channel
2680
// closes on the resident blockchain.
2681
//
2682
// NOTE: part of the V1Store interface.
2683
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2684
        var (
×
2685
                ctx        = context.TODO()
×
2686
                edgePoints []EdgePoint
×
2687
        )
×
2688

×
2689
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2690
                handleChannel := func(_ context.Context,
×
2691
                        channel sqlc.ListChannelsPaginatedRow) error {
×
2692

×
2693
                        pkScript, err := genMultiSigP2WSH(
×
2694
                                channel.BitcoinKey1, channel.BitcoinKey2,
×
2695
                        )
×
2696
                        if err != nil {
×
2697
                                return err
×
2698
                        }
×
2699

2700
                        op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2701
                        if err != nil {
×
2702
                                return err
×
2703
                        }
×
2704

2705
                        edgePoints = append(edgePoints, EdgePoint{
×
2706
                                FundingPkScript: pkScript,
×
2707
                                OutPoint:        *op,
×
2708
                        })
×
2709

×
2710
                        return nil
×
2711
                }
2712

2713
                queryFunc := func(ctx context.Context, lastID int64,
×
2714
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {
×
2715

×
2716
                        return db.ListChannelsPaginated(
×
2717
                                ctx, sqlc.ListChannelsPaginatedParams{
×
NEW
2718
                                        Version: int16(lnwire.GossipVersion1),
×
2719
                                        ID:      lastID,
×
2720
                                        Limit:   limit,
×
2721
                                },
×
2722
                        )
×
2723
                }
×
2724

2725
                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
×
2726
                        return row.ID
×
2727
                }
×
2728

2729
                return sqldb.ExecutePaginatedQuery(
×
2730
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
2731
                        extractCursor, handleChannel,
×
2732
                )
×
2733
        }, func() {
×
2734
                edgePoints = nil
×
2735
        })
×
2736
        if err != nil {
×
2737
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2738
        }
×
2739

2740
        return edgePoints, nil
×
2741
}
2742

2743
// PruneTip returns the block height and hash of the latest block that has been
2744
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2745
// to tell if the graph is currently in sync with the current best known UTXO
2746
// state.
2747
//
2748
// NOTE: part of the V1Store interface.
2749
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2750
        var (
×
2751
                ctx       = context.TODO()
×
2752
                tipHash   chainhash.Hash
×
2753
                tipHeight uint32
×
2754
        )
×
2755
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2756
                pruneTip, err := db.GetPruneTip(ctx)
×
2757
                if errors.Is(err, sql.ErrNoRows) {
×
2758
                        return ErrGraphNeverPruned
×
2759
                } else if err != nil {
×
2760
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2761
                }
×
2762

2763
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2764
                tipHeight = uint32(pruneTip.BlockHeight)
×
2765

×
2766
                return nil
×
2767
        }, sqldb.NoOpReset)
2768
        if err != nil {
×
2769
                return nil, 0, err
×
2770
        }
×
2771

2772
        return &tipHash, tipHeight, nil
×
2773
}
2774

2775
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2776
//
2777
// NOTE: this prunes nodes across protocol versions. It will never prune the
2778
// source nodes.
2779
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2780
        db SQLQueries) ([]route.Vertex, error) {
×
2781

×
2782
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2783
        if err != nil {
×
2784
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2785
                        "nodes: %w", err)
×
2786
        }
×
2787

2788
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2789
        for i, nodeKey := range nodeKeys {
×
2790
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2791
                if err != nil {
×
2792
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2793
                                "from bytes: %w", err)
×
2794
                }
×
2795

2796
                prunedNodes[i] = pub
×
2797
        }
2798

2799
        return prunedNodes, nil
×
2800
}
2801

2802
// DisconnectBlockAtHeight is used to indicate that the block specified
2803
// by the passed height has been disconnected from the main chain. This
2804
// will "rewind" the graph back to the height below, deleting channels
2805
// that are no longer confirmed from the graph. The prune log will be
2806
// set to the last prune height valid for the remaining chain.
2807
// Channels that were removed from the graph resulting from the
2808
// disconnected block are returned.
2809
//
2810
// NOTE: part of the V1Store interface.
2811
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2812
        []*models.ChannelEdgeInfo, error) {
×
2813

×
2814
        ctx := context.TODO()
×
2815

×
2816
        var (
×
2817
                // Every channel having a ShortChannelID starting at 'height'
×
2818
                // will no longer be confirmed.
×
2819
                startShortChanID = lnwire.ShortChannelID{
×
2820
                        BlockHeight: height,
×
2821
                }
×
2822

×
2823
                // Delete everything after this height from the db up until the
×
2824
                // SCID alias range.
×
2825
                endShortChanID = aliasmgr.StartingAlias
×
2826

×
2827
                removedChans []*models.ChannelEdgeInfo
×
2828

×
2829
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2830
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2831
        )
×
2832

×
2833
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2834
                rows, err := db.GetChannelsBySCIDRange(
×
2835
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2836
                                StartScid: chanIDStart,
×
2837
                                EndScid:   chanIDEnd,
×
2838
                        },
×
2839
                )
×
2840
                if err != nil {
×
2841
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2842
                }
×
2843

2844
                if len(rows) == 0 {
×
2845
                        // No channels to disconnect, but still clean up prune
×
2846
                        // log.
×
2847
                        return db.DeletePruneLogEntriesInRange(
×
2848
                                ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2849
                                        StartHeight: int64(height),
×
2850
                                        EndHeight: int64(
×
2851
                                                endShortChanID.BlockHeight,
×
2852
                                        ),
×
2853
                                },
×
2854
                        )
×
2855
                }
×
2856

2857
                // Batch build all channel edges for disconnection.
2858
                channelEdges, chanIDsToDelete, err := batchBuildChannelInfo(
×
2859
                        ctx, s.cfg, db, rows,
×
2860
                )
×
2861
                if err != nil {
×
2862
                        return err
×
2863
                }
×
2864

2865
                removedChans = channelEdges
×
2866

×
2867
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
2868
                if err != nil {
×
2869
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2870
                }
×
2871

2872
                return db.DeletePruneLogEntriesInRange(
×
2873
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2874
                                StartHeight: int64(height),
×
2875
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2876
                        },
×
2877
                )
×
2878
        }, func() {
×
2879
                removedChans = nil
×
2880
        })
×
2881
        if err != nil {
×
2882
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2883
                        "height: %w", err)
×
2884
        }
×
2885

2886
        for _, channel := range removedChans {
×
2887
                s.rejectCache.remove(channel.ChannelID)
×
2888
                s.chanCache.remove(channel.ChannelID)
×
2889
        }
×
2890

2891
        return removedChans, nil
×
2892
}
2893

2894
// AddEdgeProof sets the proof of an existing edge in the graph database.
2895
//
2896
// NOTE: part of the V1Store interface.
2897
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2898
        proof *models.ChannelAuthProof) error {
×
2899

×
2900
        var (
×
2901
                ctx       = context.TODO()
×
2902
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2903
        )
×
2904

×
2905
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2906
                res, err := db.AddV1ChannelProof(
×
2907
                        ctx, sqlc.AddV1ChannelProofParams{
×
2908
                                Scid:              scidBytes,
×
2909
                                Node1Signature:    proof.NodeSig1Bytes,
×
2910
                                Node2Signature:    proof.NodeSig2Bytes,
×
2911
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2912
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2913
                        },
×
2914
                )
×
2915
                if err != nil {
×
2916
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2917
                }
×
2918

2919
                n, err := res.RowsAffected()
×
2920
                if err != nil {
×
2921
                        return err
×
2922
                }
×
2923

2924
                if n == 0 {
×
2925
                        return fmt.Errorf("no rows affected when adding edge "+
×
2926
                                "proof for SCID %v", scid)
×
2927
                } else if n > 1 {
×
2928
                        return fmt.Errorf("multiple rows affected when adding "+
×
2929
                                "edge proof for SCID %v: %d rows affected",
×
2930
                                scid, n)
×
2931
                }
×
2932

2933
                return nil
×
2934
        }, sqldb.NoOpReset)
2935
        if err != nil {
×
2936
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2937
        }
×
2938

2939
        return nil
×
2940
}
2941

2942
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2943
// that we can ignore channel announcements that we know to be closed without
2944
// having to validate them and fetch a block.
2945
//
2946
// NOTE: part of the V1Store interface.
2947
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2948
        var (
×
2949
                ctx     = context.TODO()
×
2950
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2951
        )
×
2952

×
2953
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2954
                return db.InsertClosedChannel(ctx, chanIDB)
×
2955
        }, sqldb.NoOpReset)
×
2956
}
2957

2958
// IsClosedScid checks whether a channel identified by the passed in scid is
2959
// closed. This helps avoid having to perform expensive validation checks.
2960
//
2961
// NOTE: part of the V1Store interface.
2962
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2963
        var (
×
2964
                ctx      = context.TODO()
×
2965
                isClosed bool
×
2966
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2967
        )
×
2968
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2969
                var err error
×
2970
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2971
                if err != nil {
×
2972
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2973
                                err)
×
2974
                }
×
2975

2976
                return nil
×
2977
        }, sqldb.NoOpReset)
2978
        if err != nil {
×
2979
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2980
                        err)
×
2981
        }
×
2982

2983
        return isClosed, nil
×
2984
}
2985

2986
// GraphSession will provide the call-back with access to a NodeTraverser
2987
// instance which can be used to perform queries against the channel graph.
2988
//
2989
// NOTE: part of the V1Store interface.
2990
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
2991
        reset func()) error {
×
2992

×
2993
        var ctx = context.TODO()
×
2994

×
2995
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2996
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2997
        }, reset)
×
2998
}
2999

3000
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
3001
// read only transaction for a consistent view of the graph.
3002
type sqlNodeTraverser struct {
3003
        db    SQLQueries
3004
        chain chainhash.Hash
3005
}
3006

3007
// A compile-time assertion to ensure that sqlNodeTraverser implements the
3008
// NodeTraverser interface.
3009
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
3010

3011
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
3012
func newSQLNodeTraverser(db SQLQueries,
3013
        chain chainhash.Hash) *sqlNodeTraverser {
×
3014

×
3015
        return &sqlNodeTraverser{
×
3016
                db:    db,
×
3017
                chain: chain,
×
3018
        }
×
3019
}
×
3020

3021
// ForEachNodeDirectedChannel calls the callback for every channel of the given
3022
// node.
3023
//
3024
// NOTE: Part of the NodeTraverser interface.
3025
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
3026
        cb func(channel *DirectedChannel) error, _ func()) error {
×
3027

×
3028
        ctx := context.TODO()
×
3029

×
3030
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
3031
}
×
3032

3033
// FetchNodeFeatures returns the features of the given node. If the node is
3034
// unknown, assume no additional features are supported.
3035
//
3036
// NOTE: Part of the NodeTraverser interface.
3037
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
3038
        *lnwire.FeatureVector, error) {
×
3039

×
3040
        ctx := context.TODO()
×
3041

×
3042
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
3043
}
×
3044

3045
// forEachNodeDirectedChannel iterates through all channels of a given
3046
// node, executing the passed callback on the directed edge representing the
3047
// channel and its incoming policy. If the node is not found, no error is
3048
// returned.
3049
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
3050
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
3051

×
3052
        toNodeCallback := func() route.Vertex {
×
3053
                return nodePub
×
3054
        }
×
3055

3056
        dbID, err := db.GetNodeIDByPubKey(
×
3057
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
NEW
3058
                        Version: int16(lnwire.GossipVersion1),
×
3059
                        PubKey:  nodePub[:],
×
3060
                },
×
3061
        )
×
3062
        if errors.Is(err, sql.ErrNoRows) {
×
3063
                return nil
×
3064
        } else if err != nil {
×
3065
                return fmt.Errorf("unable to fetch node: %w", err)
×
3066
        }
×
3067

3068
        rows, err := db.ListChannelsByNodeID(
×
3069
                ctx, sqlc.ListChannelsByNodeIDParams{
×
NEW
3070
                        Version: int16(lnwire.GossipVersion1),
×
3071
                        NodeID1: dbID,
×
3072
                },
×
3073
        )
×
3074
        if err != nil {
×
3075
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3076
        }
×
3077

3078
        // Exit early if there are no channels for this node so we don't
3079
        // do the unnecessary feature fetching.
3080
        if len(rows) == 0 {
×
3081
                return nil
×
3082
        }
×
3083

3084
        features, err := getNodeFeatures(ctx, db, dbID)
×
3085
        if err != nil {
×
3086
                return fmt.Errorf("unable to fetch node features: %w", err)
×
3087
        }
×
3088

3089
        for _, row := range rows {
×
3090
                node1, node2, err := buildNodeVertices(
×
3091
                        row.Node1Pubkey, row.Node2Pubkey,
×
3092
                )
×
3093
                if err != nil {
×
3094
                        return fmt.Errorf("unable to build node vertices: %w",
×
3095
                                err)
×
3096
                }
×
3097

3098
                edge := buildCacheableChannelInfo(
×
3099
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
3100
                        node1, node2,
×
3101
                )
×
3102

×
3103
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3104
                if err != nil {
×
3105
                        return err
×
3106
                }
×
3107

3108
                p1, p2, err := buildCachedChanPolicies(
×
3109
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3110
                )
×
3111
                if err != nil {
×
3112
                        return err
×
3113
                }
×
3114

3115
                // Determine the outgoing and incoming policy for this
3116
                // channel and node combo.
3117
                outPolicy, inPolicy := p1, p2
×
3118
                if p1 != nil && node2 == nodePub {
×
3119
                        outPolicy, inPolicy = p2, p1
×
3120
                } else if p2 != nil && node1 != nodePub {
×
3121
                        outPolicy, inPolicy = p2, p1
×
3122
                }
×
3123

3124
                var cachedInPolicy *models.CachedEdgePolicy
×
3125
                if inPolicy != nil {
×
3126
                        cachedInPolicy = inPolicy
×
3127
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
3128
                        cachedInPolicy.ToNodeFeatures = features
×
3129
                }
×
3130

3131
                directedChannel := &DirectedChannel{
×
3132
                        ChannelID:    edge.ChannelID,
×
3133
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
3134
                        OtherNode:    edge.NodeKey2Bytes,
×
3135
                        Capacity:     edge.Capacity,
×
3136
                        OutPolicySet: outPolicy != nil,
×
3137
                        InPolicy:     cachedInPolicy,
×
3138
                }
×
3139
                if outPolicy != nil {
×
3140
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3141
                                directedChannel.InboundFee = fee
×
3142
                        })
×
3143
                }
3144

3145
                if nodePub == edge.NodeKey2Bytes {
×
3146
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
3147
                }
×
3148

3149
                if err := cb(directedChannel); err != nil {
×
3150
                        return err
×
3151
                }
×
3152
        }
3153

3154
        return nil
×
3155
}
3156

3157
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
3158
// and executes the provided callback for each node. It does so via pagination
3159
// along with batch loading of the node feature bits.
3160
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
3161
        db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
3162
                features *lnwire.FeatureVector) error) error {
×
3163

×
3164
        handleNode := func(_ context.Context,
×
3165
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
×
3166
                featureBits map[int64][]int) error {
×
3167

×
3168
                fv := lnwire.EmptyFeatureVector()
×
3169
                if features, exists := featureBits[dbNode.ID]; exists {
×
3170
                        for _, bit := range features {
×
3171
                                fv.Set(lnwire.FeatureBit(bit))
×
3172
                        }
×
3173
                }
3174

3175
                var pub route.Vertex
×
3176
                copy(pub[:], dbNode.PubKey)
×
3177

×
3178
                return processNode(dbNode.ID, pub, fv)
×
3179
        }
3180

3181
        queryFunc := func(ctx context.Context, lastID int64,
×
3182
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
3183

×
3184
                return db.ListNodeIDsAndPubKeys(
×
3185
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
NEW
3186
                                Version: int16(lnwire.GossipVersion1),
×
3187
                                ID:      lastID,
×
3188
                                Limit:   limit,
×
3189
                        },
×
3190
                )
×
3191
        }
×
3192

3193
        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
3194
                return row.ID
×
3195
        }
×
3196

3197
        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
×
3198
                return node.ID, nil
×
3199
        }
×
3200

3201
        batchQueryFunc := func(ctx context.Context,
×
3202
                nodeIDs []int64) (map[int64][]int, error) {
×
3203

×
3204
                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
3205
        }
×
3206

3207
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
3208
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
×
3209
                batchQueryFunc, handleNode,
×
3210
        )
×
3211
}
3212

3213
// forEachNodeChannel iterates through all channels of a node, executing
3214
// the passed callback on each. The call-back is provided with the channel's
3215
// edge information, the outgoing policy and the incoming policy for the
3216
// channel and node combo.
3217
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3218
        cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
3219
                *models.ChannelEdgePolicy,
3220
                *models.ChannelEdgePolicy) error) error {
×
3221

×
3222
        // Get all the V1 channels for this node.
×
3223
        rows, err := db.ListChannelsByNodeID(
×
3224
                ctx, sqlc.ListChannelsByNodeIDParams{
×
NEW
3225
                        Version: int16(lnwire.GossipVersion1),
×
3226
                        NodeID1: id,
×
3227
                },
×
3228
        )
×
3229
        if err != nil {
×
3230
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3231
        }
×
3232

3233
        // Collect all the channel and policy IDs.
3234
        var (
×
3235
                chanIDs   = make([]int64, 0, len(rows))
×
3236
                policyIDs = make([]int64, 0, 2*len(rows))
×
3237
        )
×
3238
        for _, row := range rows {
×
3239
                chanIDs = append(chanIDs, row.GraphChannel.ID)
×
3240

×
3241
                if row.Policy1ID.Valid {
×
3242
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
×
3243
                }
×
3244
                if row.Policy2ID.Valid {
×
3245
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
×
3246
                }
×
3247
        }
3248

3249
        batchData, err := batchLoadChannelData(
×
3250
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
×
3251
        )
×
3252
        if err != nil {
×
3253
                return fmt.Errorf("unable to batch load channel data: %w", err)
×
3254
        }
×
3255

3256
        // Call the call-back for each channel and its known policies.
3257
        for _, row := range rows {
×
3258
                node1, node2, err := buildNodeVertices(
×
3259
                        row.Node1Pubkey, row.Node2Pubkey,
×
3260
                )
×
3261
                if err != nil {
×
3262
                        return fmt.Errorf("unable to build node vertices: %w",
×
3263
                                err)
×
3264
                }
×
3265

3266
                edge, err := buildEdgeInfoWithBatchData(
×
3267
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
3268
                        batchData,
×
3269
                )
×
3270
                if err != nil {
×
3271
                        return fmt.Errorf("unable to build channel info: %w",
×
3272
                                err)
×
3273
                }
×
3274

3275
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3276
                if err != nil {
×
3277
                        return fmt.Errorf("unable to extract channel "+
×
3278
                                "policies: %w", err)
×
3279
                }
×
3280

3281
                p1, p2, err := buildChanPoliciesWithBatchData(
×
3282
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
3283
                )
×
3284
                if err != nil {
×
3285
                        return fmt.Errorf("unable to build channel "+
×
3286
                                "policies: %w", err)
×
3287
                }
×
3288

3289
                // Determine the outgoing and incoming policy for this
3290
                // channel and node combo.
3291
                p1ToNode := row.GraphChannel.NodeID2
×
3292
                p2ToNode := row.GraphChannel.NodeID1
×
3293
                outPolicy, inPolicy := p1, p2
×
3294
                if (p1 != nil && p1ToNode == id) ||
×
3295
                        (p2 != nil && p2ToNode != id) {
×
3296

×
3297
                        outPolicy, inPolicy = p2, p1
×
3298
                }
×
3299

3300
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3301
                        return err
×
3302
                }
×
3303
        }
3304

3305
        return nil
×
3306
}
3307

3308
// updateChanEdgePolicy upserts the channel policy info we have stored for
3309
// a channel we already know of.
3310
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3311
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3312
        error) {
×
3313

×
3314
        var (
×
3315
                node1Pub, node2Pub route.Vertex
×
3316
                isNode1            bool
×
3317
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3318
        )
×
3319

×
3320
        // Check that this edge policy refers to a channel that we already
×
3321
        // know of. We do this explicitly so that we can return the appropriate
×
3322
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3323
        // abort the transaction which would abort the entire batch.
×
3324
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3325
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3326
                        Scid:    chanIDB,
×
NEW
3327
                        Version: int16(lnwire.GossipVersion1),
×
3328
                },
×
3329
        )
×
3330
        if errors.Is(err, sql.ErrNoRows) {
×
3331
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3332
        } else if err != nil {
×
3333
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3334
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3335
        }
×
3336

3337
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3338
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3339

×
3340
        // Figure out which node this edge is from.
×
3341
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3342
        nodeID := dbChan.NodeID1
×
3343
        if !isNode1 {
×
3344
                nodeID = dbChan.NodeID2
×
3345
        }
×
3346

3347
        var (
×
3348
                inboundBase sql.NullInt64
×
3349
                inboundRate sql.NullInt64
×
3350
        )
×
3351
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3352
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3353
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3354
        })
×
3355

3356
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
NEW
3357
                Version:     int16(lnwire.GossipVersion1),
×
3358
                ChannelID:   dbChan.ID,
×
3359
                NodeID:      nodeID,
×
3360
                Timelock:    int32(edge.TimeLockDelta),
×
3361
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3362
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3363
                MinHtlcMsat: int64(edge.MinHTLC),
×
3364
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3365
                Disabled: sql.NullBool{
×
3366
                        Valid: true,
×
3367
                        Bool:  edge.IsDisabled(),
×
3368
                },
×
3369
                MaxHtlcMsat: sql.NullInt64{
×
3370
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3371
                        Int64: int64(edge.MaxHTLC),
×
3372
                },
×
3373
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3374
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3375
                InboundBaseFeeMsat:      inboundBase,
×
3376
                InboundFeeRateMilliMsat: inboundRate,
×
3377
                Signature:               edge.SigBytes,
×
3378
        })
×
3379
        if err != nil {
×
3380
                return node1Pub, node2Pub, isNode1,
×
3381
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3382
        }
×
3383

3384
        // Convert the flat extra opaque data into a map of TLV types to
3385
        // values.
3386
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3387
        if err != nil {
×
3388
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3389
                        "marshal extra opaque data: %w", err)
×
3390
        }
×
3391

3392
        // Update the channel policy's extra signed fields.
3393
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3394
        if err != nil {
×
3395
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3396
                        "policy extra TLVs: %w", err)
×
3397
        }
×
3398

3399
        return node1Pub, node2Pub, isNode1, nil
×
3400
}
3401

3402
// getNodeByPubKey attempts to look up a target node by its public key.
3403
func getNodeByPubKey(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3404
        pubKey route.Vertex) (int64, *models.Node, error) {
×
3405

×
3406
        dbNode, err := db.GetNodeByPubKey(
×
3407
                ctx, sqlc.GetNodeByPubKeyParams{
×
NEW
3408
                        Version: int16(lnwire.GossipVersion1),
×
3409
                        PubKey:  pubKey[:],
×
3410
                },
×
3411
        )
×
3412
        if errors.Is(err, sql.ErrNoRows) {
×
3413
                return 0, nil, ErrGraphNodeNotFound
×
3414
        } else if err != nil {
×
3415
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3416
        }
×
3417

3418
        node, err := buildNode(ctx, cfg, db, dbNode)
×
3419
        if err != nil {
×
3420
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3421
        }
×
3422

3423
        return dbNode.ID, node, nil
×
3424
}
3425

3426
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3427
// provided parameters.
3428
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
3429
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3430

×
3431
        return &models.CachedEdgeInfo{
×
3432
                ChannelID:     byteOrder.Uint64(scid),
×
3433
                NodeKey1Bytes: node1Pub,
×
3434
                NodeKey2Bytes: node2Pub,
×
3435
                Capacity:      btcutil.Amount(capacity),
×
3436
        }
×
3437
}
×
3438

3439
// buildNode constructs a Node instance from the given database node
3440
// record. The node's features, addresses and extra signed fields are also
3441
// fetched from the database and set on the node.
3442
func buildNode(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3443
        dbNode sqlc.GraphNode) (*models.Node, error) {
×
3444

×
3445
        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
×
3446
        if err != nil {
×
3447
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
3448
                        err)
×
3449
        }
×
3450

3451
        return buildNodeWithBatchData(dbNode, data)
×
3452
}
3453

3454
// buildNodeWithBatchData builds a models.Node instance
3455
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
3456
// features/addresses/extra fields, then the corresponding fields are expected
3457
// to be present in the batchNodeData.
3458
func buildNodeWithBatchData(dbNode sqlc.GraphNode,
3459
        batchData *batchNodeData) (*models.Node, error) {
×
3460

×
NEW
3461
        if dbNode.Version != int16(lnwire.GossipVersion1) {
×
3462
                return nil, fmt.Errorf("unsupported node version: %d",
×
3463
                        dbNode.Version)
×
3464
        }
×
3465

3466
        var pub [33]byte
×
3467
        copy(pub[:], dbNode.PubKey)
×
3468

×
3469
        node := &models.Node{
×
3470
                PubKeyBytes: pub,
×
3471
                Features:    lnwire.EmptyFeatureVector(),
×
3472
                LastUpdate:  time.Unix(0, 0),
×
3473
        }
×
3474

×
3475
        if len(dbNode.Signature) == 0 {
×
3476
                return node, nil
×
3477
        }
×
3478

3479
        node.HaveNodeAnnouncement = true
×
3480
        node.AuthSigBytes = dbNode.Signature
×
3481
        node.Alias = dbNode.Alias.String
×
3482
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3483

×
3484
        var err error
×
3485
        if dbNode.Color.Valid {
×
3486
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3487
                if err != nil {
×
3488
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3489
                                err)
×
3490
                }
×
3491
        }
3492

3493
        // Use preloaded features.
3494
        if features, exists := batchData.features[dbNode.ID]; exists {
×
3495
                fv := lnwire.EmptyFeatureVector()
×
3496
                for _, bit := range features {
×
3497
                        fv.Set(lnwire.FeatureBit(bit))
×
3498
                }
×
3499
                node.Features = fv
×
3500
        }
3501

3502
        // Use preloaded addresses.
3503
        addresses, exists := batchData.addresses[dbNode.ID]
×
3504
        if exists && len(addresses) > 0 {
×
3505
                node.Addresses, err = buildNodeAddresses(addresses)
×
3506
                if err != nil {
×
3507
                        return nil, fmt.Errorf("unable to build addresses "+
×
3508
                                "for node(%d): %w", dbNode.ID, err)
×
3509
                }
×
3510
        }
3511

3512
        // Use preloaded extra fields.
3513
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
×
3514
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
×
3515
                if err != nil {
×
3516
                        return nil, fmt.Errorf("unable to serialize extra "+
×
3517
                                "signed fields: %w", err)
×
3518
                }
×
3519
                if len(recs) != 0 {
×
3520
                        node.ExtraOpaqueData = recs
×
3521
                }
×
3522
        }
3523

3524
        return node, nil
×
3525
}
3526

3527
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
3528
// with the preloaded data, and executes the provided callback for each node.
3529
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
3530
        db SQLQueries, nodes []sqlc.GraphNode,
3531
        cb func(dbID int64, node *models.Node) error) error {
×
3532

×
3533
        // Extract node IDs for batch loading.
×
3534
        nodeIDs := make([]int64, len(nodes))
×
3535
        for i, node := range nodes {
×
3536
                nodeIDs[i] = node.ID
×
3537
        }
×
3538

3539
        // Batch load all related data for this page.
3540
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
3541
        if err != nil {
×
3542
                return fmt.Errorf("unable to batch load node data: %w", err)
×
3543
        }
×
3544

3545
        for _, dbNode := range nodes {
×
3546
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
3547
                if err != nil {
×
3548
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
3549
                                dbNode.ID, err)
×
3550
                }
×
3551

3552
                if err := cb(dbNode.ID, node); err != nil {
×
3553
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
3554
                                dbNode.ID, err)
×
3555
                }
×
3556
        }
3557

3558
        return nil
×
3559
}
3560

3561
// getNodeFeatures fetches the feature bits and constructs the feature vector
3562
// for a node with the given DB ID.
3563
func getNodeFeatures(ctx context.Context, db SQLQueries,
3564
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3565

×
3566
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3567
        if err != nil {
×
3568
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3569
                        nodeID, err)
×
3570
        }
×
3571

3572
        features := lnwire.EmptyFeatureVector()
×
3573
        for _, feature := range rows {
×
3574
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3575
        }
×
3576

3577
        return features, nil
×
3578
}
3579

3580
// upsertNode upserts the node record into the database. If the node already
3581
// exists, then the node's information is updated. If the node doesn't exist,
3582
// then a new node is created. The node's features, addresses and extra TLV
3583
// types are also updated. The node's DB ID is returned.
3584
func upsertNode(ctx context.Context, db SQLQueries,
3585
        node *models.Node) (int64, error) {
×
3586

×
3587
        params := sqlc.UpsertNodeParams{
×
NEW
3588
                Version: int16(lnwire.GossipVersion1),
×
3589
                PubKey:  node.PubKeyBytes[:],
×
3590
        }
×
3591

×
3592
        if node.HaveNodeAnnouncement {
×
3593
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3594
                params.Color = sqldb.SQLStrValid(EncodeHexColor(node.Color))
×
3595
                params.Alias = sqldb.SQLStrValid(node.Alias)
×
3596
                params.Signature = node.AuthSigBytes
×
3597
        }
×
3598

3599
        nodeID, err := db.UpsertNode(ctx, params)
×
3600
        if err != nil {
×
3601
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3602
                        err)
×
3603
        }
×
3604

3605
        // We can exit here if we don't have the announcement yet.
3606
        if !node.HaveNodeAnnouncement {
×
3607
                return nodeID, nil
×
3608
        }
×
3609

3610
        // Update the node's features.
3611
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3612
        if err != nil {
×
3613
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3614
        }
×
3615

3616
        // Update the node's addresses.
3617
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3618
        if err != nil {
×
3619
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3620
        }
×
3621

3622
        // Convert the flat extra opaque data into a map of TLV types to
3623
        // values.
3624
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3625
        if err != nil {
×
3626
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3627
                        err)
×
3628
        }
×
3629

3630
        // Update the node's extra signed fields.
3631
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3632
        if err != nil {
×
3633
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3634
        }
×
3635

3636
        return nodeID, nil
×
3637
}
3638

3639
// upsertNodeFeatures updates the node's features node_features table. This
3640
// includes deleting any feature bits no longer present and inserting any new
3641
// feature bits. If the feature bit does not yet exist in the features table,
3642
// then an entry is created in that table first.
3643
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3644
        features *lnwire.FeatureVector) error {
×
3645

×
3646
        // Get any existing features for the node.
×
3647
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3648
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3649
                return err
×
3650
        }
×
3651

3652
        // Copy the nodes latest set of feature bits.
3653
        newFeatures := make(map[int32]struct{})
×
3654
        if features != nil {
×
3655
                for feature := range features.Features() {
×
3656
                        newFeatures[int32(feature)] = struct{}{}
×
3657
                }
×
3658
        }
3659

3660
        // For any current feature that already exists in the DB, remove it from
3661
        // the in-memory map. For any existing feature that does not exist in
3662
        // the in-memory map, delete it from the database.
3663
        for _, feature := range existingFeatures {
×
3664
                // The feature is still present, so there are no updates to be
×
3665
                // made.
×
3666
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3667
                        delete(newFeatures, feature.FeatureBit)
×
3668
                        continue
×
3669
                }
3670

3671
                // The feature is no longer present, so we remove it from the
3672
                // database.
3673
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3674
                        NodeID:     nodeID,
×
3675
                        FeatureBit: feature.FeatureBit,
×
3676
                })
×
3677
                if err != nil {
×
3678
                        return fmt.Errorf("unable to delete node(%d) "+
×
3679
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3680
                                err)
×
3681
                }
×
3682
        }
3683

3684
        // Any remaining entries in newFeatures are new features that need to be
3685
        // added to the database for the first time.
3686
        for feature := range newFeatures {
×
3687
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3688
                        NodeID:     nodeID,
×
3689
                        FeatureBit: feature,
×
3690
                })
×
3691
                if err != nil {
×
3692
                        return fmt.Errorf("unable to insert node(%d) "+
×
3693
                                "feature(%v): %w", nodeID, feature, err)
×
3694
                }
×
3695
        }
3696

3697
        return nil
×
3698
}
3699

3700
// fetchNodeFeatures fetches the features for a node with the given public key.
3701
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3702
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3703

×
3704
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3705
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3706
                        PubKey:  nodePub[:],
×
NEW
3707
                        Version: int16(lnwire.GossipVersion1),
×
3708
                },
×
3709
        )
×
3710
        if err != nil {
×
3711
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3712
                        nodePub, err)
×
3713
        }
×
3714

3715
        features := lnwire.EmptyFeatureVector()
×
3716
        for _, bit := range rows {
×
3717
                features.Set(lnwire.FeatureBit(bit))
×
3718
        }
×
3719

3720
        return features, nil
×
3721
}
3722

3723
// dbAddressType enumerates the kinds of address that are stored in the
// node_addresses table. The type of an address determines how it is
// serialised and deserialised.
type dbAddressType uint8

const (
	// addressTypeIPv4 marks a TCP address with an IPv4 IP.
	addressTypeIPv4 = dbAddressType(1)

	// addressTypeIPv6 marks a TCP address with an IPv6 IP.
	addressTypeIPv6 = dbAddressType(2)

	// addressTypeTorV2 marks a v2 onion service address.
	addressTypeTorV2 = dbAddressType(3)

	// addressTypeTorV3 marks a v3 onion service address.
	addressTypeTorV3 = dbAddressType(4)

	// addressTypeDNS marks a DNS hostname address.
	addressTypeDNS = dbAddressType(5)

	// addressTypeOpaque marks an opaque address. It uses math.MaxInt8
	// (127) rather than the next sequential value.
	addressTypeOpaque = dbAddressType(math.MaxInt8)
)
3736

3737
// collectAddressRecords collects the addresses from the provided
3738
// net.Addr slice and returns a map of dbAddressType to a slice of address
3739
// strings.
3740
func collectAddressRecords(addresses []net.Addr) (map[dbAddressType][]string,
3741
        error) {
×
3742

×
3743
        // Copy the nodes latest set of addresses.
×
3744
        newAddresses := map[dbAddressType][]string{
×
3745
                addressTypeIPv4:   {},
×
3746
                addressTypeIPv6:   {},
×
3747
                addressTypeTorV2:  {},
×
3748
                addressTypeTorV3:  {},
×
3749
                addressTypeDNS:    {},
×
3750
                addressTypeOpaque: {},
×
3751
        }
×
3752
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3753
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3754
        }
×
3755

3756
        for _, address := range addresses {
×
3757
                switch addr := address.(type) {
×
3758
                case *net.TCPAddr:
×
3759
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3760
                                addAddr(addressTypeIPv4, addr)
×
3761
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3762
                                addAddr(addressTypeIPv6, addr)
×
3763
                        } else {
×
3764
                                return nil, fmt.Errorf("unhandled IP "+
×
3765
                                        "address: %v", addr)
×
3766
                        }
×
3767

3768
                case *tor.OnionAddr:
×
3769
                        switch len(addr.OnionService) {
×
3770
                        case tor.V2Len:
×
3771
                                addAddr(addressTypeTorV2, addr)
×
3772
                        case tor.V3Len:
×
3773
                                addAddr(addressTypeTorV3, addr)
×
3774
                        default:
×
3775
                                return nil, fmt.Errorf("invalid length for " +
×
3776
                                        "a tor address")
×
3777
                        }
3778

3779
                case *lnwire.DNSAddress:
×
3780
                        addAddr(addressTypeDNS, addr)
×
3781

3782
                case *lnwire.OpaqueAddrs:
×
3783
                        addAddr(addressTypeOpaque, addr)
×
3784

3785
                default:
×
3786
                        return nil, fmt.Errorf("unhandled address type: %T",
×
3787
                                addr)
×
3788
                }
3789
        }
3790

3791
        return newAddresses, nil
×
3792
}
3793

3794
// upsertNodeAddresses updates the node's addresses in the database. This
3795
// includes deleting any existing addresses and inserting the new set of
3796
// addresses. The deletion is necessary since the ordering of the addresses may
3797
// change, and we need to ensure that the database reflects the latest set of
3798
// addresses so that at the time of reconstructing the node announcement, the
3799
// order is preserved and the signature over the message remains valid.
3800
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3801
        addresses []net.Addr) error {
×
3802

×
3803
        // Delete any existing addresses for the node. This is required since
×
3804
        // even if the new set of addresses is the same, the ordering may have
×
3805
        // changed for a given address type.
×
3806
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3807
        if err != nil {
×
3808
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3809
                        nodeID, err)
×
3810
        }
×
3811

3812
        newAddresses, err := collectAddressRecords(addresses)
×
3813
        if err != nil {
×
3814
                return err
×
3815
        }
×
3816

3817
        // Any remaining entries in newAddresses are new addresses that need to
3818
        // be added to the database for the first time.
3819
        for addrType, addrList := range newAddresses {
×
3820
                for position, addr := range addrList {
×
3821
                        err := db.UpsertNodeAddress(
×
3822
                                ctx, sqlc.UpsertNodeAddressParams{
×
3823
                                        NodeID:   nodeID,
×
3824
                                        Type:     int16(addrType),
×
3825
                                        Address:  addr,
×
3826
                                        Position: int32(position),
×
3827
                                },
×
3828
                        )
×
3829
                        if err != nil {
×
3830
                                return fmt.Errorf("unable to insert "+
×
3831
                                        "node(%d) address(%v): %w", nodeID,
×
3832
                                        addr, err)
×
3833
                        }
×
3834
                }
3835
        }
3836

3837
        return nil
×
3838
}
3839

3840
// getNodeAddresses fetches the addresses for a node with the given DB ID.
3841
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
3842
        error) {
×
3843

×
3844
        // GetNodeAddresses ensures that the addresses for a given type are
×
3845
        // returned in the same order as they were inserted.
×
3846
        rows, err := db.GetNodeAddresses(ctx, id)
×
3847
        if err != nil {
×
3848
                return nil, err
×
3849
        }
×
3850

3851
        addresses := make([]net.Addr, 0, len(rows))
×
3852
        for _, row := range rows {
×
3853
                address := row.Address
×
3854

×
3855
                addr, err := parseAddress(dbAddressType(row.Type), address)
×
3856
                if err != nil {
×
3857
                        return nil, fmt.Errorf("unable to parse address "+
×
3858
                                "for node(%d): %v: %w", id, address, err)
×
3859
                }
×
3860

3861
                addresses = append(addresses, addr)
×
3862
        }
3863

3864
        // If we have no addresses, then we'll return nil instead of an
3865
        // empty slice.
3866
        if len(addresses) == 0 {
×
3867
                addresses = nil
×
3868
        }
×
3869

3870
        return addresses, nil
×
3871
}
3872

3873
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3874
// database. This includes updating any existing types, inserting any new types,
3875
// and deleting any types that are no longer present.
3876
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3877
        nodeID int64, extraFields map[uint64][]byte) error {
×
3878

×
3879
        // Get any existing extra signed fields for the node.
×
3880
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3881
        if err != nil {
×
3882
                return err
×
3883
        }
×
3884

3885
        // Make a lookup map of the existing field types so that we can use it
3886
        // to keep track of any fields we should delete.
3887
        m := make(map[uint64]bool)
×
3888
        for _, field := range existingFields {
×
3889
                m[uint64(field.Type)] = true
×
3890
        }
×
3891

3892
        // For all the new fields, we'll upsert them and remove them from the
3893
        // map of existing fields.
3894
        for tlvType, value := range extraFields {
×
3895
                err = db.UpsertNodeExtraType(
×
3896
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3897
                                NodeID: nodeID,
×
3898
                                Type:   int64(tlvType),
×
3899
                                Value:  value,
×
3900
                        },
×
3901
                )
×
3902
                if err != nil {
×
3903
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3904
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3905
                }
×
3906

3907
                // Remove the field from the map of existing fields if it was
3908
                // present.
3909
                delete(m, tlvType)
×
3910
        }
3911

3912
        // For all the fields that are left in the map of existing fields, we'll
3913
        // delete them as they are no longer present in the new set of fields.
3914
        for tlvType := range m {
×
3915
                err = db.DeleteExtraNodeType(
×
3916
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3917
                                NodeID: nodeID,
×
3918
                                Type:   int64(tlvType),
×
3919
                        },
×
3920
                )
×
3921
                if err != nil {
×
3922
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3923
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3924
                }
×
3925
        }
3926

3927
        return nil
×
3928
}
3929

3930
// srcNodeInfo holds the information about the source node of the graph.
3931
type srcNodeInfo struct {
3932
        // id is the DB level ID of the source node entry in the "nodes" table.
3933
        id int64
3934

3935
        // pub is the public key of the source node.
3936
        pub route.Vertex
3937
}
3938

3939
// sourceNode returns the DB node ID and pub key of the source node for the
3940
// specified protocol version.
3941
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
NEW
3942
        version lnwire.GossipVersion) (int64, route.Vertex, error) {
×
3943

×
3944
        s.srcNodeMu.Lock()
×
3945
        defer s.srcNodeMu.Unlock()
×
3946

×
3947
        // If we already have the source node ID and pub key cached, then
×
3948
        // return them.
×
3949
        if info, ok := s.srcNodes[version]; ok {
×
3950
                return info.id, info.pub, nil
×
3951
        }
×
3952

3953
        var pubKey route.Vertex
×
3954

×
3955
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3956
        if err != nil {
×
3957
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3958
                        err)
×
3959
        }
×
3960

3961
        if len(nodes) == 0 {
×
3962
                return 0, pubKey, ErrSourceNodeNotSet
×
3963
        } else if len(nodes) > 1 {
×
3964
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3965
                        "protocol %s found", version)
×
3966
        }
×
3967

3968
        copy(pubKey[:], nodes[0].PubKey)
×
3969

×
3970
        s.srcNodes[version] = &srcNodeInfo{
×
3971
                id:  nodes[0].NodeID,
×
3972
                pub: pubKey,
×
3973
        }
×
3974

×
3975
        return nodes[0].NodeID, pubKey, nil
×
3976
}
3977

3978
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3979
// This then produces a map from TLV type to value. If the input is not a
3980
// valid TLV stream, then an error is returned.
3981
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3982
        r := bytes.NewReader(data)
×
3983

×
3984
        tlvStream, err := tlv.NewStream()
×
3985
        if err != nil {
×
3986
                return nil, err
×
3987
        }
×
3988

3989
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3990
        // pass it into the P2P decoding variant.
3991
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3992
        if err != nil {
×
3993
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3994
        }
×
3995
        if len(parsedTypes) == 0 {
×
3996
                return nil, nil
×
3997
        }
×
3998

3999
        records := make(map[uint64][]byte)
×
4000
        for k, v := range parsedTypes {
×
4001
                records[uint64(k)] = v
×
4002
        }
×
4003

4004
        return records, nil
×
4005
}
4006

4007
// insertChannel inserts a new channel record into the database.
4008
func insertChannel(ctx context.Context, db SQLQueries,
4009
        edge *models.ChannelEdgeInfo) error {
×
4010

×
4011
        // Make sure that at least a "shell" entry for each node is present in
×
4012
        // the nodes table.
×
4013
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
4014
        if err != nil {
×
4015
                return fmt.Errorf("unable to create shell node: %w", err)
×
4016
        }
×
4017

4018
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
4019
        if err != nil {
×
4020
                return fmt.Errorf("unable to create shell node: %w", err)
×
4021
        }
×
4022

4023
        var capacity sql.NullInt64
×
4024
        if edge.Capacity != 0 {
×
4025
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
4026
        }
×
4027

4028
        createParams := sqlc.CreateChannelParams{
×
NEW
4029
                Version:     int16(lnwire.GossipVersion1),
×
4030
                Scid:        channelIDToBytes(edge.ChannelID),
×
4031
                NodeID1:     node1DBID,
×
4032
                NodeID2:     node2DBID,
×
4033
                Outpoint:    edge.ChannelPoint.String(),
×
4034
                Capacity:    capacity,
×
4035
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
4036
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
4037
        }
×
4038

×
4039
        if edge.AuthProof != nil {
×
4040
                proof := edge.AuthProof
×
4041

×
4042
                createParams.Node1Signature = proof.NodeSig1Bytes
×
4043
                createParams.Node2Signature = proof.NodeSig2Bytes
×
4044
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
4045
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
4046
        }
×
4047

4048
        // Insert the new channel record.
4049
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
4050
        if err != nil {
×
4051
                return err
×
4052
        }
×
4053

4054
        // Insert any channel features.
4055
        for feature := range edge.Features.Features() {
×
4056
                err = db.InsertChannelFeature(
×
4057
                        ctx, sqlc.InsertChannelFeatureParams{
×
4058
                                ChannelID:  dbChanID,
×
4059
                                FeatureBit: int32(feature),
×
4060
                        },
×
4061
                )
×
4062
                if err != nil {
×
4063
                        return fmt.Errorf("unable to insert channel(%d) "+
×
4064
                                "feature(%v): %w", dbChanID, feature, err)
×
4065
                }
×
4066
        }
4067

4068
        // Finally, insert any extra TLV fields in the channel announcement.
4069
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
4070
        if err != nil {
×
4071
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
4072
                        err)
×
4073
        }
×
4074

4075
        for tlvType, value := range extra {
×
4076
                err := db.UpsertChannelExtraType(
×
4077
                        ctx, sqlc.UpsertChannelExtraTypeParams{
×
4078
                                ChannelID: dbChanID,
×
4079
                                Type:      int64(tlvType),
×
4080
                                Value:     value,
×
4081
                        },
×
4082
                )
×
4083
                if err != nil {
×
4084
                        return fmt.Errorf("unable to upsert channel(%d) "+
×
4085
                                "extra signed field(%v): %w", edge.ChannelID,
×
4086
                                tlvType, err)
×
4087
                }
×
4088
        }
4089

4090
        return nil
×
4091
}
4092

4093
// maybeCreateShellNode checks if a shell node entry exists for the
4094
// given public key. If it does not exist, then a new shell node entry is
4095
// created. The ID of the node is returned. A shell node only has a protocol
4096
// version and public key persisted.
4097
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
4098
        pubKey route.Vertex) (int64, error) {
×
4099

×
4100
        dbNode, err := db.GetNodeByPubKey(
×
4101
                ctx, sqlc.GetNodeByPubKeyParams{
×
4102
                        PubKey:  pubKey[:],
×
NEW
4103
                        Version: int16(lnwire.GossipVersion1),
×
4104
                },
×
4105
        )
×
4106
        // The node exists. Return the ID.
×
4107
        if err == nil {
×
4108
                return dbNode.ID, nil
×
4109
        } else if !errors.Is(err, sql.ErrNoRows) {
×
4110
                return 0, err
×
4111
        }
×
4112

4113
        // Otherwise, the node does not exist, so we create a shell entry for
4114
        // it.
4115
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
NEW
4116
                Version: int16(lnwire.GossipVersion1),
×
4117
                PubKey:  pubKey[:],
×
4118
        })
×
4119
        if err != nil {
×
4120
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
4121
        }
×
4122

4123
        return id, nil
×
4124
}
4125

4126
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
4127
// the database. This includes deleting any existing types and then inserting
4128
// the new types.
4129
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
4130
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
4131

×
4132
        // Delete all existing extra signed fields for the channel policy.
×
4133
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
4134
        if err != nil {
×
4135
                return fmt.Errorf("unable to delete "+
×
4136
                        "existing policy extra signed fields for policy %d: %w",
×
4137
                        chanPolicyID, err)
×
4138
        }
×
4139

4140
        // Insert all new extra signed fields for the channel policy.
4141
        for tlvType, value := range extraFields {
×
4142
                err = db.UpsertChanPolicyExtraType(
×
4143
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
×
4144
                                ChannelPolicyID: chanPolicyID,
×
4145
                                Type:            int64(tlvType),
×
4146
                                Value:           value,
×
4147
                        },
×
4148
                )
×
4149
                if err != nil {
×
4150
                        return fmt.Errorf("unable to insert "+
×
4151
                                "channel_policy(%d) extra signed field(%v): %w",
×
4152
                                chanPolicyID, tlvType, err)
×
4153
                }
×
4154
        }
4155

4156
        return nil
×
4157
}
4158

4159
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
4160
// provided dbChanRow and also fetches any other required information
4161
// to construct the edge info.
4162
func getAndBuildEdgeInfo(ctx context.Context, cfg *SQLStoreConfig,
4163
        db SQLQueries, dbChan sqlc.GraphChannel, node1,
4164
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
4165

×
4166
        data, err := batchLoadChannelData(
×
4167
                ctx, cfg.QueryCfg, db, []int64{dbChan.ID}, nil,
×
4168
        )
×
4169
        if err != nil {
×
4170
                return nil, fmt.Errorf("unable to batch load channel data: %w",
×
4171
                        err)
×
4172
        }
×
4173

4174
        return buildEdgeInfoWithBatchData(
×
4175
                cfg.ChainHash, dbChan, node1, node2, data,
×
4176
        )
×
4177
}
4178

4179
// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
4180
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
4181
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
4182
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {
×
4183

×
NEW
4184
        if dbChan.Version != int16(lnwire.GossipVersion1) {
×
4185
                return nil, fmt.Errorf("unsupported channel version: %d",
×
4186
                        dbChan.Version)
×
4187
        }
×
4188

4189
        // Use pre-loaded features and extras types.
4190
        fv := lnwire.EmptyFeatureVector()
×
4191
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
×
4192
                for _, bit := range features {
×
4193
                        fv.Set(lnwire.FeatureBit(bit))
×
4194
                }
×
4195
        }
4196

4197
        var extras map[uint64][]byte
×
4198
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
×
4199
        if exists {
×
4200
                extras = channelExtras
×
4201
        } else {
×
4202
                extras = make(map[uint64][]byte)
×
4203
        }
×
4204

4205
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4206
        if err != nil {
×
4207
                return nil, err
×
4208
        }
×
4209

4210
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4211
        if err != nil {
×
4212
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4213
                        "fields: %w", err)
×
4214
        }
×
4215
        if recs == nil {
×
4216
                recs = make([]byte, 0)
×
4217
        }
×
4218

4219
        var btcKey1, btcKey2 route.Vertex
×
4220
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4221
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4222

×
4223
        channel := &models.ChannelEdgeInfo{
×
4224
                ChainHash:        chain,
×
4225
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4226
                NodeKey1Bytes:    node1,
×
4227
                NodeKey2Bytes:    node2,
×
4228
                BitcoinKey1Bytes: btcKey1,
×
4229
                BitcoinKey2Bytes: btcKey2,
×
4230
                ChannelPoint:     *op,
×
4231
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4232
                Features:         fv,
×
4233
                ExtraOpaqueData:  recs,
×
4234
        }
×
4235

×
4236
        // We always set all the signatures at the same time, so we can
×
4237
        // safely check if one signature is present to determine if we have the
×
4238
        // rest of the signatures for the auth proof.
×
4239
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4240
                channel.AuthProof = &models.ChannelAuthProof{
×
4241
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4242
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4243
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4244
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4245
                }
×
4246
        }
×
4247

4248
        return channel, nil
×
4249
}
4250

4251
// buildNodeVertices is a helper that converts raw node public keys
4252
// into route.Vertex instances.
4253
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4254
        route.Vertex, error) {
×
4255

×
4256
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4257
        if err != nil {
×
4258
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4259
                        "create vertex from node1 pubkey: %w", err)
×
4260
        }
×
4261

4262
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4263
        if err != nil {
×
4264
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4265
                        "create vertex from node2 pubkey: %w", err)
×
4266
        }
×
4267

4268
        return node1Vertex, node2Vertex, nil
×
4269
}
4270

4271
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4272
// retrieves all the extra info required to build the complete
4273
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4274
// the provided sqlc.GraphChannelPolicy records are nil.
4275
func getAndBuildChanPolicies(ctx context.Context, cfg *sqldb.QueryConfig,
4276
        db SQLQueries, dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4277
        channelID uint64, node1, node2 route.Vertex) (*models.ChannelEdgePolicy,
4278
        *models.ChannelEdgePolicy, error) {
×
4279

×
4280
        if dbPol1 == nil && dbPol2 == nil {
×
4281
                return nil, nil, nil
×
4282
        }
×
4283

4284
        var policyIDs = make([]int64, 0, 2)
×
4285
        if dbPol1 != nil {
×
4286
                policyIDs = append(policyIDs, dbPol1.ID)
×
4287
        }
×
4288
        if dbPol2 != nil {
×
4289
                policyIDs = append(policyIDs, dbPol2.ID)
×
4290
        }
×
4291

4292
        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
×
4293
        if err != nil {
×
4294
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
4295
                        "data: %w", err)
×
4296
        }
×
4297

4298
        pol1, err := buildChanPolicyWithBatchData(
×
4299
                dbPol1, channelID, node2, batchData,
×
4300
        )
×
4301
        if err != nil {
×
4302
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4303
        }
×
4304

4305
        pol2, err := buildChanPolicyWithBatchData(
×
4306
                dbPol2, channelID, node1, batchData,
×
4307
        )
×
4308
        if err != nil {
×
4309
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4310
        }
×
4311

4312
        return pol1, pol2, nil
×
4313
}
4314

4315
// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
4316
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
4317
// then nil is returned for it.
4318
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4319
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
4320
        *models.CachedEdgePolicy, error) {
×
4321

×
4322
        var p1, p2 *models.CachedEdgePolicy
×
4323
        if dbPol1 != nil {
×
4324
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
×
4325
                if err != nil {
×
4326
                        return nil, nil, err
×
4327
                }
×
4328

4329
                p1 = models.NewCachedPolicy(policy1)
×
4330
        }
4331
        if dbPol2 != nil {
×
4332
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
×
4333
                if err != nil {
×
4334
                        return nil, nil, err
×
4335
                }
×
4336

4337
                p2 = models.NewCachedPolicy(policy2)
×
4338
        }
4339

4340
        return p1, p2, nil
×
4341
}
4342

4343
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4344
// provided sqlc.GraphChannelPolicy and other required information.
4345
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4346
        extras map[uint64][]byte,
4347
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4348

×
4349
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4350
        if err != nil {
×
4351
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4352
                        "fields: %w", err)
×
4353
        }
×
4354

4355
        var inboundFee fn.Option[lnwire.Fee]
×
4356
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4357
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4358

×
4359
                inboundFee = fn.Some(lnwire.Fee{
×
4360
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4361
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4362
                })
×
4363
        }
×
4364

4365
        return &models.ChannelEdgePolicy{
×
4366
                SigBytes:  dbPolicy.Signature,
×
4367
                ChannelID: channelID,
×
4368
                LastUpdate: time.Unix(
×
4369
                        dbPolicy.LastUpdate.Int64, 0,
×
4370
                ),
×
4371
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4372
                        dbPolicy.MessageFlags,
×
4373
                ),
×
4374
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4375
                        dbPolicy.ChannelFlags,
×
4376
                ),
×
4377
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4378
                MinHTLC: lnwire.MilliSatoshi(
×
4379
                        dbPolicy.MinHtlcMsat,
×
4380
                ),
×
4381
                MaxHTLC: lnwire.MilliSatoshi(
×
4382
                        dbPolicy.MaxHtlcMsat.Int64,
×
4383
                ),
×
4384
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4385
                        dbPolicy.BaseFeeMsat,
×
4386
                ),
×
4387
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4388
                ToNode:                    toNode,
×
4389
                InboundFee:                inboundFee,
×
4390
                ExtraOpaqueData:           recs,
×
4391
        }, nil
×
4392
}
4393

4394
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the given
4395
// row which is expected to be a sqlc type that contains channel policy
4396
// information. It returns two policies, which may be nil if the policy
4397
// information is not present in the row.
4398
//
4399
//nolint:ll,dupl,funlen
4400
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4401
        *sqlc.GraphChannelPolicy, error) {
×
4402

×
4403
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4404
        switch r := row.(type) {
×
4405
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
×
4406
                if r.Policy1Timelock.Valid {
×
4407
                        policy1 = &sqlc.GraphChannelPolicy{
×
4408
                                Timelock:                r.Policy1Timelock.Int32,
×
4409
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4410
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4411
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4412
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4413
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4414
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4415
                                Disabled:                r.Policy1Disabled,
×
4416
                                MessageFlags:            r.Policy1MessageFlags,
×
4417
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4418
                        }
×
4419
                }
×
4420
                if r.Policy2Timelock.Valid {
×
4421
                        policy2 = &sqlc.GraphChannelPolicy{
×
4422
                                Timelock:                r.Policy2Timelock.Int32,
×
4423
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4424
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4425
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4426
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4427
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4428
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4429
                                Disabled:                r.Policy2Disabled,
×
4430
                                MessageFlags:            r.Policy2MessageFlags,
×
4431
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4432
                        }
×
4433
                }
×
4434

4435
                return policy1, policy2, nil
×
4436

4437
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
4438
                if r.Policy1ID.Valid {
×
4439
                        policy1 = &sqlc.GraphChannelPolicy{
×
4440
                                ID:                      r.Policy1ID.Int64,
×
4441
                                Version:                 r.Policy1Version.Int16,
×
4442
                                ChannelID:               r.GraphChannel.ID,
×
4443
                                NodeID:                  r.Policy1NodeID.Int64,
×
4444
                                Timelock:                r.Policy1Timelock.Int32,
×
4445
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4446
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4447
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4448
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4449
                                LastUpdate:              r.Policy1LastUpdate,
×
4450
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4451
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4452
                                Disabled:                r.Policy1Disabled,
×
4453
                                MessageFlags:            r.Policy1MessageFlags,
×
4454
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4455
                                Signature:               r.Policy1Signature,
×
4456
                        }
×
4457
                }
×
4458
                if r.Policy2ID.Valid {
×
4459
                        policy2 = &sqlc.GraphChannelPolicy{
×
4460
                                ID:                      r.Policy2ID.Int64,
×
4461
                                Version:                 r.Policy2Version.Int16,
×
4462
                                ChannelID:               r.GraphChannel.ID,
×
4463
                                NodeID:                  r.Policy2NodeID.Int64,
×
4464
                                Timelock:                r.Policy2Timelock.Int32,
×
4465
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4466
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4467
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4468
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4469
                                LastUpdate:              r.Policy2LastUpdate,
×
4470
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4471
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4472
                                Disabled:                r.Policy2Disabled,
×
4473
                                MessageFlags:            r.Policy2MessageFlags,
×
4474
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4475
                                Signature:               r.Policy2Signature,
×
4476
                        }
×
4477
                }
×
4478

4479
                return policy1, policy2, nil
×
4480

4481
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4482
                if r.Policy1ID.Valid {
×
4483
                        policy1 = &sqlc.GraphChannelPolicy{
×
4484
                                ID:                      r.Policy1ID.Int64,
×
4485
                                Version:                 r.Policy1Version.Int16,
×
4486
                                ChannelID:               r.GraphChannel.ID,
×
4487
                                NodeID:                  r.Policy1NodeID.Int64,
×
4488
                                Timelock:                r.Policy1Timelock.Int32,
×
4489
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4490
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4491
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4492
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4493
                                LastUpdate:              r.Policy1LastUpdate,
×
4494
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4495
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4496
                                Disabled:                r.Policy1Disabled,
×
4497
                                MessageFlags:            r.Policy1MessageFlags,
×
4498
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4499
                                Signature:               r.Policy1Signature,
×
4500
                        }
×
4501
                }
×
4502
                if r.Policy2ID.Valid {
×
4503
                        policy2 = &sqlc.GraphChannelPolicy{
×
4504
                                ID:                      r.Policy2ID.Int64,
×
4505
                                Version:                 r.Policy2Version.Int16,
×
4506
                                ChannelID:               r.GraphChannel.ID,
×
4507
                                NodeID:                  r.Policy2NodeID.Int64,
×
4508
                                Timelock:                r.Policy2Timelock.Int32,
×
4509
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4510
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4511
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4512
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4513
                                LastUpdate:              r.Policy2LastUpdate,
×
4514
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4515
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4516
                                Disabled:                r.Policy2Disabled,
×
4517
                                MessageFlags:            r.Policy2MessageFlags,
×
4518
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4519
                                Signature:               r.Policy2Signature,
×
4520
                        }
×
4521
                }
×
4522

4523
                return policy1, policy2, nil
×
4524

4525
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4526
                if r.Policy1ID.Valid {
×
4527
                        policy1 = &sqlc.GraphChannelPolicy{
×
4528
                                ID:                      r.Policy1ID.Int64,
×
4529
                                Version:                 r.Policy1Version.Int16,
×
4530
                                ChannelID:               r.GraphChannel.ID,
×
4531
                                NodeID:                  r.Policy1NodeID.Int64,
×
4532
                                Timelock:                r.Policy1Timelock.Int32,
×
4533
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4534
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4535
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4536
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4537
                                LastUpdate:              r.Policy1LastUpdate,
×
4538
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4539
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4540
                                Disabled:                r.Policy1Disabled,
×
4541
                                MessageFlags:            r.Policy1MessageFlags,
×
4542
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4543
                                Signature:               r.Policy1Signature,
×
4544
                        }
×
4545
                }
×
4546
                if r.Policy2ID.Valid {
×
4547
                        policy2 = &sqlc.GraphChannelPolicy{
×
4548
                                ID:                      r.Policy2ID.Int64,
×
4549
                                Version:                 r.Policy2Version.Int16,
×
4550
                                ChannelID:               r.GraphChannel.ID,
×
4551
                                NodeID:                  r.Policy2NodeID.Int64,
×
4552
                                Timelock:                r.Policy2Timelock.Int32,
×
4553
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4554
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4555
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4556
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4557
                                LastUpdate:              r.Policy2LastUpdate,
×
4558
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4559
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4560
                                Disabled:                r.Policy2Disabled,
×
4561
                                MessageFlags:            r.Policy2MessageFlags,
×
4562
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4563
                                Signature:               r.Policy2Signature,
×
4564
                        }
×
4565
                }
×
4566

4567
                return policy1, policy2, nil
×
4568

4569
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4570
                if r.Policy1ID.Valid {
×
4571
                        policy1 = &sqlc.GraphChannelPolicy{
×
4572
                                ID:                      r.Policy1ID.Int64,
×
4573
                                Version:                 r.Policy1Version.Int16,
×
4574
                                ChannelID:               r.GraphChannel.ID,
×
4575
                                NodeID:                  r.Policy1NodeID.Int64,
×
4576
                                Timelock:                r.Policy1Timelock.Int32,
×
4577
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4578
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4579
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4580
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4581
                                LastUpdate:              r.Policy1LastUpdate,
×
4582
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4583
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4584
                                Disabled:                r.Policy1Disabled,
×
4585
                                MessageFlags:            r.Policy1MessageFlags,
×
4586
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4587
                                Signature:               r.Policy1Signature,
×
4588
                        }
×
4589
                }
×
4590
                if r.Policy2ID.Valid {
×
4591
                        policy2 = &sqlc.GraphChannelPolicy{
×
4592
                                ID:                      r.Policy2ID.Int64,
×
4593
                                Version:                 r.Policy2Version.Int16,
×
4594
                                ChannelID:               r.GraphChannel.ID,
×
4595
                                NodeID:                  r.Policy2NodeID.Int64,
×
4596
                                Timelock:                r.Policy2Timelock.Int32,
×
4597
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4598
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4599
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4600
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4601
                                LastUpdate:              r.Policy2LastUpdate,
×
4602
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4603
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4604
                                Disabled:                r.Policy2Disabled,
×
4605
                                MessageFlags:            r.Policy2MessageFlags,
×
4606
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4607
                                Signature:               r.Policy2Signature,
×
4608
                        }
×
4609
                }
×
4610

4611
                return policy1, policy2, nil
×
4612

4613
        case sqlc.ListChannelsForNodeIDsRow:
×
4614
                if r.Policy1ID.Valid {
×
4615
                        policy1 = &sqlc.GraphChannelPolicy{
×
4616
                                ID:                      r.Policy1ID.Int64,
×
4617
                                Version:                 r.Policy1Version.Int16,
×
4618
                                ChannelID:               r.GraphChannel.ID,
×
4619
                                NodeID:                  r.Policy1NodeID.Int64,
×
4620
                                Timelock:                r.Policy1Timelock.Int32,
×
4621
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4622
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4623
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4624
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4625
                                LastUpdate:              r.Policy1LastUpdate,
×
4626
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4627
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4628
                                Disabled:                r.Policy1Disabled,
×
4629
                                MessageFlags:            r.Policy1MessageFlags,
×
4630
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4631
                                Signature:               r.Policy1Signature,
×
4632
                        }
×
4633
                }
×
4634
                if r.Policy2ID.Valid {
×
4635
                        policy2 = &sqlc.GraphChannelPolicy{
×
4636
                                ID:                      r.Policy2ID.Int64,
×
4637
                                Version:                 r.Policy2Version.Int16,
×
4638
                                ChannelID:               r.GraphChannel.ID,
×
4639
                                NodeID:                  r.Policy2NodeID.Int64,
×
4640
                                Timelock:                r.Policy2Timelock.Int32,
×
4641
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4642
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4643
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4644
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4645
                                LastUpdate:              r.Policy2LastUpdate,
×
4646
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4647
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4648
                                Disabled:                r.Policy2Disabled,
×
4649
                                MessageFlags:            r.Policy2MessageFlags,
×
4650
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4651
                                Signature:               r.Policy2Signature,
×
4652
                        }
×
4653
                }
×
4654

4655
                return policy1, policy2, nil
×
4656

4657
        case sqlc.ListChannelsByNodeIDRow:
×
4658
                if r.Policy1ID.Valid {
×
4659
                        policy1 = &sqlc.GraphChannelPolicy{
×
4660
                                ID:                      r.Policy1ID.Int64,
×
4661
                                Version:                 r.Policy1Version.Int16,
×
4662
                                ChannelID:               r.GraphChannel.ID,
×
4663
                                NodeID:                  r.Policy1NodeID.Int64,
×
4664
                                Timelock:                r.Policy1Timelock.Int32,
×
4665
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4666
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4667
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4668
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4669
                                LastUpdate:              r.Policy1LastUpdate,
×
4670
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4671
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4672
                                Disabled:                r.Policy1Disabled,
×
4673
                                MessageFlags:            r.Policy1MessageFlags,
×
4674
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4675
                                Signature:               r.Policy1Signature,
×
4676
                        }
×
4677
                }
×
4678
                if r.Policy2ID.Valid {
×
4679
                        policy2 = &sqlc.GraphChannelPolicy{
×
4680
                                ID:                      r.Policy2ID.Int64,
×
4681
                                Version:                 r.Policy2Version.Int16,
×
4682
                                ChannelID:               r.GraphChannel.ID,
×
4683
                                NodeID:                  r.Policy2NodeID.Int64,
×
4684
                                Timelock:                r.Policy2Timelock.Int32,
×
4685
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4686
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4687
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4688
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4689
                                LastUpdate:              r.Policy2LastUpdate,
×
4690
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4691
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4692
                                Disabled:                r.Policy2Disabled,
×
4693
                                MessageFlags:            r.Policy2MessageFlags,
×
4694
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4695
                                Signature:               r.Policy2Signature,
×
4696
                        }
×
4697
                }
×
4698

4699
                return policy1, policy2, nil
×
4700

4701
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4702
                if r.Policy1ID.Valid {
×
4703
                        policy1 = &sqlc.GraphChannelPolicy{
×
4704
                                ID:                      r.Policy1ID.Int64,
×
4705
                                Version:                 r.Policy1Version.Int16,
×
4706
                                ChannelID:               r.GraphChannel.ID,
×
4707
                                NodeID:                  r.Policy1NodeID.Int64,
×
4708
                                Timelock:                r.Policy1Timelock.Int32,
×
4709
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4710
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4711
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4712
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4713
                                LastUpdate:              r.Policy1LastUpdate,
×
4714
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4715
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4716
                                Disabled:                r.Policy1Disabled,
×
4717
                                MessageFlags:            r.Policy1MessageFlags,
×
4718
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4719
                                Signature:               r.Policy1Signature,
×
4720
                        }
×
4721
                }
×
4722
                if r.Policy2ID.Valid {
×
4723
                        policy2 = &sqlc.GraphChannelPolicy{
×
4724
                                ID:                      r.Policy2ID.Int64,
×
4725
                                Version:                 r.Policy2Version.Int16,
×
4726
                                ChannelID:               r.GraphChannel.ID,
×
4727
                                NodeID:                  r.Policy2NodeID.Int64,
×
4728
                                Timelock:                r.Policy2Timelock.Int32,
×
4729
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4730
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4731
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4732
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4733
                                LastUpdate:              r.Policy2LastUpdate,
×
4734
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4735
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4736
                                Disabled:                r.Policy2Disabled,
×
4737
                                MessageFlags:            r.Policy2MessageFlags,
×
4738
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4739
                                Signature:               r.Policy2Signature,
×
4740
                        }
×
4741
                }
×
4742

4743
                return policy1, policy2, nil
×
4744

4745
        case sqlc.GetChannelsByIDsRow:
×
4746
                if r.Policy1ID.Valid {
×
4747
                        policy1 = &sqlc.GraphChannelPolicy{
×
4748
                                ID:                      r.Policy1ID.Int64,
×
4749
                                Version:                 r.Policy1Version.Int16,
×
4750
                                ChannelID:               r.GraphChannel.ID,
×
4751
                                NodeID:                  r.Policy1NodeID.Int64,
×
4752
                                Timelock:                r.Policy1Timelock.Int32,
×
4753
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4754
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4755
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4756
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4757
                                LastUpdate:              r.Policy1LastUpdate,
×
4758
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4759
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4760
                                Disabled:                r.Policy1Disabled,
×
4761
                                MessageFlags:            r.Policy1MessageFlags,
×
4762
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4763
                                Signature:               r.Policy1Signature,
×
4764
                        }
×
4765
                }
×
4766
                if r.Policy2ID.Valid {
×
4767
                        policy2 = &sqlc.GraphChannelPolicy{
×
4768
                                ID:                      r.Policy2ID.Int64,
×
4769
                                Version:                 r.Policy2Version.Int16,
×
4770
                                ChannelID:               r.GraphChannel.ID,
×
4771
                                NodeID:                  r.Policy2NodeID.Int64,
×
4772
                                Timelock:                r.Policy2Timelock.Int32,
×
4773
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4774
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4775
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4776
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4777
                                LastUpdate:              r.Policy2LastUpdate,
×
4778
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4779
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4780
                                Disabled:                r.Policy2Disabled,
×
4781
                                MessageFlags:            r.Policy2MessageFlags,
×
4782
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4783
                                Signature:               r.Policy2Signature,
×
4784
                        }
×
4785
                }
×
4786

4787
                return policy1, policy2, nil
×
4788

4789
        default:
×
4790
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4791
                        "extractChannelPolicies: %T", r)
×
4792
        }
4793
}
4794

4795
// channelIDToBytes converts a channel ID (SCID) to a byte array
4796
// representation.
4797
func channelIDToBytes(channelID uint64) []byte {
×
4798
        var chanIDB [8]byte
×
4799
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4800

×
4801
        return chanIDB[:]
×
4802
}
×
4803

4804
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
4805
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
×
4806
        if len(addresses) == 0 {
×
4807
                return nil, nil
×
4808
        }
×
4809

4810
        result := make([]net.Addr, 0, len(addresses))
×
4811
        for _, addr := range addresses {
×
4812
                netAddr, err := parseAddress(addr.addrType, addr.address)
×
4813
                if err != nil {
×
4814
                        return nil, fmt.Errorf("unable to parse address %s "+
×
4815
                                "of type %d: %w", addr.address, addr.addrType,
×
4816
                                err)
×
4817
                }
×
4818
                if netAddr != nil {
×
4819
                        result = append(result, netAddr)
×
4820
                }
×
4821
        }
4822

4823
        // If we have no valid addresses, return nil instead of empty slice.
4824
        if len(result) == 0 {
×
4825
                return nil, nil
×
4826
        }
×
4827

4828
        return result, nil
×
4829
}
4830

4831
// parseAddress parses the given address string based on the address type
4832
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
4833
// and opaque addresses.
4834
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
×
4835
        switch addrType {
×
4836
        case addressTypeIPv4:
×
4837
                tcp, err := net.ResolveTCPAddr("tcp4", address)
×
4838
                if err != nil {
×
4839
                        return nil, err
×
4840
                }
×
4841

4842
                tcp.IP = tcp.IP.To4()
×
4843

×
4844
                return tcp, nil
×
4845

4846
        case addressTypeIPv6:
×
4847
                tcp, err := net.ResolveTCPAddr("tcp6", address)
×
4848
                if err != nil {
×
4849
                        return nil, err
×
4850
                }
×
4851

4852
                return tcp, nil
×
4853

4854
        case addressTypeTorV3, addressTypeTorV2:
×
4855
                service, portStr, err := net.SplitHostPort(address)
×
4856
                if err != nil {
×
4857
                        return nil, fmt.Errorf("unable to split tor "+
×
4858
                                "address: %v", address)
×
4859
                }
×
4860

4861
                port, err := strconv.Atoi(portStr)
×
4862
                if err != nil {
×
4863
                        return nil, err
×
4864
                }
×
4865

4866
                return &tor.OnionAddr{
×
4867
                        OnionService: service,
×
4868
                        Port:         port,
×
4869
                }, nil
×
4870

4871
        case addressTypeDNS:
×
4872
                hostname, portStr, err := net.SplitHostPort(address)
×
4873
                if err != nil {
×
4874
                        return nil, fmt.Errorf("unable to split DNS "+
×
4875
                                "address: %v", address)
×
4876
                }
×
4877

4878
                port, err := strconv.Atoi(portStr)
×
4879
                if err != nil {
×
4880
                        return nil, err
×
4881
                }
×
4882

4883
                return &lnwire.DNSAddress{
×
4884
                        Hostname: hostname,
×
4885
                        Port:     uint16(port),
×
4886
                }, nil
×
4887

4888
        case addressTypeOpaque:
×
4889
                opaque, err := hex.DecodeString(address)
×
4890
                if err != nil {
×
4891
                        return nil, fmt.Errorf("unable to decode opaque "+
×
4892
                                "address: %v", address)
×
4893
                }
×
4894

4895
                return &lnwire.OpaqueAddrs{
×
4896
                        Payload: opaque,
×
4897
                }, nil
×
4898

4899
        default:
×
4900
                return nil, fmt.Errorf("unknown address type: %v", addrType)
×
4901
        }
4902
}
4903

4904
// batchNodeData holds all the related data for a batch of nodes. Every map is
// keyed by the DB ID of the node the data belongs to.
4905
type batchNodeData struct {
4906
        // features is a map from a DB node ID to the feature bits for that
4907
        // node.
4908
        features map[int64][]int
4909

4910
        // addresses is a map from a DB node ID to the node's addresses in
        // their unparsed DB form.
4911
        addresses map[int64][]nodeAddress
4912

4913
        // extraFields is a map from a DB node ID to the extra signed fields
4914
        // for that node, keyed by the field's type.
4915
        extraFields map[int64]map[uint64][]byte
4916
}
4917

4918
// nodeAddress holds the address type, position and address string for a
4919
// node. This is used to batch the fetching of node addresses.
4920
type nodeAddress struct {
4921
        // addrType selects how the address string is parsed.
        addrType dbAddressType
4922
        // position is copied from the DB row's Position column.
        // NOTE(review): presumably the address's index within the node's
        // address list - confirm.
        position int32
4923
        // address is the stored string form: hex for opaque addresses and
        // host:port for the Tor and DNS types.
        address  string
4924
}
4925

4926
// batchLoadNodeData loads all related data for a batch of node IDs using the
4927
// provided SQLQueries interface. It returns a batchNodeData instance containing
4928
// the node features, addresses and extra signed fields.
4929
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
4930
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {
×
4931

×
4932
        // Batch load the node features.
×
4933
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
4934
        if err != nil {
×
4935
                return nil, fmt.Errorf("unable to batch load node "+
×
4936
                        "features: %w", err)
×
4937
        }
×
4938

4939
        // Batch load the node addresses.
4940
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
×
4941
        if err != nil {
×
4942
                return nil, fmt.Errorf("unable to batch load node "+
×
4943
                        "addresses: %w", err)
×
4944
        }
×
4945

4946
        // Batch load the node extra signed fields.
4947
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
×
4948
        if err != nil {
×
4949
                return nil, fmt.Errorf("unable to batch load node extra "+
×
4950
                        "signed fields: %w", err)
×
4951
        }
×
4952

4953
        return &batchNodeData{
×
4954
                features:    features,
×
4955
                addresses:   addrs,
×
4956
                extraFields: extraTypes,
×
4957
        }, nil
×
4958
}
4959

4960
// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
4961
// using ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
4962
func batchLoadNodeFeaturesHelper(ctx context.Context,
4963
        cfg *sqldb.QueryConfig, db SQLQueries,
4964
        nodeIDs []int64) (map[int64][]int, error) {
×
4965

×
4966
        features := make(map[int64][]int)
×
4967

×
4968
        return features, sqldb.ExecuteBatchQuery(
×
4969
                ctx, cfg, nodeIDs,
×
4970
                func(id int64) int64 {
×
4971
                        return id
×
4972
                },
×
4973
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
4974
                        error) {
×
4975

×
4976
                        return db.GetNodeFeaturesBatch(ctx, ids)
×
4977
                },
×
4978
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
×
4979
                        features[feature.NodeID] = append(
×
4980
                                features[feature.NodeID],
×
4981
                                int(feature.FeatureBit),
×
4982
                        )
×
4983

×
4984
                        return nil
×
4985
                },
×
4986
        )
4987
}
4988

4989
// batchLoadNodeAddressesHelper loads node addresses using ExecuteBatchQuery
4990
// wrapper around the GetNodeAddressesBatch query. It returns a map from
4991
// node ID to a slice of nodeAddress structs.
4992
func batchLoadNodeAddressesHelper(ctx context.Context,
4993
        cfg *sqldb.QueryConfig, db SQLQueries,
4994
        nodeIDs []int64) (map[int64][]nodeAddress, error) {
×
4995

×
4996
        addrs := make(map[int64][]nodeAddress)
×
4997

×
4998
        return addrs, sqldb.ExecuteBatchQuery(
×
4999
                ctx, cfg, nodeIDs,
×
5000
                func(id int64) int64 {
×
5001
                        return id
×
5002
                },
×
5003
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
5004
                        error) {
×
5005

×
5006
                        return db.GetNodeAddressesBatch(ctx, ids)
×
5007
                },
×
5008
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
×
5009
                        addrs[addr.NodeID] = append(
×
5010
                                addrs[addr.NodeID], nodeAddress{
×
5011
                                        addrType: dbAddressType(addr.Type),
×
5012
                                        position: addr.Position,
×
5013
                                        address:  addr.Address,
×
5014
                                },
×
5015
                        )
×
5016

×
5017
                        return nil
×
5018
                },
×
5019
        )
5020
}
5021

5022
// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
5023
// node IDs using ExecuteBatchQuery wrapper around the GetNodeExtraTypesBatch
5024
// query.
5025
func batchLoadNodeExtraTypesHelper(ctx context.Context,
5026
        cfg *sqldb.QueryConfig, db SQLQueries,
5027
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5028

×
5029
        extraFields := make(map[int64]map[uint64][]byte)
×
5030

×
5031
        callback := func(ctx context.Context,
×
5032
                field sqlc.GraphNodeExtraType) error {
×
5033

×
5034
                if extraFields[field.NodeID] == nil {
×
5035
                        extraFields[field.NodeID] = make(map[uint64][]byte)
×
5036
                }
×
5037
                extraFields[field.NodeID][uint64(field.Type)] = field.Value
×
5038

×
5039
                return nil
×
5040
        }
5041

5042
        return extraFields, sqldb.ExecuteBatchQuery(
×
5043
                ctx, cfg, nodeIDs,
×
5044
                func(id int64) int64 {
×
5045
                        return id
×
5046
                },
×
5047
                func(ctx context.Context, ids []int64) (
5048
                        []sqlc.GraphNodeExtraType, error) {
×
5049

×
5050
                        return db.GetNodeExtraTypesBatch(ctx, ids)
×
5051
                },
×
5052
                callback,
5053
        )
5054
}
5055

5056
// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
5057
// from the provided sqlc.GraphChannelPolicy records and the
5058
// provided batchChannelData.
5059
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
5060
        channelID uint64, node1, node2 route.Vertex,
5061
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
5062
        *models.ChannelEdgePolicy, error) {
×
5063

×
5064
        pol1, err := buildChanPolicyWithBatchData(
×
5065
                dbPol1, channelID, node2, batchData,
×
5066
        )
×
5067
        if err != nil {
×
5068
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
5069
        }
×
5070

5071
        pol2, err := buildChanPolicyWithBatchData(
×
5072
                dbPol2, channelID, node1, batchData,
×
5073
        )
×
5074
        if err != nil {
×
5075
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
5076
        }
×
5077

5078
        return pol1, pol2, nil
×
5079
}
5080

5081
// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
5082
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
5083
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
5084
        channelID uint64, toNode route.Vertex,
5085
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {
×
5086

×
5087
        if dbPol == nil {
×
5088
                return nil, nil
×
5089
        }
×
5090

5091
        var dbPol1Extras map[uint64][]byte
×
5092
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
×
5093
                dbPol1Extras = extras
×
5094
        } else {
×
5095
                dbPol1Extras = make(map[uint64][]byte)
×
5096
        }
×
5097

5098
        return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode)
×
5099
}
5100

5101
// batchChannelData holds all the related data for a batch of channels.
5102
type batchChannelData struct {
5103
        // chanfeatures is a map from DB channel ID to a slice of feature bits.
5104
        chanfeatures map[int64][]int
5105

5106
        // chanExtraTypes is a map from DB channel ID to a map of TLV type to
5107
        // extra signed field bytes.
5108
        chanExtraTypes map[int64]map[uint64][]byte
5109

5110
        // policyExtras is a map from DB channel policy ID to a map of TLV type
5111
        // to extra signed field bytes.
5112
        policyExtras map[int64]map[uint64][]byte
5113
}
5114

5115
// batchLoadChannelData loads all related data for batches of channels and
5116
// policies.
5117
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
5118
        db SQLQueries, channelIDs []int64,
5119
        policyIDs []int64) (*batchChannelData, error) {
×
5120

×
5121
        batchData := &batchChannelData{
×
5122
                chanfeatures:   make(map[int64][]int),
×
5123
                chanExtraTypes: make(map[int64]map[uint64][]byte),
×
5124
                policyExtras:   make(map[int64]map[uint64][]byte),
×
5125
        }
×
5126

×
5127
        // Batch load channel features and extras
×
5128
        var err error
×
5129
        if len(channelIDs) > 0 {
×
5130
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
×
5131
                        ctx, cfg, db, channelIDs,
×
5132
                )
×
5133
                if err != nil {
×
5134
                        return nil, fmt.Errorf("unable to batch load "+
×
5135
                                "channel features: %w", err)
×
5136
                }
×
5137

5138
                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
×
5139
                        ctx, cfg, db, channelIDs,
×
5140
                )
×
5141
                if err != nil {
×
5142
                        return nil, fmt.Errorf("unable to batch load "+
×
5143
                                "channel extras: %w", err)
×
5144
                }
×
5145
        }
5146

5147
        if len(policyIDs) > 0 {
×
5148
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
×
5149
                        ctx, cfg, db, policyIDs,
×
5150
                )
×
5151
                if err != nil {
×
5152
                        return nil, fmt.Errorf("unable to batch load "+
×
5153
                                "policy extras: %w", err)
×
5154
                }
×
5155
                batchData.policyExtras = policyExtras
×
5156
        }
5157

5158
        return batchData, nil
×
5159
}
5160

5161
// batchLoadChannelFeaturesHelper loads channel features for a batch of
5162
// channel IDs using ExecuteBatchQuery wrapper around the
5163
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
5164
// slice of feature bits.
5165
func batchLoadChannelFeaturesHelper(ctx context.Context,
5166
        cfg *sqldb.QueryConfig, db SQLQueries,
5167
        channelIDs []int64) (map[int64][]int, error) {
×
5168

×
5169
        features := make(map[int64][]int)
×
5170

×
5171
        return features, sqldb.ExecuteBatchQuery(
×
5172
                ctx, cfg, channelIDs,
×
5173
                func(id int64) int64 {
×
5174
                        return id
×
5175
                },
×
5176
                func(ctx context.Context,
5177
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {
×
5178

×
5179
                        return db.GetChannelFeaturesBatch(ctx, ids)
×
5180
                },
×
5181
                func(ctx context.Context,
5182
                        feature sqlc.GraphChannelFeature) error {
×
5183

×
5184
                        features[feature.ChannelID] = append(
×
5185
                                features[feature.ChannelID],
×
5186
                                int(feature.FeatureBit),
×
5187
                        )
×
5188

×
5189
                        return nil
×
5190
                },
×
5191
        )
5192
}
5193

5194
// batchLoadChannelExtrasHelper loads channel extra types for a batch of
5195
// channel IDs using ExecuteBatchQuery wrapper around the GetChannelExtrasBatch
5196
// query. It returns a map from DB channel ID to a map of TLV type to extra
5197
// signed field bytes.
5198
func batchLoadChannelExtrasHelper(ctx context.Context,
5199
        cfg *sqldb.QueryConfig, db SQLQueries,
5200
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5201

×
5202
        extras := make(map[int64]map[uint64][]byte)
×
5203

×
5204
        cb := func(ctx context.Context,
×
5205
                extra sqlc.GraphChannelExtraType) error {
×
5206

×
5207
                if extras[extra.ChannelID] == nil {
×
5208
                        extras[extra.ChannelID] = make(map[uint64][]byte)
×
5209
                }
×
5210
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value
×
5211

×
5212
                return nil
×
5213
        }
5214

5215
        return extras, sqldb.ExecuteBatchQuery(
×
5216
                ctx, cfg, channelIDs,
×
5217
                func(id int64) int64 {
×
5218
                        return id
×
5219
                },
×
5220
                func(ctx context.Context,
5221
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {
×
5222

×
5223
                        return db.GetChannelExtrasBatch(ctx, ids)
×
5224
                }, cb,
×
5225
        )
5226
}
5227

5228
// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
5229
// batch of policy IDs using ExecuteBatchQuery wrapper around the
5230
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
5231
// a map of TLV type to extra signed field bytes.
5232
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
5233
        cfg *sqldb.QueryConfig, db SQLQueries,
5234
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5235

×
5236
        extras := make(map[int64]map[uint64][]byte)
×
5237

×
5238
        return extras, sqldb.ExecuteBatchQuery(
×
5239
                ctx, cfg, policyIDs,
×
5240
                func(id int64) int64 {
×
5241
                        return id
×
5242
                },
×
5243
                func(ctx context.Context, ids []int64) (
5244
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {
×
5245

×
5246
                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
×
5247
                },
×
5248
                func(ctx context.Context,
5249
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {
×
5250

×
5251
                        if extras[row.PolicyID] == nil {
×
5252
                                extras[row.PolicyID] = make(map[uint64][]byte)
×
5253
                        }
×
5254
                        extras[row.PolicyID][uint64(row.Type)] = row.Value
×
5255

×
5256
                        return nil
×
5257
                },
5258
        )
5259
}
5260

5261
// forEachNodePaginated executes a paginated query to process each node in the
5262
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
5263
// and applies the provided processNode function to each node.
5264
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
5265
        db SQLQueries, protocol lnwire.GossipVersion,
5266
        processNode func(context.Context, int64,
5267
                *models.Node) error) error {
×
5268

×
5269
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5270
                limit int32) ([]sqlc.GraphNode, error) {
×
5271

×
5272
                return db.ListNodesPaginated(
×
5273
                        ctx, sqlc.ListNodesPaginatedParams{
×
5274
                                Version: int16(protocol),
×
5275
                                ID:      lastID,
×
5276
                                Limit:   limit,
×
5277
                        },
×
5278
                )
×
5279
        }
×
5280

5281
        extractPageCursor := func(node sqlc.GraphNode) int64 {
×
5282
                return node.ID
×
5283
        }
×
5284

5285
        collectFunc := func(node sqlc.GraphNode) (int64, error) {
×
5286
                return node.ID, nil
×
5287
        }
×
5288

5289
        batchQueryFunc := func(ctx context.Context,
×
5290
                nodeIDs []int64) (*batchNodeData, error) {
×
5291

×
5292
                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
5293
        }
×
5294

5295
        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
×
5296
                batchData *batchNodeData) error {
×
5297

×
5298
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
5299
                if err != nil {
×
5300
                        return fmt.Errorf("unable to build "+
×
5301
                                "node(id=%d): %w", dbNode.ID, err)
×
5302
                }
×
5303

5304
                return processNode(ctx, dbNode.ID, node)
×
5305
        }
5306

5307
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5308
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5309
                collectFunc, batchQueryFunc, processItem,
×
5310
        )
×
5311
}
5312

5313
// forEachChannelWithPolicies executes a paginated query to process each channel
5314
// with policies in the graph.
5315
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
5316
        cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
5317
                *models.ChannelEdgePolicy,
5318
                *models.ChannelEdgePolicy) error) error {
×
5319

×
5320
        type channelBatchIDs struct {
×
5321
                channelID int64
×
5322
                policyIDs []int64
×
5323
        }
×
5324

×
5325
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5326
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
×
5327
                error) {
×
5328

×
5329
                return db.ListChannelsWithPoliciesPaginated(
×
5330
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
NEW
5331
                                Version: int16(lnwire.GossipVersion1),
×
5332
                                ID:      lastID,
×
5333
                                Limit:   limit,
×
5334
                        },
×
5335
                )
×
5336
        }
×
5337

5338
        extractPageCursor := func(
×
5339
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {
×
5340

×
5341
                return row.GraphChannel.ID
×
5342
        }
×
5343

5344
        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
×
5345
                channelBatchIDs, error) {
×
5346

×
5347
                ids := channelBatchIDs{
×
5348
                        channelID: row.GraphChannel.ID,
×
5349
                }
×
5350

×
5351
                // Extract policy IDs from the row.
×
5352
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5353
                if err != nil {
×
5354
                        return ids, err
×
5355
                }
×
5356

5357
                if dbPol1 != nil {
×
5358
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
×
5359
                }
×
5360
                if dbPol2 != nil {
×
5361
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
×
5362
                }
×
5363

5364
                return ids, nil
×
5365
        }
5366

5367
        batchDataFunc := func(ctx context.Context,
×
5368
                allIDs []channelBatchIDs) (*batchChannelData, error) {
×
5369

×
5370
                // Separate channel IDs from policy IDs.
×
5371
                var (
×
5372
                        channelIDs = make([]int64, len(allIDs))
×
5373
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
×
5374
                )
×
5375

×
5376
                for i, ids := range allIDs {
×
5377
                        channelIDs[i] = ids.channelID
×
5378
                        policyIDs = append(policyIDs, ids.policyIDs...)
×
5379
                }
×
5380

5381
                return batchLoadChannelData(
×
5382
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
5383
                )
×
5384
        }
5385

5386
        processItem := func(ctx context.Context,
×
5387
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
×
5388
                batchData *batchChannelData) error {
×
5389

×
5390
                node1, node2, err := buildNodeVertices(
×
5391
                        row.Node1Pubkey, row.Node2Pubkey,
×
5392
                )
×
5393
                if err != nil {
×
5394
                        return err
×
5395
                }
×
5396

5397
                edge, err := buildEdgeInfoWithBatchData(
×
5398
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
5399
                        batchData,
×
5400
                )
×
5401
                if err != nil {
×
5402
                        return fmt.Errorf("unable to build channel info: %w",
×
5403
                                err)
×
5404
                }
×
5405

5406
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5407
                if err != nil {
×
5408
                        return err
×
5409
                }
×
5410

5411
                p1, p2, err := buildChanPoliciesWithBatchData(
×
5412
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
5413
                )
×
5414
                if err != nil {
×
5415
                        return err
×
5416
                }
×
5417

5418
                return processChannel(edge, p1, p2)
×
5419
        }
5420

5421
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5422
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5423
                collectFunc, batchDataFunc, processItem,
×
5424
        )
×
5425
}
5426

5427
// buildDirectedChannel builds a DirectedChannel instance from the provided
5428
// data.
5429
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
5430
        nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
5431
        channelBatchData *batchChannelData, features *lnwire.FeatureVector,
5432
        toNodeCallback func() route.Vertex) (*DirectedChannel, error) {
×
5433

×
5434
        node1, node2, err := buildNodeVertices(
×
5435
                channelRow.Node1Pubkey, channelRow.Node2Pubkey,
×
5436
        )
×
5437
        if err != nil {
×
5438
                return nil, fmt.Errorf("unable to build node vertices: %w", err)
×
5439
        }
×
5440

5441
        edge, err := buildEdgeInfoWithBatchData(
×
5442
                chain, channelRow.GraphChannel, node1, node2, channelBatchData,
×
5443
        )
×
5444
        if err != nil {
×
5445
                return nil, fmt.Errorf("unable to build channel info: %w", err)
×
5446
        }
×
5447

5448
        dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
×
5449
        if err != nil {
×
5450
                return nil, fmt.Errorf("unable to extract channel policies: %w",
×
5451
                        err)
×
5452
        }
×
5453

5454
        p1, p2, err := buildChanPoliciesWithBatchData(
×
5455
                dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
5456
                channelBatchData,
×
5457
        )
×
5458
        if err != nil {
×
5459
                return nil, fmt.Errorf("unable to build channel policies: %w",
×
5460
                        err)
×
5461
        }
×
5462

5463
        // Determine outgoing and incoming policy for this specific node.
5464
        p1ToNode := channelRow.GraphChannel.NodeID2
×
5465
        p2ToNode := channelRow.GraphChannel.NodeID1
×
5466
        outPolicy, inPolicy := p1, p2
×
5467
        if (p1 != nil && p1ToNode == nodeID) ||
×
5468
                (p2 != nil && p2ToNode != nodeID) {
×
5469

×
5470
                outPolicy, inPolicy = p2, p1
×
5471
        }
×
5472

5473
        // Build cached policy.
5474
        var cachedInPolicy *models.CachedEdgePolicy
×
5475
        if inPolicy != nil {
×
5476
                cachedInPolicy = models.NewCachedPolicy(inPolicy)
×
5477
                cachedInPolicy.ToNodePubKey = toNodeCallback
×
5478
                cachedInPolicy.ToNodeFeatures = features
×
5479
        }
×
5480

5481
        // Extract inbound fee.
5482
        var inboundFee lnwire.Fee
×
5483
        if outPolicy != nil {
×
5484
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
5485
                        inboundFee = fee
×
5486
                })
×
5487
        }
5488

5489
        // Build directed channel.
5490
        directedChannel := &DirectedChannel{
×
5491
                ChannelID:    edge.ChannelID,
×
5492
                IsNode1:      nodePub == edge.NodeKey1Bytes,
×
5493
                OtherNode:    edge.NodeKey2Bytes,
×
5494
                Capacity:     edge.Capacity,
×
5495
                OutPolicySet: outPolicy != nil,
×
5496
                InPolicy:     cachedInPolicy,
×
5497
                InboundFee:   inboundFee,
×
5498
        }
×
5499

×
5500
        if nodePub == edge.NodeKey2Bytes {
×
5501
                directedChannel.OtherNode = edge.NodeKey1Bytes
×
5502
        }
×
5503

5504
        return directedChannel, nil
×
5505
}
5506

5507
// batchBuildChannelEdges builds a slice of ChannelEdge instances from the
5508
// provided rows. It uses batch loading for channels, policies, and nodes.
5509
func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context,
5510
        cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) {
×
5511

×
5512
        var (
×
5513
                channelIDs = make([]int64, len(rows))
×
5514
                policyIDs  = make([]int64, 0, len(rows)*2)
×
5515
                nodeIDs    = make([]int64, 0, len(rows)*2)
×
5516

×
5517
                // nodeIDSet is used to ensure we only collect unique node IDs.
×
5518
                nodeIDSet = make(map[int64]bool)
×
5519

×
5520
                // edges will hold the final channel edges built from the rows.
×
5521
                edges = make([]ChannelEdge, 0, len(rows))
×
5522
        )
×
5523

×
5524
        // Collect all IDs needed for batch loading.
×
5525
        for i, row := range rows {
×
5526
                channelIDs[i] = row.Channel().ID
×
5527

×
5528
                // Collect policy IDs
×
5529
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5530
                if err != nil {
×
5531
                        return nil, fmt.Errorf("unable to extract channel "+
×
5532
                                "policies: %w", err)
×
5533
                }
×
5534
                if dbPol1 != nil {
×
5535
                        policyIDs = append(policyIDs, dbPol1.ID)
×
5536
                }
×
5537
                if dbPol2 != nil {
×
5538
                        policyIDs = append(policyIDs, dbPol2.ID)
×
5539
                }
×
5540

5541
                var (
×
5542
                        node1ID = row.Node1().ID
×
5543
                        node2ID = row.Node2().ID
×
5544
                )
×
5545

×
5546
                // Collect unique node IDs.
×
5547
                if !nodeIDSet[node1ID] {
×
5548
                        nodeIDs = append(nodeIDs, node1ID)
×
5549
                        nodeIDSet[node1ID] = true
×
5550
                }
×
5551

5552
                if !nodeIDSet[node2ID] {
×
5553
                        nodeIDs = append(nodeIDs, node2ID)
×
5554
                        nodeIDSet[node2ID] = true
×
5555
                }
×
5556
        }
5557

5558
        // Batch the data for all the channels and policies.
5559
        channelBatchData, err := batchLoadChannelData(
×
5560
                ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
5561
        )
×
5562
        if err != nil {
×
5563
                return nil, fmt.Errorf("unable to batch load channel and "+
×
5564
                        "policy data: %w", err)
×
5565
        }
×
5566

5567
        // Batch the data for all the nodes.
5568
        nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
×
5569
        if err != nil {
×
5570
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
5571
                        err)
×
5572
        }
×
5573

5574
        // Build all channel edges using batch data.
5575
        for _, row := range rows {
×
5576
                // Build nodes using batch data.
×
5577
                node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData)
×
5578
                if err != nil {
×
5579
                        return nil, fmt.Errorf("unable to build node1: %w", err)
×
5580
                }
×
5581

5582
                node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData)
×
5583
                if err != nil {
×
5584
                        return nil, fmt.Errorf("unable to build node2: %w", err)
×
5585
                }
×
5586

5587
                // Build channel info using batch data.
5588
                channel, err := buildEdgeInfoWithBatchData(
×
5589
                        cfg.ChainHash, row.Channel(), node1.PubKeyBytes,
×
5590
                        node2.PubKeyBytes, channelBatchData,
×
5591
                )
×
5592
                if err != nil {
×
5593
                        return nil, fmt.Errorf("unable to build channel "+
×
5594
                                "info: %w", err)
×
5595
                }
×
5596

5597
                // Extract and build policies using batch data.
5598
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5599
                if err != nil {
×
5600
                        return nil, fmt.Errorf("unable to extract channel "+
×
5601
                                "policies: %w", err)
×
5602
                }
×
5603

5604
                p1, p2, err := buildChanPoliciesWithBatchData(
×
5605
                        dbPol1, dbPol2, channel.ChannelID,
×
5606
                        node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData,
×
5607
                )
×
5608
                if err != nil {
×
5609
                        return nil, fmt.Errorf("unable to build channel "+
×
5610
                                "policies: %w", err)
×
5611
                }
×
5612

5613
                edges = append(edges, ChannelEdge{
×
5614
                        Info:    channel,
×
5615
                        Policy1: p1,
×
5616
                        Policy2: p2,
×
5617
                        Node1:   node1,
×
5618
                        Node2:   node2,
×
5619
                })
×
5620
        }
5621

5622
        return edges, nil
×
5623
}
5624

5625
// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo
5626
// instances from the provided rows using batch loading for channel data.
5627
func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context,
5628
        cfg *SQLStoreConfig, db SQLQueries, rows []T) (
5629
        []*models.ChannelEdgeInfo, []int64, error) {
×
5630

×
5631
        if len(rows) == 0 {
×
5632
                return nil, nil, nil
×
5633
        }
×
5634

5635
        // Collect all the channel IDs needed for batch loading.
5636
        channelIDs := make([]int64, len(rows))
×
5637
        for i, row := range rows {
×
5638
                channelIDs[i] = row.Channel().ID
×
5639
        }
×
5640

5641
        // Batch load the channel data.
5642
        channelBatchData, err := batchLoadChannelData(
×
5643
                ctx, cfg.QueryCfg, db, channelIDs, nil,
×
5644
        )
×
5645
        if err != nil {
×
5646
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
5647
                        "data: %w", err)
×
5648
        }
×
5649

5650
        // Build all channel edges using batch data.
5651
        edges := make([]*models.ChannelEdgeInfo, 0, len(rows))
×
5652
        for _, row := range rows {
×
5653
                node1, node2, err := buildNodeVertices(
×
5654
                        row.Node1Pub(), row.Node2Pub(),
×
5655
                )
×
5656
                if err != nil {
×
5657
                        return nil, nil, err
×
5658
                }
×
5659

5660
                // Build channel info using batch data
5661
                info, err := buildEdgeInfoWithBatchData(
×
5662
                        cfg.ChainHash, row.Channel(), node1, node2,
×
5663
                        channelBatchData,
×
5664
                )
×
5665
                if err != nil {
×
5666
                        return nil, nil, err
×
5667
                }
×
5668

5669
                edges = append(edges, info)
×
5670
        }
5671

5672
        return edges, channelIDs, nil
×
5673
}
5674

5675
// handleZombieMarking is a helper function that handles the logic of
5676
// marking a channel as a zombie in the database. It takes into account whether
5677
// we are in strict zombie pruning mode, and adjusts the node public keys
5678
// accordingly based on the last update timestamps of the channel policies.
5679
func handleZombieMarking(ctx context.Context, db SQLQueries,
5680
        row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo,
5681
        strictZombiePruning bool, scid uint64) error {
×
5682

×
5683
        nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes
×
5684

×
5685
        if strictZombiePruning {
×
5686
                var e1UpdateTime, e2UpdateTime *time.Time
×
5687
                if row.Policy1LastUpdate.Valid {
×
5688
                        e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0)
×
5689
                        e1UpdateTime = &e1Time
×
5690
                }
×
5691
                if row.Policy2LastUpdate.Valid {
×
5692
                        e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0)
×
5693
                        e2UpdateTime = &e2Time
×
5694
                }
×
5695

5696
                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
5697
                        info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime,
×
5698
                        e2UpdateTime,
×
5699
                )
×
5700
        }
5701

5702
        return db.UpsertZombieChannel(
×
5703
                ctx, sqlc.UpsertZombieChannelParams{
×
NEW
5704
                        Version:  int16(lnwire.GossipVersion1),
×
5705
                        Scid:     channelIDToBytes(scid),
×
5706
                        NodeKey1: nodeKey1[:],
×
5707
                        NodeKey2: nodeKey2[:],
×
5708
                },
×
5709
        )
×
5710
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc