lightningnetwork / lnd / build 16990665124

15 Aug 2025 01:10PM UTC coverage: 66.74% (-0.03%) from 66.765%

Pull Request #9455: [1/2] discovery+lnwire: add support for DNS host name in NodeAnnouncement msg
Merge 035fac41d into fb1adfc21

116 of 188 new or added lines in 8 files covered (61.7%).
110 existing lines in 23 files are now uncovered.
136011 of 203791 relevant lines covered (66.74%).
21482.89 hits per line.

Source File: /graph/db/sql_store.go (file coverage: 0.0%)

package graphdb

import (
        "bytes"
        "context"
        "database/sql"
        "encoding/hex"
        "errors"
        "fmt"
        "maps"
        "math"
        "net"
        "slices"
        "strconv"
        "sync"
        "time"

        "github.com/btcsuite/btcd/btcec/v2"
        "github.com/btcsuite/btcd/btcutil"
        "github.com/btcsuite/btcd/chaincfg/chainhash"
        "github.com/btcsuite/btcd/wire"
        "github.com/lightningnetwork/lnd/aliasmgr"
        "github.com/lightningnetwork/lnd/batch"
        "github.com/lightningnetwork/lnd/fn/v2"
        "github.com/lightningnetwork/lnd/graph/db/models"
        "github.com/lightningnetwork/lnd/lnwire"
        "github.com/lightningnetwork/lnd/routing/route"
        "github.com/lightningnetwork/lnd/sqldb"
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
        "github.com/lightningnetwork/lnd/tlv"
        "github.com/lightningnetwork/lnd/tor"
)

// ProtocolVersion is an enum that defines the gossip protocol version of a
// message.
type ProtocolVersion uint8

const (
        // ProtocolV1 is the gossip protocol version defined in BOLT #7.
        ProtocolV1 ProtocolVersion = 1
)

// String returns a string representation of the protocol version.
func (v ProtocolVersion) String() string {
        return fmt.Sprintf("V%d", v)
}

// SQLQueries is a subset of the sqlc.Querier interface that can be used to
// execute queries against the SQL graph tables.
//
//nolint:ll,interfacebloat
type SQLQueries interface {
        /*
                Node queries.
        */
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
        GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error)
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
        DeleteNode(ctx context.Context, id int64) error

        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
        GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error

        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
        GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
        GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error

        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
        GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error

        /*
                Source node queries.
        */
        AddSourceNode(ctx context.Context, nodeID int64) error
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)

        /*
                Channel queries.
        */
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
        GetChannelsByIDs(ctx context.Context, ids []int64) ([]sqlc.GetChannelsByIDsRow, error)
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
        ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
        ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
        DeleteChannels(ctx context.Context, ids []int64) error

        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
        GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
        GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)

        /*
                Channel Policy table queries.
        */
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)

        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
        GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error

        /*
                Zombie index queries.
        */
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
        GetZombieChannelsSCIDs(ctx context.Context, arg sqlc.GetZombieChannelsSCIDsParams) ([]sqlc.GraphZombieChannel, error)
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)

        /*
                Prune log table queries.
        */
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
        GetPruneEntriesForHeights(ctx context.Context, heights []int64) ([]sqlc.GraphPruneLog, error)
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error

        /*
                Closed SCID table queries.
        */
        InsertClosedChannel(ctx context.Context, scid []byte) error
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
        GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error)
}

// BatchedSQLQueries is a version of SQLQueries that's capable of batched
// database operations.
type BatchedSQLQueries interface {
        SQLQueries
        sqldb.BatchedTx[SQLQueries]
}
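
// NOTE: the following helper is an illustrative sketch added for this
// write-up and is not part of the original file. It shows the transaction
// pattern used throughout this store: ExecTx binds a SQLQueries view to a
// single read transaction and runs the callback against it.
func countZombiesSketch(ctx context.Context,
        db BatchedSQLQueries) (int64, error) {

        var count int64
        err := db.ExecTx(ctx, sqldb.ReadTxOpt(), func(q SQLQueries) error {
                var err error
                count, err = q.CountZombieChannels(ctx, int16(ProtocolV1))

                return err
        }, sqldb.NoOpReset)

        return count, err
}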

// SQLStore is an implementation of the V1Store interface that uses a SQL
// database as the backend.
type SQLStore struct {
        cfg *SQLStoreConfig
        db  BatchedSQLQueries

        // cacheMu guards all caches (rejectCache and chanCache). If
        // this mutex will be acquired at the same time as the DB mutex then
        // the cacheMu MUST be acquired first to prevent deadlock.
        cacheMu     sync.RWMutex
        rejectCache *rejectCache
        chanCache   *channelCache

        chanScheduler batch.Scheduler[SQLQueries]
        nodeScheduler batch.Scheduler[SQLQueries]

        srcNodes  map[ProtocolVersion]*srcNodeInfo
        srcNodeMu sync.Mutex
}

// A compile-time assertion to ensure that SQLStore implements the V1Store
// interface.
var _ V1Store = (*SQLStore)(nil)

// SQLStoreConfig holds the configuration for the SQLStore.
type SQLStoreConfig struct {
        // ChainHash is the genesis hash for the chain that all the gossip
        // messages in this store are aimed at.
        ChainHash chainhash.Hash

        // QueryConfig holds configuration values for SQL queries.
        QueryCfg *sqldb.QueryConfig
}

// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
// storage backend.
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
        options ...StoreOptionModifier) (*SQLStore, error) {

        opts := DefaultOptions()
        for _, o := range options {
                o(opts)
        }

        if opts.NoMigration {
                return nil, fmt.Errorf("the NoMigration option is not yet " +
                        "supported for SQL stores")
        }

        s := &SQLStore{
                cfg:         cfg,
                db:          db,
                rejectCache: newRejectCache(opts.RejectCacheSize),
                chanCache:   newChannelCache(opts.ChannelCacheSize),
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
        }

        s.chanScheduler = batch.NewTimeScheduler(
                db, &s.cacheMu, opts.BatchCommitInterval,
        )
        s.nodeScheduler = batch.NewTimeScheduler(
                db, nil, opts.BatchCommitInterval,
        )

        return s, nil
}
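
// NOTE: the following helper is an illustrative sketch added for this
// write-up and is not part of the original file. It shows how a caller
// might wire up the store, assuming a BatchedSQLQueries backend has
// already been constructed elsewhere (e.g. via the sqldb package).
func newGraphStoreSketch(chain chainhash.Hash, queryCfg *sqldb.QueryConfig,
        db BatchedSQLQueries) (*SQLStore, error) {

        cfg := &SQLStoreConfig{
                ChainHash: chain,
                QueryCfg:  queryCfg,
        }

        // Additional StoreOptionModifier values (cache sizes, batch commit
        // interval) could be passed here; the defaults are used otherwise.
        return NewSQLStore(cfg, db)
}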

// AddLightningNode adds a vertex/node to the graph database. If the node is not
// in the database from before, this will add a new, unconnected one to the
// graph. If it is present from before, this will update that node's
// information.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddLightningNode(ctx context.Context,
        node *models.LightningNode, opts ...batch.SchedulerOption) error {

        r := &batch.Request[SQLQueries]{
                Opts: batch.NewSchedulerOptions(opts...),
                Do: func(queries SQLQueries) error {
                        _, err := upsertNode(ctx, queries, node)
                        return err
                },
        }

        return s.nodeScheduler.Execute(ctx, r)
}

// FetchLightningNode attempts to look up a target node by its identity public
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
// returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchLightningNode(ctx context.Context,
        pubKey route.Vertex) (*models.LightningNode, error) {

        var node *models.LightningNode
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                var err error
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, pubKey)

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch node: %w", err)
        }

        return node, nil
}
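
// NOTE: the following helper is an illustrative sketch added for this
// write-up and is not part of the original file. It assumes that the
// ErrGraphNodeNotFound sentinel documented above survives the error
// wrapping, and maps it to a boolean instead of an error.
func lookupNodeSketch(ctx context.Context, store *SQLStore,
        pubKey route.Vertex) (*models.LightningNode, bool, error) {

        node, err := store.FetchLightningNode(ctx, pubKey)
        switch {
        case errors.Is(err, ErrGraphNodeNotFound):
                return nil, false, nil

        case err != nil:
                return nil, false, err
        }

        return node, true, nil
}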

// HasLightningNode determines if the graph has a vertex identified by the
// target node identity public key. If the node exists in the database, a
// timestamp of when the data for the node was last updated is returned along
// with a true boolean. Otherwise, an empty time.Time is returned with a false
// boolean.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) HasLightningNode(ctx context.Context,
        pubKey [33]byte) (time.Time, bool, error) {

        var (
                exists     bool
                lastUpdate time.Time
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pubKey[:],
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                exists = true

                if dbNode.LastUpdate.Valid {
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return time.Time{}, false,
                        fmt.Errorf("unable to fetch node: %w", err)
        }

        return lastUpdate, exists, nil
}

// AddrsForNode returns all known addresses for the target node public key
// that the graph DB is aware of. The returned boolean indicates if the
// given node is unknown to the graph DB or not.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddrsForNode(ctx context.Context,
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {

        var (
                addresses []net.Addr
                known     bool
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                // First, check if the node exists and get its DB ID if it
                // does.
                dbID, err := db.GetNodeIDByPubKey(
                        ctx, sqlc.GetNodeIDByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  nodePub.SerializeCompressed(),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                }

                known = true

                addresses, err = getNodeAddresses(ctx, db, dbID)
                if err != nil {
                        return fmt.Errorf("unable to fetch node addresses: %w",
                                err)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return false, nil, fmt.Errorf("unable to get addresses for "+
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
        }

        return known, addresses, nil
}

// DeleteLightningNode starts a new database transaction to remove a vertex/node
// from the database according to the node's public key.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
        pubKey route.Vertex) error {

        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                res, err := db.DeleteNodeByPubKey(
                        ctx, sqlc.DeleteNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pubKey[:],
                        },
                )
                if err != nil {
                        return err
                }

                rows, err := res.RowsAffected()
                if err != nil {
                        return err
                }

                if rows == 0 {
                        return ErrGraphNodeNotFound
                } else if rows > 1 {
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
                }

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return fmt.Errorf("unable to delete node: %w", err)
        }

        return nil
}

// FetchNodeFeatures returns the features of the given node. If no features are
// known for the node, an empty feature vector is returned.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
        *lnwire.FeatureVector, error) {

        ctx := context.TODO()

        return fetchNodeFeatures(ctx, s.db, nodePub)
}

// DisabledChannelIDs returns the channel ids of disabled channels.
// A channel is disabled when two of the associated ChannelEdgePolicies
// have their disabled bit on.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
        var (
                ctx     = context.TODO()
                chanIDs []uint64
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
                if err != nil {
                        return fmt.Errorf("unable to fetch disabled "+
                                "channels: %w", err)
                }

                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
                        err)
        }

        return chanIDs, nil
}

// LookupAlias attempts to return the alias as advertised by the target node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) LookupAlias(ctx context.Context,
        pub *btcec.PublicKey) (string, error) {

        var alias string
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pub.SerializeCompressed(),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return ErrNodeAliasNotFound
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                if !dbNode.Alias.Valid {
                        return ErrNodeAliasNotFound
                }

                alias = dbNode.Alias.String

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return "", fmt.Errorf("unable to look up alias: %w", err)
        }

        return alias, nil
}

// SourceNode returns the source node of the graph. The source node is treated
// as the center node within a star-graph. This method may be used to kick off
// a path finding algorithm in order to explore the reachability of another
// node based off the source node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
        error) {

        var node *models.LightningNode
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
                if err != nil {
                        return fmt.Errorf("unable to fetch V1 source node: %w",
                                err)
                }

                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, nodePub)

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
        }

        return node, nil
}

// SetSourceNode sets the source node within the graph database. The source
// node is to be used as the center of a star-graph within path finding
// algorithms.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SetSourceNode(ctx context.Context,
        node *models.LightningNode) error {

        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                id, err := upsertNode(ctx, db, node)
                if err != nil {
                        return fmt.Errorf("unable to upsert source node: %w",
                                err)
                }

                // Make sure that if a source node for this version is already
                // set, then the ID is the same as the one we are about to set.
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
                        return fmt.Errorf("unable to fetch source node: %w",
                                err)
                } else if err == nil {
                        if dbSourceNodeID != id {
                                return fmt.Errorf("v1 source node already "+
                                        "set to a different node: %d vs %d",
                                        dbSourceNodeID, id)
                        }

                        return nil
                }

                return db.AddSourceNode(ctx, id)
        }, sqldb.NoOpReset)
}

// NodeUpdatesInHorizon returns all the known lightning nodes which have an
// update timestamp within the passed range. This method can be used by two
// nodes to quickly determine if they have the same set of up to date node
// announcements.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
        endTime time.Time) ([]models.LightningNode, error) {

        ctx := context.TODO()

        var nodes []models.LightningNode
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNodes, err := db.GetNodesByLastUpdateRange(
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch nodes: %w", err)
                }

                err = forEachNodeInBatch(
                        ctx, s.cfg.QueryCfg, db, dbNodes,
                        func(_ int64, node *models.LightningNode) error {
                                nodes = append(nodes, *node)

                                return nil
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to build nodes: %w", err)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
        }

        return nodes, nil
}

// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
// undirected edge from the two target nodes is created. The information stored
// denotes the static attributes of the channel, such as the channelID, the keys
// involved in creation of the channel, and the set of features that the channel
// supports. The chanPoint and chanID are used to uniquely identify the edge
// globally within the database.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddChannelEdge(ctx context.Context,
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {

        var alreadyExists bool
        r := &batch.Request[SQLQueries]{
                Opts: batch.NewSchedulerOptions(opts...),
                Reset: func() {
                        alreadyExists = false
                },
                Do: func(tx SQLQueries) error {
                        chanIDB := channelIDToBytes(edge.ChannelID)

                        // Make sure that the channel doesn't already exist. We
                        // do this explicitly instead of relying on catching a
                        // unique constraint error because relying on SQL to
                        // throw that error would abort the entire batch of
                        // transactions.
                        _, err := tx.GetChannelBySCID(
                                ctx, sqlc.GetChannelBySCIDParams{
                                        Scid:    chanIDB,
                                        Version: int16(ProtocolV1),
                                },
                        )
                        if err == nil {
                                alreadyExists = true
                                return nil
                        } else if !errors.Is(err, sql.ErrNoRows) {
                                return fmt.Errorf("unable to fetch channel: %w",
                                        err)
                        }

                        _, err = insertChannel(ctx, tx, edge)

                        return err
                },
                OnCommit: func(err error) error {
                        switch {
                        case err != nil:
                                return err
                        case alreadyExists:
                                return ErrEdgeAlreadyExist
                        default:
                                s.rejectCache.remove(edge.ChannelID)
                                s.chanCache.remove(edge.ChannelID)
                                return nil
                        }
                },
        }

        return s.chanScheduler.Execute(ctx, r)
}
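
// NOTE: the following helper is an illustrative sketch added for this
// write-up and is not part of the original file. Because AddChannelEdge
// surfaces ErrEdgeAlreadyExist via the batch scheduler's OnCommit hook, a
// caller replaying gossip may want to treat that error as a no-op.
func addChannelIfNewSketch(ctx context.Context, store *SQLStore,
        edge *models.ChannelEdgeInfo) error {

        err := store.AddChannelEdge(ctx, edge)
        if errors.Is(err, ErrEdgeAlreadyExist) {
                // The channel is already known; nothing to do.
                return nil
        }

        return err
}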

// HighestChanID returns the "highest" known channel ID in the channel graph.
// This represents the "newest" channel from the PoV of the chain. This method
// can be used by peers to quickly determine if their graphs are in sync.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
        var highestChanID uint64
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
                                err)
                }

                highestChanID = byteOrder.Uint64(chanID)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
        }

        return highestChanID, nil
}

// UpdateEdgePolicy updates the edge routing policy for a single directed edge
// within the database for the referenced channel. The `flags` attribute within
// the ChannelEdgePolicy determines which of the directed edges are being
// updated. If the flag is 1, then the first node's information is being
// updated, otherwise it's the second node's information. The node ordering is
// determined by the lexicographical ordering of the identity public keys of the
// nodes on either side of the channel.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
        edge *models.ChannelEdgePolicy,
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {

        var (
                isUpdate1    bool
                edgeNotFound bool
                from, to     route.Vertex
        )

        r := &batch.Request[SQLQueries]{
                Opts: batch.NewSchedulerOptions(opts...),
                Reset: func() {
                        isUpdate1 = false
                        edgeNotFound = false
                },
                Do: func(tx SQLQueries) error {
                        var err error
                        from, to, isUpdate1, err = updateChanEdgePolicy(
                                ctx, tx, edge,
                        )
                        if err != nil {
                                log.Errorf("UpdateEdgePolicy failed: %v", err)
                        }

                        // Silence ErrEdgeNotFound so that the batch can
                        // succeed, but propagate the error via local state.
                        if errors.Is(err, ErrEdgeNotFound) {
                                edgeNotFound = true
                                return nil
                        }

                        return err
                },
                OnCommit: func(err error) error {
                        switch {
                        case err != nil:
                                return err
                        case edgeNotFound:
                                return ErrEdgeNotFound
                        default:
                                s.updateEdgeCache(edge, isUpdate1)
                                return nil
                        }
                },
        }

        err := s.chanScheduler.Execute(ctx, r)

        return from, to, err
}

// updateEdgeCache updates our reject and channel caches with the new
// edge policy information.
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
        isUpdate1 bool) {

        // If an entry for this channel is found in reject cache, we'll modify
        // the entry with the updated timestamp for the direction that was just
        // written. If the edge doesn't exist, we'll load the cache entry lazily
        // during the next query for this edge.
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
                if isUpdate1 {
                        entry.upd1Time = e.LastUpdate.Unix()
                } else {
                        entry.upd2Time = e.LastUpdate.Unix()
                }
                s.rejectCache.insert(e.ChannelID, entry)
        }

        // If an entry for this channel is found in channel cache, we'll modify
        // the entry with the updated policy for the direction that was just
        // written. If the edge doesn't exist, we'll defer loading the info and
        // policies and lazily read from disk during the next query.
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
                if isUpdate1 {
                        channel.Policy1 = e
                } else {
                        channel.Policy2 = e
                }
                s.chanCache.insert(e.ChannelID, channel)
        }
}

// ForEachSourceNodeChannel iterates through all channels of the source node,
// executing the passed callback on each. The call-back is provided with the
// channel's outpoint, whether we have a policy for the channel and the channel
// peer's node information.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
        cb func(chanPoint wire.OutPoint, havePolicy bool,
                otherNode *models.LightningNode) error, reset func()) error {

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
                if err != nil {
                        return fmt.Errorf("unable to fetch source node: %w",
                                err)
                }

                return forEachNodeChannel(
                        ctx, db, s.cfg, nodeID,
                        func(info *models.ChannelEdgeInfo,
                                outPolicy *models.ChannelEdgePolicy,
                                _ *models.ChannelEdgePolicy) error {

                                // Fetch the other node.
                                var (
                                        otherNodePub [33]byte
                                        node1        = info.NodeKey1Bytes
                                        node2        = info.NodeKey2Bytes
                                )
                                switch {
                                case bytes.Equal(node1[:], nodePub[:]):
                                        otherNodePub = node2
                                case bytes.Equal(node2[:], nodePub[:]):
                                        otherNodePub = node1
                                default:
                                        return fmt.Errorf("node not " +
                                                "participating in this channel")
                                }

                                _, otherNode, err := getNodeByPubKey(
                                        ctx, s.cfg.QueryCfg, db, otherNodePub,
                                )
                                if err != nil {
                                        return fmt.Errorf("unable to fetch "+
                                                "other node(%x): %w",
                                                otherNodePub, err)
                                }

                                return cb(
                                        info.ChannelPoint, outPolicy != nil,
                                        otherNode,
                                )
                        },
                )
        }, reset)
}

// ForEachNode iterates through all the stored vertices/nodes in the graph,
// executing the passed callback with each node encountered. If the callback
// returns an error, then the transaction is aborted and the iteration stops
// early.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNode(ctx context.Context,
        cb func(node *models.LightningNode) error, reset func()) error {

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return forEachNodePaginated(
                        ctx, s.cfg.QueryCfg, db,
                        ProtocolV1, func(_ context.Context, _ int64,
                                node *models.LightningNode) error {

                                return cb(node)
                        },
                )
        }, reset)
}

// ForEachNodeDirectedChannel iterates through all channels of a given node,
// executing the passed callback on the directed edge representing the channel
// and its incoming policy. If the callback returns an error, then the iteration
// is halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
        cb func(channel *DirectedChannel) error, reset func()) error {

        var ctx = context.TODO()

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
        }, reset)
}

// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
// graph, executing the passed callback with each node encountered. If the
// callback returns an error, then the transaction is aborted and the iteration
// stops early.
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
        cb func(route.Vertex, *lnwire.FeatureVector) error,
        reset func()) error {

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return forEachNodeCacheable(
                        ctx, s.cfg.QueryCfg, db,
                        func(_ int64, nodePub route.Vertex,
                                features *lnwire.FeatureVector) error {

                                return cb(nodePub, features)
                        },
                )
        }, reset)
        if err != nil {
                return fmt.Errorf("unable to fetch nodes: %w", err)
        }

        return nil
}

// ForEachNodeChannel iterates through all channels of the given node,
// executing the passed callback with an edge info structure and the policies
// of each end of the channel. The first edge policy is the outgoing edge *to*
// the connecting node, while the second is the incoming edge *from* the
// connecting node. If the callback returns an error, then the iteration is
// halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error, reset func()) error {

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  nodePub[:],
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
        }, reset)
}
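
// NOTE: the following helper is an illustrative sketch added for this
// write-up and is not part of the original file. It counts a node's
// channels with ForEachNodeChannel; the reset callback clears the
// accumulator in case the underlying read transaction is retried.
func countNodeChannelsSketch(ctx context.Context, store *SQLStore,
        nodePub route.Vertex) (int, error) {

        var count int
        err := store.ForEachNodeChannel(ctx, nodePub,
                func(_ *models.ChannelEdgeInfo, _ *models.ChannelEdgePolicy,
                        _ *models.ChannelEdgePolicy) error {

                        count++

                        return nil
                },
                func() {
                        count = 0
                },
        )

        return count, err
}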

// ChanUpdatesInHorizon returns all the known channel edges which have at least
// one edge that has an update timestamp within the specified horizon.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
        endTime time.Time) ([]ChannelEdge, error) {

        s.cacheMu.Lock()
        defer s.cacheMu.Unlock()

        var (
                ctx = context.TODO()
                // To ensure we don't return duplicate ChannelEdges, we'll use
                // an additional map to keep track of the edges already seen to
                // prevent re-adding it.
                edgesSeen    = make(map[uint64]struct{})
                edgesToCache = make(map[uint64]ChannelEdge)
                edges        []ChannelEdge
                hits         int
        )

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
                                Version:   int16(ProtocolV1),
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
                        },
                )
                if err != nil {
                        return err
                }

                if len(rows) == 0 {
                        return nil
                }

                // We'll pre-allocate the slices and maps here with a best
                // effort size in order to avoid unnecessary allocations later
                // on.
                uncachedRows := make(
                        []sqlc.GetChannelsByPolicyLastUpdateRangeRow, 0,
                        len(rows),
                )
                edgesToCache = make(map[uint64]ChannelEdge, len(rows))
                edgesSeen = make(map[uint64]struct{}, len(rows))
                edges = make([]ChannelEdge, 0, len(rows))

                // Separate cached from non-cached channels since we will only
                // batch load the data for the ones we haven't cached yet.
                for _, row := range rows {
                        chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid)

                        // Skip duplicates.
                        if _, ok := edgesSeen[chanIDInt]; ok {
                                continue
                        }
                        edgesSeen[chanIDInt] = struct{}{}

                        // Check cache first.
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
                                hits++
                                edges = append(edges, channel)
                                continue
                        }

                        // Mark this row as one we need to batch load data for.
                        uncachedRows = append(uncachedRows, row)
                }

                // If there are no uncached rows, then we can return early.
                if len(uncachedRows) == 0 {
                        return nil
                }

                // Batch load data for all uncached channels.
                newEdges, err := batchBuildChannelEdges(
                        ctx, s.cfg, db, uncachedRows,
                )
                if err != nil {
                        return fmt.Errorf("unable to batch build channel "+
                                "edges: %w", err)
                }

                edges = append(edges, newEdges...)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
        }

        // Insert any edges loaded from disk into the cache.
        for chanid, channel := range edgesToCache {
                s.chanCache.insert(chanid, channel)
        }

        if len(edges) > 0 {
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
        } else {
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
                        "horizon (%s, %s)", startTime, endTime)
        }

        return edges, nil
}

// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
// data to the call-back. If withAddrs is true, then the call-back will also be
// provided with the addresses associated with the node. The address retrieval
// results in an additional round-trip to the database, so it should only be
// used if the addresses are actually needed.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNodeCached(ctx context.Context, withAddrs bool,
        cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
                chans map[uint64]*DirectedChannel) error, reset func()) error {

        type nodeCachedBatchData struct {
                features      map[int64][]int
                addrs         map[int64][]nodeAddress
                chanBatchData *batchChannelData
                chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
        }

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                // pageQueryFunc is used to query the next page of nodes.
                pageQueryFunc := func(ctx context.Context, lastID int64,
                        limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {

                        return db.ListNodeIDsAndPubKeys(
                                ctx, sqlc.ListNodeIDsAndPubKeysParams{
                                        Version: int16(ProtocolV1),
                                        ID:      lastID,
                                        Limit:   limit,
                                },
                        )
                }

                // batchDataFunc is then used to batch load the data required
                // for each page of nodes.
                batchDataFunc := func(ctx context.Context,
                        nodeIDs []int64) (*nodeCachedBatchData, error) {

                        // Batch load node features.
                        nodeFeatures, err := batchLoadNodeFeaturesHelper(
                                ctx, s.cfg.QueryCfg, db, nodeIDs,
                        )
                        if err != nil {
                                return nil, fmt.Errorf("unable to batch load "+
                                        "node features: %w", err)
                        }

                        // Maybe fetch the node's addresses if requested.
                        var nodeAddrs map[int64][]nodeAddress
                        if withAddrs {
                                nodeAddrs, err = batchLoadNodeAddressesHelper(
                                        ctx, s.cfg.QueryCfg, db, nodeIDs,
                                )
                                if err != nil {
                                        return nil, fmt.Errorf("unable to "+
                                                "batch load node "+
                                                "addresses: %w", err)
                                }
                        }

                        // Batch load ALL unique channels for ALL nodes in this
                        // page.
                        allChannels, err := db.ListChannelsForNodeIDs(
                                ctx, sqlc.ListChannelsForNodeIDsParams{
                                        Version:  int16(ProtocolV1),
                                        Node1Ids: nodeIDs,
                                        Node2Ids: nodeIDs,
                                },
                        )
                        if err != nil {
                                return nil, fmt.Errorf("unable to batch "+
                                        "fetch channels for nodes: %w", err)
                        }

                        // Deduplicate channels and collect IDs.
                        var (
                                allChannelIDs []int64
                                allPolicyIDs  []int64
                        )
                        uniqueChannels := make(
                                map[int64]sqlc.ListChannelsForNodeIDsRow,
                        )

                        for _, channel := range allChannels {
                                channelID := channel.GraphChannel.ID

                                // Only process each unique channel once.
                                _, exists := uniqueChannels[channelID]
                                if exists {
                                        continue
                                }

                                uniqueChannels[channelID] = channel
                                allChannelIDs = append(allChannelIDs, channelID)

                                if channel.Policy1ID.Valid {
                                        allPolicyIDs = append(
                                                allPolicyIDs,
                                                channel.Policy1ID.Int64,
                                        )
                                }
                                if channel.Policy2ID.Valid {
                                        allPolicyIDs = append(
                                                allPolicyIDs,
                                                channel.Policy2ID.Int64,
                                        )
                                }
                        }

                        // Batch load channel data for all unique channels.
                        channelBatchData, err := batchLoadChannelData(
                                ctx, s.cfg.QueryCfg, db, allChannelIDs,
                                allPolicyIDs,
                        )
                        if err != nil {
                                return nil, fmt.Errorf("unable to batch "+
                                        "load channel data: %w", err)
                        }

                        // Create map of node ID to channels that involve this
                        // node.
                        nodeIDSet := make(map[int64]bool)
                        for _, nodeID := range nodeIDs {
                                nodeIDSet[nodeID] = true
                        }

                        nodeChannelMap := make(
                                map[int64][]sqlc.ListChannelsForNodeIDsRow,
                        )
                        for _, channel := range uniqueChannels {
                                // Add channel to both nodes if they're in our
                                // current page.
                                node1 := channel.GraphChannel.NodeID1
                                if nodeIDSet[node1] {
                                        nodeChannelMap[node1] = append(
                                                nodeChannelMap[node1], channel,
                                        )
                                }
                                node2 := channel.GraphChannel.NodeID2
                                if nodeIDSet[node2] {
                                        nodeChannelMap[node2] = append(
                                                nodeChannelMap[node2], channel,
                                        )
                                }
                        }

                        return &nodeCachedBatchData{
                                features:      nodeFeatures,
                                addrs:         nodeAddrs,
                                chanBatchData: channelBatchData,
                                chanMap:       nodeChannelMap,
                        }, nil
                }

                // processItem is used to process each node in the current page.
                processItem := func(ctx context.Context,
                        nodeData sqlc.ListNodeIDsAndPubKeysRow,
                        batchData *nodeCachedBatchData) error {

                        // Build feature vector for this node.
                        fv := lnwire.EmptyFeatureVector()
                        features, exists := batchData.features[nodeData.ID]
                        if exists {
                                for _, bit := range features {
                                        fv.Set(lnwire.FeatureBit(bit))
                                }
                        }

                        var nodePub route.Vertex
                        copy(nodePub[:], nodeData.PubKey)

                        nodeChannels := batchData.chanMap[nodeData.ID]

                        toNodeCallback := func() route.Vertex {
                                return nodePub
                        }

                        // Build cached channels map for this node.
                        channels := make(map[uint64]*DirectedChannel)
                        for _, channelRow := range nodeChannels {
                                directedChan, err := buildDirectedChannel(
                                        s.cfg.ChainHash, nodeData.ID, nodePub,
                                        channelRow, batchData.chanBatchData, fv,
                                        toNodeCallback,
                                )
                                if err != nil {
                                        return err
                                }

                                channels[directedChan.ChannelID] = directedChan
                        }

                        addrs, err := buildNodeAddresses(
                                batchData.addrs[nodeData.ID],
                        )
                        if err != nil {
                                return fmt.Errorf("unable to build node "+
                                        "addresses: %w", err)
                        }

                        return cb(ctx, nodePub, addrs, channels)
                }

                return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                        ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
                        func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
                                return node.ID
                        },
                        func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
                                error) {

                                return node.ID, nil
                        },
                        batchDataFunc, processItem,
                )
        }, reset)
}

// ForEachChannelCacheable iterates through all the channel edges stored
// within the graph and invokes the passed callback for each edge. The
1240
// callback takes two edges since, as this is a directed graph, both the
1241
// in/out edges are visited. If the callback returns an error, then the
1242
// transaction is aborted and the iteration stops early.
1243
//
1244
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1245
// pointer for that particular channel edge routing policy will be
1246
// passed into the callback.
1247
//
1248
// NOTE: this method is like ForEachChannel but fetches only the data
1249
// required for the graph cache.
1250
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1251
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1252
        reset func()) error {
×
1253

×
1254
        ctx := context.TODO()
×
1255

×
1256
        handleChannel := func(_ context.Context,
×
1257
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
1258

×
1259
                node1, node2, err := buildNodeVertices(
×
1260
                        row.Node1Pubkey, row.Node2Pubkey,
×
1261
                )
×
1262
                if err != nil {
×
1263
                        return err
×
1264
                }
×
1265

1266
                edge := buildCacheableChannelInfo(
×
1267
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1268
                )
×
1269

×
1270
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1271
                if err != nil {
×
1272
                        return err
×
1273
                }
×
1274

1275
                pol1, pol2, err := buildCachedChanPolicies(
×
1276
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1277
                )
×
1278
                if err != nil {
×
1279
                        return err
×
1280
                }
×
1281

1282
                return cb(edge, pol1, pol2)
×
1283
        }
1284

1285
        extractCursor := func(
×
1286
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {
×
1287

×
1288
                return row.ID
×
1289
        }
×
1290

1291
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1292
                //nolint:ll
×
1293
                queryFunc := func(ctx context.Context, lastID int64,
×
1294
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
×
1295
                        error) {
×
1296

×
1297
                        return db.ListChannelsWithPoliciesForCachePaginated(
×
1298
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
1299
                                        Version: int16(ProtocolV1),
×
1300
                                        ID:      lastID,
×
1301
                                        Limit:   limit,
×
1302
                                },
×
1303
                        )
×
1304
                }
×
1305

1306
                return sqldb.ExecutePaginatedQuery(
×
1307
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
1308
                        extractCursor, handleChannel,
×
1309
                )
×
1310
        }, reset)
1311
}
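
// exampleCountCacheableChannels is an illustrative, hypothetical helper (not
// part of the store implementation) sketching how a caller might drive
// ForEachChannelCacheable: it tallies the channels seen and how many are
// missing one of the two directed policies. The counters are cleared in the
// reset closure since the underlying read transaction may be retried.
func exampleCountCacheableChannels(s *SQLStore) (int, int, error) {
        var total, missingPolicy int

        err := s.ForEachChannelCacheable(func(_ *models.CachedEdgeInfo,
                pol1, pol2 *models.CachedEdgePolicy) error {

                total++
                if pol1 == nil || pol2 == nil {
                        missingPolicy++
                }

                return nil
        }, func() {
                // The transaction may be retried, so start counting from
                // scratch.
                total, missingPolicy = 0, 0
        })

        return total, missingPolicy, err
}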
1312

1313
// ForEachChannel iterates through all the channel edges stored within the
1314
// graph and invokes the passed callback for each edge. The callback takes two
1315
// edges since, as this is a directed graph, both the in/out edges are visited.
1316
// If the callback returns an error, then the transaction is aborted and the
1317
// iteration stops early.
1318
//
1319
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1320
// for that particular channel edge routing policy will be passed into the
1321
// callback.
1322
//
1323
// NOTE: part of the V1Store interface.
1324
func (s *SQLStore) ForEachChannel(ctx context.Context,
1325
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1326
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1327

×
1328
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1329
                return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
×
1330
        }, reset)
×
1331
}
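
// exampleListUnadvertisedChannels is an illustrative, hypothetical helper (not
// part of the store implementation) sketching a ForEachChannel use case: it
// collects the SCIDs of channels for which neither directed policy has been
// received yet. The slice is rebuilt from scratch in the reset closure because
// the read transaction may be retried.
func exampleListUnadvertisedChannels(ctx context.Context,
        s *SQLStore) ([]uint64, error) {

        var scids []uint64
        err := s.ForEachChannel(ctx, func(info *models.ChannelEdgeInfo,
                pol1, pol2 *models.ChannelEdgePolicy) error {

                if pol1 == nil && pol2 == nil {
                        scids = append(scids, info.ChannelID)
                }

                return nil
        }, func() {
                scids = nil
        })

        return scids, err
}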
1332

1333
// FilterChannelRange returns the channel ID's of all known channels which were
1334
// mined in a block height within the passed range. The channel IDs are grouped
1335
// by their common block height. This method can be used to quickly share with a
1336
// peer the set of channels we know of within a particular range to catch them
1337
// up after a period of time offline. If withTimestamps is true then the
1338
// timestamp info of the latest received channel update messages of the channel
1339
// will be included in the response.
1340
//
1341
// NOTE: This is part of the V1Store interface.
1342
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1343
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1344

×
1345
        var (
×
1346
                ctx       = context.TODO()
×
1347
                startSCID = &lnwire.ShortChannelID{
×
1348
                        BlockHeight: startHeight,
×
1349
                }
×
1350
                endSCID = lnwire.ShortChannelID{
×
1351
                        BlockHeight: endHeight,
×
1352
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1353
                        TxPosition:  math.MaxUint16,
×
1354
                }
×
1355
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1356
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1357
        )
×
1358

×
1359
        // 1) get all channels where channelID is between start and end chan ID.
×
1360
        // 2) skip if not public (ie, no channel_proof)
×
1361
        // 3) collect that channel.
×
1362
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1363
        //    and add those timestamps to the collected channel.
×
1364
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1365
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1366
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1367
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1368
                                StartScid: chanIDStart,
×
1369
                                EndScid:   chanIDEnd,
×
1370
                        },
×
1371
                )
×
1372
                if err != nil {
×
1373
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1374
                                err)
×
1375
                }
×
1376

1377
                for _, dbChan := range dbChans {
×
1378
                        cid := lnwire.NewShortChanIDFromInt(
×
1379
                                byteOrder.Uint64(dbChan.Scid),
×
1380
                        )
×
1381
                        chanInfo := NewChannelUpdateInfo(
×
1382
                                cid, time.Time{}, time.Time{},
×
1383
                        )
×
1384

×
1385
                        if !withTimestamps {
×
1386
                                channelsPerBlock[cid.BlockHeight] = append(
×
1387
                                        channelsPerBlock[cid.BlockHeight],
×
1388
                                        chanInfo,
×
1389
                                )
×
1390

×
1391
                                continue
×
1392
                        }
1393

1394
                        //nolint:ll
1395
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1396
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1397
                                        Version:   int16(ProtocolV1),
×
1398
                                        ChannelID: dbChan.ID,
×
1399
                                        NodeID:    dbChan.NodeID1,
×
1400
                                },
×
1401
                        )
×
1402
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1403
                                return fmt.Errorf("unable to fetch node1 "+
×
1404
                                        "policy: %w", err)
×
1405
                        } else if err == nil {
×
1406
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1407
                                        node1Policy.LastUpdate.Int64, 0,
×
1408
                                )
×
1409
                        }
×
1410

1411
                        //nolint:ll
1412
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1413
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1414
                                        Version:   int16(ProtocolV1),
×
1415
                                        ChannelID: dbChan.ID,
×
1416
                                        NodeID:    dbChan.NodeID2,
×
1417
                                },
×
1418
                        )
×
1419
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1420
                                return fmt.Errorf("unable to fetch node2 "+
×
1421
                                        "policy: %w", err)
×
1422
                        } else if err == nil {
×
1423
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1424
                                        node2Policy.LastUpdate.Int64, 0,
×
1425
                                )
×
1426
                        }
×
1427

1428
                        channelsPerBlock[cid.BlockHeight] = append(
×
1429
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1430
                        )
×
1431
                }
1432

1433
                return nil
×
1434
        }, func() {
×
1435
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1436
        })
×
1437
        if err != nil {
×
1438
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1439
        }
×
1440

1441
        if len(channelsPerBlock) == 0 {
×
1442
                return nil, nil
×
1443
        }
×
1444

1445
        // Return the channel ranges in ascending block height order.
1446
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1447
        slices.Sort(blocks)
×
1448

×
1449
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1450
                return BlockChannelRange{
×
1451
                        Height:   block,
×
1452
                        Channels: channelsPerBlock[block],
×
1453
                }
×
1454
        }), nil
×
1455
}
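
// exampleCountChannelsInRange is an illustrative, hypothetical helper (not
// part of the store implementation) showing a typical FilterChannelRange call
// as a gossip syncer might make it: fetch the per-block channel ranges for a
// height window without update timestamps and count the channels found.
func exampleCountChannelsInRange(s *SQLStore, start, end uint32) (int, error) {
        ranges, err := s.FilterChannelRange(start, end, false)
        if err != nil {
                return 0, err
        }

        var count int
        for _, blockRange := range ranges {
                count += len(blockRange.Channels)
        }

        return count, nil
}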
1456

1457
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1458
// zombie. This method is used on an ad-hoc basis, when channels need to be
1459
// marked as zombies outside the normal pruning cycle.
1460
//
1461
// NOTE: part of the V1Store interface.
1462
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1463
        pubKey1, pubKey2 [33]byte) error {
×
1464

×
1465
        ctx := context.TODO()
×
1466

×
1467
        s.cacheMu.Lock()
×
1468
        defer s.cacheMu.Unlock()
×
1469

×
1470
        chanIDB := channelIDToBytes(chanID)
×
1471

×
1472
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1473
                return db.UpsertZombieChannel(
×
1474
                        ctx, sqlc.UpsertZombieChannelParams{
×
1475
                                Version:  int16(ProtocolV1),
×
1476
                                Scid:     chanIDB,
×
1477
                                NodeKey1: pubKey1[:],
×
1478
                                NodeKey2: pubKey2[:],
×
1479
                        },
×
1480
                )
×
1481
        }, sqldb.NoOpReset)
×
1482
        if err != nil {
×
1483
                return fmt.Errorf("unable to upsert zombie channel "+
×
1484
                        "(channel_id=%d): %w", chanID, err)
×
1485
        }
×
1486

1487
        s.rejectCache.remove(chanID)
×
1488
        s.chanCache.remove(chanID)
×
1489

×
1490
        return nil
×
1491
}
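
// exampleMarkStaleChannelZombie is an illustrative, hypothetical helper (not
// part of the store implementation): it marks a channel as a zombie using the
// two node public keys taken from previously fetched edge info, which is the
// shape of data MarkEdgeZombie expects.
func exampleMarkStaleChannelZombie(s *SQLStore,
        info *models.ChannelEdgeInfo) error {

        return s.MarkEdgeZombie(
                info.ChannelID, info.NodeKey1Bytes, info.NodeKey2Bytes,
        )
}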
1492

1493
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1494
//
1495
// NOTE: part of the V1Store interface.
1496
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1497
        s.cacheMu.Lock()
×
1498
        defer s.cacheMu.Unlock()
×
1499

×
1500
        var (
×
1501
                ctx     = context.TODO()
×
1502
                chanIDB = channelIDToBytes(chanID)
×
1503
        )
×
1504

×
1505
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1506
                res, err := db.DeleteZombieChannel(
×
1507
                        ctx, sqlc.DeleteZombieChannelParams{
×
1508
                                Scid:    chanIDB,
×
1509
                                Version: int16(ProtocolV1),
×
1510
                        },
×
1511
                )
×
1512
                if err != nil {
×
1513
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1514
                                err)
×
1515
                }
×
1516

1517
                rows, err := res.RowsAffected()
×
1518
                if err != nil {
×
1519
                        return err
×
1520
                }
×
1521

1522
                if rows == 0 {
×
1523
                        return ErrZombieEdgeNotFound
×
1524
                } else if rows > 1 {
×
1525
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1526
                                "expected 1", rows)
×
1527
                }
×
1528

1529
                return nil
×
1530
        }, sqldb.NoOpReset)
1531
        if err != nil {
×
1532
                return fmt.Errorf("unable to mark edge live "+
×
1533
                        "(channel_id=%d): %w", chanID, err)
×
1534
        }
×
1535

1536
        s.rejectCache.remove(chanID)
×
1537
        s.chanCache.remove(chanID)
×
1538

×
1539
        return err
×
1540
}
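
// exampleResurrectEdge is an illustrative, hypothetical helper (not part of
// the store implementation) showing how MarkEdgeLive might be called when a
// fresh announcement arrives for a channel previously marked as a zombie.
// ErrZombieEdgeNotFound is treated as benign here: the edge simply was not in
// the zombie index.
func exampleResurrectEdge(s *SQLStore, chanID uint64) error {
        err := s.MarkEdgeLive(chanID)
        if errors.Is(err, ErrZombieEdgeNotFound) {
                return nil
        }

        return err
}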
1541

1542
// IsZombieEdge returns whether the edge is considered a zombie. If it is a
1543
// zombie, then the two node public keys corresponding to this edge are also
1544
// returned.
1545
//
1546
// NOTE: part of the V1Store interface.
1547
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1548
        error) {
×
1549

×
1550
        var (
×
1551
                ctx              = context.TODO()
×
1552
                isZombie         bool
×
1553
                pubKey1, pubKey2 route.Vertex
×
1554
                chanIDB          = channelIDToBytes(chanID)
×
1555
        )
×
1556

×
1557
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1558
                zombie, err := db.GetZombieChannel(
×
1559
                        ctx, sqlc.GetZombieChannelParams{
×
1560
                                Scid:    chanIDB,
×
1561
                                Version: int16(ProtocolV1),
×
1562
                        },
×
1563
                )
×
1564
                if errors.Is(err, sql.ErrNoRows) {
×
1565
                        return nil
×
1566
                }
×
1567
                if err != nil {
×
1568
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1569
                                err)
×
1570
                }
×
1571

1572
                copy(pubKey1[:], zombie.NodeKey1)
×
1573
                copy(pubKey2[:], zombie.NodeKey2)
×
1574
                isZombie = true
×
1575

×
1576
                return nil
×
1577
        }, sqldb.NoOpReset)
1578
        if err != nil {
×
1579
                return false, route.Vertex{}, route.Vertex{},
×
1580
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1581
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1582
        }
×
1583

1584
        return isZombie, pubKey1, pubKey2, nil
×
1585
}
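
// exampleDescribeZombieStatus is an illustrative, hypothetical helper (not
// part of the store implementation): it queries the zombie index for a channel
// and, if the channel is a zombie, returns a short description built from the
// two resurrection keys.
func exampleDescribeZombieStatus(s *SQLStore, chanID uint64) (string, error) {
        isZombie, key1, key2, err := s.IsZombieEdge(chanID)
        if err != nil {
                return "", err
        }
        if !isZombie {
                return fmt.Sprintf("channel %d is not a zombie", chanID), nil
        }

        return fmt.Sprintf("channel %d is a zombie (keys %x, %x)",
                chanID, key1[:], key2[:]), nil
}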
1586

1587
// NumZombies returns the current number of zombie channels in the graph.
1588
//
1589
// NOTE: part of the V1Store interface.
1590
func (s *SQLStore) NumZombies() (uint64, error) {
×
1591
        var (
×
1592
                ctx        = context.TODO()
×
1593
                numZombies uint64
×
1594
        )
×
1595
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1596
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1597
                if err != nil {
×
1598
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1599
                                err)
×
1600
                }
×
1601

1602
                numZombies = uint64(count)
×
1603

×
1604
                return nil
×
1605
        }, sqldb.NoOpReset)
1606
        if err != nil {
×
1607
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1608
        }
×
1609

1610
        return numZombies, nil
×
1611
}
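
// exampleZombieSummary is an illustrative, hypothetical helper (not part of
// the store implementation) that formats the current zombie-channel count,
// e.g. for a debug log line.
func exampleZombieSummary(s *SQLStore) (string, error) {
        numZombies, err := s.NumZombies()
        if err != nil {
                return "", err
        }

        return fmt.Sprintf("zombie index currently holds %d channels",
                numZombies), nil
}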
1612

1613
// DeleteChannelEdges removes edges with the given channel IDs from the
1614
// database and marks them as zombies. This ensures that we're unable to re-add
1615
// it to our database once again. If an edge does not exist within the
1616
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1617
// true, then when we mark these edges as zombies, we'll set up the keys such
1618
// that we require the node that failed to send the fresh update to be the one
1619
// that resurrects the channel from its zombie state. The markZombie bool
1620
// denotes whether to mark the channel as a zombie.
1621
//
1622
// NOTE: part of the V1Store interface.
1623
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1624
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1625

×
1626
        s.cacheMu.Lock()
×
1627
        defer s.cacheMu.Unlock()
×
1628

×
1629
        // Keep track of which channels we end up finding so that we can
×
1630
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
1631
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
1632
        for _, chanID := range chanIDs {
×
1633
                chanLookup[chanID] = struct{}{}
×
1634
        }
×
1635

1636
        var (
×
1637
                ctx   = context.TODO()
×
1638
                edges []*models.ChannelEdgeInfo
×
1639
        )
×
1640
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1641
                // First, collect all channel rows.
×
1642
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
1643
                chanCallBack := func(ctx context.Context,
×
1644
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1645

×
1646
                        // Deleting the entry from the map indicates that we
×
1647
                        // have found the channel.
×
1648
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1649
                        delete(chanLookup, scid)
×
1650

×
1651
                        channelRows = append(channelRows, row)
×
1652

×
1653
                        return nil
×
1654
                }
×
1655

1656
                err := s.forEachChanWithPoliciesInSCIDList(
×
1657
                        ctx, db, chanCallBack, chanIDs,
×
1658
                )
×
1659
                if err != nil {
×
1660
                        return err
×
1661
                }
×
1662

1663
                if len(chanLookup) > 0 {
×
1664
                        return ErrEdgeNotFound
×
1665
                }
×
1666

1667
                if len(channelRows) == 0 {
×
1668
                        return nil
×
1669
                }
×
1670

1671
                // Batch build all channel edges.
1672
                var chanIDsToDelete []int64
×
1673
                edges, chanIDsToDelete, err = batchBuildChannelInfo(
×
1674
                        ctx, s.cfg, db, channelRows,
×
1675
                )
×
1676
                if err != nil {
×
1677
                        return err
×
1678
                }
×
1679

1680
                if markZombie {
×
1681
                        for i, row := range channelRows {
×
1682
                                scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1683

×
1684
                                err := handleZombieMarking(
×
1685
                                        ctx, db, row, edges[i],
×
1686
                                        strictZombiePruning, scid,
×
1687
                                )
×
1688
                                if err != nil {
×
1689
                                        return fmt.Errorf("unable to mark "+
×
1690
                                                "channel as zombie: %w", err)
×
1691
                                }
×
1692
                        }
1693
                }
1694

1695
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1696
        }, func() {
×
1697
                edges = nil
×
1698

×
1699
                // Re-fill the lookup map.
×
1700
                for _, chanID := range chanIDs {
×
1701
                        chanLookup[chanID] = struct{}{}
×
1702
                }
×
1703
        })
1704
        if err != nil {
×
1705
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1706
                        err)
×
1707
        }
×
1708

1709
        for _, chanID := range chanIDs {
×
1710
                s.rejectCache.remove(chanID)
×
1711
                s.chanCache.remove(chanID)
×
1712
        }
×
1713

1714
        return edges, nil
×
1715
}
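
// exampleRemoveAndZombieChannels is an illustrative, hypothetical helper (not
// part of the store implementation): it deletes a batch of channels, marking
// them as zombies with strict pruning disabled, and returns the SCIDs that
// were removed. ErrEdgeNotFound is surfaced unchanged so the caller can decide
// how to handle unknown channels.
func exampleRemoveAndZombieChannels(s *SQLStore,
        chanIDs ...uint64) ([]uint64, error) {

        edges, err := s.DeleteChannelEdges(false, true, chanIDs...)
        if err != nil {
                return nil, err
        }

        removed := make([]uint64, 0, len(edges))
        for _, edge := range edges {
                removed = append(removed, edge.ChannelID)
        }

        return removed, nil
}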
1716

1717
// FetchChannelEdgesByID attempts to look up the two directed edges for the
1718
// channel identified by the channel ID. If the channel can't be found, then
1719
// ErrEdgeNotFound is returned. A struct which houses the general information
1720
// for the channel itself is returned as well as two structs that contain the
1721
// routing policies for the channel in either direction.
1722
//
1723
// ErrZombieEdge can be returned if the edge is currently marked as a zombie
1724
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1725
// the ChannelEdgeInfo will only include the public keys of each node.
1726
//
1727
// NOTE: part of the V1Store interface.
1728
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1729
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1730
        *models.ChannelEdgePolicy, error) {
×
1731

×
1732
        var (
×
1733
                ctx              = context.TODO()
×
1734
                edge             *models.ChannelEdgeInfo
×
1735
                policy1, policy2 *models.ChannelEdgePolicy
×
1736
                chanIDB          = channelIDToBytes(chanID)
×
1737
        )
×
1738
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1739
                row, err := db.GetChannelBySCIDWithPolicies(
×
1740
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1741
                                Scid:    chanIDB,
×
1742
                                Version: int16(ProtocolV1),
×
1743
                        },
×
1744
                )
×
1745
                if errors.Is(err, sql.ErrNoRows) {
×
1746
                        // First check if this edge is perhaps in the zombie
×
1747
                        // index.
×
1748
                        zombie, err := db.GetZombieChannel(
×
1749
                                ctx, sqlc.GetZombieChannelParams{
×
1750
                                        Scid:    chanIDB,
×
1751
                                        Version: int16(ProtocolV1),
×
1752
                                },
×
1753
                        )
×
1754
                        if errors.Is(err, sql.ErrNoRows) {
×
1755
                                return ErrEdgeNotFound
×
1756
                        } else if err != nil {
×
1757
                                return fmt.Errorf("unable to check if "+
×
1758
                                        "channel is zombie: %w", err)
×
1759
                        }
×
1760

1761
                        // At this point, we know the channel is a zombie, so
1762
                        // we'll return an error indicating this, and we will
1763
                        // populate the edge info with the public keys of each
1764
                        // party as this is the only information we have about
1765
                        // it.
1766
                        edge = &models.ChannelEdgeInfo{}
×
1767
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1768
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1769

×
1770
                        return ErrZombieEdge
×
1771
                } else if err != nil {
×
1772
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1773
                }
×
1774

1775
                node1, node2, err := buildNodeVertices(
×
1776
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1777
                )
×
1778
                if err != nil {
×
1779
                        return err
×
1780
                }
×
1781

1782
                edge, err = getAndBuildEdgeInfo(
×
1783
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
1784
                )
×
1785
                if err != nil {
×
1786
                        return fmt.Errorf("unable to build channel info: %w",
×
1787
                                err)
×
1788
                }
×
1789

1790
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1791
                if err != nil {
×
1792
                        return fmt.Errorf("unable to extract channel "+
×
1793
                                "policies: %w", err)
×
1794
                }
×
1795

1796
                policy1, policy2, err = getAndBuildChanPolicies(
×
1797
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
1798
                        node1, node2,
×
1799
                )
×
1800
                if err != nil {
×
1801
                        return fmt.Errorf("unable to build channel "+
×
1802
                                "policies: %w", err)
×
1803
                }
×
1804

1805
                return nil
×
1806
        }, sqldb.NoOpReset)
1807
        if err != nil {
×
1808
                // If we are returning the ErrZombieEdge, then we also need to
×
1809
                // return the edge info as the method comment indicates that
×
1810
                // this will be populated when the edge is a zombie.
×
1811
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1812
                        err)
×
1813
        }
×
1814

1815
        return edge, policy1, policy2, nil
×
1816
}
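
// exampleFetchEdgeNodeKeys is an illustrative, hypothetical helper (not part
// of the store implementation) showing how FetchChannelEdgesByID is typically
// consumed: the zombie case is detected via errors.Is(err, ErrZombieEdge), in
// which case the returned edge info still carries the two node keys.
func exampleFetchEdgeNodeKeys(s *SQLStore, chanID uint64) ([33]byte, [33]byte,
        error) {

        info, _, _, err := s.FetchChannelEdgesByID(chanID)
        switch {
        case errors.Is(err, ErrZombieEdge):
                // The channel is a zombie, but its node keys are populated.
                return info.NodeKey1Bytes, info.NodeKey2Bytes, nil

        case err != nil:
                return [33]byte{}, [33]byte{}, err
        }

        return info.NodeKey1Bytes, info.NodeKey2Bytes, nil
}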
1817

1818
// FetchChannelEdgesByOutpoint attempts to look up the two directed edges for
1819
// the channel identified by the funding outpoint. If the channel can't be
1820
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1821
// information for the channel itself is returned as well as two structs that
1822
// contain the routing policies for the channel in either direction.
1823
//
1824
// NOTE: part of the V1Store interface.
1825
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1826
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1827
        *models.ChannelEdgePolicy, error) {
×
1828

×
1829
        var (
×
1830
                ctx              = context.TODO()
×
1831
                edge             *models.ChannelEdgeInfo
×
1832
                policy1, policy2 *models.ChannelEdgePolicy
×
1833
        )
×
1834
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1835
                row, err := db.GetChannelByOutpointWithPolicies(
×
1836
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1837
                                Outpoint: op.String(),
×
1838
                                Version:  int16(ProtocolV1),
×
1839
                        },
×
1840
                )
×
1841
                if errors.Is(err, sql.ErrNoRows) {
×
1842
                        return ErrEdgeNotFound
×
1843
                } else if err != nil {
×
1844
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1845
                }
×
1846

1847
                node1, node2, err := buildNodeVertices(
×
1848
                        row.Node1Pubkey, row.Node2Pubkey,
×
1849
                )
×
1850
                if err != nil {
×
1851
                        return err
×
1852
                }
×
1853

1854
                edge, err = getAndBuildEdgeInfo(
×
1855
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
1856
                )
×
1857
                if err != nil {
×
1858
                        return fmt.Errorf("unable to build channel info: %w",
×
1859
                                err)
×
1860
                }
×
1861

1862
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1863
                if err != nil {
×
1864
                        return fmt.Errorf("unable to extract channel "+
×
1865
                                "policies: %w", err)
×
1866
                }
×
1867

1868
                policy1, policy2, err = getAndBuildChanPolicies(
×
1869
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
1870
                        node1, node2,
×
1871
                )
×
1872
                if err != nil {
×
1873
                        return fmt.Errorf("unable to build channel "+
×
1874
                                "policies: %w", err)
×
1875
                }
×
1876

1877
                return nil
×
1878
        }, sqldb.NoOpReset)
1879
        if err != nil {
×
1880
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1881
                        err)
×
1882
        }
×
1883

1884
        return edge, policy1, policy2, nil
×
1885
}
1886

1887
// HasChannelEdge returns true if the database knows of a channel edge with the
1888
// passed channel ID, and false otherwise. If an edge with that ID is found
1889
// within the graph, then two time stamps representing the last time the edge
1890
// was updated for both directed edges are returned along with the boolean. If
1891
// it is not found, then the zombie index is checked and its result is returned
1892
// as the second boolean.
1893
//
1894
// NOTE: part of the V1Store interface.
1895
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1896
        bool, error) {
×
1897

×
1898
        ctx := context.TODO()
×
1899

×
1900
        var (
×
1901
                exists          bool
×
1902
                isZombie        bool
×
1903
                node1LastUpdate time.Time
×
1904
                node2LastUpdate time.Time
×
1905
        )
×
1906

×
1907
        // We'll query the cache with the shared lock held to allow multiple
×
1908
        // readers to access values in the cache concurrently if they exist.
×
1909
        s.cacheMu.RLock()
×
1910
        if entry, ok := s.rejectCache.get(chanID); ok {
×
1911
                s.cacheMu.RUnlock()
×
1912
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
1913
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
1914
                exists, isZombie = entry.flags.unpack()
×
1915

×
1916
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
1917
        }
×
1918
        s.cacheMu.RUnlock()
×
1919

×
1920
        s.cacheMu.Lock()
×
1921
        defer s.cacheMu.Unlock()
×
1922

×
1923
        // The item was not found with the shared lock, so we'll acquire the
×
1924
        // exclusive lock and check the cache again in case another method added
×
1925
        // the entry to the cache while no lock was held.
×
1926
        if entry, ok := s.rejectCache.get(chanID); ok {
×
1927
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
1928
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
1929
                exists, isZombie = entry.flags.unpack()
×
1930

×
1931
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
1932
        }
×
1933

1934
        chanIDB := channelIDToBytes(chanID)
×
1935
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1936
                channel, err := db.GetChannelBySCID(
×
1937
                        ctx, sqlc.GetChannelBySCIDParams{
×
1938
                                Scid:    chanIDB,
×
1939
                                Version: int16(ProtocolV1),
×
1940
                        },
×
1941
                )
×
1942
                if errors.Is(err, sql.ErrNoRows) {
×
1943
                        // Check if it is a zombie channel.
×
1944
                        isZombie, err = db.IsZombieChannel(
×
1945
                                ctx, sqlc.IsZombieChannelParams{
×
1946
                                        Scid:    chanIDB,
×
1947
                                        Version: int16(ProtocolV1),
×
1948
                                },
×
1949
                        )
×
1950
                        if err != nil {
×
1951
                                return fmt.Errorf("could not check if channel "+
×
1952
                                        "is zombie: %w", err)
×
1953
                        }
×
1954

1955
                        return nil
×
1956
                } else if err != nil {
×
1957
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1958
                }
×
1959

1960
                exists = true
×
1961

×
1962
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
1963
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1964
                                Version:   int16(ProtocolV1),
×
1965
                                ChannelID: channel.ID,
×
1966
                                NodeID:    channel.NodeID1,
×
1967
                        },
×
1968
                )
×
1969
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1970
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
1971
                                err)
×
1972
                } else if err == nil {
×
1973
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
1974
                }
×
1975

1976
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
1977
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1978
                                Version:   int16(ProtocolV1),
×
1979
                                ChannelID: channel.ID,
×
1980
                                NodeID:    channel.NodeID2,
×
1981
                        },
×
1982
                )
×
1983
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1984
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
1985
                                err)
×
1986
                } else if err == nil {
×
1987
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
1988
                }
×
1989

1990
                return nil
×
1991
        }, sqldb.NoOpReset)
1992
        if err != nil {
×
1993
                return time.Time{}, time.Time{}, false, false,
×
1994
                        fmt.Errorf("unable to fetch channel: %w", err)
×
1995
        }
×
1996

1997
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
1998
                upd1Time: node1LastUpdate.Unix(),
×
1999
                upd2Time: node2LastUpdate.Unix(),
×
2000
                flags:    packRejectFlags(exists, isZombie),
×
2001
        })
×
2002

×
2003
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2004
}
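
// exampleNeedsChannelUpdate is an illustrative, hypothetical helper (not part
// of the store implementation): given a remote update timestamp for node 1 of
// a channel, it uses HasChannelEdge to decide whether the update is news to
// us. Unknown and zombie channels are reported as not needing the update here;
// a real caller may want to treat those cases differently.
func exampleNeedsChannelUpdate(s *SQLStore, chanID uint64,
        remoteNode1Update time.Time) (bool, error) {

        node1Update, _, exists, isZombie, err := s.HasChannelEdge(chanID)
        if err != nil {
                return false, err
        }
        if !exists || isZombie {
                return false, nil
        }

        return remoteNode1Update.After(node1Update), nil
}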
2005

2006
// ChannelID attempts to look up the 8-byte compact channel ID which maps to
2007
// the passed channel point (outpoint). If the passed channel doesn't exist
2008
// within the database, then ErrEdgeNotFound is returned.
2009
//
2010
// NOTE: part of the V1Store interface.
2011
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2012
        var (
×
2013
                ctx       = context.TODO()
×
2014
                channelID uint64
×
2015
        )
×
2016
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2017
                chanID, err := db.GetSCIDByOutpoint(
×
2018
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2019
                                Outpoint: chanPoint.String(),
×
2020
                                Version:  int16(ProtocolV1),
×
2021
                        },
×
2022
                )
×
2023
                if errors.Is(err, sql.ErrNoRows) {
×
2024
                        return ErrEdgeNotFound
×
2025
                } else if err != nil {
×
2026
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2027
                                err)
×
2028
                }
×
2029

2030
                channelID = byteOrder.Uint64(chanID)
×
2031

×
2032
                return nil
×
2033
        }, sqldb.NoOpReset)
2034
        if err != nil {
×
2035
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2036
        }
×
2037

2038
        return channelID, nil
×
2039
}
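
// exampleSCIDForOutpoint is an illustrative, hypothetical helper (not part of
// the store implementation): it resolves a funding outpoint to its compact
// channel ID and converts the raw uint64 into an lnwire.ShortChannelID.
func exampleSCIDForOutpoint(s *SQLStore,
        chanPoint *wire.OutPoint) (lnwire.ShortChannelID, error) {

        chanID, err := s.ChannelID(chanPoint)
        if err != nil {
                return lnwire.ShortChannelID{}, err
        }

        return lnwire.NewShortChanIDFromInt(chanID), nil
}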
2040

2041
// IsPublicNode is a helper method that determines whether the node with the
2042
// given public key is seen as a public node in the graph from the graph's
2043
// source node's point of view.
2044
//
2045
// NOTE: part of the V1Store interface.
2046
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2047
        ctx := context.TODO()
×
2048

×
2049
        var isPublic bool
×
2050
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2051
                var err error
×
2052
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2053

×
2054
                return err
×
2055
        }, sqldb.NoOpReset)
×
2056
        if err != nil {
×
2057
                return false, fmt.Errorf("unable to check if node is "+
×
2058
                        "public: %w", err)
×
2059
        }
×
2060

2061
        return isPublic, nil
×
2062
}
2063

2064
// FetchChanInfos returns the set of channel edges that correspond to the passed
2065
// channel ID's. If an edge in the query is unknown to the database, it will be
2066
// skipped and the result will contain only those edges that exist at the time
2067
// of the query. This can be used to respond to peer queries that are seeking to
2068
// fill in gaps in their view of the channel graph.
2069
//
2070
// NOTE: part of the V1Store interface.
2071
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2072
        var (
×
2073
                ctx   = context.TODO()
×
2074
                edges = make(map[uint64]ChannelEdge)
×
2075
        )
×
2076
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2077
                // First, collect all channel rows.
×
2078
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
2079
                chanCallBack := func(ctx context.Context,
×
2080
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2081

×
2082
                        channelRows = append(channelRows, row)
×
2083
                        return nil
×
2084
                }
×
2085

2086
                err := s.forEachChanWithPoliciesInSCIDList(
×
2087
                        ctx, db, chanCallBack, chanIDs,
×
2088
                )
×
2089
                if err != nil {
×
2090
                        return err
×
2091
                }
×
2092

2093
                if len(channelRows) == 0 {
×
2094
                        return nil
×
2095
                }
×
2096

2097
                // Batch build all channel edges.
2098
                chans, err := batchBuildChannelEdges(
×
2099
                        ctx, s.cfg, db, channelRows,
×
2100
                )
×
2101
                if err != nil {
×
2102
                        return fmt.Errorf("unable to build channel edges: %w",
×
2103
                                err)
×
2104
                }
×
2105

2106
                for _, c := range chans {
×
2107
                        edges[c.Info.ChannelID] = c
×
2108
                }
×
2109

2110
                return err
×
2111
        }, func() {
×
2112
                clear(edges)
×
2113
        })
×
2114
        if err != nil {
×
2115
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2116
        }
×
2117

2118
        res := make([]ChannelEdge, 0, len(edges))
×
2119
        for _, chanID := range chanIDs {
×
2120
                edge, ok := edges[chanID]
×
2121
                if !ok {
×
2122
                        continue
×
2123
                }
2124

2125
                res = append(res, edge)
×
2126
        }
2127

2128
        return res, nil
×
2129
}
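
// exampleFetchKnownEdges is an illustrative, hypothetical helper (not part of
// the store implementation): it fetches the edges for a peer-provided SCID
// list and reports which of the requested IDs were unknown to us, relying on
// the fact that FetchChanInfos silently skips unknown channels.
func exampleFetchKnownEdges(s *SQLStore,
        chanIDs []uint64) ([]ChannelEdge, []uint64, error) {

        edges, err := s.FetchChanInfos(chanIDs)
        if err != nil {
                return nil, nil, err
        }

        known := make(map[uint64]struct{}, len(edges))
        for _, edge := range edges {
                known[edge.Info.ChannelID] = struct{}{}
        }

        var unknown []uint64
        for _, chanID := range chanIDs {
                if _, ok := known[chanID]; !ok {
                        unknown = append(unknown, chanID)
                }
        }

        return edges, unknown, nil
}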
2130

2131
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2132
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2133
// channels in a paginated manner.
2134
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2135
        db SQLQueries, cb func(ctx context.Context,
2136
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
2137
        chanIDs []uint64) error {
×
2138

×
2139
        queryWrapper := func(ctx context.Context,
×
2140
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
2141
                error) {
×
2142

×
2143
                return db.GetChannelsBySCIDWithPolicies(
×
2144
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
2145
                                Version: int16(ProtocolV1),
×
2146
                                Scids:   scids,
×
2147
                        },
×
2148
                )
×
2149
        }
×
2150

2151
        return sqldb.ExecuteBatchQuery(
×
2152
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
×
2153
                cb,
×
2154
        )
×
2155
}
2156

2157
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2158
// ID's that we don't know and are not known zombies of the passed set. In other
2159
// words, we perform a set difference of our set of chan ID's and the ones
2160
// passed in. This method can be used by callers to determine the set of
2161
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2162
// known zombies are also returned.
2163
//
2164
// NOTE: part of the V1Store interface.
2165
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2166
        []ChannelUpdateInfo, error) {
×
2167

×
2168
        var (
×
2169
                ctx          = context.TODO()
×
2170
                newChanIDs   []uint64
×
2171
                knownZombies []ChannelUpdateInfo
×
2172
                infoLookup   = make(
×
2173
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2174
                )
×
2175
        )
×
2176

×
2177
        // We first build a lookup map of the channel ID's to the
×
2178
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
2179
        // already know about.
×
2180
        for _, chanInfo := range chansInfo {
×
2181
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
2182
        }
×
2183

2184
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2185
                // The call-back function deletes known channels from
×
2186
                // infoLookup, so that we can later check which channels are
×
2187
                // zombies by only looking at the remaining channels in the set.
×
2188
                cb := func(ctx context.Context,
×
2189
                        channel sqlc.GraphChannel) error {
×
2190

×
2191
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2192

×
2193
                        return nil
×
2194
                }
×
2195

2196
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
2197
                if err != nil {
×
2198
                        return fmt.Errorf("unable to iterate through "+
×
2199
                                "channels: %w", err)
×
2200
                }
×
2201

2202
                // We want to ensure that we deal with the channels in the
2203
                // same order that they were passed in, so we iterate over the
2204
                // original chansInfo slice and then check if that channel is
2205
                // still in the infoLookup map.
2206
                for _, chanInfo := range chansInfo {
×
2207
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2208
                        if _, ok := infoLookup[channelID]; !ok {
×
2209
                                continue
×
2210
                        }
2211

2212
                        isZombie, err := db.IsZombieChannel(
×
2213
                                ctx, sqlc.IsZombieChannelParams{
×
2214
                                        Scid:    channelIDToBytes(channelID),
×
2215
                                        Version: int16(ProtocolV1),
×
2216
                                },
×
2217
                        )
×
2218
                        if err != nil {
×
2219
                                return fmt.Errorf("unable to fetch zombie "+
×
2220
                                        "channel: %w", err)
×
2221
                        }
×
2222

2223
                        if isZombie {
×
2224
                                knownZombies = append(knownZombies, chanInfo)
×
2225

×
2226
                                continue
×
2227
                        }
2228

2229
                        newChanIDs = append(newChanIDs, channelID)
×
2230
                }
2231

2232
                return nil
×
2233
        }, func() {
×
2234
                newChanIDs = nil
×
2235
                knownZombies = nil
×
2236
                // Rebuild the infoLookup map in case of a rollback.
×
2237
                for _, chanInfo := range chansInfo {
×
2238
                        scid := chanInfo.ShortChannelID.ToUint64()
×
2239
                        infoLookup[scid] = chanInfo
×
2240
                }
×
2241
        })
2242
        if err != nil {
×
2243
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2244
        }
×
2245

2246
        return newChanIDs, knownZombies, nil
×
2247
}
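
// exampleFilterPeerSCIDs is an illustrative, hypothetical helper (not part of
// the store implementation): given SCIDs advertised by a peer, it builds the
// ChannelUpdateInfo slice expected by FilterKnownChanIDs (with zero update
// timestamps) and returns only the channel IDs we have never seen; known
// zombies are simply dropped here.
func exampleFilterPeerSCIDs(s *SQLStore,
        peerSCIDs []lnwire.ShortChannelID) ([]uint64, error) {

        chansInfo := make([]ChannelUpdateInfo, 0, len(peerSCIDs))
        for _, scid := range peerSCIDs {
                chansInfo = append(chansInfo, NewChannelUpdateInfo(
                        scid, time.Time{}, time.Time{},
                ))
        }

        newIDs, _, err := s.FilterKnownChanIDs(chansInfo)
        if err != nil {
                return nil, err
        }

        return newIDs, nil
}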
2248

2249
// forEachChanInSCIDList is a helper method that executes a paged query
2250
// against the database to fetch all channels that match the passed
2251
// ChannelUpdateInfo slice. The callback function is called for each channel
2252
// that is found.
2253
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2254
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
2255
        chansInfo []ChannelUpdateInfo) error {
×
2256

×
2257
        queryWrapper := func(ctx context.Context,
×
2258
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
2259

×
2260
                return db.GetChannelsBySCIDs(
×
2261
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
2262
                                Version: int16(ProtocolV1),
×
2263
                                Scids:   scids,
×
2264
                        },
×
2265
                )
×
2266
        }
×
2267

2268
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
2269
                channelID := chanInfo.ShortChannelID.ToUint64()
×
2270

×
2271
                return channelIDToBytes(channelID)
×
2272
        }
×
2273

2274
        return sqldb.ExecuteBatchQuery(
×
2275
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
×
2276
                cb,
×
2277
        )
×
2278
}
2279

2280
// PruneGraphNodes is a garbage collection method which attempts to prune out
2281
// any nodes from the channel graph that are currently unconnected. This ensures
2282
// that we only maintain a graph of reachable nodes. In the event that a pruned
2283
// node gains more channels, it will be re-added back to the graph.
2284
//
2285
// NOTE: this prunes nodes across protocol versions. It will never prune the
2286
// source nodes.
2287
//
2288
// NOTE: part of the V1Store interface.
2289
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2290
        var ctx = context.TODO()
×
2291

×
2292
        var prunedNodes []route.Vertex
×
2293
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2294
                var err error
×
2295
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2296

×
2297
                return err
×
2298
        }, func() {
×
2299
                prunedNodes = nil
×
2300
        })
×
2301
        if err != nil {
×
2302
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2303
        }
×
2304

2305
        return prunedNodes, nil
×
2306
}
2307

2308
// PruneGraph prunes newly closed channels from the channel graph in response
2309
// to a new block being solved on the network. Any transactions which spend the
2310
// funding output of any known channels within the graph will be deleted.
2311
// Additionally, the "prune tip", or the last block which has been used to
2312
// prune the graph is stored so callers can ensure the graph is fully in sync
2313
// with the current UTXO state. A slice of channels that have been closed by
2314
// the target block along with any pruned nodes are returned if the function
2315
// succeeds without error.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
        blockHash *chainhash.Hash, blockHeight uint32) (
        []*models.ChannelEdgeInfo, []route.Vertex, error) {

        ctx := context.TODO()

        s.cacheMu.Lock()
        defer s.cacheMu.Unlock()

        var (
                closedChans []*models.ChannelEdgeInfo
                prunedNodes []route.Vertex
        )
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                // First, collect all channel rows that need to be pruned.
                var channelRows []sqlc.GetChannelsByOutpointsRow
                channelCallback := func(ctx context.Context,
                        row sqlc.GetChannelsByOutpointsRow) error {

                        channelRows = append(channelRows, row)

                        return nil
                }

                err := s.forEachChanInOutpoints(
                        ctx, db, spentOutputs, channelCallback,
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch channels by "+
                                "outpoints: %w", err)
                }

                if len(channelRows) == 0 {
                        // There are no channels to prune. So we can exit early
                        // after updating the prune log.
                        err = db.UpsertPruneLogEntry(
                                ctx, sqlc.UpsertPruneLogEntryParams{
                                        BlockHash:   blockHash[:],
                                        BlockHeight: int64(blockHeight),
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("unable to insert prune log "+
                                        "entry: %w", err)
                        }

                        return nil
                }

                // Batch build all channel edges for pruning.
                var chansToDelete []int64
                closedChans, chansToDelete, err = batchBuildChannelInfo(
                        ctx, s.cfg, db, channelRows,
                )
                if err != nil {
                        return err
                }

                err = s.deleteChannels(ctx, db, chansToDelete)
                if err != nil {
                        return fmt.Errorf("unable to delete channels: %w", err)
                }

                err = db.UpsertPruneLogEntry(
                        ctx, sqlc.UpsertPruneLogEntryParams{
                                BlockHash:   blockHash[:],
                                BlockHeight: int64(blockHeight),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to insert prune log "+
                                "entry: %w", err)
                }

                // Now that we've pruned some channels, we'll also prune any
                // nodes that no longer have any channels.
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
                if err != nil {
                        return fmt.Errorf("unable to prune graph nodes: %w",
                                err)
                }

                return nil
        }, func() {
                prunedNodes = nil
                closedChans = nil
        })
        if err != nil {
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
        }

        for _, channel := range closedChans {
                s.rejectCache.remove(channel.ChannelID)
                s.chanCache.remove(channel.ChannelID)
        }

        return closedChans, prunedNodes, nil
}

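// Illustrative usage sketch: once a block has been connected, a caller could
// feed the block's spent outpoints into PruneGraph. The spentOutputs,
// blockHash and blockHeight values are assumed to be supplied by the caller's
// chain notifier; this is not part of the original file.
func examplePruneOnBlockConnected(store *SQLStore,
        spentOutputs []*wire.OutPoint, blockHash *chainhash.Hash,
        blockHeight uint32) error {

        closedChans, prunedNodes, err := store.PruneGraph(
                spentOutputs, blockHash, blockHeight,
        )
        if err != nil {
                return err
        }

        fmt.Printf("pruned %d channels and %d nodes at height %d\n",
                len(closedChans), len(prunedNodes), blockHeight)

        return nil
}
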
// forEachChanInOutpoints is a helper function that executes a paginated
// query to fetch channels by their outpoints and applies the given call-back
// to each.
//
// NOTE: this fetches channels for all protocol versions.
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
                row sqlc.GetChannelsByOutpointsRow) error) error {

        // Create a wrapper that uses the transaction's db instance to execute
        // the query.
        queryWrapper := func(ctx context.Context,
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
                error) {

                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
        }

        // Define the conversion function from Outpoint to string.
        outpointToString := func(outpoint *wire.OutPoint) string {
                return outpoint.String()
        }

        return sqldb.ExecuteBatchQuery(
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
                queryWrapper, cb,
        )
}

// deleteChannels deletes the channels with the given DB IDs. The deletions
// are executed in batches using the store's query configuration.
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
        dbIDs []int64) error {

        // Create a wrapper that uses the transaction's db instance to execute
        // the query.
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
                return nil, db.DeleteChannels(ctx, ids)
        }

        idConverter := func(id int64) int64 {
                return id
        }

        return sqldb.ExecuteBatchQuery(
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
                queryWrapper, func(ctx context.Context, _ any) error {
                        return nil
                },
        )
}

// ChannelView returns the verifiable edge information for each active channel
// within the known channel graph. The set of UTXOs (along with their scripts)
// returned are the ones that need to be watched on chain to detect channel
// closes on the resident blockchain.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
        var (
                ctx        = context.TODO()
                edgePoints []EdgePoint
        )

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                handleChannel := func(_ context.Context,
                        channel sqlc.ListChannelsPaginatedRow) error {

                        pkScript, err := genMultiSigP2WSH(
                                channel.BitcoinKey1, channel.BitcoinKey2,
                        )
                        if err != nil {
                                return err
                        }

                        op, err := wire.NewOutPointFromString(channel.Outpoint)
                        if err != nil {
                                return err
                        }

                        edgePoints = append(edgePoints, EdgePoint{
                                FundingPkScript: pkScript,
                                OutPoint:        *op,
                        })

                        return nil
                }

                queryFunc := func(ctx context.Context, lastID int64,
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {

                        return db.ListChannelsPaginated(
                                ctx, sqlc.ListChannelsPaginatedParams{
                                        Version: int16(ProtocolV1),
                                        ID:      lastID,
                                        Limit:   limit,
                                },
                        )
                }

                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
                        return row.ID
                }

                return sqldb.ExecutePaginatedQuery(
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
                        extractCursor, handleChannel,
                )
        }, func() {
                edgePoints = nil
        })
        if err != nil {
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
        }

        return edgePoints, nil
}

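// Illustrative usage sketch: the edge points returned by ChannelView are the
// outpoints a chain watcher would register for spend notifications. The
// registerSpend helper here is a hypothetical callback supplied by the
// caller; this is not part of the original file.
func exampleWatchChannelPoints(store *SQLStore,
        registerSpend func(op wire.OutPoint, pkScript []byte) error) error {

        edgePoints, err := store.ChannelView()
        if err != nil {
                return err
        }

        for _, point := range edgePoints {
                err := registerSpend(point.OutPoint, point.FundingPkScript)
                if err != nil {
                        return err
                }
        }

        return nil
}
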
// PruneTip returns the block height and hash of the latest block that has been
// used to prune channels in the graph. Knowing the "prune tip" allows callers
// to tell if the graph is currently in sync with the current best known UTXO
// state.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
        var (
                ctx       = context.TODO()
                tipHash   chainhash.Hash
                tipHeight uint32
        )
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                pruneTip, err := db.GetPruneTip(ctx)
                if errors.Is(err, sql.ErrNoRows) {
                        return ErrGraphNeverPruned
                } else if err != nil {
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
                }

                tipHash = chainhash.Hash(pruneTip.BlockHash)
                tipHeight = uint32(pruneTip.BlockHeight)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, 0, err
        }

        return &tipHash, tipHeight, nil
}

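// Illustrative usage sketch: callers typically treat ErrGraphNeverPruned as a
// non-fatal signal that no prune has been recorded yet. This is not part of
// the original file.
func examplePruneTip(store *SQLStore) error {
        hash, height, err := store.PruneTip()
        switch {
        case errors.Is(err, ErrGraphNeverPruned):
                // The graph has never been pruned; start syncing from the
                // first relevant height instead.
                return nil

        case err != nil:
                return err
        }

        fmt.Printf("graph is pruned up to block %v (height %d)\n", hash,
                height)

        return nil
}
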
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
//
// NOTE: this prunes nodes across protocol versions. It will never prune the
// source nodes.
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
        db SQLQueries) ([]route.Vertex, error) {

        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
        if err != nil {
                return nil, fmt.Errorf("unable to delete unconnected "+
                        "nodes: %w", err)
        }

        prunedNodes := make([]route.Vertex, len(nodeKeys))
        for i, nodeKey := range nodeKeys {
                pub, err := route.NewVertexFromBytes(nodeKey)
                if err != nil {
                        return nil, fmt.Errorf("unable to parse pubkey "+
                                "from bytes: %w", err)
                }

                prunedNodes[i] = pub
        }

        return prunedNodes, nil
}

// DisconnectBlockAtHeight is used to indicate that the block specified
// by the passed height has been disconnected from the main chain. This
// will "rewind" the graph back to the height below, deleting channels
// that are no longer confirmed from the graph. The prune log will be
// set to the last prune height valid for the remaining chain.
// Channels that were removed from the graph resulting from the
// disconnected block are returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
        []*models.ChannelEdgeInfo, error) {

        ctx := context.TODO()

        var (
                // Every channel having a ShortChannelID starting at 'height'
                // will no longer be confirmed.
                startShortChanID = lnwire.ShortChannelID{
                        BlockHeight: height,
                }

                // Delete everything after this height from the db up until the
                // SCID alias range.
                endShortChanID = aliasmgr.StartingAlias

                removedChans []*models.ChannelEdgeInfo

                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
        )

        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                rows, err := db.GetChannelsBySCIDRange(
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
                                StartScid: chanIDStart,
                                EndScid:   chanIDEnd,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch channels: %w", err)
                }

                if len(rows) == 0 {
                        // No channels to disconnect, but still clean up prune
                        // log.
                        return db.DeletePruneLogEntriesInRange(
                                ctx, sqlc.DeletePruneLogEntriesInRangeParams{
                                        StartHeight: int64(height),
                                        EndHeight: int64(
                                                endShortChanID.BlockHeight,
                                        ),
                                },
                        )
                }

                // Batch build all channel edges for disconnection.
                channelEdges, chanIDsToDelete, err := batchBuildChannelInfo(
                        ctx, s.cfg, db, rows,
                )
                if err != nil {
                        return err
                }

                removedChans = channelEdges

                err = s.deleteChannels(ctx, db, chanIDsToDelete)
                if err != nil {
                        return fmt.Errorf("unable to delete channels: %w", err)
                }

                return db.DeletePruneLogEntriesInRange(
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
                                StartHeight: int64(height),
                                EndHeight:   int64(endShortChanID.BlockHeight),
                        },
                )
        }, func() {
                removedChans = nil
        })
        if err != nil {
                return nil, fmt.Errorf("unable to disconnect block at "+
                        "height: %w", err)
        }

        for _, channel := range removedChans {
                s.rejectCache.remove(channel.ChannelID)
                s.chanCache.remove(channel.ChannelID)
        }

        return removedChans, nil
}

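// Illustrative sketch: the rewind above relies on the block height being
// encoded in the most significant bytes of a short channel ID, so every
// channel confirmed at or after a given height falls into one contiguous SCID
// range that ends at the SCID alias range. This helper is not part of the
// original file.
func exampleSCIDRangeForHeight(height uint32) (uint64, uint64) {
        start := lnwire.ShortChannelID{BlockHeight: height}
        end := aliasmgr.StartingAlias

        return start.ToUint64(), end.ToUint64()
}
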
// AddEdgeProof sets the proof of an existing edge in the graph database.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
        proof *models.ChannelAuthProof) error {

        var (
                ctx       = context.TODO()
                scidBytes = channelIDToBytes(scid.ToUint64())
        )

        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                res, err := db.AddV1ChannelProof(
                        ctx, sqlc.AddV1ChannelProofParams{
                                Scid:              scidBytes,
                                Node1Signature:    proof.NodeSig1Bytes,
                                Node2Signature:    proof.NodeSig2Bytes,
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to add edge proof: %w", err)
                }

                n, err := res.RowsAffected()
                if err != nil {
                        return err
                }

                if n == 0 {
                        return fmt.Errorf("no rows affected when adding edge "+
                                "proof for SCID %v", scid)
                } else if n > 1 {
                        return fmt.Errorf("multiple rows affected when adding "+
                                "edge proof for SCID %v: %d rows affected",
                                scid, n)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return fmt.Errorf("unable to add edge proof: %w", err)
        }

        return nil
}

// PutClosedScid stores a SCID for a closed channel in the database. This is so
// that we can ignore channel announcements that we know to be closed without
// having to validate them and fetch a block.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
        var (
                ctx     = context.TODO()
                chanIDB = channelIDToBytes(scid.ToUint64())
        )

        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                return db.InsertClosedChannel(ctx, chanIDB)
        }, sqldb.NoOpReset)
}

// IsClosedScid checks whether a channel identified by the passed in scid is
// closed. This helps avoid having to perform expensive validation checks.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
        var (
                ctx      = context.TODO()
                isClosed bool
                chanIDB  = channelIDToBytes(scid.ToUint64())
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                var err error
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
                if err != nil {
                        return fmt.Errorf("unable to fetch closed channel: %w",
                                err)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return false, fmt.Errorf("unable to fetch closed channel: %w",
                        err)
        }

        return isClosed, nil
}

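// Illustrative usage sketch: a gossiper can mark a channel as closed once and
// then cheaply skip any later announcements for the same SCID without
// fetching the funding block again. This helper is not part of the original
// file.
func exampleClosedSCIDCache(store *SQLStore,
        scid lnwire.ShortChannelID) error {

        // Record that the channel behind this SCID was closed on-chain.
        if err := store.PutClosedScid(scid); err != nil {
                return err
        }

        // Later announcements for the same SCID can now be ignored.
        closed, err := store.IsClosedScid(scid)
        if err != nil {
                return err
        }
        if closed {
                fmt.Printf("ignoring announcement for closed SCID %v\n", scid)
        }

        return nil
}
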
// GraphSession will provide the call-back with access to a NodeTraverser
// instance which can be used to perform queries against the channel graph.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
        reset func()) error {

        var ctx = context.TODO()

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
        }, reset)
}

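// Illustrative usage sketch: GraphSession hands a pathfinding routine a
// consistent, read-only view of the graph for the duration of the callback.
// The self vertex and the per-channel handling are assumptions of this
// example, which is not part of the original file.
func exampleGraphSession(store *SQLStore, self route.Vertex) error {
        return store.GraphSession(func(graph NodeTraverser) error {
                return graph.ForEachNodeDirectedChannel(
                        self, func(channel *DirectedChannel) error {
                                // Feed each channel into the pathfinder here.
                                return nil
                        }, func() {},
                )
        }, func() {})
}
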
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
// read-only transaction for a consistent view of the graph.
type sqlNodeTraverser struct {
        db    SQLQueries
        chain chainhash.Hash
}

// A compile-time assertion to ensure that sqlNodeTraverser implements the
// NodeTraverser interface.
var _ NodeTraverser = (*sqlNodeTraverser)(nil)

// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
func newSQLNodeTraverser(db SQLQueries,
        chain chainhash.Hash) *sqlNodeTraverser {

        return &sqlNodeTraverser{
                db:    db,
                chain: chain,
        }
}

// ForEachNodeDirectedChannel calls the callback for every channel of the given
// node.
//
// NOTE: Part of the NodeTraverser interface.
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
        cb func(channel *DirectedChannel) error, _ func()) error {

        ctx := context.TODO()

        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
}

// FetchNodeFeatures returns the features of the given node. If the node is
// unknown, assume no additional features are supported.
//
// NOTE: Part of the NodeTraverser interface.
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
        *lnwire.FeatureVector, error) {

        ctx := context.TODO()

        return fetchNodeFeatures(ctx, s.db, nodePub)
}

// forEachNodeDirectedChannel iterates through all channels of a given
// node, executing the passed callback on the directed edge representing the
// channel and its incoming policy. If the node is not found, no error is
// returned.
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {

        toNodeCallback := func() route.Vertex {
                return nodePub
        }

        dbID, err := db.GetNodeIDByPubKey(
                ctx, sqlc.GetNodeIDByPubKeyParams{
                        Version: int16(ProtocolV1),
                        PubKey:  nodePub[:],
                },
        )
        if errors.Is(err, sql.ErrNoRows) {
                return nil
        } else if err != nil {
                return fmt.Errorf("unable to fetch node: %w", err)
        }

        rows, err := db.ListChannelsByNodeID(
                ctx, sqlc.ListChannelsByNodeIDParams{
                        Version: int16(ProtocolV1),
                        NodeID1: dbID,
                },
        )
        if err != nil {
                return fmt.Errorf("unable to fetch channels: %w", err)
        }

        // Exit early if there are no channels for this node so we don't
        // do the unnecessary feature fetching.
        if len(rows) == 0 {
                return nil
        }

        features, err := getNodeFeatures(ctx, db, dbID)
        if err != nil {
                return fmt.Errorf("unable to fetch node features: %w", err)
        }

        for _, row := range rows {
                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return fmt.Errorf("unable to build node vertices: %w",
                                err)
                }

                edge := buildCacheableChannelInfo(
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
                        node1, node2,
                )

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return err
                }

                p1, p2, err := buildCachedChanPolicies(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
                )
                if err != nil {
                        return err
                }

                // Determine the outgoing and incoming policy for this
                // channel and node combo.
                outPolicy, inPolicy := p1, p2
                if p1 != nil && node2 == nodePub {
                        outPolicy, inPolicy = p2, p1
                } else if p2 != nil && node1 != nodePub {
                        outPolicy, inPolicy = p2, p1
                }

                var cachedInPolicy *models.CachedEdgePolicy
                if inPolicy != nil {
                        cachedInPolicy = inPolicy
                        cachedInPolicy.ToNodePubKey = toNodeCallback
                        cachedInPolicy.ToNodeFeatures = features
                }

                directedChannel := &DirectedChannel{
                        ChannelID:    edge.ChannelID,
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
                        OtherNode:    edge.NodeKey2Bytes,
                        Capacity:     edge.Capacity,
                        OutPolicySet: outPolicy != nil,
                        InPolicy:     cachedInPolicy,
                }
                if outPolicy != nil {
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
                                directedChannel.InboundFee = fee
                        })
                }

                if nodePub == edge.NodeKey2Bytes {
                        directedChannel.OtherNode = edge.NodeKey1Bytes
                }

                if err := cb(directedChannel); err != nil {
                        return err
                }
        }

        return nil
}

// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
// and executes the provided callback for each node. It does so via pagination
// along with batch loading of the node feature bits.
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
                features *lnwire.FeatureVector) error) error {

        handleNode := func(_ context.Context,
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
                featureBits map[int64][]int) error {

                fv := lnwire.EmptyFeatureVector()
                if features, exists := featureBits[dbNode.ID]; exists {
                        for _, bit := range features {
                                fv.Set(lnwire.FeatureBit(bit))
                        }
                }

                var pub route.Vertex
                copy(pub[:], dbNode.PubKey)

                return processNode(dbNode.ID, pub, fv)
        }

        queryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {

                return db.ListNodeIDsAndPubKeys(
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
                                Version: int16(ProtocolV1),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
                return row.ID
        }

        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
                return node.ID, nil
        }

        batchQueryFunc := func(ctx context.Context,
                nodeIDs []int64) (map[int64][]int, error) {

                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
                batchQueryFunc, handleNode,
        )
}

// forEachNodeChannel iterates through all channels of a node, executing
// the passed callback on each. The call-back is provided with the channel's
// edge information, the outgoing policy and the incoming policy for the
// channel and node combo.
func forEachNodeChannel(ctx context.Context, db SQLQueries,
        cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
                *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error) error {

        // Get all the V1 channels for this node.
        rows, err := db.ListChannelsByNodeID(
                ctx, sqlc.ListChannelsByNodeIDParams{
                        Version: int16(ProtocolV1),
                        NodeID1: id,
                },
        )
        if err != nil {
                return fmt.Errorf("unable to fetch channels: %w", err)
        }

        // Collect all the channel and policy IDs.
        var (
                chanIDs   = make([]int64, 0, len(rows))
                policyIDs = make([]int64, 0, 2*len(rows))
        )
        for _, row := range rows {
                chanIDs = append(chanIDs, row.GraphChannel.ID)

                if row.Policy1ID.Valid {
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
                }
                if row.Policy2ID.Valid {
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
                }
        }

        batchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
        )
        if err != nil {
                return fmt.Errorf("unable to batch load channel data: %w", err)
        }

        // Call the call-back for each channel and its known policies.
        for _, row := range rows {
                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return fmt.Errorf("unable to build node vertices: %w",
                                err)
                }

                edge, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.GraphChannel, node1, node2,
                        batchData,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel info: %w",
                                err)
                }

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel "+
                                "policies: %w", err)
                }

                // Determine the outgoing and incoming policy for this
                // channel and node combo.
                p1ToNode := row.GraphChannel.NodeID2
                p2ToNode := row.GraphChannel.NodeID1
                outPolicy, inPolicy := p1, p2
                if (p1 != nil && p1ToNode == id) ||
                        (p2 != nil && p2ToNode != id) {

                        outPolicy, inPolicy = p2, p1
                }

                if err := cb(edge, outPolicy, inPolicy); err != nil {
                        return err
                }
        }

        return nil
}

// updateChanEdgePolicy upserts the channel policy info we have stored for
// a channel we already know of.
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
        error) {

        var (
                node1Pub, node2Pub route.Vertex
                isNode1            bool
                chanIDB            = channelIDToBytes(edge.ChannelID)
        )

        // Check that this edge policy refers to a channel that we already
        // know of. We do this explicitly so that we can return the appropriate
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
        // abort the transaction which would abort the entire batch.
        dbChan, err := tx.GetChannelAndNodesBySCID(
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
                        Scid:    chanIDB,
                        Version: int16(ProtocolV1),
                },
        )
        if errors.Is(err, sql.ErrNoRows) {
                return node1Pub, node2Pub, false, ErrEdgeNotFound
        } else if err != nil {
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
                        "fetch channel(%v): %w", edge.ChannelID, err)
        }

        copy(node1Pub[:], dbChan.Node1PubKey)
        copy(node2Pub[:], dbChan.Node2PubKey)

        // Figure out which node this edge is from.
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
        nodeID := dbChan.NodeID1
        if !isNode1 {
                nodeID = dbChan.NodeID2
        }

        var (
                inboundBase sql.NullInt64
                inboundRate sql.NullInt64
        )
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
        })

        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
                Version:     int16(ProtocolV1),
                ChannelID:   dbChan.ID,
                NodeID:      nodeID,
                Timelock:    int32(edge.TimeLockDelta),
                FeePpm:      int64(edge.FeeProportionalMillionths),
                BaseFeeMsat: int64(edge.FeeBaseMSat),
                MinHtlcMsat: int64(edge.MinHTLC),
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
                Disabled: sql.NullBool{
                        Valid: true,
                        Bool:  edge.IsDisabled(),
                },
                MaxHtlcMsat: sql.NullInt64{
                        Valid: edge.MessageFlags.HasMaxHtlc(),
                        Int64: int64(edge.MaxHTLC),
                },
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
                InboundBaseFeeMsat:      inboundBase,
                InboundFeeRateMilliMsat: inboundRate,
                Signature:               edge.SigBytes,
        })
        if err != nil {
                return node1Pub, node2Pub, isNode1,
                        fmt.Errorf("unable to upsert edge policy: %w", err)
        }

        // Convert the flat extra opaque data into a map of TLV types to
        // values.
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
        if err != nil {
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
                        "marshal extra opaque data: %w", err)
        }

        // Update the channel policy's extra signed fields.
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
        if err != nil {
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
                        "policy extra TLVs: %w", err)
        }

        return node1Pub, node2Pub, isNode1, nil
}

// getNodeByPubKey attempts to look up a target node by its public key.
func getNodeByPubKey(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
        pubKey route.Vertex) (int64, *models.LightningNode, error) {

        dbNode, err := db.GetNodeByPubKey(
                ctx, sqlc.GetNodeByPubKeyParams{
                        Version: int16(ProtocolV1),
                        PubKey:  pubKey[:],
                },
        )
        if errors.Is(err, sql.ErrNoRows) {
                return 0, nil, ErrGraphNodeNotFound
        } else if err != nil {
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
        }

        node, err := buildNode(ctx, cfg, db, dbNode)
        if err != nil {
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
        }

        return dbNode.ID, node, nil
}

// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
// provided parameters.
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
        node2Pub route.Vertex) *models.CachedEdgeInfo {

        return &models.CachedEdgeInfo{
                ChannelID:     byteOrder.Uint64(scid),
                NodeKey1Bytes: node1Pub,
                NodeKey2Bytes: node2Pub,
                Capacity:      btcutil.Amount(capacity),
        }
}

// buildNode constructs a LightningNode instance from the given database node
// record. The node's features, addresses and extra signed fields are also
// fetched from the database and set on the node.
func buildNode(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
        dbNode sqlc.GraphNode) (*models.LightningNode, error) {

        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node data: %w",
                        err)
        }

        return buildNodeWithBatchData(dbNode, data)
}

// buildNodeWithBatchData builds a models.LightningNode instance
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
// features/addresses/extra fields, then the corresponding fields are expected
// to be present in the batchNodeData.
func buildNodeWithBatchData(dbNode sqlc.GraphNode,
        batchData *batchNodeData) (*models.LightningNode, error) {

        if dbNode.Version != int16(ProtocolV1) {
                return nil, fmt.Errorf("unsupported node version: %d",
                        dbNode.Version)
        }

        var pub [33]byte
        copy(pub[:], dbNode.PubKey)

        node := &models.LightningNode{
                PubKeyBytes: pub,
                Features:    lnwire.EmptyFeatureVector(),
                LastUpdate:  time.Unix(0, 0),
        }

        if len(dbNode.Signature) == 0 {
                return node, nil
        }

        node.HaveNodeAnnouncement = true
        node.AuthSigBytes = dbNode.Signature
        node.Alias = dbNode.Alias.String
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)

        var err error
        if dbNode.Color.Valid {
                node.Color, err = DecodeHexColor(dbNode.Color.String)
                if err != nil {
                        return nil, fmt.Errorf("unable to decode color: %w",
                                err)
                }
        }

        // Use preloaded features.
        if features, exists := batchData.features[dbNode.ID]; exists {
                fv := lnwire.EmptyFeatureVector()
                for _, bit := range features {
                        fv.Set(lnwire.FeatureBit(bit))
                }
                node.Features = fv
        }

        // Use preloaded addresses.
        addresses, exists := batchData.addresses[dbNode.ID]
        if exists && len(addresses) > 0 {
                node.Addresses, err = buildNodeAddresses(addresses)
                if err != nil {
                        return nil, fmt.Errorf("unable to build addresses "+
                                "for node(%d): %w", dbNode.ID, err)
                }
        }

        // Use preloaded extra fields.
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
                if err != nil {
                        return nil, fmt.Errorf("unable to serialize extra "+
                                "signed fields: %w", err)
                }
                if len(recs) != 0 {
                        node.ExtraOpaqueData = recs
                }
        }

        return node, nil
}

// forEachNodeInBatch fetches all nodes in the provided batch, builds them
// with the preloaded data, and executes the provided callback for each node.
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, nodes []sqlc.GraphNode,
        cb func(dbID int64, node *models.LightningNode) error) error {

        // Extract node IDs for batch loading.
        nodeIDs := make([]int64, len(nodes))
        for i, node := range nodes {
                nodeIDs[i] = node.ID
        }

        // Batch load all related data for this page.
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
        if err != nil {
                return fmt.Errorf("unable to batch load node data: %w", err)
        }

        for _, dbNode := range nodes {
                node, err := buildNodeWithBatchData(dbNode, batchData)
                if err != nil {
                        return fmt.Errorf("unable to build node(id=%d): %w",
                                dbNode.ID, err)
                }

                if err := cb(dbNode.ID, node); err != nil {
                        return fmt.Errorf("callback failed for node(id=%d): %w",
                                dbNode.ID, err)
                }
        }

        return nil
}

// getNodeFeatures fetches the feature bits and constructs the feature vector
// for a node with the given DB ID.
func getNodeFeatures(ctx context.Context, db SQLQueries,
        nodeID int64) (*lnwire.FeatureVector, error) {

        rows, err := db.GetNodeFeatures(ctx, nodeID)
        if err != nil {
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
                        nodeID, err)
        }

        features := lnwire.EmptyFeatureVector()
        for _, feature := range rows {
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
        }

        return features, nil
}

// upsertNode upserts the node record into the database. If the node already
// exists, then the node's information is updated. If the node doesn't exist,
// then a new node is created. The node's features, addresses and extra TLV
// types are also updated. The node's DB ID is returned.
func upsertNode(ctx context.Context, db SQLQueries,
        node *models.LightningNode) (int64, error) {

        params := sqlc.UpsertNodeParams{
                Version: int16(ProtocolV1),
                PubKey:  node.PubKeyBytes[:],
        }

        if node.HaveNodeAnnouncement {
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
                params.Alias = sqldb.SQLStr(node.Alias)
                params.Signature = node.AuthSigBytes
        }

        nodeID, err := db.UpsertNode(ctx, params)
        if err != nil {
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
                        err)
        }

        // We can exit here if we don't have the announcement yet.
        if !node.HaveNodeAnnouncement {
                return nodeID, nil
        }

        // Update the node's features.
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
        if err != nil {
                return 0, fmt.Errorf("inserting node features: %w", err)
        }

        // Update the node's addresses.
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
        if err != nil {
                return 0, fmt.Errorf("inserting node addresses: %w", err)
        }

        // Convert the flat extra opaque data into a map of TLV types to
        // values.
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
        if err != nil {
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
                        err)
        }

        // Update the node's extra signed fields.
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
        if err != nil {
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
        }

        return nodeID, nil
}

// upsertNodeFeatures updates the node's features in the node_features table.
// This includes deleting any feature bits no longer present and inserting any
// new feature bits. If the feature bit does not yet exist in the features
// table, then an entry is created in that table first.
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
        features *lnwire.FeatureVector) error {

        // Get any existing features for the node.
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
                return err
        }

        // Copy the node's latest set of feature bits.
        newFeatures := make(map[int32]struct{})
        if features != nil {
                for feature := range features.Features() {
                        newFeatures[int32(feature)] = struct{}{}
                }
        }

        // For any current feature that already exists in the DB, remove it from
        // the in-memory map. For any existing feature that does not exist in
        // the in-memory map, delete it from the database.
        for _, feature := range existingFeatures {
                // The feature is still present, so there are no updates to be
                // made.
                if _, ok := newFeatures[feature.FeatureBit]; ok {
                        delete(newFeatures, feature.FeatureBit)
                        continue
                }

                // The feature is no longer present, so we remove it from the
                // database.
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
                        NodeID:     nodeID,
                        FeatureBit: feature.FeatureBit,
                })
                if err != nil {
                        return fmt.Errorf("unable to delete node(%d) "+
                                "feature(%v): %w", nodeID, feature.FeatureBit,
                                err)
                }
        }

        // Any remaining entries in newFeatures are new features that need to be
        // added to the database for the first time.
        for feature := range newFeatures {
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
                        NodeID:     nodeID,
                        FeatureBit: feature,
                })
                if err != nil {
                        return fmt.Errorf("unable to insert node(%d) "+
                                "feature(%v): %w", nodeID, feature, err)
                }
        }

        return nil
}

// fetchNodeFeatures fetches the features for a node with the given public key.
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {

        rows, err := queries.GetNodeFeaturesByPubKey(
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
                        PubKey:  nodePub[:],
                        Version: int16(ProtocolV1),
                },
        )
        if err != nil {
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
                        nodePub, err)
        }

        features := lnwire.EmptyFeatureVector()
        for _, bit := range rows {
                features.Set(lnwire.FeatureBit(bit))
        }

        return features, nil
}

// dbAddressType is an enum type that represents the different address types
// that we store in the node_addresses table. The address type determines how
// the address is to be serialised/deserialised.
type dbAddressType uint8

const (
        addressTypeIPv4   dbAddressType = 1
        addressTypeIPv6   dbAddressType = 2
        addressTypeTorV2  dbAddressType = 3
        addressTypeTorV3  dbAddressType = 4
        addressTypeDNS    dbAddressType = 5
        addressTypeOpaque dbAddressType = math.MaxInt8
)

// upsertNodeAddresses updates the node's addresses in the database. This
// includes deleting any existing addresses and inserting the new set of
// addresses. The deletion is necessary since the ordering of the addresses may
// change, and we need to ensure that the database reflects the latest set of
// addresses so that at the time of reconstructing the node announcement, the
// order is preserved and the signature over the message remains valid.
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
        addresses []net.Addr) error {

        // Delete any existing addresses for the node. This is required since
        // even if the new set of addresses is the same, the ordering may have
        // changed for a given address type.
        err := db.DeleteNodeAddresses(ctx, nodeID)
        if err != nil {
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
                        nodeID, err)
        }

        // Copy the node's latest set of addresses, grouped by address type.
        newAddresses := map[dbAddressType][]string{
                addressTypeIPv4:   {},
                addressTypeIPv6:   {},
                addressTypeTorV2:  {},
                addressTypeTorV3:  {},
                addressTypeDNS:    {},
                addressTypeOpaque: {},
        }
        addAddr := func(t dbAddressType, addr net.Addr) {
                newAddresses[t] = append(newAddresses[t], addr.String())
        }

        for _, address := range addresses {
                switch addr := address.(type) {
                case *net.TCPAddr:
                        if ip4 := addr.IP.To4(); ip4 != nil {
                                addAddr(addressTypeIPv4, addr)
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
                                addAddr(addressTypeIPv6, addr)
                        } else {
                                return fmt.Errorf("unhandled IP address: %v",
                                        addr)
                        }

                case *tor.OnionAddr:
                        switch len(addr.OnionService) {
                        case tor.V2Len:
                                addAddr(addressTypeTorV2, addr)
                        case tor.V3Len:
                                addAddr(addressTypeTorV3, addr)
                        default:
                                return fmt.Errorf("invalid length for a tor " +
                                        "address")
                        }

                case *lnwire.DNSAddress:
                        addAddr(addressTypeDNS, addr)

                case *lnwire.OpaqueAddrs:
                        addAddr(addressTypeOpaque, addr)

                default:
                        return fmt.Errorf("unhandled address type: %T", addr)
                }
        }

        // Insert the new set of addresses, recording each address's position
        // within its type so that the original ordering is preserved.
        for addrType, addrList := range newAddresses {
                for position, addr := range addrList {
                        err := db.InsertNodeAddress(
                                ctx, sqlc.InsertNodeAddressParams{
                                        NodeID:   nodeID,
                                        Type:     int16(addrType),
                                        Address:  addr,
                                        Position: int32(position),
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("unable to insert "+
                                        "node(%d) address(%v): %w", nodeID,
                                        addr, err)
                        }
                }
        }

        return nil
}

// getNodeAddresses fetches the addresses for a node with the given DB ID.
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
        error) {

        // GetNodeAddresses ensures that the addresses for a given type are
        // returned in the same order as they were inserted.
        rows, err := db.GetNodeAddresses(ctx, id)
        if err != nil {
                return nil, err
        }

        addresses := make([]net.Addr, 0, len(rows))
        for _, row := range rows {
                address := row.Address

                addr, err := parseAddress(dbAddressType(row.Type), address)
                if err != nil {
                        return nil, fmt.Errorf("unable to parse address "+
                                "for node(%d): %v: %w", id, address, err)
                }

                addresses = append(addresses, addr)
        }

        // If we have no addresses, then we'll return nil instead of an
        // empty slice.
        if len(addresses) == 0 {
                addresses = nil
        }

        return addresses, nil
}

// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3649
// database. This includes updating any existing types, inserting any new types,
3650
// and deleting any types that are no longer present.
3651
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3652
        nodeID int64, extraFields map[uint64][]byte) error {
×
3653

×
3654
        // Get any existing extra signed fields for the node.
×
3655
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3656
        if err != nil {
×
3657
                return err
×
3658
        }
×
3659

3660
        // Make a lookup map of the existing field types so that we can use it
3661
        // to keep track of any fields we should delete.
3662
        m := make(map[uint64]bool)
×
3663
        for _, field := range existingFields {
×
3664
                m[uint64(field.Type)] = true
×
3665
        }
×
3666

3667
        // For all the new fields, we'll upsert them and remove them from the
3668
        // map of existing fields.
3669
        for tlvType, value := range extraFields {
×
3670
                err = db.UpsertNodeExtraType(
×
3671
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3672
                                NodeID: nodeID,
×
3673
                                Type:   int64(tlvType),
×
3674
                                Value:  value,
×
3675
                        },
×
3676
                )
×
3677
                if err != nil {
×
3678
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3679
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3680
                }
×
3681

3682
                // Remove the field from the map of existing fields if it was
3683
                // present.
3684
                delete(m, tlvType)
×
3685
        }
3686

3687
        // For all the fields that are left in the map of existing fields, we'll
3688
        // delete them as they are no longer present in the new set of fields.
3689
        for tlvType := range m {
×
3690
                err = db.DeleteExtraNodeType(
×
3691
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3692
                                NodeID: nodeID,
×
3693
                                Type:   int64(tlvType),
×
3694
                        },
×
3695
                )
×
3696
                if err != nil {
×
3697
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3698
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3699
                }
×
3700
        }
3701

3702
        return nil
×
3703
}
3704
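To illustrate the diff-style behaviour described above (a sketch only, not part of the original file; the node ID and TLV values are placeholders): writing a smaller map on a second call deletes the types that are no longer present.

func demoNodeExtraFieldUpsert(ctx context.Context, db SQLQueries,
        nodeID int64) error {

        // The first write persists TLV types 7 and 9 for the node.
        first := map[uint64][]byte{7: {0x01}, 9: {0x02}}
        err := upsertNodeExtraSignedFields(ctx, db, nodeID, first)
        if err != nil {
                return err
        }

        // The second write omits type 9, so it is deleted while type 7 is
        // simply re-upserted.
        second := map[uint64][]byte{7: {0x01}}

        return upsertNodeExtraSignedFields(ctx, db, nodeID, second)
}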

3705
// srcNodeInfo holds the information about the source node of the graph.
3706
type srcNodeInfo struct {
3707
        // id is the DB level ID of the source node entry in the "nodes" table.
3708
        id int64
3709

3710
        // pub is the public key of the source node.
3711
        pub route.Vertex
3712
}
3713

3714
// getSourceNode returns the DB node ID and pub key of the source node for the
3715
// specified protocol version.
3716
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3717
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3718

×
3719
        s.srcNodeMu.Lock()
×
3720
        defer s.srcNodeMu.Unlock()
×
3721

×
3722
        // If we already have the source node ID and pub key cached, then
×
3723
        // return them.
×
3724
        if info, ok := s.srcNodes[version]; ok {
×
3725
                return info.id, info.pub, nil
×
3726
        }
×
3727

3728
        var pubKey route.Vertex
×
3729

×
3730
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3731
        if err != nil {
×
3732
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3733
                        err)
×
3734
        }
×
3735

3736
        if len(nodes) == 0 {
×
3737
                return 0, pubKey, ErrSourceNodeNotSet
×
3738
        } else if len(nodes) > 1 {
×
3739
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3740
                        "protocol %s found", version)
×
3741
        }
×
3742

3743
        copy(pubKey[:], nodes[0].PubKey)
×
3744

×
3745
        s.srcNodes[version] = &srcNodeInfo{
×
3746
                id:  nodes[0].NodeID,
×
3747
                pub: pubKey,
×
3748
        }
×
3749

×
3750
        return nodes[0].NodeID, pubKey, nil
×
3751
}
3752
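A sketch of the cached lookup behaviour (illustrative only, not part of the original file): the first call for a protocol version queries the database, while repeated calls are served from the srcNodes cache. The helper name and surrounding values are assumed.

func demoSourceNodeLookup(ctx context.Context, s *SQLStore,
        db SQLQueries) error {

        // The first call for a version hits GetSourceNodesByVersion; the
        // result is then cached in s.srcNodes for subsequent calls.
        id, pub, err := s.getSourceNode(ctx, db, ProtocolV1)
        switch {
        case errors.Is(err, ErrSourceNodeNotSet):
                // No source node has been persisted for V1 yet.
                return err

        case err != nil:
                return err
        }

        fmt.Printf("source node %d: %x\n", id, pub[:])

        return nil
}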

3753
// marshalExtraOpaqueData takes a flat byte slice and parses it as a TLV
3754
// stream, producing a map from TLV type to value. If the input is not a
3755
// valid TLV stream, then an error is returned.
3756
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3757
        r := bytes.NewReader(data)
×
3758

×
3759
        tlvStream, err := tlv.NewStream()
×
3760
        if err != nil {
×
3761
                return nil, err
×
3762
        }
×
3763

3764
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3765
        // pass it into the P2P decoding variant.
3766
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3767
        if err != nil {
×
3768
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3769
        }
×
3770
        if len(parsedTypes) == 0 {
×
3771
                return nil, nil
×
3772
        }
×
3773

3774
        records := make(map[uint64][]byte)
×
3775
        for k, v := range parsedTypes {
×
3776
                records[uint64(k)] = v
×
3777
        }
×
3778

3779
        return records, nil
×
3780
}
3781
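A small worked example (illustrative, not part of the original file): a two-record TLV stream with odd types parses into a type-to-value map, while malformed input surfaces as ErrParsingExtraTLVBytes. The helper name is assumed.

func demoMarshalExtraOpaqueData() error {
        // A valid TLV stream: type=1 len=1 value=0xff, then type=3 len=0.
        raw := []byte{0x01, 0x01, 0xff, 0x03, 0x00}

        recs, err := marshalExtraOpaqueData(raw)
        if err != nil {
                // Truncated or otherwise invalid input is wrapped in
                // ErrParsingExtraTLVBytes.
                return err
        }

        // recs is expected to map type 1 to {0xff} and type 3 to an empty
        // value.
        fmt.Println(len(recs))

        return nil
}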

3782
// dbChanInfo holds the DB level IDs of a channel and the nodes involved in the
3783
// channel.
3784
type dbChanInfo struct {
3785
        channelID int64
3786
        node1ID   int64
3787
        node2ID   int64
3788
}
3789

3790
// insertChannel inserts a new channel record into the database.
3791
func insertChannel(ctx context.Context, db SQLQueries,
3792
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
3793

×
3794
        // Make sure that at least a "shell" entry for each node is present in
×
3795
        // the nodes table.
×
3796
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3797
        if err != nil {
×
3798
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3799
        }
×
3800

3801
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3802
        if err != nil {
×
3803
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3804
        }
×
3805

3806
        var capacity sql.NullInt64
×
3807
        if edge.Capacity != 0 {
×
3808
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3809
        }
×
3810

3811
        createParams := sqlc.CreateChannelParams{
×
3812
                Version:     int16(ProtocolV1),
×
3813
                Scid:        channelIDToBytes(edge.ChannelID),
×
3814
                NodeID1:     node1DBID,
×
3815
                NodeID2:     node2DBID,
×
3816
                Outpoint:    edge.ChannelPoint.String(),
×
3817
                Capacity:    capacity,
×
3818
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3819
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3820
        }
×
3821

×
3822
        if edge.AuthProof != nil {
×
3823
                proof := edge.AuthProof
×
3824

×
3825
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3826
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3827
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3828
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3829
        }
×
3830

3831
        // Insert the new channel record.
3832
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3833
        if err != nil {
×
3834
                return nil, err
×
3835
        }
×
3836

3837
        // Insert any channel features.
3838
        for feature := range edge.Features.Features() {
×
3839
                err = db.InsertChannelFeature(
×
3840
                        ctx, sqlc.InsertChannelFeatureParams{
×
3841
                                ChannelID:  dbChanID,
×
3842
                                FeatureBit: int32(feature),
×
3843
                        },
×
3844
                )
×
3845
                if err != nil {
×
3846
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
×
3847
                                "feature(%v): %w", dbChanID, feature, err)
×
3848
                }
×
3849
        }
3850

3851
        // Finally, insert any extra TLV fields in the channel announcement.
3852
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3853
        if err != nil {
×
3854
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
3855
                        "data: %w", err)
×
3856
        }
×
3857

3858
        for tlvType, value := range extra {
×
3859
                err := db.CreateChannelExtraType(
×
3860
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3861
                                ChannelID: dbChanID,
×
3862
                                Type:      int64(tlvType),
×
3863
                                Value:     value,
×
3864
                        },
×
3865
                )
×
3866
                if err != nil {
×
3867
                        return nil, fmt.Errorf("unable to upsert "+
×
3868
                                "channel(%d) extra signed field(%v): %w",
×
3869
                                edge.ChannelID, tlvType, err)
×
3870
                }
×
3871
        }
3872

3873
        return &dbChanInfo{
×
3874
                channelID: dbChanID,
×
3875
                node1ID:   node1DBID,
×
3876
                node2ID:   node2DBID,
×
3877
        }, nil
×
3878
}
3879

3880
// maybeCreateShellNode checks if a shell node entry exists for the
3881
// given public key. If it does not exist, then a new shell node entry is
3882
// created. The ID of the node is returned. A shell node only has a protocol
3883
// version and public key persisted.
3884
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3885
        pubKey route.Vertex) (int64, error) {
×
3886

×
3887
        dbNode, err := db.GetNodeByPubKey(
×
3888
                ctx, sqlc.GetNodeByPubKeyParams{
×
3889
                        PubKey:  pubKey[:],
×
3890
                        Version: int16(ProtocolV1),
×
3891
                },
×
3892
        )
×
3893
        // The node exists. Return the ID.
×
3894
        if err == nil {
×
3895
                return dbNode.ID, nil
×
3896
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3897
                return 0, err
×
3898
        }
×
3899

3900
        // Otherwise, the node does not exist, so we create a shell entry for
3901
        // it.
3902
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3903
                Version: int16(ProtocolV1),
×
3904
                PubKey:  pubKey[:],
×
3905
        })
×
3906
        if err != nil {
×
3907
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3908
        }
×
3909

3910
        return id, nil
×
3911
}
3912
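The following sketch (not part of the original file; names and values are placeholders) shows the idempotent behaviour: calling the helper twice with the same public key returns the same DB ID, with only the first call inserting a shell row.

func demoShellNode(ctx context.Context, db SQLQueries,
        pub route.Vertex) error {

        id1, err := maybeCreateShellNode(ctx, db, pub)
        if err != nil {
                return err
        }

        // A second call finds the existing row and returns the same ID.
        id2, err := maybeCreateShellNode(ctx, db, pub)
        if err != nil {
                return err
        }

        if id1 != id2 {
                return fmt.Errorf("expected identical IDs, got %d and %d",
                        id1, id2)
        }

        return nil
}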

3913
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3914
// the database. This includes deleting any existing types and then inserting
3915
// the new types.
3916
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3917
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3918

×
3919
        // Delete all existing extra signed fields for the channel policy.
×
3920
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
3921
        if err != nil {
×
3922
                return fmt.Errorf("unable to delete "+
×
3923
                        "existing policy extra signed fields for policy %d: %w",
×
3924
                        chanPolicyID, err)
×
3925
        }
×
3926

3927
        // Insert all new extra signed fields for the channel policy.
3928
        for tlvType, value := range extraFields {
×
3929
                err = db.InsertChanPolicyExtraType(
×
3930
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
3931
                                ChannelPolicyID: chanPolicyID,
×
3932
                                Type:            int64(tlvType),
×
3933
                                Value:           value,
×
3934
                        },
×
3935
                )
×
3936
                if err != nil {
×
3937
                        return fmt.Errorf("unable to insert "+
×
3938
                                "channel_policy(%d) extra signed field(%v): %w",
×
3939
                                chanPolicyID, tlvType, err)
×
3940
                }
×
3941
        }
3942

3943
        return nil
×
3944
}
3945
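As a contrast to the node variant above, this helper replaces the full set in one shot. A brief sketch (illustrative only; the policy ID and TLV values are placeholders):

func demoPolicyExtraFields(ctx context.Context, db SQLQueries,
        policyID int64) error {

        // Any previously stored types are deleted first, so passing an empty
        // map simply clears the policy's extra signed fields.
        fields := map[uint64][]byte{55555: {0xde, 0xad}}

        return upsertChanPolicyExtraSignedFields(ctx, db, policyID, fields)
}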

3946
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3947
// provided sqlc.GraphChannel row and fetches any other required information
3948
// to construct the edge info.
3949
func getAndBuildEdgeInfo(ctx context.Context, cfg *SQLStoreConfig,
3950
        db SQLQueries, dbChan sqlc.GraphChannel, node1,
3951
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
3952

×
3953
        data, err := batchLoadChannelData(
×
3954
                ctx, cfg.QueryCfg, db, []int64{dbChan.ID}, nil,
×
3955
        )
×
3956
        if err != nil {
×
3957
                return nil, fmt.Errorf("unable to batch load channel data: %w",
×
3958
                        err)
×
3959
        }
×
3960

3961
        return buildEdgeInfoWithBatchData(
×
3962
                cfg.ChainHash, dbChan, node1, node2, data,
×
3963
        )
×
3964
}
3965

3966
// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
3967
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
3968
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
3969
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {
×
3970

×
3971
        if dbChan.Version != int16(ProtocolV1) {
×
3972
                return nil, fmt.Errorf("unsupported channel version: %d",
×
3973
                        dbChan.Version)
×
3974
        }
×
3975

3976
        // Use the pre-loaded features and extra types.
3977
        fv := lnwire.EmptyFeatureVector()
×
3978
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
×
3979
                for _, bit := range features {
×
3980
                        fv.Set(lnwire.FeatureBit(bit))
×
3981
                }
×
3982
        }
3983

3984
        var extras map[uint64][]byte
×
3985
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
×
3986
        if exists {
×
3987
                extras = channelExtras
×
3988
        } else {
×
3989
                extras = make(map[uint64][]byte)
×
3990
        }
×
3991

3992
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
3993
        if err != nil {
×
3994
                return nil, err
×
3995
        }
×
3996

3997
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
3998
        if err != nil {
×
3999
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4000
                        "fields: %w", err)
×
4001
        }
×
4002
        if recs == nil {
×
4003
                recs = make([]byte, 0)
×
4004
        }
×
4005

4006
        var btcKey1, btcKey2 route.Vertex
×
4007
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4008
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4009

×
4010
        channel := &models.ChannelEdgeInfo{
×
4011
                ChainHash:        chain,
×
4012
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4013
                NodeKey1Bytes:    node1,
×
4014
                NodeKey2Bytes:    node2,
×
4015
                BitcoinKey1Bytes: btcKey1,
×
4016
                BitcoinKey2Bytes: btcKey2,
×
4017
                ChannelPoint:     *op,
×
4018
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4019
                Features:         fv,
×
4020
                ExtraOpaqueData:  recs,
×
4021
        }
×
4022

×
4023
        // We always set all the signatures at the same time, so we can
×
4024
        // safely check if one signature is present to determine if we have the
×
4025
        // rest of the signatures for the auth proof.
×
4026
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4027
                channel.AuthProof = &models.ChannelAuthProof{
×
4028
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4029
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4030
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4031
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4032
                }
×
4033
        }
×
4034

4035
        return channel, nil
×
4036
}
4037

4038
// buildNodeVertices is a helper that converts raw node public keys
4039
// into route.Vertex instances.
4040
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4041
        route.Vertex, error) {
×
4042

×
4043
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4044
        if err != nil {
×
4045
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4046
                        "create vertex from node1 pubkey: %w", err)
×
4047
        }
×
4048

4049
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4050
        if err != nil {
×
4051
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4052
                        "create vertex from node2 pubkey: %w", err)
×
4053
        }
×
4054

4055
        return node1Vertex, node2Vertex, nil
×
4056
}
4057
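A short sketch (not part of the original file; the helper name is assumed): the function simply validates and copies the two serialized public keys, so inputs of the wrong length produce an error.

func demoBuildVertices(node1Pub, node2Pub []byte) error {
        v1, v2, err := buildNodeVertices(node1Pub, node2Pub)
        if err != nil {
                // Returned when either slice is not a 33-byte compressed key.
                return err
        }

        fmt.Printf("node1=%x node2=%x\n", v1[:], v2[:])

        return nil
}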

4058
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4059
// retrieves all the extra info required to build the complete
4060
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4061
// the provided sqlc.GraphChannelPolicy records are nil.
4062
func getAndBuildChanPolicies(ctx context.Context, cfg *sqldb.QueryConfig,
4063
        db SQLQueries, dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4064
        channelID uint64, node1, node2 route.Vertex) (*models.ChannelEdgePolicy,
4065
        *models.ChannelEdgePolicy, error) {
×
4066

×
4067
        if dbPol1 == nil && dbPol2 == nil {
×
4068
                return nil, nil, nil
×
4069
        }
×
4070

4071
        var policyIDs = make([]int64, 0, 2)
×
4072
        if dbPol1 != nil {
×
4073
                policyIDs = append(policyIDs, dbPol1.ID)
×
4074
        }
×
4075
        if dbPol2 != nil {
×
4076
                policyIDs = append(policyIDs, dbPol2.ID)
×
4077
        }
×
4078

4079
        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
×
4080
        if err != nil {
×
4081
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
4082
                        "data: %w", err)
×
4083
        }
×
4084

4085
        pol1, err := buildChanPolicyWithBatchData(
×
4086
                dbPol1, channelID, node2, batchData,
×
4087
        )
×
4088
        if err != nil {
×
4089
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4090
        }
×
4091

4092
        pol2, err := buildChanPolicyWithBatchData(
×
4093
                dbPol2, channelID, node1, batchData,
×
4094
        )
×
4095
        if err != nil {
×
4096
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4097
        }
×
4098

4099
        return pol1, pol2, nil
×
4100
}
4101

4102
// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
4103
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
4104
// then nil is returned for it.
4105
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4106
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
4107
        *models.CachedEdgePolicy, error) {
×
4108

×
4109
        var p1, p2 *models.CachedEdgePolicy
×
4110
        if dbPol1 != nil {
×
4111
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
×
4112
                if err != nil {
×
4113
                        return nil, nil, err
×
4114
                }
×
4115

4116
                p1 = models.NewCachedPolicy(policy1)
×
4117
        }
4118
        if dbPol2 != nil {
×
4119
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
×
4120
                if err != nil {
×
4121
                        return nil, nil, err
×
4122
                }
×
4123

4124
                p2 = models.NewCachedPolicy(policy2)
×
4125
        }
4126

4127
        return p1, p2, nil
×
4128
}
4129

4130
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4131
// provided sqlc.GraphChannelPolicy and other required information.
4132
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4133
        extras map[uint64][]byte,
4134
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4135

×
4136
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4137
        if err != nil {
×
4138
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4139
                        "fields: %w", err)
×
4140
        }
×
4141

4142
        var inboundFee fn.Option[lnwire.Fee]
×
4143
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4144
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4145

×
4146
                inboundFee = fn.Some(lnwire.Fee{
×
4147
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4148
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4149
                })
×
4150
        }
×
4151

4152
        return &models.ChannelEdgePolicy{
×
4153
                SigBytes:  dbPolicy.Signature,
×
4154
                ChannelID: channelID,
×
4155
                LastUpdate: time.Unix(
×
4156
                        dbPolicy.LastUpdate.Int64, 0,
×
4157
                ),
×
4158
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4159
                        dbPolicy.MessageFlags,
×
4160
                ),
×
4161
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4162
                        dbPolicy.ChannelFlags,
×
4163
                ),
×
4164
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4165
                MinHTLC: lnwire.MilliSatoshi(
×
4166
                        dbPolicy.MinHtlcMsat,
×
4167
                ),
×
4168
                MaxHTLC: lnwire.MilliSatoshi(
×
4169
                        dbPolicy.MaxHtlcMsat.Int64,
×
4170
                ),
×
4171
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4172
                        dbPolicy.BaseFeeMsat,
×
4173
                ),
×
4174
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4175
                ToNode:                    toNode,
×
4176
                InboundFee:                inboundFee,
×
4177
                ExtraOpaqueData:           recs,
×
4178
        }, nil
×
4179
}
4180
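One detail worth illustrating (sketch only, not part of the original file; it assumes the fn.Option WhenSome accessor and a placeholder row): the inbound fee option is populated only when at least one of the inbound columns is valid, so a policy without inbound data yields fn.None.

func demoInboundFee(dbPolicy sqlc.GraphChannelPolicy, scid uint64,
        toNode route.Vertex) error {

        pol, err := buildChanPolicy(dbPolicy, scid, nil, toNode)
        if err != nil {
                return err
        }

        // InboundFee stays fn.None unless InboundBaseFeeMsat or
        // InboundFeeRateMilliMsat was valid in the row.
        pol.InboundFee.WhenSome(func(fee lnwire.Fee) {
                fmt.Printf("inbound base=%d rate=%d\n", fee.BaseFee,
                        fee.FeeRate)
        })

        return nil
}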

4181
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the
4182
// given row, which is expected to be a sqlc type that contains channel policy
4183
// information. It returns two policies, which may be nil if the policy
4184
// information is not present in the row.
4185
//
4186
//nolint:ll,dupl,funlen
4187
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4188
        *sqlc.GraphChannelPolicy, error) {
×
4189

×
4190
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4191
        switch r := row.(type) {
×
4192
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
×
4193
                if r.Policy1Timelock.Valid {
×
4194
                        policy1 = &sqlc.GraphChannelPolicy{
×
4195
                                Timelock:                r.Policy1Timelock.Int32,
×
4196
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4197
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4198
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4199
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4200
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4201
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4202
                                Disabled:                r.Policy1Disabled,
×
4203
                                MessageFlags:            r.Policy1MessageFlags,
×
4204
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4205
                        }
×
4206
                }
×
4207
                if r.Policy2Timelock.Valid {
×
4208
                        policy2 = &sqlc.GraphChannelPolicy{
×
4209
                                Timelock:                r.Policy2Timelock.Int32,
×
4210
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4211
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4212
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4213
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4214
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4215
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4216
                                Disabled:                r.Policy2Disabled,
×
4217
                                MessageFlags:            r.Policy2MessageFlags,
×
4218
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4219
                        }
×
4220
                }
×
4221

4222
                return policy1, policy2, nil
×
4223

4224
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
4225
                if r.Policy1ID.Valid {
×
4226
                        policy1 = &sqlc.GraphChannelPolicy{
×
4227
                                ID:                      r.Policy1ID.Int64,
×
4228
                                Version:                 r.Policy1Version.Int16,
×
4229
                                ChannelID:               r.GraphChannel.ID,
×
4230
                                NodeID:                  r.Policy1NodeID.Int64,
×
4231
                                Timelock:                r.Policy1Timelock.Int32,
×
4232
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4233
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4234
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4235
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4236
                                LastUpdate:              r.Policy1LastUpdate,
×
4237
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4238
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4239
                                Disabled:                r.Policy1Disabled,
×
4240
                                MessageFlags:            r.Policy1MessageFlags,
×
4241
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4242
                                Signature:               r.Policy1Signature,
×
4243
                        }
×
4244
                }
×
4245
                if r.Policy2ID.Valid {
×
4246
                        policy2 = &sqlc.GraphChannelPolicy{
×
4247
                                ID:                      r.Policy2ID.Int64,
×
4248
                                Version:                 r.Policy2Version.Int16,
×
4249
                                ChannelID:               r.GraphChannel.ID,
×
4250
                                NodeID:                  r.Policy2NodeID.Int64,
×
4251
                                Timelock:                r.Policy2Timelock.Int32,
×
4252
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4253
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4254
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4255
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4256
                                LastUpdate:              r.Policy2LastUpdate,
×
4257
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4258
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4259
                                Disabled:                r.Policy2Disabled,
×
4260
                                MessageFlags:            r.Policy2MessageFlags,
×
4261
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4262
                                Signature:               r.Policy2Signature,
×
4263
                        }
×
4264
                }
×
4265

4266
                return policy1, policy2, nil
×
4267

4268
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4269
                if r.Policy1ID.Valid {
×
4270
                        policy1 = &sqlc.GraphChannelPolicy{
×
4271
                                ID:                      r.Policy1ID.Int64,
×
4272
                                Version:                 r.Policy1Version.Int16,
×
4273
                                ChannelID:               r.GraphChannel.ID,
×
4274
                                NodeID:                  r.Policy1NodeID.Int64,
×
4275
                                Timelock:                r.Policy1Timelock.Int32,
×
4276
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4277
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4278
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4279
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4280
                                LastUpdate:              r.Policy1LastUpdate,
×
4281
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4282
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4283
                                Disabled:                r.Policy1Disabled,
×
4284
                                MessageFlags:            r.Policy1MessageFlags,
×
4285
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4286
                                Signature:               r.Policy1Signature,
×
4287
                        }
×
4288
                }
×
4289
                if r.Policy2ID.Valid {
×
4290
                        policy2 = &sqlc.GraphChannelPolicy{
×
4291
                                ID:                      r.Policy2ID.Int64,
×
4292
                                Version:                 r.Policy2Version.Int16,
×
4293
                                ChannelID:               r.GraphChannel.ID,
×
4294
                                NodeID:                  r.Policy2NodeID.Int64,
×
4295
                                Timelock:                r.Policy2Timelock.Int32,
×
4296
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4297
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4298
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4299
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4300
                                LastUpdate:              r.Policy2LastUpdate,
×
4301
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4302
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4303
                                Disabled:                r.Policy2Disabled,
×
4304
                                MessageFlags:            r.Policy2MessageFlags,
×
4305
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4306
                                Signature:               r.Policy2Signature,
×
4307
                        }
×
4308
                }
×
4309

4310
                return policy1, policy2, nil
×
4311

4312
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4313
                if r.Policy1ID.Valid {
×
4314
                        policy1 = &sqlc.GraphChannelPolicy{
×
4315
                                ID:                      r.Policy1ID.Int64,
×
4316
                                Version:                 r.Policy1Version.Int16,
×
4317
                                ChannelID:               r.GraphChannel.ID,
×
4318
                                NodeID:                  r.Policy1NodeID.Int64,
×
4319
                                Timelock:                r.Policy1Timelock.Int32,
×
4320
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4321
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4322
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4323
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4324
                                LastUpdate:              r.Policy1LastUpdate,
×
4325
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4326
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4327
                                Disabled:                r.Policy1Disabled,
×
4328
                                MessageFlags:            r.Policy1MessageFlags,
×
4329
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4330
                                Signature:               r.Policy1Signature,
×
4331
                        }
×
4332
                }
×
4333
                if r.Policy2ID.Valid {
×
4334
                        policy2 = &sqlc.GraphChannelPolicy{
×
4335
                                ID:                      r.Policy2ID.Int64,
×
4336
                                Version:                 r.Policy2Version.Int16,
×
4337
                                ChannelID:               r.GraphChannel.ID,
×
4338
                                NodeID:                  r.Policy2NodeID.Int64,
×
4339
                                Timelock:                r.Policy2Timelock.Int32,
×
4340
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4341
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4342
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4343
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4344
                                LastUpdate:              r.Policy2LastUpdate,
×
4345
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4346
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4347
                                Disabled:                r.Policy2Disabled,
×
4348
                                MessageFlags:            r.Policy2MessageFlags,
×
4349
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4350
                                Signature:               r.Policy2Signature,
×
4351
                        }
×
4352
                }
×
4353

4354
                return policy1, policy2, nil
×
4355

4356
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4357
                if r.Policy1ID.Valid {
×
4358
                        policy1 = &sqlc.GraphChannelPolicy{
×
4359
                                ID:                      r.Policy1ID.Int64,
×
4360
                                Version:                 r.Policy1Version.Int16,
×
4361
                                ChannelID:               r.GraphChannel.ID,
×
4362
                                NodeID:                  r.Policy1NodeID.Int64,
×
4363
                                Timelock:                r.Policy1Timelock.Int32,
×
4364
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4365
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4366
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4367
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4368
                                LastUpdate:              r.Policy1LastUpdate,
×
4369
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4370
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4371
                                Disabled:                r.Policy1Disabled,
×
4372
                                MessageFlags:            r.Policy1MessageFlags,
×
4373
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4374
                                Signature:               r.Policy1Signature,
×
4375
                        }
×
4376
                }
×
4377
                if r.Policy2ID.Valid {
×
4378
                        policy2 = &sqlc.GraphChannelPolicy{
×
4379
                                ID:                      r.Policy2ID.Int64,
×
4380
                                Version:                 r.Policy2Version.Int16,
×
4381
                                ChannelID:               r.GraphChannel.ID,
×
4382
                                NodeID:                  r.Policy2NodeID.Int64,
×
4383
                                Timelock:                r.Policy2Timelock.Int32,
×
4384
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4385
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4386
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4387
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4388
                                LastUpdate:              r.Policy2LastUpdate,
×
4389
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4390
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4391
                                Disabled:                r.Policy2Disabled,
×
4392
                                MessageFlags:            r.Policy2MessageFlags,
×
4393
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4394
                                Signature:               r.Policy2Signature,
×
4395
                        }
×
4396
                }
×
4397

4398
                return policy1, policy2, nil
×
4399

4400
        case sqlc.ListChannelsForNodeIDsRow:
×
4401
                if r.Policy1ID.Valid {
×
4402
                        policy1 = &sqlc.GraphChannelPolicy{
×
4403
                                ID:                      r.Policy1ID.Int64,
×
4404
                                Version:                 r.Policy1Version.Int16,
×
4405
                                ChannelID:               r.GraphChannel.ID,
×
4406
                                NodeID:                  r.Policy1NodeID.Int64,
×
4407
                                Timelock:                r.Policy1Timelock.Int32,
×
4408
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4409
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4410
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4411
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4412
                                LastUpdate:              r.Policy1LastUpdate,
×
4413
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4414
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4415
                                Disabled:                r.Policy1Disabled,
×
4416
                                MessageFlags:            r.Policy1MessageFlags,
×
4417
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4418
                                Signature:               r.Policy1Signature,
×
4419
                        }
×
4420
                }
×
4421
                if r.Policy2ID.Valid {
×
4422
                        policy2 = &sqlc.GraphChannelPolicy{
×
4423
                                ID:                      r.Policy2ID.Int64,
×
4424
                                Version:                 r.Policy2Version.Int16,
×
4425
                                ChannelID:               r.GraphChannel.ID,
×
4426
                                NodeID:                  r.Policy2NodeID.Int64,
×
4427
                                Timelock:                r.Policy2Timelock.Int32,
×
4428
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4429
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4430
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4431
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4432
                                LastUpdate:              r.Policy2LastUpdate,
×
4433
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4434
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4435
                                Disabled:                r.Policy2Disabled,
×
4436
                                MessageFlags:            r.Policy2MessageFlags,
×
4437
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4438
                                Signature:               r.Policy2Signature,
×
4439
                        }
×
4440
                }
×
4441

4442
                return policy1, policy2, nil
×
4443

4444
        case sqlc.ListChannelsByNodeIDRow:
×
4445
                if r.Policy1ID.Valid {
×
4446
                        policy1 = &sqlc.GraphChannelPolicy{
×
4447
                                ID:                      r.Policy1ID.Int64,
×
4448
                                Version:                 r.Policy1Version.Int16,
×
4449
                                ChannelID:               r.GraphChannel.ID,
×
4450
                                NodeID:                  r.Policy1NodeID.Int64,
×
4451
                                Timelock:                r.Policy1Timelock.Int32,
×
4452
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4453
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4454
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4455
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4456
                                LastUpdate:              r.Policy1LastUpdate,
×
4457
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4458
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4459
                                Disabled:                r.Policy1Disabled,
×
4460
                                MessageFlags:            r.Policy1MessageFlags,
×
4461
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4462
                                Signature:               r.Policy1Signature,
×
4463
                        }
×
4464
                }
×
4465
                if r.Policy2ID.Valid {
×
4466
                        policy2 = &sqlc.GraphChannelPolicy{
×
4467
                                ID:                      r.Policy2ID.Int64,
×
4468
                                Version:                 r.Policy2Version.Int16,
×
4469
                                ChannelID:               r.GraphChannel.ID,
×
4470
                                NodeID:                  r.Policy2NodeID.Int64,
×
4471
                                Timelock:                r.Policy2Timelock.Int32,
×
4472
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4473
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4474
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4475
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4476
                                LastUpdate:              r.Policy2LastUpdate,
×
4477
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4478
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4479
                                Disabled:                r.Policy2Disabled,
×
4480
                                MessageFlags:            r.Policy2MessageFlags,
×
4481
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4482
                                Signature:               r.Policy2Signature,
×
4483
                        }
×
4484
                }
×
4485

4486
                return policy1, policy2, nil
×
4487

4488
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4489
                if r.Policy1ID.Valid {
×
4490
                        policy1 = &sqlc.GraphChannelPolicy{
×
4491
                                ID:                      r.Policy1ID.Int64,
×
4492
                                Version:                 r.Policy1Version.Int16,
×
4493
                                ChannelID:               r.GraphChannel.ID,
×
4494
                                NodeID:                  r.Policy1NodeID.Int64,
×
4495
                                Timelock:                r.Policy1Timelock.Int32,
×
4496
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4497
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4498
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4499
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4500
                                LastUpdate:              r.Policy1LastUpdate,
×
4501
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4502
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4503
                                Disabled:                r.Policy1Disabled,
×
4504
                                MessageFlags:            r.Policy1MessageFlags,
×
4505
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4506
                                Signature:               r.Policy1Signature,
×
4507
                        }
×
4508
                }
×
4509
                if r.Policy2ID.Valid {
×
4510
                        policy2 = &sqlc.GraphChannelPolicy{
×
4511
                                ID:                      r.Policy2ID.Int64,
×
4512
                                Version:                 r.Policy2Version.Int16,
×
4513
                                ChannelID:               r.GraphChannel.ID,
×
4514
                                NodeID:                  r.Policy2NodeID.Int64,
×
4515
                                Timelock:                r.Policy2Timelock.Int32,
×
4516
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4517
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4518
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4519
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4520
                                LastUpdate:              r.Policy2LastUpdate,
×
4521
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4522
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4523
                                Disabled:                r.Policy2Disabled,
×
4524
                                MessageFlags:            r.Policy2MessageFlags,
×
4525
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4526
                                Signature:               r.Policy2Signature,
×
4527
                        }
×
4528
                }
×
4529

4530
                return policy1, policy2, nil
×
4531

4532
        case sqlc.GetChannelsByIDsRow:
×
4533
                if r.Policy1ID.Valid {
×
4534
                        policy1 = &sqlc.GraphChannelPolicy{
×
4535
                                ID:                      r.Policy1ID.Int64,
×
4536
                                Version:                 r.Policy1Version.Int16,
×
4537
                                ChannelID:               r.GraphChannel.ID,
×
4538
                                NodeID:                  r.Policy1NodeID.Int64,
×
4539
                                Timelock:                r.Policy1Timelock.Int32,
×
4540
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4541
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4542
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4543
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4544
                                LastUpdate:              r.Policy1LastUpdate,
×
4545
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4546
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4547
                                Disabled:                r.Policy1Disabled,
×
4548
                                MessageFlags:            r.Policy1MessageFlags,
×
4549
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4550
                                Signature:               r.Policy1Signature,
×
4551
                        }
×
4552
                }
×
4553
                if r.Policy2ID.Valid {
×
4554
                        policy2 = &sqlc.GraphChannelPolicy{
×
4555
                                ID:                      r.Policy2ID.Int64,
×
4556
                                Version:                 r.Policy2Version.Int16,
×
4557
                                ChannelID:               r.GraphChannel.ID,
×
4558
                                NodeID:                  r.Policy2NodeID.Int64,
×
4559
                                Timelock:                r.Policy2Timelock.Int32,
×
4560
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4561
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4562
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4563
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4564
                                LastUpdate:              r.Policy2LastUpdate,
×
4565
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4566
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4567
                                Disabled:                r.Policy2Disabled,
×
4568
                                MessageFlags:            r.Policy2MessageFlags,
×
4569
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4570
                                Signature:               r.Policy2Signature,
×
4571
                        }
×
4572
                }
×
4573

4574
                return policy1, policy2, nil
×
4575

4576
        default:
×
4577
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4578
                        "extractChannelPolicies: %T", r)
×
4579
        }
4580
}
4581
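A usage sketch (illustrative, not part of the original file; the row value is assumed to come from one of the listed queries): any of the supported sqlc row types can be passed straight through, and either returned policy may be nil.

func demoExtractPolicies(row sqlc.GetChannelBySCIDWithPoliciesRow) error {
        pol1, pol2, err := extractChannelPolicies(row)
        if err != nil {
                return err
        }

        // Either direction may be nil if no channel update has been stored
        // for it yet.
        if pol1 != nil {
                fmt.Printf("policy1 timelock: %d\n", pol1.Timelock)
        }
        if pol2 != nil {
                fmt.Printf("policy2 timelock: %d\n", pol2.Timelock)
        }

        return nil
}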

4582
// channelIDToBytes converts a channel ID (SCID) to a byte array
4583
// representation.
4584
func channelIDToBytes(channelID uint64) []byte {
×
4585
        var chanIDB [8]byte
×
4586
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4587

×
4588
        return chanIDB[:]
×
4589
}
×
4590
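A quick round-trip check (sketch only, not part of the original file): the eight bytes produced here decode back to the original SCID with the same package-level byteOrder used elsewhere in this file.

func demoSCIDRoundTrip() error {
        scid := uint64(123456789)

        b := channelIDToBytes(scid)
        if got := byteOrder.Uint64(b); got != scid {
                return fmt.Errorf("scid round trip failed: got %d", got)
        }

        return nil
}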

4591
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
4592
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
×
4593
        if len(addresses) == 0 {
×
4594
                return nil, nil
×
4595
        }
×
4596

4597
        result := make([]net.Addr, 0, len(addresses))
×
4598
        for _, addr := range addresses {
×
4599
                netAddr, err := parseAddress(addr.addrType, addr.address)
×
4600
                if err != nil {
×
4601
                        return nil, fmt.Errorf("unable to parse address %s "+
×
4602
                                "of type %d: %w", addr.address, addr.addrType,
×
4603
                                err)
×
4604
                }
×
4605
                if netAddr != nil {
×
4606
                        result = append(result, netAddr)
×
4607
                }
×
4608
        }
4609

4610
        // If we have no valid addresses, return nil instead of an empty slice.
4611
        if len(result) == 0 {
×
4612
                return nil, nil
×
4613
        }
×
4614

4615
        return result, nil
×
4616
}
4617

4618
// parseAddress parses the given address string based on the address type
4619
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
4620
// and opaque addresses.
4621
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
×
4622
        switch addrType {
×
4623
        case addressTypeIPv4:
×
4624
                tcp, err := net.ResolveTCPAddr("tcp4", address)
×
4625
                if err != nil {
×
4626
                        return nil, err
×
4627
                }
×
4628

4629
                tcp.IP = tcp.IP.To4()
×
4630

×
4631
                return tcp, nil
×
4632

4633
        case addressTypeIPv6:
×
4634
                tcp, err := net.ResolveTCPAddr("tcp6", address)
×
4635
                if err != nil {
×
4636
                        return nil, err
×
4637
                }
×
4638

4639
                return tcp, nil
×
4640

4641
        case addressTypeTorV3, addressTypeTorV2:
×
4642
                service, portStr, err := net.SplitHostPort(address)
×
4643
                if err != nil {
×
4644
                        return nil, fmt.Errorf("unable to split tor "+
×
4645
                                "address: %v", address)
×
4646
                }
×
4647

4648
                port, err := strconv.Atoi(portStr)
×
4649
                if err != nil {
×
4650
                        return nil, err
×
4651
                }
×
4652

4653
                return &tor.OnionAddr{
×
4654
                        OnionService: service,
×
4655
                        Port:         port,
×
4656
                }, nil
×
4657

4658
        case addressTypeDNS:
×
4659
                hostname, portStr, err := net.SplitHostPort(address)
×
4660
                if err != nil {
×
4661
                        return nil, fmt.Errorf("unable to split DNS "+
×
4662
                                "address: %v", address)
×
4663
                }
×
4664

4665
                port, err := strconv.Atoi(portStr)
×
4666
                if err != nil {
×
4667
                        return nil, err
×
4668
                }
×
4669

4670
                return &lnwire.DNSAddress{
×
4671
                        Hostname: hostname,
×
4672
                        Port:     uint16(port),
×
4673
                }, nil
×
4674

4675
        case addressTypeOpaque:
×
4676
                opaque, err := hex.DecodeString(address)
×
4677
                if err != nil {
×
4678
                        return nil, fmt.Errorf("unable to decode opaque "+
×
4679
                                "address: %v", address)
×
4680
                }
×
4681

4682
                return &lnwire.OpaqueAddrs{
×
4683
                        Payload: opaque,
×
4684
                }, nil
×
4685

4686
        default:
×
4687
                return nil, fmt.Errorf("unknown address type: %v", addrType)
×
4688
        }
4689
}
4690
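Two representative calls (illustrative only, not part of the original file; the hosts and ports are placeholders): an IPv4 string resolves to a *net.TCPAddr, while the DNS type added in this change is kept as a host name rather than being resolved.

func demoParseAddress() error {
        // IPv4 host:port strings are resolved into *net.TCPAddr values.
        ip, err := parseAddress(addressTypeIPv4, "203.0.113.7:9735")
        if err != nil {
                return err
        }

        // DNS host names are kept as-is and wrapped in an lnwire.DNSAddress.
        dns, err := parseAddress(addressTypeDNS, "node.example.com:9735")
        if err != nil {
                return err
        }

        fmt.Println(ip.Network(), dns.String())

        return nil
}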

4691
// batchNodeData holds all the related data for a batch of nodes.
4692
type batchNodeData struct {
4693
        // features is a map from a DB node ID to the feature bits for that
4694
        // node.
4695
        features map[int64][]int
4696

4697
        // addresses is a map from a DB node ID to the node's addresses.
4698
        addresses map[int64][]nodeAddress
4699

4700
        // extraFields is a map from a DB node ID to the extra signed fields
4701
        // for that node.
4702
        extraFields map[int64]map[uint64][]byte
4703
}
4704

4705
// nodeAddress holds the address type, position and address string for a
4706
// node. This is used to batch the fetching of node addresses.
4707
type nodeAddress struct {
4708
        addrType dbAddressType
4709
        position int32
4710
        address  string
4711
}
4712

4713
// batchLoadNodeData loads all related data for a batch of node IDs using the
4714
// provided SQLQueries interface. It returns a batchNodeData instance containing
4715
// the node features, addresses and extra signed fields.
4716
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
4717
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {
×
4718

×
4719
        // Batch load the node features.
×
4720
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
4721
        if err != nil {
×
4722
                return nil, fmt.Errorf("unable to batch load node "+
×
4723
                        "features: %w", err)
×
4724
        }
×
4725

4726
        // Batch load the node addresses.
4727
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
×
4728
        if err != nil {
×
4729
                return nil, fmt.Errorf("unable to batch load node "+
×
4730
                        "addresses: %w", err)
×
4731
        }
×
4732

4733
        // Batch load the node extra signed fields.
4734
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
×
4735
        if err != nil {
×
4736
                return nil, fmt.Errorf("unable to batch load node extra "+
×
4737
                        "signed fields: %w", err)
×
4738
        }
×
4739

4740
        return &batchNodeData{
×
4741
                features:    features,
×
4742
                addresses:   addrs,
×
4743
                extraFields: extraTypes,
×
4744
        }, nil
×
4745
}
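
// Illustrative usage (a hedged sketch, not taken from this package): callers
// that have already collected the DB node IDs for a page of results can
// hydrate all related node data up front and then index the returned maps by
// node ID:
//
//        data, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
//        if err != nil {
//                return err
//        }
//        for _, id := range nodeIDs {
//                _ = data.features[id]    // feature bits
//                _ = data.addresses[id]   // nodeAddress records
//                _ = data.extraFields[id] // TLV type -> extra signed field bytes
//        }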

// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
// using the ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
func batchLoadNodeFeaturesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64][]int, error) {

        features := make(map[int64][]int)

        return features, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
                        error) {

                        return db.GetNodeFeaturesBatch(ctx, ids)
                },
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
                        features[feature.NodeID] = append(
                                features[feature.NodeID],
                                int(feature.FeatureBit),
                        )

                        return nil
                },
        )
}
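
// The address and extra-type helpers below follow the same shape as
// batchLoadNodeFeaturesHelper: sqldb.ExecuteBatchQuery receives the full ID
// slice, an identity key extractor, a batched query, and a per-row callback
// that folds rows into a map keyed by node ID. Schematically (Row, runQuery
// and collect are placeholders, not real identifiers):
//
//        sqldb.ExecuteBatchQuery(
//                ctx, cfg, ids,
//                func(id int64) int64 { return id },
//                func(ctx context.Context, ids []int64) ([]Row, error) {
//                        return runQuery(ctx, ids)
//                },
//                func(ctx context.Context, r Row) error {
//                        collect(r)
//                        return nil
//                },
//        )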

// batchLoadNodeAddressesHelper loads node addresses using the
// ExecuteBatchQuery wrapper around the GetNodeAddressesBatch query. It
// returns a map from node ID to a slice of nodeAddress structs.
func batchLoadNodeAddressesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64][]nodeAddress, error) {

        addrs := make(map[int64][]nodeAddress)

        return addrs, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
                        error) {

                        return db.GetNodeAddressesBatch(ctx, ids)
                },
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
                        addrs[addr.NodeID] = append(
                                addrs[addr.NodeID], nodeAddress{
                                        addrType: dbAddressType(addr.Type),
                                        position: addr.Position,
                                        address:  addr.Address,
                                },
                        )

                        return nil
                },
        )
}

// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
// node IDs using the ExecuteBatchQuery wrapper around the
// GetNodeExtraTypesBatch query.
func batchLoadNodeExtraTypesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {

        extraFields := make(map[int64]map[uint64][]byte)

        callback := func(ctx context.Context,
                field sqlc.GraphNodeExtraType) error {

                if extraFields[field.NodeID] == nil {
                        extraFields[field.NodeID] = make(map[uint64][]byte)
                }
                extraFields[field.NodeID][uint64(field.Type)] = field.Value

                return nil
        }

        return extraFields, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) (
                        []sqlc.GraphNodeExtraType, error) {

                        return db.GetNodeExtraTypesBatch(ctx, ids)
                },
                callback,
        )
}

// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
// from the provided sqlc.GraphChannelPolicy records and the provided
// batchChannelData.
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
        channelID uint64, node1, node2 route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
        *models.ChannelEdgePolicy, error) {

        pol1, err := buildChanPolicyWithBatchData(
                dbPol1, channelID, node2, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
        }

        pol2, err := buildChanPolicyWithBatchData(
                dbPol2, channelID, node1, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
        }

        return pol1, pol2, nil
}
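
// Note the cross-pairing above: dbPol1 is built with node2 as its to-node and
// dbPol2 with node1, since each policy is resolved against the counterparty
// it points to. A condensed restatement (illustrative only, error handling
// omitted):
//
//        // policy1 governs the node1 -> node2 direction.
//        p1, _ := buildChanPolicyWithBatchData(dbPol1, chanID, node2, batchData)
//        // policy2 governs the node2 -> node1 direction.
//        p2, _ := buildChanPolicyWithBatchData(dbPol2, chanID, node1, batchData)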

// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
        channelID uint64, toNode route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {

        if dbPol == nil {
                return nil, nil
        }

        var dbPol1Extras map[uint64][]byte
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
                dbPol1Extras = extras
        } else {
                dbPol1Extras = make(map[uint64][]byte)
        }

        return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode)
}

// batchChannelData holds all the related data for a batch of channels.
type batchChannelData struct {
        // chanfeatures is a map from DB channel ID to a slice of feature bits.
        chanfeatures map[int64][]int

        // chanExtraTypes is a map from DB channel ID to a map of TLV type to
        // extra signed field bytes.
        chanExtraTypes map[int64]map[uint64][]byte

        // policyExtras is a map from DB channel policy ID to a map of TLV type
        // to extra signed field bytes.
        policyExtras map[int64]map[uint64][]byte
}

// batchLoadChannelData loads all related data for batches of channels and
// policies.
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, channelIDs []int64,
        policyIDs []int64) (*batchChannelData, error) {

        batchData := &batchChannelData{
                chanfeatures:   make(map[int64][]int),
                chanExtraTypes: make(map[int64]map[uint64][]byte),
                policyExtras:   make(map[int64]map[uint64][]byte),
        }

        // Batch load channel features and extras.
        var err error
        if len(channelIDs) > 0 {
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
                        ctx, cfg, db, channelIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "channel features: %w", err)
                }

                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
                        ctx, cfg, db, channelIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "channel extras: %w", err)
                }
        }

        if len(policyIDs) > 0 {
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
                        ctx, cfg, db, policyIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "policy extras: %w", err)
                }
                batchData.policyExtras = policyExtras
        }

        return batchData, nil
}
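
// Illustrative usage (a hedged sketch; chanID and policyID are placeholder
// variables): the ID slices are typically gathered while scanning a page of
// rows and then hydrated in a single pass:
//
//        chanData, err := batchLoadChannelData(ctx, cfg, db, channelIDs, policyIDs)
//        if err != nil {
//                return err
//        }
//        _ = chanData.chanfeatures[chanID]   // feature bits for one channel
//        _ = chanData.chanExtraTypes[chanID] // channel TLV extras
//        _ = chanData.policyExtras[policyID] // policy TLV extras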

// batchLoadChannelFeaturesHelper loads channel features for a batch of
// channel IDs using the ExecuteBatchQuery wrapper around the
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
// slice of feature bits.
func batchLoadChannelFeaturesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        channelIDs []int64) (map[int64][]int, error) {

        features := make(map[int64][]int)

        return features, sqldb.ExecuteBatchQuery(
                ctx, cfg, channelIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context,
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {

                        return db.GetChannelFeaturesBatch(ctx, ids)
                },
                func(ctx context.Context,
                        feature sqlc.GraphChannelFeature) error {

                        features[feature.ChannelID] = append(
                                features[feature.ChannelID],
                                int(feature.FeatureBit),
                        )

                        return nil
                },
        )
}

// batchLoadChannelExtrasHelper loads channel extra types for a batch of
// channel IDs using the ExecuteBatchQuery wrapper around the
// GetChannelExtrasBatch query. It returns a map from DB channel ID to a map
// of TLV type to extra signed field bytes.
func batchLoadChannelExtrasHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {

        extras := make(map[int64]map[uint64][]byte)

        cb := func(ctx context.Context,
                extra sqlc.GraphChannelExtraType) error {

                if extras[extra.ChannelID] == nil {
                        extras[extra.ChannelID] = make(map[uint64][]byte)
                }
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value

                return nil
        }

        return extras, sqldb.ExecuteBatchQuery(
                ctx, cfg, channelIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context,
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {

                        return db.GetChannelExtrasBatch(ctx, ids)
                }, cb,
        )
}

// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
// batch of policy IDs using the ExecuteBatchQuery wrapper around the
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID
// to a map of TLV type to extra signed field bytes.
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {

        extras := make(map[int64]map[uint64][]byte)

        return extras, sqldb.ExecuteBatchQuery(
                ctx, cfg, policyIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) (
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {

                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
                },
                func(ctx context.Context,
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {

                        if extras[row.PolicyID] == nil {
                                extras[row.PolicyID] = make(map[uint64][]byte)
                        }
                        extras[row.PolicyID][uint64(row.Type)] = row.Value

                        return nil
                },
        )
}

// forEachNodePaginated executes a paginated query to process each node in the
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
// and applies the provided processNode function to each node.
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, protocol ProtocolVersion,
        processNode func(context.Context, int64,
                *models.LightningNode) error) error {

        pageQueryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.GraphNode, error) {

                return db.ListNodesPaginated(
                        ctx, sqlc.ListNodesPaginatedParams{
                                Version: int16(protocol),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractPageCursor := func(node sqlc.GraphNode) int64 {
                return node.ID
        }

        collectFunc := func(node sqlc.GraphNode) (int64, error) {
                return node.ID, nil
        }

        batchQueryFunc := func(ctx context.Context,
                nodeIDs []int64) (*batchNodeData, error) {

                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
        }

        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
                batchData *batchNodeData) error {

                node, err := buildNodeWithBatchData(dbNode, batchData)
                if err != nil {
                        return fmt.Errorf("unable to build "+
                                "node(id=%d): %w", dbNode.ID, err)
                }

                return processNode(ctx, dbNode.ID, node)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
                collectFunc, batchQueryFunc, processItem,
        )
}
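
// A hedged usage sketch (the surrounding variables are hypothetical): callers
// supply a closure that receives each hydrated node, and an error returned
// from the closure is expected to surface as the helper's error:
//
//        err := forEachNodePaginated(
//                ctx, cfg, db, ProtocolV1,
//                func(ctx context.Context, nodeID int64,
//                        node *models.LightningNode) error {
//
//                        // Handle one node here.
//                        return nil
//                },
//        )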

// forEachChannelWithPolicies executes a paginated query to process each
// channel with policies in the graph.
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
        cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
                *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error) error {

        type channelBatchIDs struct {
                channelID int64
                policyIDs []int64
        }

        pageQueryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
                error) {

                return db.ListChannelsWithPoliciesPaginated(
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
                                Version: int16(ProtocolV1),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractPageCursor := func(
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {

                return row.GraphChannel.ID
        }

        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
                channelBatchIDs, error) {

                ids := channelBatchIDs{
                        channelID: row.GraphChannel.ID,
                }

                // Extract policy IDs from the row.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return ids, err
                }

                if dbPol1 != nil {
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
                }
                if dbPol2 != nil {
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
                }

                return ids, nil
        }

        batchDataFunc := func(ctx context.Context,
                allIDs []channelBatchIDs) (*batchChannelData, error) {

                // Separate channel IDs from policy IDs.
                var (
                        channelIDs = make([]int64, len(allIDs))
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
                )

                for i, ids := range allIDs {
                        channelIDs[i] = ids.channelID
                        policyIDs = append(policyIDs, ids.policyIDs...)
                }

                return batchLoadChannelData(
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
                )
        }

        processItem := func(ctx context.Context,
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
                batchData *batchChannelData) error {

                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return err
                }

                edge, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.GraphChannel, node1, node2,
                        batchData,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel info: %w",
                                err)
                }

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return err
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
                )
                if err != nil {
                        return err
                }

                return processChannel(edge, p1, p2)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
                collectFunc, batchDataFunc, processItem,
        )
}
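
// The paginated pipeline above can be read as the following schematic flow.
// This is only an interpretation of what the sqldb helper appears to do based
// on the callbacks it is given, not its actual implementation (fetchPage,
// collect, batchData, cursor and lastRow are placeholders):
//
//        for page := fetchPage(lastID, limit); len(page) > 0; page = fetchPage(lastID, limit) {
//                ids := collect(page)     // channel + policy IDs per row
//                shared := batchData(ids) // one batched load per page
//                for _, row := range page {
//                        processItem(row, shared)
//                }
//                lastID = cursor(lastRow(page))
//        }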

// buildDirectedChannel builds a DirectedChannel instance from the provided
// data.
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
        nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
        channelBatchData *batchChannelData, features *lnwire.FeatureVector,
        toNodeCallback func() route.Vertex) (*DirectedChannel, error) {

        node1, node2, err := buildNodeVertices(
                channelRow.Node1Pubkey, channelRow.Node2Pubkey,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build node vertices: %w", err)
        }

        edge, err := buildEdgeInfoWithBatchData(
                chain, channelRow.GraphChannel, node1, node2, channelBatchData,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build channel info: %w", err)
        }

        dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
        if err != nil {
                return nil, fmt.Errorf("unable to extract channel policies: %w",
                        err)
        }

        p1, p2, err := buildChanPoliciesWithBatchData(
                dbPol1, dbPol2, edge.ChannelID, node1, node2,
                channelBatchData,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build channel policies: %w",
                        err)
        }

        // Determine outgoing and incoming policy for this specific node.
        p1ToNode := channelRow.GraphChannel.NodeID2
        p2ToNode := channelRow.GraphChannel.NodeID1
        outPolicy, inPolicy := p1, p2
        if (p1 != nil && p1ToNode == nodeID) ||
                (p2 != nil && p2ToNode != nodeID) {

                outPolicy, inPolicy = p2, p1
        }

        // Build cached policy.
        var cachedInPolicy *models.CachedEdgePolicy
        if inPolicy != nil {
                cachedInPolicy = models.NewCachedPolicy(inPolicy)
                cachedInPolicy.ToNodePubKey = toNodeCallback
                cachedInPolicy.ToNodeFeatures = features
        }

        // Extract inbound fee.
        var inboundFee lnwire.Fee
        if outPolicy != nil {
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
                        inboundFee = fee
                })
        }

        // Build directed channel.
        directedChannel := &DirectedChannel{
                ChannelID:    edge.ChannelID,
                IsNode1:      nodePub == edge.NodeKey1Bytes,
                OtherNode:    edge.NodeKey2Bytes,
                Capacity:     edge.Capacity,
                OutPolicySet: outPolicy != nil,
                InPolicy:     cachedInPolicy,
                InboundFee:   inboundFee,
        }

        if nodePub == edge.NodeKey2Bytes {
                directedChannel.OtherNode = edge.NodeKey1Bytes
        }

        return directedChannel, nil
}
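
// Orientation recap (illustrative only, and ignoring the nil checks in the
// code above): policy1 always points at node2 and policy2 at node1, so from
// the perspective of the node being processed:
//
//        if nodeID == channelRow.GraphChannel.NodeID1 {
//                // outgoing = p1, incoming = p2
//        } else {
//                // outgoing = p2, incoming = p1
//        }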

// batchBuildChannelEdges builds a slice of ChannelEdge instances from the
// provided rows. It uses batch loading for channels, policies, and nodes.
func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context,
        cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) {

        var (
                channelIDs = make([]int64, len(rows))
                policyIDs  = make([]int64, 0, len(rows)*2)
                nodeIDs    = make([]int64, 0, len(rows)*2)

                // nodeIDSet is used to ensure we only collect unique node IDs.
                nodeIDSet = make(map[int64]bool)

                // edges will hold the final channel edges built from the rows.
                edges = make([]ChannelEdge, 0, len(rows))
        )

        // Collect all IDs needed for batch loading.
        for i, row := range rows {
                channelIDs[i] = row.Channel().ID

                // Collect policy IDs.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return nil, fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }
                if dbPol1 != nil {
                        policyIDs = append(policyIDs, dbPol1.ID)
                }
                if dbPol2 != nil {
                        policyIDs = append(policyIDs, dbPol2.ID)
                }

                var (
                        node1ID = row.Node1().ID
                        node2ID = row.Node2().ID
                )

                // Collect unique node IDs.
                if !nodeIDSet[node1ID] {
                        nodeIDs = append(nodeIDs, node1ID)
                        nodeIDSet[node1ID] = true
                }

                if !nodeIDSet[node2ID] {
                        nodeIDs = append(nodeIDs, node2ID)
                        nodeIDSet[node2ID] = true
                }
        }

        // Batch the data for all the channels and policies.
        channelBatchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to batch load channel and "+
                        "policy data: %w", err)
        }

        // Batch the data for all the nodes.
        nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node data: %w",
                        err)
        }

        // Build all channel edges using batch data.
        for _, row := range rows {
                // Build nodes using batch data.
                node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData)
                if err != nil {
                        return nil, fmt.Errorf("unable to build node1: %w", err)
                }

                node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData)
                if err != nil {
                        return nil, fmt.Errorf("unable to build node2: %w", err)
                }

                // Build channel info using batch data.
                channel, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.Channel(), node1.PubKeyBytes,
                        node2.PubKeyBytes, channelBatchData,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to build channel "+
                                "info: %w", err)
                }

                // Extract and build policies using batch data.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return nil, fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, channel.ChannelID,
                        node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to build channel "+
                                "policies: %w", err)
                }

                edges = append(edges, ChannelEdge{
                        Info:    channel,
                        Policy1: p1,
                        Policy2: p2,
                        Node1:   node1,
                        Node2:   node2,
                })
        }

        return edges, nil
}
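
// A hedged summary of the two-pass shape used above and in
// batchBuildChannelInfo below (collectIDs, loadChannels, loadNodes and build
// are placeholders for the calls made in the function bodies):
//
//        ids := collectIDs(rows)       // pass 1: gather channel/policy/node IDs
//        chanData := loadChannels(ids) // batched channel + policy load
//        nodeData := loadNodes(ids)    // batched node load
//        for _, row := range rows {    // pass 2: pure in-memory builds
//                edges = append(edges, build(row, chanData, nodeData))
//        }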

// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo
// instances from the provided rows using batch loading for channel data.
func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context,
        cfg *SQLStoreConfig, db SQLQueries, rows []T) (
        []*models.ChannelEdgeInfo, []int64, error) {

        if len(rows) == 0 {
                return nil, nil, nil
        }

        // Collect all the channel IDs needed for batch loading.
        channelIDs := make([]int64, len(rows))
        for i, row := range rows {
                channelIDs[i] = row.Channel().ID
        }

        // Batch load the channel data.
        channelBatchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, channelIDs, nil,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to batch load channel "+
                        "data: %w", err)
        }

        // Build all channel edges using batch data.
        edges := make([]*models.ChannelEdgeInfo, 0, len(rows))
        for _, row := range rows {
                node1, node2, err := buildNodeVertices(
                        row.Node1Pub(), row.Node2Pub(),
                )
                if err != nil {
                        return nil, nil, err
                }

                // Build channel info using batch data.
                info, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.Channel(), node1, node2,
                        channelBatchData,
                )
                if err != nil {
                        return nil, nil, err
                }

                edges = append(edges, info)
        }

        return edges, channelIDs, nil
}

// handleZombieMarking is a helper function that handles the logic of
// marking a channel as a zombie in the database. It takes into account whether
// we are in strict zombie pruning mode, and adjusts the node public keys
// accordingly based on the last update timestamps of the channel policies.
func handleZombieMarking(ctx context.Context, db SQLQueries,
        row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo,
        strictZombiePruning bool, scid uint64) error {

        nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes

        if strictZombiePruning {
                var e1UpdateTime, e2UpdateTime *time.Time
                if row.Policy1LastUpdate.Valid {
                        e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0)
                        e1UpdateTime = &e1Time
                }
                if row.Policy2LastUpdate.Valid {
                        e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0)
                        e2UpdateTime = &e2Time
                }

                nodeKey1, nodeKey2 = makeZombiePubkeys(
                        info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime,
                        e2UpdateTime,
                )
        }

        return db.UpsertZombieChannel(
                ctx, sqlc.UpsertZombieChannelParams{
                        Version:  int16(ProtocolV1),
                        Scid:     channelIDToBytes(scid),
                        NodeKey1: nodeKey1[:],
                        NodeKey2: nodeKey2[:],
                },
        )
}
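
// A condensed view of the branch above (illustrative only): in strict zombie
// pruning mode the policy last-update timestamps, when present, are handed to
// makeZombiePubkeys so that the keys stored with the zombie entry reflect
// which side's update went stale; otherwise both original node keys are
// recorded as-is.
//
//        if strictZombiePruning {
//                nodeKey1, nodeKey2 = makeZombiePubkeys(k1, k2, t1, t2)
//        } else {
//                nodeKey1, nodeKey2 = k1, k2
//        }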