lightningnetwork / lnd / 17411742737

02 Sep 2025 05:51PM UTC coverage: 66.659% (-0.01%) from 66.672%
Merge 7d084eeb2 into f51684dd0
Pull Request #9455: [1/2] discovery+lnwire: add support for DNS host name in NodeAnnouncement msg

115 of 187 new or added lines in 8 files covered. (61.5%)

85 existing lines in 14 files now uncovered.

136115 of 204197 relevant lines covered (66.66%)

21456.31 hits per line

Source File
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// ProtocolVersion is an enum that defines the gossip protocol version of a
35
// message.
36
type ProtocolVersion uint8
37

38
const (
39
        // ProtocolV1 is the gossip protocol version defined in BOLT #7.
40
        ProtocolV1 ProtocolVersion = 1
41
)
42

43
// String returns a string representation of the protocol version.
44
func (v ProtocolVersion) String() string {
×
45
        return fmt.Sprintf("V%d", v)
×
46
}
×
47

48
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
49
// execute queries against the SQL graph tables.
50
//
51
//nolint:ll,interfacebloat
52
type SQLQueries interface {
53
        /*
54
                Node queries.
55
        */
56
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
57
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
58
        GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error)
59
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
60
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
61
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
62
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
63
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
64
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
65
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
66
        DeleteNode(ctx context.Context, id int64) error
67

68
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
69
        GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
70
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
71
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
72

73
        UpsertNodeAddress(ctx context.Context, arg sqlc.UpsertNodeAddressParams) error
74
        GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
75
        GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
76
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
77

78
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
79
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
80
        GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
81
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
82
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
83

84
        /*
85
                Source node queries.
86
        */
87
        AddSourceNode(ctx context.Context, nodeID int64) error
88
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
89

90
        /*
91
                Channel queries.
92
        */
93
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
94
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
95
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
96
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
97
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
98
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
99
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
100
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
101
        GetChannelsByIDs(ctx context.Context, ids []int64) ([]sqlc.GetChannelsByIDsRow, error)
102
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
103
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
104
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
105
        ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
106
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
107
        ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
108
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
109
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
110
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
111
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
112
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
113
        DeleteChannels(ctx context.Context, ids []int64) error
114

115
        UpsertChannelExtraType(ctx context.Context, arg sqlc.UpsertChannelExtraTypeParams) error
116
        GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
117
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
118
        GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)
119

120
        /*
121
                Channel Policy table queries.
122
        */
123
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
124
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
125
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
126

127
        UpsertChanPolicyExtraType(ctx context.Context, arg sqlc.UpsertChanPolicyExtraTypeParams) error
128
        GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
129
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
130

131
        /*
132
                Zombie index queries.
133
        */
134
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
135
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
136
        GetZombieChannelsSCIDs(ctx context.Context, arg sqlc.GetZombieChannelsSCIDsParams) ([]sqlc.GraphZombieChannel, error)
137
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
138
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
139
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
140

141
        /*
142
                Prune log table queries.
143
        */
144
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
145
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
146
        GetPruneEntriesForHeights(ctx context.Context, heights []int64) ([]sqlc.GraphPruneLog, error)
147
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
148
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
149

150
        /*
151
                Closed SCID table queries.
152
        */
153
        InsertClosedChannel(ctx context.Context, scid []byte) error
154
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
155
        GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error)
156

157
        /*
158
                Migration specific queries.
159

160
                NOTE: these should not be used in code other than migrations.
161
                Once sqldbv2 is in place, these can be removed from this struct
162
                as then migrations will have their own dedicated queries
163
                structs.
164
        */
165
        InsertNodeMig(ctx context.Context, arg sqlc.InsertNodeMigParams) (int64, error)
166
        InsertChannelMig(ctx context.Context, arg sqlc.InsertChannelMigParams) (int64, error)
167
        InsertEdgePolicyMig(ctx context.Context, arg sqlc.InsertEdgePolicyMigParams) (int64, error)
168
}
169

170
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
171
// database operations.
172
type BatchedSQLQueries interface {
173
        SQLQueries
174
        sqldb.BatchedTx[SQLQueries]
175
}
176

177
// SQLStore is an implementation of the V1Store interface that uses a SQL
178
// database as the backend.
179
type SQLStore struct {
180
        cfg *SQLStoreConfig
181
        db  BatchedSQLQueries
182

183
        // cacheMu guards all caches (rejectCache and chanCache). If
184
        // this mutex is to be acquired at the same time as the DB mutex,
185
        // then cacheMu MUST be acquired first to prevent deadlock.
186
        cacheMu     sync.RWMutex
187
        rejectCache *rejectCache
188
        chanCache   *channelCache
189

190
        chanScheduler batch.Scheduler[SQLQueries]
191
        nodeScheduler batch.Scheduler[SQLQueries]
192

193
        srcNodes  map[ProtocolVersion]*srcNodeInfo
194
        srcNodeMu sync.Mutex
195
}
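
The lock-ordering rule spelled out for cacheMu above is the classic deadlock-avoidance discipline: every code path that needs both the cache mutex and the DB mutex must take them in the same order. Below is a minimal, self-contained sketch of that discipline using plain sync.Mutex values standing in for lnd's actual locks; it is illustrative only, not the store's code.

package main

import (
	"fmt"
	"sync"
)

// Illustrative only: two mutexes standing in for the cache mutex and the
// DB mutex. Every code path that needs both acquires them in the same
// order (cache first, then DB), so no two goroutines can deadlock by each
// holding one lock while waiting for the other.
var (
	cacheMu sync.Mutex
	dbMu    sync.Mutex
)

func updateCachesAndDB() {
	cacheMu.Lock()
	defer cacheMu.Unlock()

	dbMu.Lock()
	defer dbMu.Unlock()

	fmt.Println("both locks held in the prescribed order")
}

func main() {
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			updateCachesAndDB()
		}()
	}
	wg.Wait()
}
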
196

197
// A compile-time assertion to ensure that SQLStore implements the V1Store
198
// interface.
199
var _ V1Store = (*SQLStore)(nil)
200

201
// SQLStoreConfig holds the configuration for the SQLStore.
202
type SQLStoreConfig struct {
203
        // ChainHash is the genesis hash for the chain that all the gossip
204
        // messages in this store are aimed at.
205
        ChainHash chainhash.Hash
206

207
        // QueryConfig holds configuration values for SQL queries.
208
        QueryCfg *sqldb.QueryConfig
209
}
210

211
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
212
// storage backend.
213
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
214
        options ...StoreOptionModifier) (*SQLStore, error) {
×
215

×
216
        opts := DefaultOptions()
×
217
        for _, o := range options {
×
218
                o(opts)
×
219
        }
×
220

221
        if opts.NoMigration {
×
222
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
223
                        "supported for SQL stores")
×
224
        }
×
225

226
        s := &SQLStore{
×
227
                cfg:         cfg,
×
228
                db:          db,
×
229
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
230
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
231
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
232
        }
×
233

×
234
        s.chanScheduler = batch.NewTimeScheduler(
×
235
                db, &s.cacheMu, opts.BatchCommitInterval,
×
236
        )
×
237
        s.nodeScheduler = batch.NewTimeScheduler(
×
238
                db, nil, opts.BatchCommitInterval,
×
239
        )
×
240

×
241
        return s, nil
×
242
}
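
NewSQLStore itself only needs an SQLStoreConfig and a BatchedSQLQueries backend; options are layered on top of DefaultOptions, and NoMigration is rejected outright. The sketch below shows roughly how a caller might wire it up. The backend value and query config are assumed to be constructed elsewhere, and the graphdb import alias for github.com/lightningnetwork/lnd/graph/db is inferred from the package name, so treat this as a sketch rather than the canonical setup path.

package graphexample

import (
	"github.com/btcsuite/btcd/chaincfg"
	graphdb "github.com/lightningnetwork/lnd/graph/db"
	"github.com/lightningnetwork/lnd/sqldb"
)

// openGraphStore wires a BatchedSQLQueries backend into an SQLStore. The
// backend and query config are assumed to be built by the caller; how they
// are constructed is outside the scope of this file.
func openGraphStore(db graphdb.BatchedSQLQueries,
	queryCfg *sqldb.QueryConfig) (*graphdb.SQLStore, error) {

	cfg := &graphdb.SQLStoreConfig{
		// Pin the store to mainnet's genesis hash so gossip for
		// other chains is rejected.
		ChainHash: *chaincfg.MainNetParams.GenesisHash,
		QueryCfg:  queryCfg,
	}

	// Any StoreOptionModifier could be appended here; with none given,
	// DefaultOptions() is used as-is.
	return graphdb.NewSQLStore(cfg, db)
}
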
243

244
// AddLightningNode adds a vertex/node to the graph database. If the node is not
245
// in the database from before, this will add a new, unconnected one to the
246
// graph. If it is present from before, this will update that node's
247
// information.
248
//
249
// NOTE: part of the V1Store interface.
250
func (s *SQLStore) AddLightningNode(ctx context.Context,
251
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
252

×
253
        r := &batch.Request[SQLQueries]{
×
254
                Opts: batch.NewSchedulerOptions(opts...),
×
255
                Do: func(queries SQLQueries) error {
×
256
                        _, err := upsertNode(ctx, queries, node)
×
257
                        return err
×
258
                },
×
259
        }
260

261
        return s.nodeScheduler.Execute(ctx, r)
×
262
}
263

264
// FetchLightningNode attempts to look up a target node by its identity public
265
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
266
// returned.
267
//
268
// NOTE: part of the V1Store interface.
269
func (s *SQLStore) FetchLightningNode(ctx context.Context,
270
        pubKey route.Vertex) (*models.LightningNode, error) {
×
271

×
272
        var node *models.LightningNode
×
273
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
274
                var err error
×
275
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, pubKey)
×
276

×
277
                return err
×
278
        }, sqldb.NoOpReset)
×
279
        if err != nil {
×
280
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
281
        }
×
282

283
        return node, nil
×
284
}
285

286
// HasLightningNode determines if the graph has a vertex identified by the
287
// target node identity public key. If the node exists in the database, a
288
// timestamp of when the data for the node was last updated is returned along
289
// with a true boolean. Otherwise, an empty time.Time is returned with a false
290
// boolean.
291
//
292
// NOTE: part of the V1Store interface.
293
func (s *SQLStore) HasLightningNode(ctx context.Context,
294
        pubKey [33]byte) (time.Time, bool, error) {
×
295

×
296
        var (
×
297
                exists     bool
×
298
                lastUpdate time.Time
×
299
        )
×
300
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
301
                dbNode, err := db.GetNodeByPubKey(
×
302
                        ctx, sqlc.GetNodeByPubKeyParams{
×
303
                                Version: int16(ProtocolV1),
×
304
                                PubKey:  pubKey[:],
×
305
                        },
×
306
                )
×
307
                if errors.Is(err, sql.ErrNoRows) {
×
308
                        return nil
×
309
                } else if err != nil {
×
310
                        return fmt.Errorf("unable to fetch node: %w", err)
×
311
                }
×
312

313
                exists = true
×
314

×
315
                if dbNode.LastUpdate.Valid {
×
316
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
317
                }
×
318

319
                return nil
×
320
        }, sqldb.NoOpReset)
321
        if err != nil {
×
322
                return time.Time{}, false,
×
323
                        fmt.Errorf("unable to fetch node: %w", err)
×
324
        }
×
325

326
        return lastUpdate, exists, nil
×
327
}
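
Taken together, HasLightningNode and FetchLightningNode allow a cheap existence-and-staleness check before paying for the full record. A hedged usage sketch follows; the store, public key, and staleness threshold are assumed to come from the caller, and the helper name is made up for illustration.

package graphexample

import (
	"context"
	"fmt"
	"time"

	graphdb "github.com/lightningnetwork/lnd/graph/db"
	"github.com/lightningnetwork/lnd/graph/db/models"
	"github.com/lightningnetwork/lnd/routing/route"
)

// refreshIfStale fetches the full node record only when the stored
// announcement is older than maxAge. A nil node with a nil error means the
// cached announcement is still fresh. Illustrative sketch with minimal
// error handling.
func refreshIfStale(ctx context.Context, store *graphdb.SQLStore,
	pub route.Vertex, maxAge time.Duration) (*models.LightningNode, error) {

	lastUpdate, exists, err := store.HasLightningNode(ctx, pub)
	if err != nil {
		return nil, fmt.Errorf("existence check failed: %w", err)
	}
	if !exists {
		return nil, fmt.Errorf("node %x not found", pub[:])
	}

	// If the announcement is fresh enough, skip the heavier fetch.
	if time.Since(lastUpdate) < maxAge {
		return nil, nil
	}

	return store.FetchLightningNode(ctx, pub)
}
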
328

329
// AddrsForNode returns all known addresses for the target node public key
330
// that the graph DB is aware of. The returned boolean indicates if the
331
// given node is unknown to the graph DB or not.
332
//
333
// NOTE: part of the V1Store interface.
334
func (s *SQLStore) AddrsForNode(ctx context.Context,
335
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
336

×
337
        var (
×
338
                addresses []net.Addr
×
339
                known     bool
×
340
        )
×
341
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
342
                // First, check if the node exists and get its DB ID if it
×
343
                // does.
×
344
                dbID, err := db.GetNodeIDByPubKey(
×
345
                        ctx, sqlc.GetNodeIDByPubKeyParams{
×
346
                                Version: int16(ProtocolV1),
×
347
                                PubKey:  nodePub.SerializeCompressed(),
×
348
                        },
×
349
                )
×
350
                if errors.Is(err, sql.ErrNoRows) {
×
351
                        return nil
×
352
                }
×
353

354
                known = true
×
355

×
356
                addresses, err = getNodeAddresses(ctx, db, dbID)
×
357
                if err != nil {
×
358
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
359
                                err)
×
360
                }
×
361

362
                return nil
×
363
        }, sqldb.NoOpReset)
364
        if err != nil {
×
365
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
366
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
367
        }
×
368

369
        return known, addresses, nil
×
370
}
371

372
// DeleteLightningNode starts a new database transaction to remove a vertex/node
373
// from the database according to the node's public key.
374
//
375
// NOTE: part of the V1Store interface.
376
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
377
        pubKey route.Vertex) error {
×
378

×
379
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
380
                res, err := db.DeleteNodeByPubKey(
×
381
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
382
                                Version: int16(ProtocolV1),
×
383
                                PubKey:  pubKey[:],
×
384
                        },
×
385
                )
×
386
                if err != nil {
×
387
                        return err
×
388
                }
×
389

390
                rows, err := res.RowsAffected()
×
391
                if err != nil {
×
392
                        return err
×
393
                }
×
394

395
                if rows == 0 {
×
396
                        return ErrGraphNodeNotFound
×
397
                } else if rows > 1 {
×
398
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
399
                }
×
400

401
                return err
×
402
        }, sqldb.NoOpReset)
403
        if err != nil {
×
404
                return fmt.Errorf("unable to delete node: %w", err)
×
405
        }
×
406

407
        return nil
×
408
}
409

410
// FetchNodeFeatures returns the features of the given node. If no features are
411
// known for the node, an empty feature vector is returned.
412
//
413
// NOTE: this is part of the graphdb.NodeTraverser interface.
414
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
415
        *lnwire.FeatureVector, error) {
×
416

×
417
        ctx := context.TODO()
×
418

×
419
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
420
}
×
421

422
// DisabledChannelIDs returns the channel ids of disabled channels.
423
// A channel is disabled when both of the associated ChannelEdgePolicies
424
// have their disabled bit on.
425
//
426
// NOTE: part of the V1Store interface.
427
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
428
        var (
×
429
                ctx     = context.TODO()
×
430
                chanIDs []uint64
×
431
        )
×
432
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
433
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
434
                if err != nil {
×
435
                        return fmt.Errorf("unable to fetch disabled "+
×
436
                                "channels: %w", err)
×
437
                }
×
438

439
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
440

×
441
                return nil
×
442
        }, sqldb.NoOpReset)
443
        if err != nil {
×
444
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
445
                        err)
×
446
        }
×
447

448
        return chanIDs, nil
×
449
}
450

451
// LookupAlias attempts to return the alias as advertised by the target node.
452
//
453
// NOTE: part of the V1Store interface.
454
func (s *SQLStore) LookupAlias(ctx context.Context,
455
        pub *btcec.PublicKey) (string, error) {
×
456

×
457
        var alias string
×
458
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
459
                dbNode, err := db.GetNodeByPubKey(
×
460
                        ctx, sqlc.GetNodeByPubKeyParams{
×
461
                                Version: int16(ProtocolV1),
×
462
                                PubKey:  pub.SerializeCompressed(),
×
463
                        },
×
464
                )
×
465
                if errors.Is(err, sql.ErrNoRows) {
×
466
                        return ErrNodeAliasNotFound
×
467
                } else if err != nil {
×
468
                        return fmt.Errorf("unable to fetch node: %w", err)
×
469
                }
×
470

471
                if !dbNode.Alias.Valid {
×
472
                        return ErrNodeAliasNotFound
×
473
                }
×
474

475
                alias = dbNode.Alias.String
×
476

×
477
                return nil
×
478
        }, sqldb.NoOpReset)
479
        if err != nil {
×
480
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
481
        }
×
482

483
        return alias, nil
×
484
}
485

486
// SourceNode returns the source node of the graph. The source node is treated
487
// as the center node within a star-graph. This method may be used to kick off
488
// a path finding algorithm in order to explore the reachability of another
489
// node based off the source node.
490
//
491
// NOTE: part of the V1Store interface.
492
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
493
        error) {
×
494

×
495
        var node *models.LightningNode
×
496
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
497
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
498
                if err != nil {
×
499
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
500
                                err)
×
501
                }
×
502

503
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, nodePub)
×
504

×
505
                return err
×
506
        }, sqldb.NoOpReset)
507
        if err != nil {
×
508
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
509
        }
×
510

511
        return node, nil
×
512
}
513

514
// SetSourceNode sets the source node within the graph database. The source
515
// node is to be used as the center of a star-graph within path finding
516
// algorithms.
517
//
518
// NOTE: part of the V1Store interface.
519
func (s *SQLStore) SetSourceNode(ctx context.Context,
520
        node *models.LightningNode) error {
×
521

×
522
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
523
                id, err := upsertNode(ctx, db, node)
×
524
                if err != nil {
×
525
                        return fmt.Errorf("unable to upsert source node: %w",
×
526
                                err)
×
527
                }
×
528

529
                // Make sure that if a source node for this version is already
530
                // set, then the ID is the same as the one we are about to set.
531
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
532
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
533
                        return fmt.Errorf("unable to fetch source node: %w",
×
534
                                err)
×
535
                } else if err == nil {
×
536
                        if dbSourceNodeID != id {
×
537
                                return fmt.Errorf("v1 source node already "+
×
538
                                        "set to a different node: %d vs %d",
×
539
                                        dbSourceNodeID, id)
×
540
                        }
×
541

542
                        return nil
×
543
                }
544

545
                return db.AddSourceNode(ctx, id)
×
546
        }, sqldb.NoOpReset)
547
}
548

549
// NodeUpdatesInHorizon returns all the known lightning nodes which have an
550
// update timestamp within the passed range. This method can be used by two
551
// nodes to quickly determine if they have the same set of up to date node
552
// announcements.
553
//
554
// NOTE: This is part of the V1Store interface.
555
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
556
        endTime time.Time) ([]models.LightningNode, error) {
×
557

×
558
        ctx := context.TODO()
×
559

×
560
        var nodes []models.LightningNode
×
561
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
562
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
563
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
564
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
565
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
566
                        },
×
567
                )
×
568
                if err != nil {
×
569
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
570
                }
×
571

572
                err = forEachNodeInBatch(
×
573
                        ctx, s.cfg.QueryCfg, db, dbNodes,
×
574
                        func(_ int64, node *models.LightningNode) error {
×
575
                                nodes = append(nodes, *node)
×
576

×
577
                                return nil
×
578
                        },
×
579
                )
580
                if err != nil {
×
581
                        return fmt.Errorf("unable to build nodes: %w", err)
×
582
                }
×
583

584
                return nil
×
585
        }, sqldb.NoOpReset)
586
        if err != nil {
×
587
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
588
        }
×
589

590
        return nodes, nil
×
591
}
592

593
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
594
// undirected edge between the two target nodes is created. The information stored
595
// denotes the static attributes of the channel, such as the channelID, the keys
596
// involved in creation of the channel, and the set of features that the channel
597
// supports. The chanPoint and chanID are used to uniquely identify the edge
598
// globally within the database.
599
//
600
// NOTE: part of the V1Store interface.
601
func (s *SQLStore) AddChannelEdge(ctx context.Context,
602
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
603

×
604
        var alreadyExists bool
×
605
        r := &batch.Request[SQLQueries]{
×
606
                Opts: batch.NewSchedulerOptions(opts...),
×
607
                Reset: func() {
×
608
                        alreadyExists = false
×
609
                },
×
610
                Do: func(tx SQLQueries) error {
×
611
                        chanIDB := channelIDToBytes(edge.ChannelID)
×
612

×
613
                        // Make sure that the channel doesn't already exist. We
×
614
                        // do this explicitly instead of relying on catching a
×
615
                        // unique constraint error because relying on SQL to
×
616
                        // throw that error would abort the entire batch of
×
617
                        // transactions.
×
618
                        _, err := tx.GetChannelBySCID(
×
619
                                ctx, sqlc.GetChannelBySCIDParams{
×
620
                                        Scid:    chanIDB,
×
621
                                        Version: int16(ProtocolV1),
×
622
                                },
×
623
                        )
×
624
                        if err == nil {
×
625
                                alreadyExists = true
×
626
                                return nil
×
627
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
628
                                return fmt.Errorf("unable to fetch channel: %w",
×
629
                                        err)
×
630
                        }
×
631

632
                        return insertChannel(ctx, tx, edge)
×
633
                },
634
                OnCommit: func(err error) error {
×
635
                        switch {
×
636
                        case err != nil:
×
637
                                return err
×
638
                        case alreadyExists:
×
639
                                return ErrEdgeAlreadyExist
×
640
                        default:
×
641
                                s.rejectCache.remove(edge.ChannelID)
×
642
                                s.chanCache.remove(edge.ChannelID)
×
643
                                return nil
×
644
                        }
645
                },
646
        }
647

648
        return s.chanScheduler.Execute(ctx, r)
×
649
}
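
The Do closure above deliberately checks GetChannelBySCID for sql.ErrNoRows before inserting, rather than relying on a UNIQUE-constraint failure, because a constraint error would abort the whole batched transaction shared with other requests. Reduced to plain database/sql, the same look-up-then-insert shape looks roughly like the sketch below; the table and column names are placeholders, not lnd's schema.

package graphexample

import (
	"database/sql"
	"errors"
	"fmt"
)

// insertChannelIfAbsent checks for an existing row before inserting so that
// a duplicate never triggers a constraint error inside a shared transaction.
// The "channels" table and its columns are invented for illustration.
func insertChannelIfAbsent(tx *sql.Tx, scid []byte) error {
	var id int64
	err := tx.QueryRow(
		"SELECT id FROM channels WHERE scid = ?", scid,
	).Scan(&id)

	switch {
	case err == nil:
		// Row already present: report it without poisoning the
		// enclosing transaction.
		return fmt.Errorf("channel already exists (id=%d)", id)

	case errors.Is(err, sql.ErrNoRows):
		// Not found, safe to insert.
		_, err := tx.Exec(
			"INSERT INTO channels (scid) VALUES (?)", scid,
		)
		return err

	default:
		return fmt.Errorf("lookup failed: %w", err)
	}
}
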
650

651
// HighestChanID returns the "highest" known channel ID in the channel graph.
652
// This represents the "newest" channel from the PoV of the chain. This method
653
// can be used by peers to quickly determine if their graphs are in sync.
654
//
655
// NOTE: This is part of the V1Store interface.
656
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
657
        var highestChanID uint64
×
658
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
659
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
660
                if errors.Is(err, sql.ErrNoRows) {
×
661
                        return nil
×
662
                } else if err != nil {
×
663
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
664
                                err)
×
665
                }
×
666

667
                highestChanID = byteOrder.Uint64(chanID)
×
668

×
669
                return nil
×
670
        }, sqldb.NoOpReset)
671
        if err != nil {
×
672
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
673
        }
×
674

675
        return highestChanID, nil
×
676
}
677

678
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
679
// within the database for the referenced channel. The `flags` attribute within
680
// the ChannelEdgePolicy determines which of the directed edges are being
681
// updated. If the flag is 1, then the first node's information is being
682
// updated, otherwise it's the second node's information. The node ordering is
683
// determined by the lexicographical ordering of the identity public keys of the
684
// nodes on either side of the channel.
685
//
686
// NOTE: part of the V1Store interface.
687
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
688
        edge *models.ChannelEdgePolicy,
689
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
690

×
691
        var (
×
692
                isUpdate1    bool
×
693
                edgeNotFound bool
×
694
                from, to     route.Vertex
×
695
        )
×
696

×
697
        r := &batch.Request[SQLQueries]{
×
698
                Opts: batch.NewSchedulerOptions(opts...),
×
699
                Reset: func() {
×
700
                        isUpdate1 = false
×
701
                        edgeNotFound = false
×
702
                },
×
703
                Do: func(tx SQLQueries) error {
×
704
                        var err error
×
705
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
706
                                ctx, tx, edge,
×
707
                        )
×
708
                        if err != nil {
×
709
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
710
                        }
×
711

712
                        // Silence ErrEdgeNotFound so that the batch can
713
                        // succeed, but propagate the error via local state.
714
                        if errors.Is(err, ErrEdgeNotFound) {
×
715
                                edgeNotFound = true
×
716
                                return nil
×
717
                        }
×
718

719
                        return err
×
720
                },
721
                OnCommit: func(err error) error {
×
722
                        switch {
×
723
                        case err != nil:
×
724
                                return err
×
725
                        case edgeNotFound:
×
726
                                return ErrEdgeNotFound
×
727
                        default:
×
728
                                s.updateEdgeCache(edge, isUpdate1)
×
729
                                return nil
×
730
                        }
731
                },
732
        }
733

734
        err := s.chanScheduler.Execute(ctx, r)
×
735

×
736
        return from, to, err
×
737
}
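
UpdateEdgePolicy uses the same capture-and-re-raise shape as AddChannelEdge: the Do closure swallows ErrEdgeNotFound so the shared batch can still commit, records the condition in local state, and OnCommit converts that state back into the caller's error. A stripped-down sketch of the shape follows; it uses ordinary closures rather than lnd's batch package, so the names are illustrative.

package graphexample

import "errors"

// errNotFound stands in for ErrEdgeNotFound in this sketch.
var errNotFound = errors.New("edge not found")

// request mirrors the capture/re-raise shape: Do swallows the "not found"
// case so a shared batch is not rolled back, and OnCommit turns the
// captured flag back into the caller's error.
type request struct {
	notFound bool

	Do       func() error
	OnCommit func(commitErr error) error
}

func newUpdateRequest(update func() error) *request {
	r := &request{}

	r.Do = func() error {
		err := update()
		if errors.Is(err, errNotFound) {
			// Remember the failure locally, but report success to
			// the batch so sibling requests still commit.
			r.notFound = true
			return nil
		}
		return err
	}

	r.OnCommit = func(commitErr error) error {
		switch {
		case commitErr != nil:
			return commitErr
		case r.notFound:
			return errNotFound
		default:
			return nil
		}
	}

	return r
}
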
738

739
// updateEdgeCache updates our reject and channel caches with the new
740
// edge policy information.
741
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
742
        isUpdate1 bool) {
×
743

×
744
        // If an entry for this channel is found in reject cache, we'll modify
×
745
        // the entry with the updated timestamp for the direction that was just
×
746
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
747
        // during the next query for this edge.
×
748
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
749
                if isUpdate1 {
×
750
                        entry.upd1Time = e.LastUpdate.Unix()
×
751
                } else {
×
752
                        entry.upd2Time = e.LastUpdate.Unix()
×
753
                }
×
754
                s.rejectCache.insert(e.ChannelID, entry)
×
755
        }
756

757
        // If an entry for this channel is found in channel cache, we'll modify
758
        // the entry with the updated policy for the direction that was just
759
        // written. If the edge doesn't exist, we'll defer loading the info and
760
        // policies and lazily read from disk during the next query.
761
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
762
                if isUpdate1 {
×
763
                        channel.Policy1 = e
×
764
                } else {
×
765
                        channel.Policy2 = e
×
766
                }
×
767
                s.chanCache.insert(e.ChannelID, channel)
×
768
        }
769
}
770

771
// ForEachSourceNodeChannel iterates through all channels of the source node,
772
// executing the passed callback on each. The call-back is provided with the
773
// channel's outpoint, whether we have a policy for the channel and the channel
774
// peer's node information.
775
//
776
// NOTE: part of the V1Store interface.
777
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
778
        cb func(chanPoint wire.OutPoint, havePolicy bool,
779
                otherNode *models.LightningNode) error, reset func()) error {
×
780

×
781
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
782
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
783
                if err != nil {
×
784
                        return fmt.Errorf("unable to fetch source node: %w",
×
785
                                err)
×
786
                }
×
787

788
                return forEachNodeChannel(
×
789
                        ctx, db, s.cfg, nodeID,
×
790
                        func(info *models.ChannelEdgeInfo,
×
791
                                outPolicy *models.ChannelEdgePolicy,
×
792
                                _ *models.ChannelEdgePolicy) error {
×
793

×
794
                                // Fetch the other node.
×
795
                                var (
×
796
                                        otherNodePub [33]byte
×
797
                                        node1        = info.NodeKey1Bytes
×
798
                                        node2        = info.NodeKey2Bytes
×
799
                                )
×
800
                                switch {
×
801
                                case bytes.Equal(node1[:], nodePub[:]):
×
802
                                        otherNodePub = node2
×
803
                                case bytes.Equal(node2[:], nodePub[:]):
×
804
                                        otherNodePub = node1
×
805
                                default:
×
806
                                        return fmt.Errorf("node not " +
×
807
                                                "participating in this channel")
×
808
                                }
809

810
                                _, otherNode, err := getNodeByPubKey(
×
811
                                        ctx, s.cfg.QueryCfg, db, otherNodePub,
×
812
                                )
×
813
                                if err != nil {
×
814
                                        return fmt.Errorf("unable to fetch "+
×
815
                                                "other node(%x): %w",
×
816
                                                otherNodePub, err)
×
817
                                }
×
818

819
                                return cb(
×
820
                                        info.ChannelPoint, outPolicy != nil,
×
821
                                        otherNode,
×
822
                                )
×
823
                        },
824
                )
825
        }, reset)
826
}
827

828
// ForEachNode iterates through all the stored vertices/nodes in the graph,
829
// executing the passed callback with each node encountered. If the callback
830
// returns an error, then the transaction is aborted and the iteration stops
831
// early.
832
//
833
// NOTE: part of the V1Store interface.
834
func (s *SQLStore) ForEachNode(ctx context.Context,
835
        cb func(node *models.LightningNode) error, reset func()) error {
×
836

×
837
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
838
                return forEachNodePaginated(
×
839
                        ctx, s.cfg.QueryCfg, db,
×
840
                        ProtocolV1, func(_ context.Context, _ int64,
×
841
                                node *models.LightningNode) error {
×
842

×
843
                                return cb(node)
×
844
                        },
×
845
                )
846
        }, reset)
847
}
848

849
// ForEachNodeDirectedChannel iterates through all channels of a given node,
850
// executing the passed callback on the directed edge representing the channel
851
// and its incoming policy. If the callback returns an error, then the iteration
852
// is halted with the error propagated back up to the caller.
853
//
854
// Unknown policies are passed into the callback as nil values.
855
//
856
// NOTE: this is part of the graphdb.NodeTraverser interface.
857
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
858
        cb func(channel *DirectedChannel) error, reset func()) error {
×
859

×
860
        var ctx = context.TODO()
×
861

×
862
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
863
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
864
        }, reset)
×
865
}
866

867
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
868
// graph, executing the passed callback with each node encountered. If the
869
// callback returns an error, then the transaction is aborted and the iteration
870
// stops early.
871
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
872
        cb func(route.Vertex, *lnwire.FeatureVector) error,
873
        reset func()) error {
×
874

×
875
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
876
                return forEachNodeCacheable(
×
877
                        ctx, s.cfg.QueryCfg, db,
×
878
                        func(_ int64, nodePub route.Vertex,
×
879
                                features *lnwire.FeatureVector) error {
×
880

×
881
                                return cb(nodePub, features)
×
882
                        },
×
883
                )
884
        }, reset)
885
        if err != nil {
×
886
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
887
        }
×
888

889
        return nil
×
890
}
891

892
// ForEachNodeChannel iterates through all channels of the given node,
893
// executing the passed callback with an edge info structure and the policies
894
// of each end of the channel. The first edge policy is the outgoing edge *to*
895
// the connecting node, while the second is the incoming edge *from* the
896
// connecting node. If the callback returns an error, then the iteration is
897
// halted with the error propagated back up to the caller.
898
//
899
// Unknown policies are passed into the callback as nil values.
900
//
901
// NOTE: part of the V1Store interface.
902
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
903
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
904
                *models.ChannelEdgePolicy) error, reset func()) error {
×
905

×
906
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
907
                dbNode, err := db.GetNodeByPubKey(
×
908
                        ctx, sqlc.GetNodeByPubKeyParams{
×
909
                                Version: int16(ProtocolV1),
×
910
                                PubKey:  nodePub[:],
×
911
                        },
×
912
                )
×
913
                if errors.Is(err, sql.ErrNoRows) {
×
914
                        return nil
×
915
                } else if err != nil {
×
916
                        return fmt.Errorf("unable to fetch node: %w", err)
×
917
                }
×
918

919
                return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
×
920
        }, reset)
921
}
922

923
// ChanUpdatesInHorizon returns all the known channel edges which have at least
924
// one edge that has an update timestamp within the specified horizon.
925
//
926
// NOTE: This is part of the V1Store interface.
927
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
928
        endTime time.Time) ([]ChannelEdge, error) {
×
929

×
930
        s.cacheMu.Lock()
×
931
        defer s.cacheMu.Unlock()
×
932

×
933
        var (
×
934
                ctx = context.TODO()
×
935
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
936
                // an additional map to keep track of the edges already seen to
×
937
                // prevent re-adding them.
×
938
                edgesSeen    = make(map[uint64]struct{})
×
939
                edgesToCache = make(map[uint64]ChannelEdge)
×
940
                edges        []ChannelEdge
×
941
                hits         int
×
942
        )
×
943

×
944
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
945
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
946
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
947
                                Version:   int16(ProtocolV1),
×
948
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
949
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
950
                        },
×
951
                )
×
952
                if err != nil {
×
953
                        return err
×
954
                }
×
955

956
                if len(rows) == 0 {
×
957
                        return nil
×
958
                }
×
959

960
                // We'll pre-allocate the slices and maps here with a best
961
                // effort size in order to avoid unnecessary allocations later
962
                // on.
963
                uncachedRows := make(
×
964
                        []sqlc.GetChannelsByPolicyLastUpdateRangeRow, 0,
×
965
                        len(rows),
×
966
                )
×
967
                edgesToCache = make(map[uint64]ChannelEdge, len(rows))
×
968
                edgesSeen = make(map[uint64]struct{}, len(rows))
×
969
                edges = make([]ChannelEdge, 0, len(rows))
×
970

×
971
                // Separate cached from non-cached channels since we will only
×
972
                // batch load the data for the ones we haven't cached yet.
×
973
                for _, row := range rows {
×
974
                        chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid)
×
975

×
976
                        // Skip duplicates.
×
977
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
978
                                continue
×
979
                        }
980
                        edgesSeen[chanIDInt] = struct{}{}
×
981

×
982
                        // Check cache first.
×
983
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
984
                                hits++
×
985
                                edges = append(edges, channel)
×
986
                                continue
×
987
                        }
988

989
                        // Mark this row as one we need to batch load data for.
990
                        uncachedRows = append(uncachedRows, row)
×
991
                }
992

993
                // If there are no uncached rows, then we can return early.
994
                if len(uncachedRows) == 0 {
×
995
                        return nil
×
996
                }
×
997

998
                // Batch load data for all uncached channels.
999
                newEdges, err := batchBuildChannelEdges(
×
1000
                        ctx, s.cfg, db, uncachedRows,
×
1001
                )
×
1002
                if err != nil {
×
1003
                        return fmt.Errorf("unable to batch build channel "+
×
1004
                                "edges: %w", err)
×
1005
                }
×
1006

1007
                edges = append(edges, newEdges...)
×
1008

×
1009
                return nil
×
1010
        }, sqldb.NoOpReset)
1011
        if err != nil {
×
1012
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1013
        }
×
1014

1015
        // Insert any edges loaded from disk into the cache.
1016
        for chanid, channel := range edgesToCache {
×
1017
                s.chanCache.insert(chanid, channel)
×
1018
        }
×
1019

1020
        if len(edges) > 0 {
×
1021
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
1022
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1023
        } else {
×
1024
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1025
                        "horizon (%s, %s)", startTime, endTime)
×
1026
        }
×
1027

1028
        return edges, nil
×
1029
}
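
Within ChanUpdatesInHorizon the rows are walked once: duplicates are skipped via edgesSeen, cache hits are appended immediately, and only the remainder is handed to batchBuildChannelEdges for a single batched load. The sketch below condenses that split into a generic helper; it is illustrative only and uses a plain map as the cache rather than lnd's channelCache. The hit/miss counts recovered this way are what drive the hit-percentage debug log above.

package graphexample

// partitionByCache walks the candidate IDs once, drops duplicates, serves
// anything already cached, and returns the IDs that still need a batched
// database load.
func partitionByCache[V any](ids []uint64,
	cache map[uint64]V) (hits []V, missing []uint64) {

	seen := make(map[uint64]struct{}, len(ids))
	for _, id := range ids {
		// Skip duplicates so each channel is handled at most once.
		if _, ok := seen[id]; ok {
			continue
		}
		seen[id] = struct{}{}

		if v, ok := cache[id]; ok {
			hits = append(hits, v)
			continue
		}
		missing = append(missing, id)
	}

	return hits, missing
}
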
1030

1031
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1032
// data to the call-back. If withAddrs is true, then the call-back will also be
1033
// provided with the addresses associated with the node. The address retrieval
1034
// result in an additional round-trip to the database, so it should only be used
1035
// if the addresses are actually needed.
1036
//
1037
// NOTE: part of the V1Store interface.
1038
func (s *SQLStore) ForEachNodeCached(ctx context.Context, withAddrs bool,
1039
        cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
1040
                chans map[uint64]*DirectedChannel) error, reset func()) error {
×
1041

×
1042
        type nodeCachedBatchData struct {
×
1043
                features      map[int64][]int
×
1044
                addrs         map[int64][]nodeAddress
×
1045
                chanBatchData *batchChannelData
×
1046
                chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
×
1047
        }
×
1048

×
1049
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1050
                // pageQueryFunc is used to query the next page of nodes.
×
1051
                pageQueryFunc := func(ctx context.Context, lastID int64,
×
1052
                        limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
1053

×
1054
                        return db.ListNodeIDsAndPubKeys(
×
1055
                                ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
1056
                                        Version: int16(ProtocolV1),
×
1057
                                        ID:      lastID,
×
1058
                                        Limit:   limit,
×
1059
                                },
×
1060
                        )
×
1061
                }
×
1062

1063
                // batchDataFunc is then used to batch load the data required
1064
                // for each page of nodes.
1065
                batchDataFunc := func(ctx context.Context,
×
1066
                        nodeIDs []int64) (*nodeCachedBatchData, error) {
×
1067

×
1068
                        // Batch load node features.
×
1069
                        nodeFeatures, err := batchLoadNodeFeaturesHelper(
×
1070
                                ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1071
                        )
×
1072
                        if err != nil {
×
1073
                                return nil, fmt.Errorf("unable to batch load "+
×
1074
                                        "node features: %w", err)
×
1075
                        }
×
1076

1077
                        // Maybe fetch the node's addresses if requested.
1078
                        var nodeAddrs map[int64][]nodeAddress
×
1079
                        if withAddrs {
×
1080
                                nodeAddrs, err = batchLoadNodeAddressesHelper(
×
1081
                                        ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1082
                                )
×
1083
                                if err != nil {
×
1084
                                        return nil, fmt.Errorf("unable to "+
×
1085
                                                "batch load node "+
×
1086
                                                "addresses: %w", err)
×
1087
                                }
×
1088
                        }
1089

1090
                        // Batch load ALL unique channels for ALL nodes in this
1091
                        // page.
1092
                        allChannels, err := db.ListChannelsForNodeIDs(
×
1093
                                ctx, sqlc.ListChannelsForNodeIDsParams{
×
1094
                                        Version:  int16(ProtocolV1),
×
1095
                                        Node1Ids: nodeIDs,
×
1096
                                        Node2Ids: nodeIDs,
×
1097
                                },
×
1098
                        )
×
1099
                        if err != nil {
×
1100
                                return nil, fmt.Errorf("unable to batch "+
×
1101
                                        "fetch channels for nodes: %w", err)
×
1102
                        }
×
1103

1104
                        // Deduplicate channels and collect IDs.
1105
                        var (
×
1106
                                allChannelIDs []int64
×
1107
                                allPolicyIDs  []int64
×
1108
                        )
×
1109
                        uniqueChannels := make(
×
1110
                                map[int64]sqlc.ListChannelsForNodeIDsRow,
×
1111
                        )
×
1112

×
1113
                        for _, channel := range allChannels {
×
1114
                                channelID := channel.GraphChannel.ID
×
1115

×
1116
                                // Only process each unique channel once.
×
1117
                                _, exists := uniqueChannels[channelID]
×
1118
                                if exists {
×
1119
                                        continue
×
1120
                                }
1121

1122
                                uniqueChannels[channelID] = channel
×
1123
                                allChannelIDs = append(allChannelIDs, channelID)
×
1124

×
1125
                                if channel.Policy1ID.Valid {
×
1126
                                        allPolicyIDs = append(
×
1127
                                                allPolicyIDs,
×
1128
                                                channel.Policy1ID.Int64,
×
1129
                                        )
×
1130
                                }
×
1131
                                if channel.Policy2ID.Valid {
×
1132
                                        allPolicyIDs = append(
×
1133
                                                allPolicyIDs,
×
1134
                                                channel.Policy2ID.Int64,
×
1135
                                        )
×
1136
                                }
×
1137
                        }
1138

1139
                        // Batch load channel data for all unique channels.
1140
                        channelBatchData, err := batchLoadChannelData(
×
1141
                                ctx, s.cfg.QueryCfg, db, allChannelIDs,
×
1142
                                allPolicyIDs,
×
1143
                        )
×
1144
                        if err != nil {
×
1145
                                return nil, fmt.Errorf("unable to batch "+
×
1146
                                        "load channel data: %w", err)
×
1147
                        }
×
1148

1149
                        // Create map of node ID to channels that involve this
1150
                        // node.
1151
                        nodeIDSet := make(map[int64]bool)
×
1152
                        for _, nodeID := range nodeIDs {
×
1153
                                nodeIDSet[nodeID] = true
×
1154
                        }
×
1155

1156
                        nodeChannelMap := make(
×
1157
                                map[int64][]sqlc.ListChannelsForNodeIDsRow,
×
1158
                        )
×
1159
                        for _, channel := range uniqueChannels {
×
1160
                                // Add channel to both nodes if they're in our
×
1161
                                // current page.
×
1162
                                node1 := channel.GraphChannel.NodeID1
×
1163
                                if nodeIDSet[node1] {
×
1164
                                        nodeChannelMap[node1] = append(
×
1165
                                                nodeChannelMap[node1], channel,
×
1166
                                        )
×
1167
                                }
×
1168
                                node2 := channel.GraphChannel.NodeID2
×
1169
                                if nodeIDSet[node2] {
×
1170
                                        nodeChannelMap[node2] = append(
×
1171
                                                nodeChannelMap[node2], channel,
×
1172
                                        )
×
1173
                                }
×
1174
                        }
1175

1176
                        return &nodeCachedBatchData{
×
1177
                                features:      nodeFeatures,
×
1178
                                addrs:         nodeAddrs,
×
1179
                                chanBatchData: channelBatchData,
×
1180
                                chanMap:       nodeChannelMap,
×
1181
                        }, nil
×
1182
                }
1183

1184
                // processItem is used to process each node in the current page.
1185
                processItem := func(ctx context.Context,
×
1186
                        nodeData sqlc.ListNodeIDsAndPubKeysRow,
×
1187
                        batchData *nodeCachedBatchData) error {
×
1188

×
1189
                        // Build feature vector for this node.
×
1190
                        fv := lnwire.EmptyFeatureVector()
×
1191
                        features, exists := batchData.features[nodeData.ID]
×
1192
                        if exists {
×
1193
                                for _, bit := range features {
×
1194
                                        fv.Set(lnwire.FeatureBit(bit))
×
1195
                                }
×
1196
                        }
1197

1198
                        var nodePub route.Vertex
×
1199
                        copy(nodePub[:], nodeData.PubKey)
×
1200

×
1201
                        nodeChannels := batchData.chanMap[nodeData.ID]
×
1202

×
1203
                        toNodeCallback := func() route.Vertex {
×
1204
                                return nodePub
×
1205
                        }
×
1206

1207
                        // Build cached channels map for this node.
1208
                        channels := make(map[uint64]*DirectedChannel)
×
1209
                        for _, channelRow := range nodeChannels {
×
1210
                                directedChan, err := buildDirectedChannel(
×
1211
                                        s.cfg.ChainHash, nodeData.ID, nodePub,
×
1212
                                        channelRow, batchData.chanBatchData, fv,
×
1213
                                        toNodeCallback,
×
1214
                                )
×
1215
                                if err != nil {
×
1216
                                        return err
×
1217
                                }
×
1218

1219
                                channels[directedChan.ChannelID] = directedChan
×
1220
                        }
1221

1222
                        addrs, err := buildNodeAddresses(
×
1223
                                batchData.addrs[nodeData.ID],
×
1224
                        )
×
1225
                        if err != nil {
×
1226
                                return fmt.Errorf("unable to build node "+
×
1227
                                        "addresses: %w", err)
×
1228
                        }
×
1229

1230
                        return cb(ctx, nodePub, addrs, channels)
×
1231
                }
1232

1233
                return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
1234
                        ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
×
1235
                        func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
1236
                                return node.ID
×
1237
                        },
×
1238
                        func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
1239
                                error) {
×
1240

×
1241
                                return node.ID, nil
×
1242
                        },
×
1243
                        batchDataFunc, processItem,
1244
                )
1245
        }, reset)
1246
}
1247

1248
// ForEachChannelCacheable iterates through all the channel edges stored
1249
// within the graph and invokes the passed callback for each edge. The
1250
// callback takes two edges since this is a directed graph; both the
1251
// in/out edges are visited. If the callback returns an error, then the
1252
// transaction is aborted and the iteration stops early.
1253
//
1254
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1255
// pointer for that particular channel edge routing policy will be
1256
// passed into the callback.
1257
//
1258
// NOTE: this method is like ForEachChannel but fetches only the data
1259
// required for the graph cache.
1260
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1261
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1262
        reset func()) error {
×
1263

×
1264
        ctx := context.TODO()
×
1265

×
1266
        handleChannel := func(_ context.Context,
×
1267
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
1268

×
1269
                node1, node2, err := buildNodeVertices(
×
1270
                        row.Node1Pubkey, row.Node2Pubkey,
×
1271
                )
×
1272
                if err != nil {
×
1273
                        return err
×
1274
                }
×
1275

1276
                edge := buildCacheableChannelInfo(
×
1277
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1278
                )
×
1279

×
1280
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1281
                if err != nil {
×
1282
                        return err
×
1283
                }
×
1284

1285
                pol1, pol2, err := buildCachedChanPolicies(
×
1286
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1287
                )
×
1288
                if err != nil {
×
1289
                        return err
×
1290
                }
×
1291

1292
                return cb(edge, pol1, pol2)
×
1293
        }
1294

1295
        extractCursor := func(
×
1296
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {
×
1297

×
1298
                return row.ID
×
1299
        }
×
1300

1301
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1302
                //nolint:ll
×
1303
                queryFunc := func(ctx context.Context, lastID int64,
×
1304
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
×
1305
                        error) {
×
1306

×
1307
                        return db.ListChannelsWithPoliciesForCachePaginated(
×
1308
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
1309
                                        Version: int16(ProtocolV1),
×
1310
                                        ID:      lastID,
×
1311
                                        Limit:   limit,
×
1312
                                },
×
1313
                        )
×
1314
                }
×
1315

1316
                return sqldb.ExecutePaginatedQuery(
×
1317
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
1318
                        extractCursor, handleChannel,
×
1319
                )
×
1320
        }, reset)
1321
}
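
// exampleCountCacheableChannels is a minimal illustrative sketch (the helper
// itself is hypothetical) of how ForEachChannelCacheable might be used to
// count the known channels and the directed policies available for them.
func exampleCountCacheableChannels(store *SQLStore) (int, int, error) {
        var numChans, numPolicies int

        cb := func(edge *models.CachedEdgeInfo, pol1,
                pol2 *models.CachedEdgePolicy) error {

                numChans++

                // A nil policy means that no update has been received for
                // that direction yet.
                if pol1 != nil {
                        numPolicies++
                }
                if pol2 != nil {
                        numPolicies++
                }

                return nil
        }

        // The reset closure is invoked if the underlying transaction needs
        // to be retried, so any accumulated state is cleared here.
        reset := func() {
                numChans, numPolicies = 0, 0
        }

        if err := store.ForEachChannelCacheable(cb, reset); err != nil {
                return 0, 0, err
        }

        return numChans, numPolicies, nil
}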
1322

1323
// ForEachChannel iterates through all the channel edges stored within the
1324
// graph and invokes the passed callback for each edge. The callback takes two
1325
// edges since this is a directed graph; both the in/out edges are visited.
1326
// If the callback returns an error, then the transaction is aborted and the
1327
// iteration stops early.
1328
//
1329
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1330
// for that particular channel edge routing policy will be passed into the
1331
// callback.
1332
//
1333
// NOTE: part of the V1Store interface.
1334
func (s *SQLStore) ForEachChannel(ctx context.Context,
1335
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1336
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1337

×
1338
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1339
                return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
×
1340
        }, reset)
×
1341
}
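
// exampleTotalGraphCapacity is a minimal illustrative sketch (the helper is
// hypothetical) showing how ForEachChannel can be used to sum the capacity
// of every channel in the graph, assuming the edge info's Capacity field.
func exampleTotalGraphCapacity(ctx context.Context,
        store *SQLStore) (btcutil.Amount, error) {

        var total btcutil.Amount

        err := store.ForEachChannel(ctx, func(info *models.ChannelEdgeInfo,
                _, _ *models.ChannelEdgePolicy) error {

                total += info.Capacity

                return nil
        }, func() {
                // Reset the accumulator if the read transaction is retried.
                total = 0
        })
        if err != nil {
                return 0, err
        }

        return total, nil
}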
1342

1343
// FilterChannelRange returns the channel ID's of all known channels which were
1344
// mined in a block height within the passed range. The channel IDs are grouped
1345
// by their common block height. This method can be used to quickly share with a
1346
// peer the set of channels we know of within a particular range to catch them
1347
// up after a period of time offline. If withTimestamps is true then the
1348
// timestamp info of the latest received channel update messages of the channel
1349
// will be included in the response.
1350
//
1351
// NOTE: This is part of the V1Store interface.
1352
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1353
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1354

×
1355
        var (
×
1356
                ctx       = context.TODO()
×
1357
                startSCID = &lnwire.ShortChannelID{
×
1358
                        BlockHeight: startHeight,
×
1359
                }
×
1360
                endSCID = lnwire.ShortChannelID{
×
1361
                        BlockHeight: endHeight,
×
1362
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1363
                        TxPosition:  math.MaxUint16,
×
1364
                }
×
1365
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1366
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1367
        )
×
1368

×
1369
        // 1) get all channels where channelID is between start and end chan ID.
×
1370
        // 2) skip if not public (ie, no channel_proof)
×
1371
        // 3) collect that channel.
×
1372
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1373
        //    and add those timestamps to the collected channel.
×
1374
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1375
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1376
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1377
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1378
                                StartScid: chanIDStart,
×
1379
                                EndScid:   chanIDEnd,
×
1380
                        },
×
1381
                )
×
1382
                if err != nil {
×
1383
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1384
                                err)
×
1385
                }
×
1386

1387
                for _, dbChan := range dbChans {
×
1388
                        cid := lnwire.NewShortChanIDFromInt(
×
1389
                                byteOrder.Uint64(dbChan.Scid),
×
1390
                        )
×
1391
                        chanInfo := NewChannelUpdateInfo(
×
1392
                                cid, time.Time{}, time.Time{},
×
1393
                        )
×
1394

×
1395
                        if !withTimestamps {
×
1396
                                channelsPerBlock[cid.BlockHeight] = append(
×
1397
                                        channelsPerBlock[cid.BlockHeight],
×
1398
                                        chanInfo,
×
1399
                                )
×
1400

×
1401
                                continue
×
1402
                        }
1403

1404
                        //nolint:ll
1405
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1406
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1407
                                        Version:   int16(ProtocolV1),
×
1408
                                        ChannelID: dbChan.ID,
×
1409
                                        NodeID:    dbChan.NodeID1,
×
1410
                                },
×
1411
                        )
×
1412
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1413
                                return fmt.Errorf("unable to fetch node1 "+
×
1414
                                        "policy: %w", err)
×
1415
                        } else if err == nil {
×
1416
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1417
                                        node1Policy.LastUpdate.Int64, 0,
×
1418
                                )
×
1419
                        }
×
1420

1421
                        //nolint:ll
1422
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1423
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1424
                                        Version:   int16(ProtocolV1),
×
1425
                                        ChannelID: dbChan.ID,
×
1426
                                        NodeID:    dbChan.NodeID2,
×
1427
                                },
×
1428
                        )
×
1429
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1430
                                return fmt.Errorf("unable to fetch node2 "+
×
1431
                                        "policy: %w", err)
×
1432
                        } else if err == nil {
×
1433
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1434
                                        node2Policy.LastUpdate.Int64, 0,
×
1435
                                )
×
1436
                        }
×
1437

1438
                        channelsPerBlock[cid.BlockHeight] = append(
×
1439
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1440
                        )
×
1441
                }
1442

1443
                return nil
×
1444
        }, func() {
×
1445
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1446
        })
×
1447
        if err != nil {
×
1448
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1449
        }
×
1450

1451
        if len(channelsPerBlock) == 0 {
×
1452
                return nil, nil
×
1453
        }
×
1454

1455
        // Return the channel ranges in ascending block height order.
1456
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1457
        slices.Sort(blocks)
×
1458

×
1459
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1460
                return BlockChannelRange{
×
1461
                        Height:   block,
×
1462
                        Channels: channelsPerBlock[block],
×
1463
                }
×
1464
        }), nil
×
1465
}
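
// examplePrintChannelRange is a minimal illustrative sketch (the helper is
// hypothetical) of how FilterChannelRange could be used to inspect the
// channels mined within a block range, including the latest update
// timestamps for each direction.
func examplePrintChannelRange(store *SQLStore, start, end uint32) error {
        ranges, err := store.FilterChannelRange(start, end, true)
        if err != nil {
                return err
        }

        for _, blockRange := range ranges {
                for _, info := range blockRange.Channels {
                        fmt.Printf("height=%d scid=%v node1_update=%v "+
                                "node2_update=%v\n", blockRange.Height,
                                info.ShortChannelID,
                                info.Node1UpdateTimestamp,
                                info.Node2UpdateTimestamp)
                }
        }

        return nil
}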
1466

1467
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1468
// zombie. This method is used on an ad-hoc basis, when channels need to be
1469
// marked as zombies outside the normal pruning cycle.
1470
//
1471
// NOTE: part of the V1Store interface.
1472
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1473
        pubKey1, pubKey2 [33]byte) error {
×
1474

×
1475
        ctx := context.TODO()
×
1476

×
1477
        s.cacheMu.Lock()
×
1478
        defer s.cacheMu.Unlock()
×
1479

×
1480
        chanIDB := channelIDToBytes(chanID)
×
1481

×
1482
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1483
                return db.UpsertZombieChannel(
×
1484
                        ctx, sqlc.UpsertZombieChannelParams{
×
1485
                                Version:  int16(ProtocolV1),
×
1486
                                Scid:     chanIDB,
×
1487
                                NodeKey1: pubKey1[:],
×
1488
                                NodeKey2: pubKey2[:],
×
1489
                        },
×
1490
                )
×
1491
        }, sqldb.NoOpReset)
×
1492
        if err != nil {
×
1493
                return fmt.Errorf("unable to upsert zombie channel "+
×
1494
                        "(channel_id=%d): %w", chanID, err)
×
1495
        }
×
1496

1497
        s.rejectCache.remove(chanID)
×
1498
        s.chanCache.remove(chanID)
×
1499

×
1500
        return nil
×
1501
}
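
// exampleZombieLifecycle is a minimal illustrative sketch (the helper is
// hypothetical) showing how a channel can be marked as a zombie, queried via
// IsZombieEdge, and later resurrected with MarkEdgeLive once a fresh update
// arrives.
func exampleZombieLifecycle(store *SQLStore, chanID uint64,
        pubKey1, pubKey2 [33]byte) error {

        // Mark the channel as a zombie so that stale gossip for it is
        // rejected going forward.
        if err := store.MarkEdgeZombie(chanID, pubKey1, pubKey2); err != nil {
                return err
        }

        isZombie, _, _, err := store.IsZombieEdge(chanID)
        if err != nil {
                return err
        }
        if !isZombie {
                return fmt.Errorf("expected channel %d to be a zombie", chanID)
        }

        // Once a fresh update is received for the channel, it can be removed
        // from the zombie index again. ErrZombieEdgeNotFound indicates that
        // it was already resurrected.
        err = store.MarkEdgeLive(chanID)
        if err != nil && !errors.Is(err, ErrZombieEdgeNotFound) {
                return err
        }

        return nil
}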
1502

1503
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1504
//
1505
// NOTE: part of the V1Store interface.
1506
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1507
        s.cacheMu.Lock()
×
1508
        defer s.cacheMu.Unlock()
×
1509

×
1510
        var (
×
1511
                ctx     = context.TODO()
×
1512
                chanIDB = channelIDToBytes(chanID)
×
1513
        )
×
1514

×
1515
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1516
                res, err := db.DeleteZombieChannel(
×
1517
                        ctx, sqlc.DeleteZombieChannelParams{
×
1518
                                Scid:    chanIDB,
×
1519
                                Version: int16(ProtocolV1),
×
1520
                        },
×
1521
                )
×
1522
                if err != nil {
×
1523
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1524
                                err)
×
1525
                }
×
1526

1527
                rows, err := res.RowsAffected()
×
1528
                if err != nil {
×
1529
                        return err
×
1530
                }
×
1531

1532
                if rows == 0 {
×
1533
                        return ErrZombieEdgeNotFound
×
1534
                } else if rows > 1 {
×
1535
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1536
                                "expected 1", rows)
×
1537
                }
×
1538

1539
                return nil
×
1540
        }, sqldb.NoOpReset)
1541
        if err != nil {
×
1542
                return fmt.Errorf("unable to mark edge live "+
×
1543
                        "(channel_id=%d): %w", chanID, err)
×
1544
        }
×
1545

1546
        s.rejectCache.remove(chanID)
×
1547
        s.chanCache.remove(chanID)
×
1548

×
1549
        return err
×
1550
}
1551

1552
// IsZombieEdge returns whether the edge is considered a zombie. If it is a
1553
// zombie, then the two node public keys corresponding to this edge are also
1554
// returned.
1555
//
1556
// NOTE: part of the V1Store interface.
1557
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1558
        error) {
×
1559

×
1560
        var (
×
1561
                ctx              = context.TODO()
×
1562
                isZombie         bool
×
1563
                pubKey1, pubKey2 route.Vertex
×
1564
                chanIDB          = channelIDToBytes(chanID)
×
1565
        )
×
1566

×
1567
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1568
                zombie, err := db.GetZombieChannel(
×
1569
                        ctx, sqlc.GetZombieChannelParams{
×
1570
                                Scid:    chanIDB,
×
1571
                                Version: int16(ProtocolV1),
×
1572
                        },
×
1573
                )
×
1574
                if errors.Is(err, sql.ErrNoRows) {
×
1575
                        return nil
×
1576
                }
×
1577
                if err != nil {
×
1578
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1579
                                err)
×
1580
                }
×
1581

1582
                copy(pubKey1[:], zombie.NodeKey1)
×
1583
                copy(pubKey2[:], zombie.NodeKey2)
×
1584
                isZombie = true
×
1585

×
1586
                return nil
×
1587
        }, sqldb.NoOpReset)
1588
        if err != nil {
×
1589
                return false, route.Vertex{}, route.Vertex{},
×
1590
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1591
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1592
        }
×
1593

1594
        return isZombie, pubKey1, pubKey2, nil
×
1595
}
1596

1597
// NumZombies returns the current number of zombie channels in the graph.
1598
//
1599
// NOTE: part of the V1Store interface.
1600
func (s *SQLStore) NumZombies() (uint64, error) {
×
1601
        var (
×
1602
                ctx        = context.TODO()
×
1603
                numZombies uint64
×
1604
        )
×
1605
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1606
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1607
                if err != nil {
×
1608
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1609
                                err)
×
1610
                }
×
1611

1612
                numZombies = uint64(count)
×
1613

×
1614
                return nil
×
1615
        }, sqldb.NoOpReset)
1616
        if err != nil {
×
1617
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1618
        }
×
1619

1620
        return numZombies, nil
×
1621
}
1622

1623
// DeleteChannelEdges removes edges with the given channel IDs from the
1624
// database and marks them as zombies. This ensures that we're unable to re-add
1625
// them to our database once again. If an edge does not exist within the
1626
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1627
// true, then when we mark these edges as zombies, we'll set up the keys such
1628
// that we require the node that failed to send the fresh update to be the one
1629
// that resurrects the channel from its zombie state. The markZombie bool
1630
// denotes whether to mark the channel as a zombie.
1631
//
1632
// NOTE: part of the V1Store interface.
1633
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1634
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1635

×
1636
        s.cacheMu.Lock()
×
1637
        defer s.cacheMu.Unlock()
×
1638

×
1639
        // Keep track of which channels we end up finding so that we can
×
1640
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
1641
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
1642
        for _, chanID := range chanIDs {
×
1643
                chanLookup[chanID] = struct{}{}
×
1644
        }
×
1645

1646
        var (
×
1647
                ctx   = context.TODO()
×
1648
                edges []*models.ChannelEdgeInfo
×
1649
        )
×
1650
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1651
                // First, collect all channel rows.
×
1652
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
1653
                chanCallBack := func(ctx context.Context,
×
1654
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1655

×
1656
                        // Deleting the entry from the map indicates that we
×
1657
                        // have found the channel.
×
1658
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1659
                        delete(chanLookup, scid)
×
1660

×
1661
                        channelRows = append(channelRows, row)
×
1662

×
1663
                        return nil
×
1664
                }
×
1665

1666
                err := s.forEachChanWithPoliciesInSCIDList(
×
1667
                        ctx, db, chanCallBack, chanIDs,
×
1668
                )
×
1669
                if err != nil {
×
1670
                        return err
×
1671
                }
×
1672

1673
                if len(chanLookup) > 0 {
×
1674
                        return ErrEdgeNotFound
×
1675
                }
×
1676

1677
                if len(channelRows) == 0 {
×
1678
                        return nil
×
1679
                }
×
1680

1681
                // Batch build all channel edges.
1682
                var chanIDsToDelete []int64
×
1683
                edges, chanIDsToDelete, err = batchBuildChannelInfo(
×
1684
                        ctx, s.cfg, db, channelRows,
×
1685
                )
×
1686
                if err != nil {
×
1687
                        return err
×
1688
                }
×
1689

1690
                if markZombie {
×
1691
                        for i, row := range channelRows {
×
1692
                                scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1693

×
1694
                                err := handleZombieMarking(
×
1695
                                        ctx, db, row, edges[i],
×
1696
                                        strictZombiePruning, scid,
×
1697
                                )
×
1698
                                if err != nil {
×
1699
                                        return fmt.Errorf("unable to mark "+
×
1700
                                                "channel as zombie: %w", err)
×
1701
                                }
×
1702
                        }
1703
                }
1704

1705
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1706
        }, func() {
×
1707
                edges = nil
×
1708

×
1709
                // Re-fill the lookup map.
×
1710
                for _, chanID := range chanIDs {
×
1711
                        chanLookup[chanID] = struct{}{}
×
1712
                }
×
1713
        })
1714
        if err != nil {
×
1715
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1716
                        err)
×
1717
        }
×
1718

1719
        for _, chanID := range chanIDs {
×
1720
                s.rejectCache.remove(chanID)
×
1721
                s.chanCache.remove(chanID)
×
1722
        }
×
1723

1724
        return edges, nil
×
1725
}
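
// exampleRemoveClosedChannels is a minimal illustrative sketch (the helper
// is hypothetical) of how DeleteChannelEdges can be used to remove a batch
// of channels while adding them to the zombie index with strict pruning.
func exampleRemoveClosedChannels(store *SQLStore,
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {

        // With strictZombiePruning=true and markZombie=true, the deleted
        // channels are marked such that only a fresh update from the node
        // that failed to send one can resurrect them.
        edges, err := store.DeleteChannelEdges(true, true, chanIDs...)
        switch {
        // If any of the requested channels is unknown, the whole call fails
        // with ErrEdgeNotFound.
        case errors.Is(err, ErrEdgeNotFound):
                return nil, fmt.Errorf("unknown channel in batch: %w", err)

        case err != nil:
                return nil, err
        }

        return edges, nil
}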
1726

1727
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1728
// channel identified by the channel ID. If the channel can't be found, then
1729
// ErrEdgeNotFound is returned. A struct which houses the general information
1730
// for the channel itself is returned as well as two structs that contain the
1731
// routing policies for the channel in either direction.
1732
//
1733
// ErrZombieEdge can be returned if the edge is currently marked as a zombie
1734
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1735
// the ChannelEdgeInfo will only include the public keys of each node.
1736
//
1737
// NOTE: part of the V1Store interface.
1738
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1739
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1740
        *models.ChannelEdgePolicy, error) {
×
1741

×
1742
        var (
×
1743
                ctx              = context.TODO()
×
1744
                edge             *models.ChannelEdgeInfo
×
1745
                policy1, policy2 *models.ChannelEdgePolicy
×
1746
                chanIDB          = channelIDToBytes(chanID)
×
1747
        )
×
1748
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1749
                row, err := db.GetChannelBySCIDWithPolicies(
×
1750
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1751
                                Scid:    chanIDB,
×
1752
                                Version: int16(ProtocolV1),
×
1753
                        },
×
1754
                )
×
1755
                if errors.Is(err, sql.ErrNoRows) {
×
1756
                        // First check if this edge is perhaps in the zombie
×
1757
                        // index.
×
1758
                        zombie, err := db.GetZombieChannel(
×
1759
                                ctx, sqlc.GetZombieChannelParams{
×
1760
                                        Scid:    chanIDB,
×
1761
                                        Version: int16(ProtocolV1),
×
1762
                                },
×
1763
                        )
×
1764
                        if errors.Is(err, sql.ErrNoRows) {
×
1765
                                return ErrEdgeNotFound
×
1766
                        } else if err != nil {
×
1767
                                return fmt.Errorf("unable to check if "+
×
1768
                                        "channel is zombie: %w", err)
×
1769
                        }
×
1770

1771
                        // At this point, we know the channel is a zombie, so
1772
                        // we'll return an error indicating this, and we will
1773
                        // populate the edge info with the public keys of each
1774
                        // party as this is the only information we have about
1775
                        // it.
1776
                        edge = &models.ChannelEdgeInfo{}
×
1777
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1778
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1779

×
1780
                        return ErrZombieEdge
×
1781
                } else if err != nil {
×
1782
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1783
                }
×
1784

1785
                node1, node2, err := buildNodeVertices(
×
1786
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1787
                )
×
1788
                if err != nil {
×
1789
                        return err
×
1790
                }
×
1791

1792
                edge, err = getAndBuildEdgeInfo(
×
1793
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
1794
                )
×
1795
                if err != nil {
×
1796
                        return fmt.Errorf("unable to build channel info: %w",
×
1797
                                err)
×
1798
                }
×
1799

1800
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1801
                if err != nil {
×
1802
                        return fmt.Errorf("unable to extract channel "+
×
1803
                                "policies: %w", err)
×
1804
                }
×
1805

1806
                policy1, policy2, err = getAndBuildChanPolicies(
×
1807
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
1808
                        node1, node2,
×
1809
                )
×
1810
                if err != nil {
×
1811
                        return fmt.Errorf("unable to build channel "+
×
1812
                                "policies: %w", err)
×
1813
                }
×
1814

1815
                return nil
×
1816
        }, sqldb.NoOpReset)
1817
        if err != nil {
×
1818
                // If we are returning the ErrZombieEdge, then we also need to
×
1819
                // return the edge info as the method comment indicates that
×
1820
                // this will be populated when the edge is a zombie.
×
1821
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1822
                        err)
×
1823
        }
×
1824

1825
        return edge, policy1, policy2, nil
×
1826
}
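
// exampleFetchEdgeByID is a minimal illustrative sketch (the helper is
// hypothetical) of how FetchChannelEdgesByID callers can distinguish a
// zombie channel, for which only the node public keys are populated, from a
// channel that is simply unknown.
func exampleFetchEdgeByID(store *SQLStore, chanID uint64) error {
        info, pol1, pol2, err := store.FetchChannelEdgesByID(chanID)
        switch {
        case errors.Is(err, ErrZombieEdge):
                // The edge info still carries the two node keys so the caller
                // knows which nodes could resurrect the channel.
                fmt.Printf("channel %d is a zombie between %x and %x\n",
                        chanID, info.NodeKey1Bytes, info.NodeKey2Bytes)

                return nil

        case errors.Is(err, ErrEdgeNotFound):
                return fmt.Errorf("channel %d is unknown", chanID)

        case err != nil:
                return err
        }

        // A nil policy means no update has been received for that direction.
        fmt.Printf("channel %d: have policy1=%v, policy2=%v\n", chanID,
                pol1 != nil, pol2 != nil)

        return nil
}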
1827

1828
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1829
// the channel identified by the funding outpoint. If the channel can't be
1830
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1831
// information for the channel itself is returned as well as two structs that
1832
// contain the routing policies for the channel in either direction.
1833
//
1834
// NOTE: part of the V1Store interface.
1835
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1836
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1837
        *models.ChannelEdgePolicy, error) {
×
1838

×
1839
        var (
×
1840
                ctx              = context.TODO()
×
1841
                edge             *models.ChannelEdgeInfo
×
1842
                policy1, policy2 *models.ChannelEdgePolicy
×
1843
        )
×
1844
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1845
                row, err := db.GetChannelByOutpointWithPolicies(
×
1846
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1847
                                Outpoint: op.String(),
×
1848
                                Version:  int16(ProtocolV1),
×
1849
                        },
×
1850
                )
×
1851
                if errors.Is(err, sql.ErrNoRows) {
×
1852
                        return ErrEdgeNotFound
×
1853
                } else if err != nil {
×
1854
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1855
                }
×
1856

1857
                node1, node2, err := buildNodeVertices(
×
1858
                        row.Node1Pubkey, row.Node2Pubkey,
×
1859
                )
×
1860
                if err != nil {
×
1861
                        return err
×
1862
                }
×
1863

1864
                edge, err = getAndBuildEdgeInfo(
×
1865
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
1866
                )
×
1867
                if err != nil {
×
1868
                        return fmt.Errorf("unable to build channel info: %w",
×
1869
                                err)
×
1870
                }
×
1871

1872
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1873
                if err != nil {
×
1874
                        return fmt.Errorf("unable to extract channel "+
×
1875
                                "policies: %w", err)
×
1876
                }
×
1877

1878
                policy1, policy2, err = getAndBuildChanPolicies(
×
1879
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
1880
                        node1, node2,
×
1881
                )
×
1882
                if err != nil {
×
1883
                        return fmt.Errorf("unable to build channel "+
×
1884
                                "policies: %w", err)
×
1885
                }
×
1886

1887
                return nil
×
1888
        }, sqldb.NoOpReset)
1889
        if err != nil {
×
1890
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1891
                        err)
×
1892
        }
×
1893

1894
        return edge, policy1, policy2, nil
×
1895
}
1896

1897
// HasChannelEdge returns true if the database knows of a channel edge with the
1898
// passed channel ID, and false otherwise. If an edge with that ID is found
1899
// within the graph, then two time stamps representing the last time the edge
1900
// was updated for both directed edges are returned along with the boolean. If
1901
// it is not found, then the zombie index is checked and its result is returned
1902
// as the second boolean.
1903
//
1904
// NOTE: part of the V1Store interface.
1905
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1906
        bool, error) {
×
1907

×
1908
        ctx := context.TODO()
×
1909

×
1910
        var (
×
1911
                exists          bool
×
1912
                isZombie        bool
×
1913
                node1LastUpdate time.Time
×
1914
                node2LastUpdate time.Time
×
1915
        )
×
1916

×
1917
        // We'll query the cache with the shared lock held to allow multiple
×
1918
        // readers to access values in the cache concurrently if they exist.
×
1919
        s.cacheMu.RLock()
×
1920
        if entry, ok := s.rejectCache.get(chanID); ok {
×
1921
                s.cacheMu.RUnlock()
×
1922
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
1923
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
1924
                exists, isZombie = entry.flags.unpack()
×
1925

×
1926
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
1927
        }
×
1928
        s.cacheMu.RUnlock()
×
1929

×
1930
        s.cacheMu.Lock()
×
1931
        defer s.cacheMu.Unlock()
×
1932

×
1933
        // The item was not found with the shared lock, so we'll acquire the
×
1934
        // exclusive lock and check the cache again in case another method added
×
1935
        // the entry to the cache while no lock was held.
×
1936
        if entry, ok := s.rejectCache.get(chanID); ok {
×
1937
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
1938
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
1939
                exists, isZombie = entry.flags.unpack()
×
1940

×
1941
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
1942
        }
×
1943

1944
        chanIDB := channelIDToBytes(chanID)
×
1945
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1946
                channel, err := db.GetChannelBySCID(
×
1947
                        ctx, sqlc.GetChannelBySCIDParams{
×
1948
                                Scid:    chanIDB,
×
1949
                                Version: int16(ProtocolV1),
×
1950
                        },
×
1951
                )
×
1952
                if errors.Is(err, sql.ErrNoRows) {
×
1953
                        // Check if it is a zombie channel.
×
1954
                        isZombie, err = db.IsZombieChannel(
×
1955
                                ctx, sqlc.IsZombieChannelParams{
×
1956
                                        Scid:    chanIDB,
×
1957
                                        Version: int16(ProtocolV1),
×
1958
                                },
×
1959
                        )
×
1960
                        if err != nil {
×
1961
                                return fmt.Errorf("could not check if channel "+
×
1962
                                        "is zombie: %w", err)
×
1963
                        }
×
1964

1965
                        return nil
×
1966
                } else if err != nil {
×
1967
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1968
                }
×
1969

1970
                exists = true
×
1971

×
1972
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
1973
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1974
                                Version:   int16(ProtocolV1),
×
1975
                                ChannelID: channel.ID,
×
1976
                                NodeID:    channel.NodeID1,
×
1977
                        },
×
1978
                )
×
1979
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1980
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
1981
                                err)
×
1982
                } else if err == nil {
×
1983
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
1984
                }
×
1985

1986
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
1987
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1988
                                Version:   int16(ProtocolV1),
×
1989
                                ChannelID: channel.ID,
×
1990
                                NodeID:    channel.NodeID2,
×
1991
                        },
×
1992
                )
×
1993
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1994
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
1995
                                err)
×
1996
                } else if err == nil {
×
1997
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
1998
                }
×
1999

2000
                return nil
×
2001
        }, sqldb.NoOpReset)
2002
        if err != nil {
×
2003
                return time.Time{}, time.Time{}, false, false,
×
2004
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2005
        }
×
2006

2007
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2008
                upd1Time: node1LastUpdate.Unix(),
×
2009
                upd2Time: node2LastUpdate.Unix(),
×
2010
                flags:    packRejectFlags(exists, isZombie),
×
2011
        })
×
2012

×
2013
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2014
}
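
// exampleShouldProcessUpdate is a minimal illustrative sketch (the helper is
// hypothetical) showing how HasChannelEdge can be consulted before accepting
// an incoming channel update carrying the given timestamp for one direction.
func exampleShouldProcessUpdate(store *SQLStore, chanID uint64,
        updateTime time.Time, isNode1 bool) (bool, error) {

        node1Time, node2Time, exists, isZombie, err := store.HasChannelEdge(
                chanID,
        )
        if err != nil {
                return false, err
        }

        // Updates for unknown or zombie channels are of no interest.
        if !exists || isZombie {
                return false, nil
        }

        // Only process the update if it is newer than the one we already
        // have for this direction.
        lastUpdate := node2Time
        if isNode1 {
                lastUpdate = node1Time
        }

        return updateTime.After(lastUpdate), nil
}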
2015

2016
// ChannelID attempts to look up the 8-byte compact channel ID which maps to the
2017
// passed channel point (outpoint). If the passed channel doesn't exist within
2018
// the database, then ErrEdgeNotFound is returned.
2019
//
2020
// NOTE: part of the V1Store interface.
2021
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2022
        var (
×
2023
                ctx       = context.TODO()
×
2024
                channelID uint64
×
2025
        )
×
2026
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2027
                chanID, err := db.GetSCIDByOutpoint(
×
2028
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2029
                                Outpoint: chanPoint.String(),
×
2030
                                Version:  int16(ProtocolV1),
×
2031
                        },
×
2032
                )
×
2033
                if errors.Is(err, sql.ErrNoRows) {
×
2034
                        return ErrEdgeNotFound
×
2035
                } else if err != nil {
×
2036
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2037
                                err)
×
2038
                }
×
2039

2040
                channelID = byteOrder.Uint64(chanID)
×
2041

×
2042
                return nil
×
2043
        }, sqldb.NoOpReset)
2044
        if err != nil {
×
2045
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2046
        }
×
2047

2048
        return channelID, nil
×
2049
}
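
// exampleLookupSCID is a minimal illustrative sketch (the helper is
// hypothetical) of how ChannelID maps a funding outpoint back to its short
// channel ID, treating ErrEdgeNotFound as "channel unknown".
func exampleLookupSCID(store *SQLStore, chanPoint *wire.OutPoint) (uint64,
        bool, error) {

        scid, err := store.ChannelID(chanPoint)
        switch {
        case errors.Is(err, ErrEdgeNotFound):
                return 0, false, nil

        case err != nil:
                return 0, false, err
        }

        return scid, true, nil
}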
2050

2051
// IsPublicNode is a helper method that determines whether the node with the
2052
// given public key is seen as a public node in the graph from the graph's
2053
// source node's point of view.
2054
//
2055
// NOTE: part of the V1Store interface.
2056
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2057
        ctx := context.TODO()
×
2058

×
2059
        var isPublic bool
×
2060
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2061
                var err error
×
2062
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2063

×
2064
                return err
×
2065
        }, sqldb.NoOpReset)
×
2066
        if err != nil {
×
2067
                return false, fmt.Errorf("unable to check if node is "+
×
2068
                        "public: %w", err)
×
2069
        }
×
2070

2071
        return isPublic, nil
×
2072
}
2073

2074
// FetchChanInfos returns the set of channel edges that correspond to the passed
2075
// channel ID's. If an edge in the query is unknown to the database, it will be
2076
// skipped and the result will contain only those edges that exist at the time
2077
// of the query. This can be used to respond to peer queries that are seeking to
2078
// fill in gaps in their view of the channel graph.
2079
//
2080
// NOTE: part of the V1Store interface.
2081
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2082
        var (
×
2083
                ctx   = context.TODO()
×
2084
                edges = make(map[uint64]ChannelEdge)
×
2085
        )
×
2086
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2087
                // First, collect all channel rows.
×
2088
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
2089
                chanCallBack := func(ctx context.Context,
×
2090
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2091

×
2092
                        channelRows = append(channelRows, row)
×
2093
                        return nil
×
2094
                }
×
2095

2096
                err := s.forEachChanWithPoliciesInSCIDList(
×
2097
                        ctx, db, chanCallBack, chanIDs,
×
2098
                )
×
2099
                if err != nil {
×
2100
                        return err
×
2101
                }
×
2102

2103
                if len(channelRows) == 0 {
×
2104
                        return nil
×
2105
                }
×
2106

2107
                // Batch build all channel edges.
2108
                chans, err := batchBuildChannelEdges(
×
2109
                        ctx, s.cfg, db, channelRows,
×
2110
                )
×
2111
                if err != nil {
×
2112
                        return fmt.Errorf("unable to build channel edges: %w",
×
2113
                                err)
×
2114
                }
×
2115

2116
                for _, c := range chans {
×
2117
                        edges[c.Info.ChannelID] = c
×
2118
                }
×
2119

2120
                return err
×
2121
        }, func() {
×
2122
                clear(edges)
×
2123
        })
×
2124
        if err != nil {
×
2125
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2126
        }
×
2127

2128
        res := make([]ChannelEdge, 0, len(edges))
×
2129
        for _, chanID := range chanIDs {
×
2130
                edge, ok := edges[chanID]
×
2131
                if !ok {
×
2132
                        continue
×
2133
                }
2134

2135
                res = append(res, edge)
×
2136
        }
2137

2138
        return res, nil
×
2139
}
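
// exampleReplyWithChanInfos is a minimal illustrative sketch (the helper is
// hypothetical) showing how FetchChanInfos can back a reply to a peer asking
// about a set of short channel IDs: channels we don't know are simply absent
// from the result.
func exampleReplyWithChanInfos(store *SQLStore,
        queried []uint64) ([]ChannelEdge, error) {

        edges, err := store.FetchChanInfos(queried)
        if err != nil {
                return nil, err
        }

        if len(edges) < len(queried) {
                fmt.Printf("peer asked about %d channels, we know %d of "+
                        "them\n", len(queried), len(edges))
        }

        return edges, nil
}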
2140

2141
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2142
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2143
// channels in a paginated manner.
2144
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2145
        db SQLQueries, cb func(ctx context.Context,
2146
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
2147
        chanIDs []uint64) error {
×
2148

×
2149
        queryWrapper := func(ctx context.Context,
×
2150
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
2151
                error) {
×
2152

×
2153
                return db.GetChannelsBySCIDWithPolicies(
×
2154
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
2155
                                Version: int16(ProtocolV1),
×
2156
                                Scids:   scids,
×
2157
                        },
×
2158
                )
×
2159
        }
×
2160

2161
        return sqldb.ExecuteBatchQuery(
×
2162
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
×
2163
                cb,
×
2164
        )
×
2165
}
2166

2167
// FilterKnownChanIDs takes a set of channel IDs and returns the subset of chan
2168
// ID's that we don't know and are not known zombies of the passed set. In other
2169
// words, we perform a set difference of our set of chan ID's and the ones
2170
// passed in. This method can be used by callers to determine the set of
2171
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2172
// known zombies are also returned.
2173
//
2174
// NOTE: part of the V1Store interface.
2175
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2176
        []ChannelUpdateInfo, error) {
×
2177

×
2178
        var (
×
2179
                ctx          = context.TODO()
×
2180
                newChanIDs   []uint64
×
2181
                knownZombies []ChannelUpdateInfo
×
2182
                infoLookup   = make(
×
2183
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2184
                )
×
2185
        )
×
2186

×
2187
        // We first build a lookup map of the channel ID's to the
×
2188
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
2189
        // already know about.
×
2190
        for _, chanInfo := range chansInfo {
×
2191
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
2192
        }
×
2193

2194
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2195
                // The call-back function deletes known channels from
×
2196
                // infoLookup, so that we can later check which channels are
×
2197
                // zombies by only looking at the remaining channels in the set.
×
2198
                cb := func(ctx context.Context,
×
2199
                        channel sqlc.GraphChannel) error {
×
2200

×
2201
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2202

×
2203
                        return nil
×
2204
                }
×
2205

2206
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
2207
                if err != nil {
×
2208
                        return fmt.Errorf("unable to iterate through "+
×
2209
                                "channels: %w", err)
×
2210
                }
×
2211

2212
                // We want to ensure that we deal with the channels in the
2213
                // same order that they were passed in, so we iterate over the
2214
                // original chansInfo slice and then check if that channel is
2215
                // still in the infoLookup map.
2216
                for _, chanInfo := range chansInfo {
×
2217
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2218
                        if _, ok := infoLookup[channelID]; !ok {
×
2219
                                continue
×
2220
                        }
2221

2222
                        isZombie, err := db.IsZombieChannel(
×
2223
                                ctx, sqlc.IsZombieChannelParams{
×
2224
                                        Scid:    channelIDToBytes(channelID),
×
2225
                                        Version: int16(ProtocolV1),
×
2226
                                },
×
2227
                        )
×
2228
                        if err != nil {
×
2229
                                return fmt.Errorf("unable to fetch zombie "+
×
2230
                                        "channel: %w", err)
×
2231
                        }
×
2232

2233
                        if isZombie {
×
2234
                                knownZombies = append(knownZombies, chanInfo)
×
2235

×
2236
                                continue
×
2237
                        }
2238

2239
                        newChanIDs = append(newChanIDs, channelID)
×
2240
                }
2241

2242
                return nil
×
2243
        }, func() {
×
2244
                newChanIDs = nil
×
2245
                knownZombies = nil
×
2246
                // Rebuild the infoLookup map in case of a rollback.
×
2247
                for _, chanInfo := range chansInfo {
×
2248
                        scid := chanInfo.ShortChannelID.ToUint64()
×
2249
                        infoLookup[scid] = chanInfo
×
2250
                }
×
2251
        })
2252
        if err != nil {
×
2253
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2254
        }
×
2255

2256
        return newChanIDs, knownZombies, nil
×
2257
}
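
// exampleFindNewChannels is a minimal illustrative sketch (the helper is
// hypothetical) showing how FilterKnownChanIDs can be used during gossip
// sync to work out which channels advertised by a peer still need to be
// queried.
func exampleFindNewChannels(store *SQLStore,
        advertised []lnwire.ShortChannelID) ([]uint64, error) {

        chansInfo := make([]ChannelUpdateInfo, 0, len(advertised))
        for _, scid := range advertised {
                // No update timestamps are known for the advertised channels
                // yet, so zero values are used.
                chansInfo = append(chansInfo, NewChannelUpdateInfo(
                        scid, time.Time{}, time.Time{},
                ))
        }

        newChanIDs, knownZombies, err := store.FilterKnownChanIDs(chansInfo)
        if err != nil {
                return nil, err
        }

        fmt.Printf("%d new channels, %d known zombies\n", len(newChanIDs),
                len(knownZombies))

        return newChanIDs, nil
}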
2258

2259
// forEachChanInSCIDList is a helper method that executes a paged query
2260
// against the database to fetch all channels that match the passed
2261
// ChannelUpdateInfo slice. The callback function is called for each channel
2262
// that is found.
2263
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2264
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
2265
        chansInfo []ChannelUpdateInfo) error {
×
2266

×
2267
        queryWrapper := func(ctx context.Context,
×
2268
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
2269

×
2270
                return db.GetChannelsBySCIDs(
×
2271
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
2272
                                Version: int16(ProtocolV1),
×
2273
                                Scids:   scids,
×
2274
                        },
×
2275
                )
×
2276
        }
×
2277

2278
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
2279
                channelID := chanInfo.ShortChannelID.ToUint64()
×
2280

×
2281
                return channelIDToBytes(channelID)
×
2282
        }
×
2283

2284
        return sqldb.ExecuteBatchQuery(
×
2285
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
×
2286
                cb,
×
2287
        )
×
2288
}
2289

2290
// PruneGraphNodes is a garbage collection method which attempts to prune out
2291
// any nodes from the channel graph that are currently unconnected. This ensures
2292
// that we only maintain a graph of reachable nodes. In the event that a pruned
2293
// node gains more channels, it will be re-added back to the graph.
2294
//
2295
// NOTE: this prunes nodes across protocol versions. It will never prune the
2296
// source nodes.
2297
//
2298
// NOTE: part of the V1Store interface.
2299
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2300
        var ctx = context.TODO()
×
2301

×
2302
        var prunedNodes []route.Vertex
×
2303
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2304
                var err error
×
2305
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2306

×
2307
                return err
×
2308
        }, func() {
×
2309
                prunedNodes = nil
×
2310
        })
×
2311
        if err != nil {
×
2312
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2313
        }
×
2314

2315
        return prunedNodes, nil
×
2316
}
2317

2318
// PruneGraph prunes newly closed channels from the channel graph in response
2319
// to a new block being solved on the network. Any transactions which spend the
2320
// funding output of any known channels within the graph will be deleted.
2321
// Additionally, the "prune tip", or the last block which has been used to
2322
// prune the graph is stored so callers can ensure the graph is fully in sync
2323
// with the current UTXO state. A slice of channels that have been closed by
2324
// the target block along with any pruned nodes are returned if the function
2325
// succeeds without error.
2326
//
2327
// NOTE: part of the V1Store interface.
2328
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2329
        blockHash *chainhash.Hash, blockHeight uint32) (
2330
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2331

×
2332
        ctx := context.TODO()
×
2333

×
2334
        s.cacheMu.Lock()
×
2335
        defer s.cacheMu.Unlock()
×
2336

×
2337
        var (
×
2338
                closedChans []*models.ChannelEdgeInfo
×
2339
                prunedNodes []route.Vertex
×
2340
        )
×
2341
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2342
                // First, collect all channel rows that need to be pruned.
×
2343
                var channelRows []sqlc.GetChannelsByOutpointsRow
×
2344
                channelCallback := func(ctx context.Context,
×
2345
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2346

×
2347
                        channelRows = append(channelRows, row)
×
2348

×
2349
                        return nil
×
2350
                }
×
2351

2352
                err := s.forEachChanInOutpoints(
×
2353
                        ctx, db, spentOutputs, channelCallback,
×
2354
                )
×
2355
                if err != nil {
×
2356
                        return fmt.Errorf("unable to fetch channels by "+
×
2357
                                "outpoints: %w", err)
×
2358
                }
×
2359

2360
                if len(channelRows) == 0 {
×
2361
                        // There are no channels to prune. So we can exit early
×
2362
                        // after updating the prune log.
×
2363
                        err = db.UpsertPruneLogEntry(
×
2364
                                ctx, sqlc.UpsertPruneLogEntryParams{
×
2365
                                        BlockHash:   blockHash[:],
×
2366
                                        BlockHeight: int64(blockHeight),
×
2367
                                },
×
2368
                        )
×
2369
                        if err != nil {
×
2370
                                return fmt.Errorf("unable to insert prune log "+
×
2371
                                        "entry: %w", err)
×
2372
                        }
×
2373

2374
                        return nil
×
2375
                }
2376

2377
                // Batch build all channel edges for pruning.
2378
                var chansToDelete []int64
×
2379
                closedChans, chansToDelete, err = batchBuildChannelInfo(
×
2380
                        ctx, s.cfg, db, channelRows,
×
2381
                )
×
2382
                if err != nil {
×
2383
                        return err
×
2384
                }
×
2385

2386
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2387
                if err != nil {
×
2388
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2389
                }
×
2390

2391
                err = db.UpsertPruneLogEntry(
×
2392
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2393
                                BlockHash:   blockHash[:],
×
2394
                                BlockHeight: int64(blockHeight),
×
2395
                        },
×
2396
                )
×
2397
                if err != nil {
×
2398
                        return fmt.Errorf("unable to insert prune log "+
×
2399
                                "entry: %w", err)
×
2400
                }
×
2401

2402
                // Now that we've pruned some channels, we'll also prune any
2403
                // nodes that no longer have any channels.
2404
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2405
                if err != nil {
×
2406
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2407
                                err)
×
2408
                }
×
2409

2410
                return nil
×
2411
        }, func() {
×
2412
                prunedNodes = nil
×
2413
                closedChans = nil
×
2414
        })
×
2415
        if err != nil {
×
2416
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2417
        }
×
2418

2419
        for _, channel := range closedChans {
×
2420
                s.rejectCache.remove(channel.ChannelID)
×
2421
                s.chanCache.remove(channel.ChannelID)
×
2422
        }
×
2423

2424
        return closedChans, prunedNodes, nil
×
2425
}
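
// The snippet below is an illustrative usage sketch (not part of the
// upstream file) of how a chain-event handler might call PruneGraph once a
// block is connected. The identifiers store, spentOutpoints, blockHash and
// blockHeight are assumed to exist in the caller's scope.
//
//	closed, pruned, err := store.PruneGraph(
//	        spentOutpoints, &blockHash, blockHeight,
//	)
//	if err != nil {
//	        return err
//	}
//	// closed holds the edges removed because their funding outputs were
//	// spent in this block; pruned holds nodes left with no channels.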
2426

2427
// forEachChanInOutpoints is a helper function that executes a paginated
2428
// query to fetch channels by their outpoints and applies the given call-back
2429
// to each.
2430
//
2431
// NOTE: this fetches channels for all protocol versions.
2432
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2433
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2434
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
2435

×
2436
        // Create a wrapper that uses the transaction's db instance to execute
×
2437
        // the query.
×
2438
        queryWrapper := func(ctx context.Context,
×
2439
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
2440
                error) {
×
2441

×
2442
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
2443
        }
×
2444

2445
        // Define the conversion function from Outpoint to string.
2446
        outpointToString := func(outpoint *wire.OutPoint) string {
×
2447
                return outpoint.String()
×
2448
        }
×
2449

2450
        return sqldb.ExecuteBatchQuery(
×
2451
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
×
2452
                queryWrapper, cb,
×
2453
        )
×
2454
}
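
// Conceptually, the batch helper above behaves like the simplified sketch
// below (not the actual sqldb implementation; split, convert and pageSize
// are hypothetical helpers used only for illustration): the outpoints are
// converted to their string form, chunked into pages, and the callback runs
// once per returned row.
//
//	for _, page := range split(convert(outpoints), pageSize) {
//	        rows, err := db.GetChannelsByOutpoints(ctx, page)
//	        if err != nil {
//	                return err
//	        }
//	        for _, row := range rows {
//	                if err := cb(ctx, row); err != nil {
//	                        return err
//	                }
//	        }
//	}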
2455

2456
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
2457
        dbIDs []int64) error {
×
2458

×
2459
        // Create a wrapper that uses the transaction's db instance to execute
×
2460
        // the query.
×
2461
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
2462
                return nil, db.DeleteChannels(ctx, ids)
×
2463
        }
×
2464

2465
        idConverter := func(id int64) int64 {
×
2466
                return id
×
2467
        }
×
2468

2469
        return sqldb.ExecuteBatchQuery(
×
2470
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
×
2471
                queryWrapper, func(ctx context.Context, _ any) error {
×
2472
                        return nil
×
2473
                },
×
2474
        )
2475
}
2476

2477
// ChannelView returns the verifiable edge information for each active channel
2478
// within the known channel graph. The set of UTXOs (along with their scripts)
2479
// returned are the ones that need to be watched on chain to detect channel
2480
// closes on the resident blockchain.
2481
//
2482
// NOTE: part of the V1Store interface.
2483
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2484
        var (
×
2485
                ctx        = context.TODO()
×
2486
                edgePoints []EdgePoint
×
2487
        )
×
2488

×
2489
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2490
                handleChannel := func(_ context.Context,
×
2491
                        channel sqlc.ListChannelsPaginatedRow) error {
×
2492

×
2493
                        pkScript, err := genMultiSigP2WSH(
×
2494
                                channel.BitcoinKey1, channel.BitcoinKey2,
×
2495
                        )
×
2496
                        if err != nil {
×
2497
                                return err
×
2498
                        }
×
2499

2500
                        op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2501
                        if err != nil {
×
2502
                                return err
×
2503
                        }
×
2504

2505
                        edgePoints = append(edgePoints, EdgePoint{
×
2506
                                FundingPkScript: pkScript,
×
2507
                                OutPoint:        *op,
×
2508
                        })
×
2509

×
2510
                        return nil
×
2511
                }
2512

2513
                queryFunc := func(ctx context.Context, lastID int64,
×
2514
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {
×
2515

×
2516
                        return db.ListChannelsPaginated(
×
2517
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2518
                                        Version: int16(ProtocolV1),
×
2519
                                        ID:      lastID,
×
2520
                                        Limit:   limit,
×
2521
                                },
×
2522
                        )
×
2523
                }
×
2524

2525
                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
×
2526
                        return row.ID
×
2527
                }
×
2528

2529
                return sqldb.ExecutePaginatedQuery(
×
2530
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
2531
                        extractCursor, handleChannel,
×
2532
                )
×
2533
        }, func() {
×
2534
                edgePoints = nil
×
2535
        })
×
2536
        if err != nil {
×
2537
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2538
        }
×
2539

2540
        return edgePoints, nil
×
2541
}
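
// Illustrative usage sketch (not part of the upstream file): a chain watcher
// could collect every returned outpoint and funding script in order to
// register them for spend notifications.
//
//	edgePoints, err := store.ChannelView()
//	if err != nil {
//	        return err
//	}
//	watchList := make(map[wire.OutPoint][]byte, len(edgePoints))
//	for _, ep := range edgePoints {
//	        watchList[ep.OutPoint] = ep.FundingPkScript
//	}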
2542

2543
// PruneTip returns the block height and hash of the latest block that has been
2544
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2545
// to tell if the graph is currently in sync with the current best known UTXO
2546
// state.
2547
//
2548
// NOTE: part of the V1Store interface.
2549
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2550
        var (
×
2551
                ctx       = context.TODO()
×
2552
                tipHash   chainhash.Hash
×
2553
                tipHeight uint32
×
2554
        )
×
2555
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2556
                pruneTip, err := db.GetPruneTip(ctx)
×
2557
                if errors.Is(err, sql.ErrNoRows) {
×
2558
                        return ErrGraphNeverPruned
×
2559
                } else if err != nil {
×
2560
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2561
                }
×
2562

2563
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2564
                tipHeight = uint32(pruneTip.BlockHeight)
×
2565

×
2566
                return nil
×
2567
        }, sqldb.NoOpReset)
2568
        if err != nil {
×
2569
                return nil, 0, err
×
2570
        }
×
2571

2572
        return &tipHash, tipHeight, nil
×
2573
}
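
// Illustrative sketch (not part of the upstream file): a caller can use the
// prune tip to decide whether the graph still needs to catch up with the
// chain. bestHeight is assumed to be the current best block height known to
// the caller.
//
//	_, pruneHeight, err := store.PruneTip()
//	switch {
//	case errors.Is(err, ErrGraphNeverPruned):
//	        // The graph has never been pruned, so a full sweep is needed.
//	case err != nil:
//	        return err
//	case pruneHeight < bestHeight:
//	        // Blocks from pruneHeight+1 to bestHeight still need to be
//	        // applied via PruneGraph.
//	}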
2574

2575
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2576
//
2577
// NOTE: this prunes nodes across protocol versions. It will never prune the
2578
// source nodes.
2579
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2580
        db SQLQueries) ([]route.Vertex, error) {
×
2581

×
2582
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2583
        if err != nil {
×
2584
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2585
                        "nodes: %w", err)
×
2586
        }
×
2587

2588
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2589
        for i, nodeKey := range nodeKeys {
×
2590
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2591
                if err != nil {
×
2592
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2593
                                "from bytes: %w", err)
×
2594
                }
×
2595

2596
                prunedNodes[i] = pub
×
2597
        }
2598

2599
        return prunedNodes, nil
×
2600
}
2601

2602
// DisconnectBlockAtHeight is used to indicate that the block specified
2603
// by the passed height has been disconnected from the main chain. This
2604
// will "rewind" the graph back to the height below, deleting channels
2605
// that are no longer confirmed from the graph. The prune log will be
2606
// set to the last prune height valid for the remaining chain.
2607
// Channels that were removed from the graph as a result of the
2608
// disconnected block are returned.
2609
//
2610
// NOTE: part of the V1Store interface.
2611
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2612
        []*models.ChannelEdgeInfo, error) {
×
2613

×
2614
        ctx := context.TODO()
×
2615

×
2616
        var (
×
2617
                // Every channel having a ShortChannelID starting at 'height'
×
2618
                // will no longer be confirmed.
×
2619
                startShortChanID = lnwire.ShortChannelID{
×
2620
                        BlockHeight: height,
×
2621
                }
×
2622

×
2623
                // Delete everything after this height from the db up until the
×
2624
                // SCID alias range.
×
2625
                endShortChanID = aliasmgr.StartingAlias
×
2626

×
2627
                removedChans []*models.ChannelEdgeInfo
×
2628

×
2629
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2630
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2631
        )
×
2632

×
2633
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2634
                rows, err := db.GetChannelsBySCIDRange(
×
2635
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2636
                                StartScid: chanIDStart,
×
2637
                                EndScid:   chanIDEnd,
×
2638
                        },
×
2639
                )
×
2640
                if err != nil {
×
2641
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2642
                }
×
2643

2644
                if len(rows) == 0 {
×
2645
                        // No channels to disconnect, but still clean up prune
×
2646
                        // log.
×
2647
                        return db.DeletePruneLogEntriesInRange(
×
2648
                                ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2649
                                        StartHeight: int64(height),
×
2650
                                        EndHeight: int64(
×
2651
                                                endShortChanID.BlockHeight,
×
2652
                                        ),
×
2653
                                },
×
2654
                        )
×
2655
                }
×
2656

2657
                // Batch build all channel edges for disconnection.
2658
                channelEdges, chanIDsToDelete, err := batchBuildChannelInfo(
×
2659
                        ctx, s.cfg, db, rows,
×
2660
                )
×
2661
                if err != nil {
×
2662
                        return err
×
2663
                }
×
2664

2665
                removedChans = channelEdges
×
2666

×
2667
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
2668
                if err != nil {
×
2669
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2670
                }
×
2671

2672
                return db.DeletePruneLogEntriesInRange(
×
2673
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2674
                                StartHeight: int64(height),
×
2675
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2676
                        },
×
2677
                )
×
2678
        }, func() {
×
2679
                removedChans = nil
×
2680
        })
×
2681
        if err != nil {
×
2682
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2683
                        "height: %w", err)
×
2684
        }
×
2685

2686
        for _, channel := range removedChans {
×
2687
                s.rejectCache.remove(channel.ChannelID)
×
2688
                s.chanCache.remove(channel.ChannelID)
×
2689
        }
×
2690

2691
        return removedChans, nil
×
2692
}
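
// Illustrative sketch (not part of the upstream file): during a reorg, the
// caller rewinds the graph to the first stale height before applying the
// replacement chain. staleHeight is an assumed variable.
//
//	removed, err := store.DisconnectBlockAtHeight(staleHeight)
//	if err != nil {
//	        return err
//	}
//	// Channels in removed are no longer considered confirmed; they may be
//	// re-added if they confirm again on the new chain.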
2693

2694
// AddEdgeProof sets the proof of an existing edge in the graph database.
2695
//
2696
// NOTE: part of the V1Store interface.
2697
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2698
        proof *models.ChannelAuthProof) error {
×
2699

×
2700
        var (
×
2701
                ctx       = context.TODO()
×
2702
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2703
        )
×
2704

×
2705
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2706
                res, err := db.AddV1ChannelProof(
×
2707
                        ctx, sqlc.AddV1ChannelProofParams{
×
2708
                                Scid:              scidBytes,
×
2709
                                Node1Signature:    proof.NodeSig1Bytes,
×
2710
                                Node2Signature:    proof.NodeSig2Bytes,
×
2711
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2712
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2713
                        },
×
2714
                )
×
2715
                if err != nil {
×
2716
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2717
                }
×
2718

2719
                n, err := res.RowsAffected()
×
2720
                if err != nil {
×
2721
                        return err
×
2722
                }
×
2723

2724
                if n == 0 {
×
2725
                        return fmt.Errorf("no rows affected when adding edge "+
×
2726
                                "proof for SCID %v", scid)
×
2727
                } else if n > 1 {
×
2728
                        return fmt.Errorf("multiple rows affected when adding "+
×
2729
                                "edge proof for SCID %v: %d rows affected",
×
2730
                                scid, n)
×
2731
                }
×
2732

2733
                return nil
×
2734
        }, sqldb.NoOpReset)
2735
        if err != nil {
×
2736
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2737
        }
×
2738

2739
        return nil
×
2740
}
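
// Illustrative sketch (not part of the upstream file): once both halves of
// an announcement proof have been assembled, the full proof can be attached
// to the existing edge. The signature byte slices are assumed inputs.
//
//	proof := &models.ChannelAuthProof{
//	        NodeSig1Bytes:    nodeSig1,
//	        NodeSig2Bytes:    nodeSig2,
//	        BitcoinSig1Bytes: bitcoinSig1,
//	        BitcoinSig2Bytes: bitcoinSig2,
//	}
//	if err := store.AddEdgeProof(scid, proof); err != nil {
//	        return err
//	}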
2741

2742
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2743
// that we can ignore channel announcements that we know to be closed without
2744
// having to validate them and fetch a block.
2745
//
2746
// NOTE: part of the V1Store interface.
2747
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2748
        var (
×
2749
                ctx     = context.TODO()
×
2750
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2751
        )
×
2752

×
2753
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2754
                return db.InsertClosedChannel(ctx, chanIDB)
×
2755
        }, sqldb.NoOpReset)
×
2756
}
2757

2758
// IsClosedScid checks whether a channel identified by the passed in scid is
2759
// closed. This helps avoid having to perform expensive validation checks.
2760
//
2761
// NOTE: part of the V1Store interface.
2762
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2763
        var (
×
2764
                ctx      = context.TODO()
×
2765
                isClosed bool
×
2766
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2767
        )
×
2768
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2769
                var err error
×
2770
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2771
                if err != nil {
×
2772
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2773
                                err)
×
2774
                }
×
2775

2776
                return nil
×
2777
        }, sqldb.NoOpReset)
2778
        if err != nil {
×
2779
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2780
                        err)
×
2781
        }
×
2782

2783
        return isClosed, nil
×
2784
}
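
// Illustrative sketch (not part of the upstream file): a gossip handler can
// consult the closed-SCID index before doing any expensive validation of an
// incoming channel announcement. scid is assumed to be the SCID carried by
// the received announcement.
//
//	closed, err := store.IsClosedScid(scid)
//	if err != nil {
//	        return err
//	}
//	if closed {
//	        // Skip validation entirely; the channel is known to be closed.
//	        return nil
//	}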
2785

2786
// GraphSession will provide the call-back with access to a NodeTraverser
2787
// instance which can be used to perform queries against the channel graph.
2788
//
2789
// NOTE: part of the V1Store interface.
2790
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
2791
        reset func()) error {
×
2792

×
2793
        var ctx = context.TODO()
×
2794

×
2795
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2796
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2797
        }, reset)
×
2798
}
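
// Illustrative sketch (not part of the upstream file): pathfinding code can
// use GraphSession to traverse the graph under a single consistent read
// transaction. sourceNode is an assumed route.Vertex.
//
//	err := store.GraphSession(func(graph NodeTraverser) error {
//	        return graph.ForEachNodeDirectedChannel(
//	                sourceNode, func(c *DirectedChannel) error {
//	                        // Inspect c.Capacity, c.InPolicy, etc.
//	                        return nil
//	                }, func() {},
//	        )
//	}, func() {})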
2799

2800
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2801
// read-only transaction for a consistent view of the graph.
2802
type sqlNodeTraverser struct {
2803
        db    SQLQueries
2804
        chain chainhash.Hash
2805
}
2806

2807
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2808
// NodeTraverser interface.
2809
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2810

2811
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2812
func newSQLNodeTraverser(db SQLQueries,
2813
        chain chainhash.Hash) *sqlNodeTraverser {
×
2814

×
2815
        return &sqlNodeTraverser{
×
2816
                db:    db,
×
2817
                chain: chain,
×
2818
        }
×
2819
}
×
2820

2821
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2822
// node.
2823
//
2824
// NOTE: Part of the NodeTraverser interface.
2825
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2826
        cb func(channel *DirectedChannel) error, _ func()) error {
×
2827

×
2828
        ctx := context.TODO()
×
2829

×
2830
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2831
}
×
2832

2833
// FetchNodeFeatures returns the features of the given node. If the node is
2834
// unknown, assume no additional features are supported.
2835
//
2836
// NOTE: Part of the NodeTraverser interface.
2837
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2838
        *lnwire.FeatureVector, error) {
×
2839

×
2840
        ctx := context.TODO()
×
2841

×
2842
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2843
}
×
2844

2845
// forEachNodeDirectedChannel iterates through all channels of a given
2846
// node, executing the passed callback on the directed edge representing the
2847
// channel and its incoming policy. If the node is not found, no error is
2848
// returned.
2849
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2850
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2851

×
2852
        toNodeCallback := func() route.Vertex {
×
2853
                return nodePub
×
2854
        }
×
2855

2856
        dbID, err := db.GetNodeIDByPubKey(
×
2857
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2858
                        Version: int16(ProtocolV1),
×
2859
                        PubKey:  nodePub[:],
×
2860
                },
×
2861
        )
×
2862
        if errors.Is(err, sql.ErrNoRows) {
×
2863
                return nil
×
2864
        } else if err != nil {
×
2865
                return fmt.Errorf("unable to fetch node: %w", err)
×
2866
        }
×
2867

2868
        rows, err := db.ListChannelsByNodeID(
×
2869
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2870
                        Version: int16(ProtocolV1),
×
2871
                        NodeID1: dbID,
×
2872
                },
×
2873
        )
×
2874
        if err != nil {
×
2875
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2876
        }
×
2877

2878
        // Exit early if there are no channels for this node so we don't
2879
        // do the unnecessary feature fetching.
2880
        if len(rows) == 0 {
×
2881
                return nil
×
2882
        }
×
2883

2884
        features, err := getNodeFeatures(ctx, db, dbID)
×
2885
        if err != nil {
×
2886
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2887
        }
×
2888

2889
        for _, row := range rows {
×
2890
                node1, node2, err := buildNodeVertices(
×
2891
                        row.Node1Pubkey, row.Node2Pubkey,
×
2892
                )
×
2893
                if err != nil {
×
2894
                        return fmt.Errorf("unable to build node vertices: %w",
×
2895
                                err)
×
2896
                }
×
2897

2898
                edge := buildCacheableChannelInfo(
×
2899
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
2900
                        node1, node2,
×
2901
                )
×
2902

×
2903
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2904
                if err != nil {
×
2905
                        return err
×
2906
                }
×
2907

2908
                p1, p2, err := buildCachedChanPolicies(
×
2909
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
2910
                )
×
2911
                if err != nil {
×
2912
                        return err
×
2913
                }
×
2914

2915
                // Determine the outgoing and incoming policy for this
2916
                // channel and node combo.
2917
                outPolicy, inPolicy := p1, p2
×
2918
                if p1 != nil && node2 == nodePub {
×
2919
                        outPolicy, inPolicy = p2, p1
×
2920
                } else if p2 != nil && node1 != nodePub {
×
2921
                        outPolicy, inPolicy = p2, p1
×
2922
                }
×
2923

2924
                var cachedInPolicy *models.CachedEdgePolicy
×
2925
                if inPolicy != nil {
×
2926
                        cachedInPolicy = inPolicy
×
2927
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
2928
                        cachedInPolicy.ToNodeFeatures = features
×
2929
                }
×
2930

2931
                directedChannel := &DirectedChannel{
×
2932
                        ChannelID:    edge.ChannelID,
×
2933
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2934
                        OtherNode:    edge.NodeKey2Bytes,
×
2935
                        Capacity:     edge.Capacity,
×
2936
                        OutPolicySet: outPolicy != nil,
×
2937
                        InPolicy:     cachedInPolicy,
×
2938
                }
×
2939
                if outPolicy != nil {
×
2940
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2941
                                directedChannel.InboundFee = fee
×
2942
                        })
×
2943
                }
2944

2945
                if nodePub == edge.NodeKey2Bytes {
×
2946
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
2947
                }
×
2948

2949
                if err := cb(directedChannel); err != nil {
×
2950
                        return err
×
2951
                }
×
2952
        }
2953

2954
        return nil
×
2955
}
2956

2957
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
2958
// and executes the provided callback for each node. It does so via pagination
2959
// along with batch loading of the node feature bits.
2960
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
2961
        db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
2962
                features *lnwire.FeatureVector) error) error {
×
2963

×
2964
        handleNode := func(_ context.Context,
×
2965
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
×
2966
                featureBits map[int64][]int) error {
×
2967

×
2968
                fv := lnwire.EmptyFeatureVector()
×
2969
                if features, exists := featureBits[dbNode.ID]; exists {
×
2970
                        for _, bit := range features {
×
2971
                                fv.Set(lnwire.FeatureBit(bit))
×
2972
                        }
×
2973
                }
2974

2975
                var pub route.Vertex
×
2976
                copy(pub[:], dbNode.PubKey)
×
2977

×
2978
                return processNode(dbNode.ID, pub, fv)
×
2979
        }
2980

2981
        queryFunc := func(ctx context.Context, lastID int64,
×
2982
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
2983

×
2984
                return db.ListNodeIDsAndPubKeys(
×
2985
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
2986
                                Version: int16(ProtocolV1),
×
2987
                                ID:      lastID,
×
2988
                                Limit:   limit,
×
2989
                        },
×
2990
                )
×
2991
        }
×
2992

2993
        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
2994
                return row.ID
×
2995
        }
×
2996

2997
        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
×
2998
                return node.ID, nil
×
2999
        }
×
3000

3001
        batchQueryFunc := func(ctx context.Context,
×
3002
                nodeIDs []int64) (map[int64][]int, error) {
×
3003

×
3004
                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
3005
        }
×
3006

3007
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
3008
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
×
3009
                batchQueryFunc, handleNode,
×
3010
        )
×
3011
}
3012

3013
// forEachNodeChannel iterates through all channels of a node, executing
3014
// the passed callback on each. The call-back is provided with the channel's
3015
// edge information, the outgoing policy and the incoming policy for the
3016
// channel and node combo.
3017
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3018
        cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
3019
                *models.ChannelEdgePolicy,
3020
                *models.ChannelEdgePolicy) error) error {
×
3021

×
3022
        // Get all the V1 channels for this node.
×
3023
        rows, err := db.ListChannelsByNodeID(
×
3024
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3025
                        Version: int16(ProtocolV1),
×
3026
                        NodeID1: id,
×
3027
                },
×
3028
        )
×
3029
        if err != nil {
×
3030
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3031
        }
×
3032

3033
        // Collect all the channel and policy IDs.
3034
        var (
×
3035
                chanIDs   = make([]int64, 0, len(rows))
×
3036
                policyIDs = make([]int64, 0, 2*len(rows))
×
3037
        )
×
3038
        for _, row := range rows {
×
3039
                chanIDs = append(chanIDs, row.GraphChannel.ID)
×
3040

×
3041
                if row.Policy1ID.Valid {
×
3042
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
×
3043
                }
×
3044
                if row.Policy2ID.Valid {
×
3045
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
×
3046
                }
×
3047
        }
3048

3049
        batchData, err := batchLoadChannelData(
×
3050
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
×
3051
        )
×
3052
        if err != nil {
×
3053
                return fmt.Errorf("unable to batch load channel data: %w", err)
×
3054
        }
×
3055

3056
        // Call the call-back for each channel and its known policies.
3057
        for _, row := range rows {
×
3058
                node1, node2, err := buildNodeVertices(
×
3059
                        row.Node1Pubkey, row.Node2Pubkey,
×
3060
                )
×
3061
                if err != nil {
×
3062
                        return fmt.Errorf("unable to build node vertices: %w",
×
3063
                                err)
×
3064
                }
×
3065

3066
                edge, err := buildEdgeInfoWithBatchData(
×
3067
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
3068
                        batchData,
×
3069
                )
×
3070
                if err != nil {
×
3071
                        return fmt.Errorf("unable to build channel info: %w",
×
3072
                                err)
×
3073
                }
×
3074

3075
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3076
                if err != nil {
×
3077
                        return fmt.Errorf("unable to extract channel "+
×
3078
                                "policies: %w", err)
×
3079
                }
×
3080

3081
                p1, p2, err := buildChanPoliciesWithBatchData(
×
3082
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
3083
                )
×
3084
                if err != nil {
×
3085
                        return fmt.Errorf("unable to build channel "+
×
3086
                                "policies: %w", err)
×
3087
                }
×
3088

3089
                // Determine the outgoing and incoming policy for this
3090
                // channel and node combo.
3091
                p1ToNode := row.GraphChannel.NodeID2
×
3092
                p2ToNode := row.GraphChannel.NodeID1
×
3093
                outPolicy, inPolicy := p1, p2
×
3094
                if (p1 != nil && p1ToNode == id) ||
×
3095
                        (p2 != nil && p2ToNode != id) {
×
3096

×
3097
                        outPolicy, inPolicy = p2, p1
×
3098
                }
×
3099

3100
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3101
                        return err
×
3102
                }
×
3103
        }
3104

3105
        return nil
×
3106
}
3107

3108
// updateChanEdgePolicy upserts the channel policy info we have stored for
3109
// a channel we already know of.
3110
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3111
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3112
        error) {
×
3113

×
3114
        var (
×
3115
                node1Pub, node2Pub route.Vertex
×
3116
                isNode1            bool
×
3117
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3118
        )
×
3119

×
3120
        // Check that this edge policy refers to a channel that we already
×
3121
        // know of. We do this explicitly so that we can return the appropriate
×
3122
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3123
        // abort the transaction which would abort the entire batch.
×
3124
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3125
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3126
                        Scid:    chanIDB,
×
3127
                        Version: int16(ProtocolV1),
×
3128
                },
×
3129
        )
×
3130
        if errors.Is(err, sql.ErrNoRows) {
×
3131
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3132
        } else if err != nil {
×
3133
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3134
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3135
        }
×
3136

3137
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3138
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3139

×
3140
        // Figure out which node this edge is from.
×
3141
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3142
        nodeID := dbChan.NodeID1
×
3143
        if !isNode1 {
×
3144
                nodeID = dbChan.NodeID2
×
3145
        }
×
3146

3147
        var (
×
3148
                inboundBase sql.NullInt64
×
3149
                inboundRate sql.NullInt64
×
3150
        )
×
3151
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3152
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3153
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3154
        })
×
3155

3156
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3157
                Version:     int16(ProtocolV1),
×
3158
                ChannelID:   dbChan.ID,
×
3159
                NodeID:      nodeID,
×
3160
                Timelock:    int32(edge.TimeLockDelta),
×
3161
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3162
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3163
                MinHtlcMsat: int64(edge.MinHTLC),
×
3164
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3165
                Disabled: sql.NullBool{
×
3166
                        Valid: true,
×
3167
                        Bool:  edge.IsDisabled(),
×
3168
                },
×
3169
                MaxHtlcMsat: sql.NullInt64{
×
3170
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3171
                        Int64: int64(edge.MaxHTLC),
×
3172
                },
×
3173
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3174
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3175
                InboundBaseFeeMsat:      inboundBase,
×
3176
                InboundFeeRateMilliMsat: inboundRate,
×
3177
                Signature:               edge.SigBytes,
×
3178
        })
×
3179
        if err != nil {
×
3180
                return node1Pub, node2Pub, isNode1,
×
3181
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3182
        }
×
3183

3184
        // Convert the flat extra opaque data into a map of TLV types to
3185
        // values.
3186
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3187
        if err != nil {
×
3188
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3189
                        "marshal extra opaque data: %w", err)
×
3190
        }
×
3191

3192
        // Update the channel policy's extra signed fields.
3193
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3194
        if err != nil {
×
3195
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3196
                        "policy extra TLVs: %w", err)
×
3197
        }
×
3198

3199
        return node1Pub, node2Pub, isNode1, nil
×
3200
}
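
// For reference (a clarifying note, not part of the upstream file): the
// direction of a policy is encoded in the least significant bit of the
// channel flags, so an update originating from node 1 has the direction bit
// unset.
//
//	// ChannelFlags = 0b00000000 -> isNode1 = true  (update from node 1)
//	// ChannelFlags = 0b00000001 -> isNode1 = false (update from node 2)
//	isNode1 := edge.ChannelFlags&lnwire.ChanUpdateDirection == 0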
3201

3202
// getNodeByPubKey attempts to look up a target node by its public key.
3203
func getNodeByPubKey(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3204
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3205

×
3206
        dbNode, err := db.GetNodeByPubKey(
×
3207
                ctx, sqlc.GetNodeByPubKeyParams{
×
3208
                        Version: int16(ProtocolV1),
×
3209
                        PubKey:  pubKey[:],
×
3210
                },
×
3211
        )
×
3212
        if errors.Is(err, sql.ErrNoRows) {
×
3213
                return 0, nil, ErrGraphNodeNotFound
×
3214
        } else if err != nil {
×
3215
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3216
        }
×
3217

3218
        node, err := buildNode(ctx, cfg, db, dbNode)
×
3219
        if err != nil {
×
3220
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3221
        }
×
3222

3223
        return dbNode.ID, node, nil
×
3224
}
3225

3226
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3227
// provided parameters.
3228
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
3229
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3230

×
3231
        return &models.CachedEdgeInfo{
×
3232
                ChannelID:     byteOrder.Uint64(scid),
×
3233
                NodeKey1Bytes: node1Pub,
×
3234
                NodeKey2Bytes: node2Pub,
×
3235
                Capacity:      btcutil.Amount(capacity),
×
3236
        }
×
3237
}
×
3238

3239
// buildNode constructs a LightningNode instance from the given database node
3240
// record. The node's features, addresses and extra signed fields are also
3241
// fetched from the database and set on the node.
3242
func buildNode(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3243
        dbNode sqlc.GraphNode) (*models.LightningNode, error) {
×
3244

×
3245
        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
×
3246
        if err != nil {
×
3247
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
3248
                        err)
×
3249
        }
×
3250

3251
        return buildNodeWithBatchData(dbNode, data)
×
3252
}
3253

3254
// buildNodeWithBatchData builds a models.LightningNode instance
3255
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
3256
// features/addresses/extra fields, then the corresponding fields are expected
3257
// to be present in the batchNodeData.
3258
func buildNodeWithBatchData(dbNode sqlc.GraphNode,
3259
        batchData *batchNodeData) (*models.LightningNode, error) {
×
3260

×
3261
        if dbNode.Version != int16(ProtocolV1) {
×
3262
                return nil, fmt.Errorf("unsupported node version: %d",
×
3263
                        dbNode.Version)
×
3264
        }
×
3265

3266
        var pub [33]byte
×
3267
        copy(pub[:], dbNode.PubKey)
×
3268

×
3269
        node := &models.LightningNode{
×
3270
                PubKeyBytes: pub,
×
3271
                Features:    lnwire.EmptyFeatureVector(),
×
3272
                LastUpdate:  time.Unix(0, 0),
×
3273
        }
×
3274

×
3275
        if len(dbNode.Signature) == 0 {
×
3276
                return node, nil
×
3277
        }
×
3278

3279
        node.HaveNodeAnnouncement = true
×
3280
        node.AuthSigBytes = dbNode.Signature
×
3281
        node.Alias = dbNode.Alias.String
×
3282
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3283

×
3284
        var err error
×
3285
        if dbNode.Color.Valid {
×
3286
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3287
                if err != nil {
×
3288
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3289
                                err)
×
3290
                }
×
3291
        }
3292

3293
        // Use preloaded features.
3294
        if features, exists := batchData.features[dbNode.ID]; exists {
×
3295
                fv := lnwire.EmptyFeatureVector()
×
3296
                for _, bit := range features {
×
3297
                        fv.Set(lnwire.FeatureBit(bit))
×
3298
                }
×
3299
                node.Features = fv
×
3300
        }
3301

3302
        // Use preloaded addresses.
3303
        addresses, exists := batchData.addresses[dbNode.ID]
×
3304
        if exists && len(addresses) > 0 {
×
3305
                node.Addresses, err = buildNodeAddresses(addresses)
×
3306
                if err != nil {
×
3307
                        return nil, fmt.Errorf("unable to build addresses "+
×
3308
                                "for node(%d): %w", dbNode.ID, err)
×
3309
                }
×
3310
        }
3311

3312
        // Use preloaded extra fields.
3313
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
×
3314
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
×
3315
                if err != nil {
×
3316
                        return nil, fmt.Errorf("unable to serialize extra "+
×
3317
                                "signed fields: %w", err)
×
3318
                }
×
3319
                if len(recs) != 0 {
×
3320
                        node.ExtraOpaqueData = recs
×
3321
                }
×
3322
        }
3323

3324
        return node, nil
×
3325
}
3326

3327
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
3328
// with the preloaded data, and executes the provided callback for each node.
3329
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
3330
        db SQLQueries, nodes []sqlc.GraphNode,
3331
        cb func(dbID int64, node *models.LightningNode) error) error {
×
3332

×
3333
        // Extract node IDs for batch loading.
×
3334
        nodeIDs := make([]int64, len(nodes))
×
3335
        for i, node := range nodes {
×
3336
                nodeIDs[i] = node.ID
×
3337
        }
×
3338

3339
        // Batch load all related data for this page.
3340
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
3341
        if err != nil {
×
3342
                return fmt.Errorf("unable to batch load node data: %w", err)
×
3343
        }
×
3344

3345
        for _, dbNode := range nodes {
×
3346
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
3347
                if err != nil {
×
3348
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
3349
                                dbNode.ID, err)
×
3350
                }
×
3351

3352
                if err := cb(dbNode.ID, node); err != nil {
×
3353
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
3354
                                dbNode.ID, err)
×
3355
                }
×
3356
        }
3357

3358
        return nil
×
3359
}
3360

3361
// getNodeFeatures fetches the feature bits and constructs the feature vector
3362
// for a node with the given DB ID.
3363
func getNodeFeatures(ctx context.Context, db SQLQueries,
3364
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3365

×
3366
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3367
        if err != nil {
×
3368
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3369
                        nodeID, err)
×
3370
        }
×
3371

3372
        features := lnwire.EmptyFeatureVector()
×
3373
        for _, feature := range rows {
×
3374
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3375
        }
×
3376

3377
        return features, nil
×
3378
}
3379

3380
// upsertNode upserts the node record into the database. If the node already
3381
// exists, then the node's information is updated. If the node doesn't exist,
3382
// then a new node is created. The node's features, addresses and extra TLV
3383
// types are also updated. The node's DB ID is returned.
3384
func upsertNode(ctx context.Context, db SQLQueries,
3385
        node *models.LightningNode) (int64, error) {
×
3386

×
3387
        params := sqlc.UpsertNodeParams{
×
3388
                Version: int16(ProtocolV1),
×
3389
                PubKey:  node.PubKeyBytes[:],
×
3390
        }
×
3391

×
3392
        if node.HaveNodeAnnouncement {
×
3393
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3394
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3395
                params.Alias = sqldb.SQLStr(node.Alias)
×
3396
                params.Signature = node.AuthSigBytes
×
3397
        }
×
3398

3399
        nodeID, err := db.UpsertNode(ctx, params)
×
3400
        if err != nil {
×
3401
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3402
                        err)
×
3403
        }
×
3404

3405
        // We can exit here if we don't have the announcement yet.
3406
        if !node.HaveNodeAnnouncement {
×
3407
                return nodeID, nil
×
3408
        }
×
3409

3410
        // Update the node's features.
3411
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3412
        if err != nil {
×
3413
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3414
        }
×
3415

3416
        // Update the node's addresses.
3417
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3418
        if err != nil {
×
3419
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3420
        }
×
3421

3422
        // Convert the flat extra opaque data into a map of TLV types to
3423
        // values.
3424
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3425
        if err != nil {
×
3426
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3427
                        err)
×
3428
        }
×
3429

3430
        // Update the node's extra signed fields.
3431
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3432
        if err != nil {
×
3433
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3434
        }
×
3435

3436
        return nodeID, nil
×
3437
}
3438

3439
// upsertNodeFeatures updates the node_features table for the given node. This
3440
// includes deleting any feature bits no longer present and inserting any new
3441
// feature bits. If the feature bit does not yet exist in the features table,
3442
// then an entry is created in that table first.
3443
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3444
        features *lnwire.FeatureVector) error {
×
3445

×
3446
        // Get any existing features for the node.
×
3447
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3448
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3449
                return err
×
3450
        }
×
3451

3452
        // Copy the node's latest set of feature bits.
3453
        newFeatures := make(map[int32]struct{})
×
3454
        if features != nil {
×
3455
                for feature := range features.Features() {
×
3456
                        newFeatures[int32(feature)] = struct{}{}
×
3457
                }
×
3458
        }
3459

3460
        // For any current feature that already exists in the DB, remove it from
3461
        // the in-memory map. For any existing feature that does not exist in
3462
        // the in-memory map, delete it from the database.
3463
        for _, feature := range existingFeatures {
×
3464
                // The feature is still present, so there are no updates to be
×
3465
                // made.
×
3466
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3467
                        delete(newFeatures, feature.FeatureBit)
×
3468
                        continue
×
3469
                }
3470

3471
                // The feature is no longer present, so we remove it from the
3472
                // database.
3473
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3474
                        NodeID:     nodeID,
×
3475
                        FeatureBit: feature.FeatureBit,
×
3476
                })
×
3477
                if err != nil {
×
3478
                        return fmt.Errorf("unable to delete node(%d) "+
×
3479
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3480
                                err)
×
3481
                }
×
3482
        }
3483

3484
        // Any remaining entries in newFeatures are new features that need to be
3485
        // added to the database for the first time.
3486
        for feature := range newFeatures {
×
3487
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3488
                        NodeID:     nodeID,
×
3489
                        FeatureBit: feature,
×
3490
                })
×
3491
                if err != nil {
×
3492
                        return fmt.Errorf("unable to insert node(%d) "+
×
3493
                                "feature(%v): %w", nodeID, feature, err)
×
3494
                }
×
3495
        }
3496

3497
        return nil
×
3498
}
3499

3500
// fetchNodeFeatures fetches the features for a node with the given public key.
3501
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3502
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3503

×
3504
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3505
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3506
                        PubKey:  nodePub[:],
×
3507
                        Version: int16(ProtocolV1),
×
3508
                },
×
3509
        )
×
3510
        if err != nil {
×
3511
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3512
                        nodePub, err)
×
3513
        }
×
3514

3515
        features := lnwire.EmptyFeatureVector()
×
3516
        for _, bit := range rows {
×
3517
                features.Set(lnwire.FeatureBit(bit))
×
3518
        }
×
3519

3520
        return features, nil
×
3521
}
3522

3523
// dbAddressType is an enum type that represents the different address types
3524
// that we store in the node_addresses table. The address type determines how
3525
// the address is to be serialised/deserialised.
3526
type dbAddressType uint8
3527

3528
const (
3529
        addressTypeIPv4   dbAddressType = 1
3530
        addressTypeIPv6   dbAddressType = 2
3531
        addressTypeTorV2  dbAddressType = 3
3532
        addressTypeTorV3  dbAddressType = 4
3533
        addressTypeDNS    dbAddressType = 5
3534
        addressTypeOpaque dbAddressType = math.MaxInt8
3535
)
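
// For reference (a clarifying note, not part of the upstream file), the
// address types above map to net.Addr implementations as handled by
// collectAddressRecords below:
//
//	*net.TCPAddr (IPv4)      -> addressTypeIPv4
//	*net.TCPAddr (IPv6)      -> addressTypeIPv6
//	*tor.OnionAddr (v2 / v3) -> addressTypeTorV2 / addressTypeTorV3
//	*lnwire.DNSAddress       -> addressTypeDNS
//	*lnwire.OpaqueAddrs      -> addressTypeOpaque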
3536

3537
// collectAddressRecords collects the addresses from the provided
3538
// net.Addr slice and returns a map of dbAddressType to a slice of address
3539
// strings.
3540
func collectAddressRecords(addresses []net.Addr) (map[dbAddressType][]string,
3541
        error) {
×
3542

×
3543
        // Copy the node's latest set of addresses.
×
3544
        newAddresses := map[dbAddressType][]string{
×
3545
                addressTypeIPv4:   {},
×
3546
                addressTypeIPv6:   {},
×
3547
                addressTypeTorV2:  {},
×
3548
                addressTypeTorV3:  {},
×
NEW
3549
                addressTypeDNS:    {},
×
3550
                addressTypeOpaque: {},
×
3551
        }
×
3552
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3553
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3554
        }
×
3555

3556
        for _, address := range addresses {
×
3557
                switch addr := address.(type) {
×
3558
                case *net.TCPAddr:
×
3559
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3560
                                addAddr(addressTypeIPv4, addr)
×
3561
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3562
                                addAddr(addressTypeIPv6, addr)
×
3563
                        } else {
×
3564
                                return nil, fmt.Errorf("unhandled IP "+
×
3565
                                        "address: %v", addr)
×
3566
                        }
×
3567

3568
                case *tor.OnionAddr:
×
3569
                        switch len(addr.OnionService) {
×
3570
                        case tor.V2Len:
×
3571
                                addAddr(addressTypeTorV2, addr)
×
3572
                        case tor.V3Len:
×
3573
                                addAddr(addressTypeTorV3, addr)
×
3574
                        default:
×
3575
                                return nil, fmt.Errorf("invalid length for " +
×
3576
                                        "a tor address")
×
3577
                        }
3578

NEW
3579
                case *lnwire.DNSAddress:
×
NEW
3580
                        addAddr(addressTypeDNS, addr)
×
3581

3582
                case *lnwire.OpaqueAddrs:
×
3583
                        addAddr(addressTypeOpaque, addr)
×
3584

3585
                default:
×
3586
                        return nil, fmt.Errorf("unhandled address type: %T",
×
3587
                                addr)
×
3588
                }
3589
        }
3590

3591
        return newAddresses, nil
×
3592
}
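
// Illustrative sketch (not part of the upstream file): given a node that
// advertises a single IPv4 address, the resulting map groups the string
// forms by type, with empty slices for the unused types.
//
//	addrs := []net.Addr{
//	        &net.TCPAddr{IP: net.ParseIP("203.0.113.1"), Port: 9735},
//	}
//	records, err := collectAddressRecords(addrs)
//	// err == nil
//	// records[addressTypeIPv4] == []string{"203.0.113.1:9735"}
//	// records[addressTypeIPv6] == []string{} (and so on for other types)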
3593

3594
// upsertNodeAddresses updates the node's addresses in the database. This
3595
// includes deleting any existing addresses and inserting the new set of
3596
// addresses. The deletion is necessary since the ordering of the addresses may
3597
// change, and we need to ensure that the database reflects the latest set of
3598
// addresses so that at the time of reconstructing the node announcement, the
3599
// order is preserved and the signature over the message remains valid.
3600
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3601
        addresses []net.Addr) error {
×
3602

×
3603
        // Delete any existing addresses for the node. This is required since
×
3604
        // even if the new set of addresses is the same, the ordering may have
×
3605
        // changed for a given address type.
×
3606
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3607
        if err != nil {
×
3608
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3609
                        nodeID, err)
×
3610
        }
×
3611

3612
        newAddresses, err := collectAddressRecords(addresses)
×
3613
        if err != nil {
×
3614
                return err
×
3615
        }
×
3616

3617
        // Insert the new set of addresses. Since any existing addresses were
3618
        // deleted above, each entry here is written for the first time.
3619
        for addrType, addrList := range newAddresses {
×
3620
                for position, addr := range addrList {
×
3621
                        err := db.UpsertNodeAddress(
×
3622
                                ctx, sqlc.UpsertNodeAddressParams{
×
3623
                                        NodeID:   nodeID,
×
3624
                                        Type:     int16(addrType),
×
3625
                                        Address:  addr,
×
3626
                                        Position: int32(position),
×
3627
                                },
×
3628
                        )
×
3629
                        if err != nil {
×
3630
                                return fmt.Errorf("unable to insert "+
×
3631
                                        "node(%d) address(%v): %w", nodeID,
×
3632
                                        addr, err)
×
3633
                        }
×
3634
                }
3635
        }
3636

3637
        return nil
×
3638
}
3639

3640
// getNodeAddresses fetches the addresses for a node with the given DB ID.
3641
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
3642
        error) {
×
3643

×
3644
        // GetNodeAddresses ensures that the addresses for a given type are
×
3645
        // returned in the same order as they were inserted.
×
3646
        rows, err := db.GetNodeAddresses(ctx, id)
×
3647
        if err != nil {
×
3648
                return nil, err
×
3649
        }
×
3650

3651
        addresses := make([]net.Addr, 0, len(rows))
×
3652
        for _, row := range rows {
×
3653
                address := row.Address
×
3654

×
3655
                addr, err := parseAddress(dbAddressType(row.Type), address)
×
3656
                if err != nil {
×
3657
                        return nil, fmt.Errorf("unable to parse address "+
×
3658
                                "for node(%d): %v: %w", id, address, err)
×
3659
                }
×
3660

3661
                addresses = append(addresses, addr)
×
3662
        }
3663

3664
        // If we have no addresses, then we'll return nil instead of an
3665
        // empty slice.
3666
        if len(addresses) == 0 {
×
3667
                addresses = nil
×
3668
        }
×
3669

3670
        return addresses, nil
×
3671
}
3672

3673
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3674
// database. This includes updating any existing types, inserting any new types,
3675
// and deleting any types that are no longer present.
3676
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3677
        nodeID int64, extraFields map[uint64][]byte) error {
×
3678

×
3679
        // Get any existing extra signed fields for the node.
×
3680
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3681
        if err != nil {
×
3682
                return err
×
3683
        }
×
3684

3685
        // Make a lookup map of the existing field types so that we can use it
3686
        // to keep track of any fields we should delete.
3687
        m := make(map[uint64]bool)
×
3688
        for _, field := range existingFields {
×
3689
                m[uint64(field.Type)] = true
×
3690
        }
×
3691

3692
        // For all the new fields, we'll upsert them and remove them from the
3693
        // map of existing fields.
3694
        for tlvType, value := range extraFields {
×
3695
                err = db.UpsertNodeExtraType(
×
3696
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3697
                                NodeID: nodeID,
×
3698
                                Type:   int64(tlvType),
×
3699
                                Value:  value,
×
3700
                        },
×
3701
                )
×
3702
                if err != nil {
×
3703
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3704
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3705
                }
×
3706

3707
                // Remove the field from the map of existing fields if it was
3708
                // present.
3709
                delete(m, tlvType)
×
3710
        }
3711

3712
        // For all the fields that are left in the map of existing fields, we'll
3713
        // delete them as they are no longer present in the new set of fields.
3714
        for tlvType := range m {
×
3715
                err = db.DeleteExtraNodeType(
×
3716
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3717
                                NodeID: nodeID,
×
3718
                                Type:   int64(tlvType),
×
3719
                        },
×
3720
                )
×
3721
                if err != nil {
×
3722
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3723
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3724
                }
×
3725
        }
3726

3727
        return nil
×
3728
}
3729

3730
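// exampleUpsertNodeExtraSignedFields is an illustrative sketch (not part of
// the store API) of the diff-based behaviour documented on
// upsertNodeExtraSignedFields: if the node currently has TLV types 1 and 2
// stored, the call below upserts types 2 and 3 and deletes type 1, since it
// is absent from the new set of fields.
func exampleUpsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
        nodeID int64) error {

        newFields := map[uint64][]byte{
                2: {0x02},
                3: {0x03},
        }

        return upsertNodeExtraSignedFields(ctx, db, nodeID, newFields)
}
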
// srcNodeInfo holds the information about the source node of the graph.
3731
type srcNodeInfo struct {
3732
        // id is the DB level ID of the source node entry in the "nodes" table.
3733
        id int64
3734

3735
        // pub is the public key of the source node.
3736
        pub route.Vertex
3737
}
3738

3739
// getSourceNode returns the DB node ID and pub key of the source node for the
3740
// specified protocol version.
3741
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3742
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3743

×
3744
        s.srcNodeMu.Lock()
×
3745
        defer s.srcNodeMu.Unlock()
×
3746

×
3747
        // If we already have the source node ID and pub key cached, then
×
3748
        // return them.
×
3749
        if info, ok := s.srcNodes[version]; ok {
×
3750
                return info.id, info.pub, nil
×
3751
        }
×
3752

3753
        var pubKey route.Vertex
×
3754

×
3755
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3756
        if err != nil {
×
3757
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3758
                        err)
×
3759
        }
×
3760

3761
        if len(nodes) == 0 {
×
3762
                return 0, pubKey, ErrSourceNodeNotSet
×
3763
        } else if len(nodes) > 1 {
×
3764
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3765
                        "protocol %s found", version)
×
3766
        }
×
3767

3768
        copy(pubKey[:], nodes[0].PubKey)
×
3769

×
3770
        s.srcNodes[version] = &srcNodeInfo{
×
3771
                id:  nodes[0].NodeID,
×
3772
                pub: pubKey,
×
3773
        }
×
3774

×
3775
        return nodes[0].NodeID, pubKey, nil
×
3776
}
3777

3778
// marshalExtraOpaqueData takes a flat byte slice and parses it as a TLV
3779
// stream. This produces a map from TLV type to value. If the input is not a
3780
// valid TLV stream, then an error is returned.
3781
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3782
        r := bytes.NewReader(data)
×
3783

×
3784
        tlvStream, err := tlv.NewStream()
×
3785
        if err != nil {
×
3786
                return nil, err
×
3787
        }
×
3788

3789
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3790
        // pass it into the P2P decoding variant.
3791
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3792
        if err != nil {
×
3793
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3794
        }
×
3795
        if len(parsedTypes) == 0 {
×
3796
                return nil, nil
×
3797
        }
×
3798

3799
        records := make(map[uint64][]byte)
×
3800
        for k, v := range parsedTypes {
×
3801
                records[uint64(k)] = v
×
3802
        }
×
3803

3804
        return records, nil
×
3805
}
3806

3807
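// exampleMarshalExtraOpaqueData is an illustrative sketch (not part of the
// store API) of how a raw extra-opaque-data blob maps to TLV records. The
// bytes below encode a single TLV record with type 1 and the 3-byte value
// "abc", so parsing them yields a one-entry map keyed by type 1.
func exampleMarshalExtraOpaqueData() error {
        raw := []byte{0x01, 0x03, 'a', 'b', 'c'}

        records, err := marshalExtraOpaqueData(raw)
        if err != nil {
                return err
        }

        fmt.Printf("parsed %d record(s), type 1 value: %x\n", len(records),
                records[1])

        return nil
}
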
// insertChannel inserts a new channel record into the database.
3808
func insertChannel(ctx context.Context, db SQLQueries,
3809
        edge *models.ChannelEdgeInfo) error {
×
3810

×
3811
        // Make sure that at least a "shell" entry for each node is present in
×
3812
        // the nodes table.
×
3813
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3814
        if err != nil {
×
3815
                return fmt.Errorf("unable to create shell node: %w", err)
×
3816
        }
×
3817

3818
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3819
        if err != nil {
×
3820
                return fmt.Errorf("unable to create shell node: %w", err)
×
3821
        }
×
3822

3823
        var capacity sql.NullInt64
×
3824
        if edge.Capacity != 0 {
×
3825
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3826
        }
×
3827

3828
        createParams := sqlc.CreateChannelParams{
×
3829
                Version:     int16(ProtocolV1),
×
3830
                Scid:        channelIDToBytes(edge.ChannelID),
×
3831
                NodeID1:     node1DBID,
×
3832
                NodeID2:     node2DBID,
×
3833
                Outpoint:    edge.ChannelPoint.String(),
×
3834
                Capacity:    capacity,
×
3835
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3836
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3837
        }
×
3838

×
3839
        if edge.AuthProof != nil {
×
3840
                proof := edge.AuthProof
×
3841

×
3842
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3843
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3844
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3845
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3846
        }
×
3847

3848
        // Insert the new channel record.
3849
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3850
        if err != nil {
×
3851
                return err
×
3852
        }
×
3853

3854
        // Insert any channel features.
3855
        for feature := range edge.Features.Features() {
×
3856
                err = db.InsertChannelFeature(
×
3857
                        ctx, sqlc.InsertChannelFeatureParams{
×
3858
                                ChannelID:  dbChanID,
×
3859
                                FeatureBit: int32(feature),
×
3860
                        },
×
3861
                )
×
3862
                if err != nil {
×
3863
                        return fmt.Errorf("unable to insert channel(%d) "+
×
3864
                                "feature(%v): %w", dbChanID, feature, err)
×
3865
                }
×
3866
        }
3867

3868
        // Finally, insert any extra TLV fields in the channel announcement.
3869
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3870
        if err != nil {
×
3871
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
3872
                        err)
×
3873
        }
×
3874

3875
        for tlvType, value := range extra {
×
3876
                err := db.UpsertChannelExtraType(
×
3877
                        ctx, sqlc.UpsertChannelExtraTypeParams{
×
3878
                                ChannelID: dbChanID,
×
3879
                                Type:      int64(tlvType),
×
3880
                                Value:     value,
×
3881
                        },
×
3882
                )
×
3883
                if err != nil {
×
3884
                        return fmt.Errorf("unable to upsert channel(%d) "+
×
3885
                                "extra signed field(%v): %w", edge.ChannelID,
×
3886
                                tlvType, err)
×
3887
                }
×
3888
        }
3889

3890
        return nil
×
3891
}
3892

3893
// maybeCreateShellNode checks if a shell node entry exists for the
3894
// given public key. If it does not exist, then a new shell node entry is
3895
// created. The ID of the node is returned. A shell node only has a protocol
3896
// version and public key persisted.
3897
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3898
        pubKey route.Vertex) (int64, error) {
×
3899

×
3900
        dbNode, err := db.GetNodeByPubKey(
×
3901
                ctx, sqlc.GetNodeByPubKeyParams{
×
3902
                        PubKey:  pubKey[:],
×
3903
                        Version: int16(ProtocolV1),
×
3904
                },
×
3905
        )
×
3906
        // The node exists. Return the ID.
×
3907
        if err == nil {
×
3908
                return dbNode.ID, nil
×
3909
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3910
                return 0, err
×
3911
        }
×
3912

3913
        // Otherwise, the node does not exist, so we create a shell entry for
3914
        // it.
3915
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3916
                Version: int16(ProtocolV1),
×
3917
                PubKey:  pubKey[:],
×
3918
        })
×
3919
        if err != nil {
×
3920
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3921
        }
×
3922

3923
        return id, nil
×
3924
}
3925

3926
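// exampleMaybeCreateShellNode is an illustrative sketch (not part of the
// store API): maybeCreateShellNode is idempotent, so calling it twice with
// the same public key returns the same DB node ID rather than creating a
// duplicate row.
func exampleMaybeCreateShellNode(ctx context.Context, db SQLQueries,
        pubKey route.Vertex) (bool, error) {

        id1, err := maybeCreateShellNode(ctx, db, pubKey)
        if err != nil {
                return false, err
        }

        // The second call finds the existing row and returns its ID.
        id2, err := maybeCreateShellNode(ctx, db, pubKey)
        if err != nil {
                return false, err
        }

        return id1 == id2, nil
}
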
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3927
// the database. This includes deleting any existing types and then inserting
3928
// the new types.
3929
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3930
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3931

×
3932
        // Delete all existing extra signed fields for the channel policy.
×
3933
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
3934
        if err != nil {
×
3935
                return fmt.Errorf("unable to delete "+
×
3936
                        "existing policy extra signed fields for policy %d: %w",
×
3937
                        chanPolicyID, err)
×
3938
        }
×
3939

3940
        // Insert all new extra signed fields for the channel policy.
3941
        for tlvType, value := range extraFields {
×
3942
                err = db.UpsertChanPolicyExtraType(
×
3943
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
×
3944
                                ChannelPolicyID: chanPolicyID,
×
3945
                                Type:            int64(tlvType),
×
3946
                                Value:           value,
×
3947
                        },
×
3948
                )
×
3949
                if err != nil {
×
3950
                        return fmt.Errorf("unable to insert "+
×
3951
                                "channel_policy(%d) extra signed field(%v): %w",
×
3952
                                chanPolicyID, tlvType, err)
×
3953
                }
×
3954
        }
3955

3956
        return nil
×
3957
}
3958

3959
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3960
// provided dbChan record and also fetches any other required information
3961
// to construct the edge info.
3962
func getAndBuildEdgeInfo(ctx context.Context, cfg *SQLStoreConfig,
3963
        db SQLQueries, dbChan sqlc.GraphChannel, node1,
3964
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
3965

×
3966
        data, err := batchLoadChannelData(
×
3967
                ctx, cfg.QueryCfg, db, []int64{dbChan.ID}, nil,
×
3968
        )
×
3969
        if err != nil {
×
3970
                return nil, fmt.Errorf("unable to batch load channel data: %w",
×
3971
                        err)
×
3972
        }
×
3973

3974
        return buildEdgeInfoWithBatchData(
×
3975
                cfg.ChainHash, dbChan, node1, node2, data,
×
3976
        )
×
3977
}
3978

3979
// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
3980
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
3981
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
3982
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {
×
3983

×
3984
        if dbChan.Version != int16(ProtocolV1) {
×
3985
                return nil, fmt.Errorf("unsupported channel version: %d",
×
3986
                        dbChan.Version)
×
3987
        }
×
3988

3989
        // Use the pre-loaded features and extra types.
3990
        fv := lnwire.EmptyFeatureVector()
×
3991
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
×
3992
                for _, bit := range features {
×
3993
                        fv.Set(lnwire.FeatureBit(bit))
×
3994
                }
×
3995
        }
3996

3997
        var extras map[uint64][]byte
×
3998
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
×
3999
        if exists {
×
4000
                extras = channelExtras
×
4001
        } else {
×
4002
                extras = make(map[uint64][]byte)
×
4003
        }
×
4004

4005
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4006
        if err != nil {
×
4007
                return nil, err
×
4008
        }
×
4009

4010
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4011
        if err != nil {
×
4012
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4013
                        "fields: %w", err)
×
4014
        }
×
4015
        if recs == nil {
×
4016
                recs = make([]byte, 0)
×
4017
        }
×
4018

4019
        var btcKey1, btcKey2 route.Vertex
×
4020
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4021
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4022

×
4023
        channel := &models.ChannelEdgeInfo{
×
4024
                ChainHash:        chain,
×
4025
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4026
                NodeKey1Bytes:    node1,
×
4027
                NodeKey2Bytes:    node2,
×
4028
                BitcoinKey1Bytes: btcKey1,
×
4029
                BitcoinKey2Bytes: btcKey2,
×
4030
                ChannelPoint:     *op,
×
4031
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4032
                Features:         fv,
×
4033
                ExtraOpaqueData:  recs,
×
4034
        }
×
4035

×
4036
        // We always set all the signatures at the same time, so we can
×
4037
        // safely check if one signature is present to determine if we have the
×
4038
        // rest of the signatures for the auth proof.
×
4039
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4040
                channel.AuthProof = &models.ChannelAuthProof{
×
4041
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4042
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4043
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4044
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4045
                }
×
4046
        }
×
4047

4048
        return channel, nil
×
4049
}
4050

4051
// buildNodeVertices is a helper that converts raw node public keys
4052
// into route.Vertex instances.
4053
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4054
        route.Vertex, error) {
×
4055

×
4056
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4057
        if err != nil {
×
4058
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4059
                        "create vertex from node1 pubkey: %w", err)
×
4060
        }
×
4061

4062
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4063
        if err != nil {
×
4064
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4065
                        "create vertex from node2 pubkey: %w", err)
×
4066
        }
×
4067

4068
        return node1Vertex, node2Vertex, nil
×
4069
}
4070

4071
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4072
// retrieves all the extra info required to build the complete
4073
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4074
// the provided sqlc.GraphChannelPolicy records are nil.
4075
func getAndBuildChanPolicies(ctx context.Context, cfg *sqldb.QueryConfig,
4076
        db SQLQueries, dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4077
        channelID uint64, node1, node2 route.Vertex) (*models.ChannelEdgePolicy,
4078
        *models.ChannelEdgePolicy, error) {
×
4079

×
4080
        if dbPol1 == nil && dbPol2 == nil {
×
4081
                return nil, nil, nil
×
4082
        }
×
4083

4084
        var policyIDs = make([]int64, 0, 2)
×
4085
        if dbPol1 != nil {
×
4086
                policyIDs = append(policyIDs, dbPol1.ID)
×
4087
        }
×
4088
        if dbPol2 != nil {
×
4089
                policyIDs = append(policyIDs, dbPol2.ID)
×
4090
        }
×
4091

4092
        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
×
4093
        if err != nil {
×
4094
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
4095
                        "data: %w", err)
×
4096
        }
×
4097

4098
        pol1, err := buildChanPolicyWithBatchData(
×
4099
                dbPol1, channelID, node2, batchData,
×
4100
        )
×
4101
        if err != nil {
×
4102
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4103
        }
×
4104

4105
        pol2, err := buildChanPolicyWithBatchData(
×
4106
                dbPol2, channelID, node1, batchData,
×
4107
        )
×
4108
        if err != nil {
×
4109
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4110
        }
×
4111

4112
        return pol1, pol2, nil
×
4113
}
4114

4115
// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
4116
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
4117
// then nil is returned for it.
4118
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4119
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
4120
        *models.CachedEdgePolicy, error) {
×
4121

×
4122
        var p1, p2 *models.CachedEdgePolicy
×
4123
        if dbPol1 != nil {
×
4124
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
×
4125
                if err != nil {
×
4126
                        return nil, nil, err
×
4127
                }
×
4128

4129
                p1 = models.NewCachedPolicy(policy1)
×
4130
        }
4131
        if dbPol2 != nil {
×
4132
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
×
4133
                if err != nil {
×
4134
                        return nil, nil, err
×
4135
                }
×
4136

4137
                p2 = models.NewCachedPolicy(policy2)
×
4138
        }
4139

4140
        return p1, p2, nil
×
4141
}
4142

4143
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4144
// provided sqlc.GraphChannelPolicy and other required information.
4145
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4146
        extras map[uint64][]byte,
4147
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4148

×
4149
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4150
        if err != nil {
×
4151
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4152
                        "fields: %w", err)
×
4153
        }
×
4154

4155
        var inboundFee fn.Option[lnwire.Fee]
×
4156
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4157
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4158

×
4159
                inboundFee = fn.Some(lnwire.Fee{
×
4160
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4161
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4162
                })
×
4163
        }
×
4164

4165
        return &models.ChannelEdgePolicy{
×
4166
                SigBytes:  dbPolicy.Signature,
×
4167
                ChannelID: channelID,
×
4168
                LastUpdate: time.Unix(
×
4169
                        dbPolicy.LastUpdate.Int64, 0,
×
4170
                ),
×
4171
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4172
                        dbPolicy.MessageFlags,
×
4173
                ),
×
4174
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4175
                        dbPolicy.ChannelFlags,
×
4176
                ),
×
4177
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4178
                MinHTLC: lnwire.MilliSatoshi(
×
4179
                        dbPolicy.MinHtlcMsat,
×
4180
                ),
×
4181
                MaxHTLC: lnwire.MilliSatoshi(
×
4182
                        dbPolicy.MaxHtlcMsat.Int64,
×
4183
                ),
×
4184
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4185
                        dbPolicy.BaseFeeMsat,
×
4186
                ),
×
4187
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4188
                ToNode:                    toNode,
×
4189
                InboundFee:                inboundFee,
×
4190
                ExtraOpaqueData:           recs,
×
4191
        }, nil
×
4192
}
4193

4194
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the
4195
// given row, which is expected to be a sqlc type that contains channel policy
4196
// information. It returns two policies, which may be nil if the policy
4197
// information is not present in the row.
4198
//
4199
//nolint:ll,dupl,funlen
4200
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4201
        *sqlc.GraphChannelPolicy, error) {
×
4202

×
4203
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4204
        switch r := row.(type) {
×
4205
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
×
4206
                if r.Policy1Timelock.Valid {
×
4207
                        policy1 = &sqlc.GraphChannelPolicy{
×
4208
                                Timelock:                r.Policy1Timelock.Int32,
×
4209
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4210
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4211
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4212
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4213
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4214
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4215
                                Disabled:                r.Policy1Disabled,
×
4216
                                MessageFlags:            r.Policy1MessageFlags,
×
4217
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4218
                        }
×
4219
                }
×
4220
                if r.Policy2Timelock.Valid {
×
4221
                        policy2 = &sqlc.GraphChannelPolicy{
×
4222
                                Timelock:                r.Policy2Timelock.Int32,
×
4223
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4224
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4225
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4226
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4227
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4228
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4229
                                Disabled:                r.Policy2Disabled,
×
4230
                                MessageFlags:            r.Policy2MessageFlags,
×
4231
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4232
                        }
×
4233
                }
×
4234

4235
                return policy1, policy2, nil
×
4236

4237
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
4238
                if r.Policy1ID.Valid {
×
4239
                        policy1 = &sqlc.GraphChannelPolicy{
×
4240
                                ID:                      r.Policy1ID.Int64,
×
4241
                                Version:                 r.Policy1Version.Int16,
×
4242
                                ChannelID:               r.GraphChannel.ID,
×
4243
                                NodeID:                  r.Policy1NodeID.Int64,
×
4244
                                Timelock:                r.Policy1Timelock.Int32,
×
4245
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4246
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4247
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4248
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4249
                                LastUpdate:              r.Policy1LastUpdate,
×
4250
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4251
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4252
                                Disabled:                r.Policy1Disabled,
×
4253
                                MessageFlags:            r.Policy1MessageFlags,
×
4254
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4255
                                Signature:               r.Policy1Signature,
×
4256
                        }
×
4257
                }
×
4258
                if r.Policy2ID.Valid {
×
4259
                        policy2 = &sqlc.GraphChannelPolicy{
×
4260
                                ID:                      r.Policy2ID.Int64,
×
4261
                                Version:                 r.Policy2Version.Int16,
×
4262
                                ChannelID:               r.GraphChannel.ID,
×
4263
                                NodeID:                  r.Policy2NodeID.Int64,
×
4264
                                Timelock:                r.Policy2Timelock.Int32,
×
4265
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4266
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4267
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4268
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4269
                                LastUpdate:              r.Policy2LastUpdate,
×
4270
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4271
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4272
                                Disabled:                r.Policy2Disabled,
×
4273
                                MessageFlags:            r.Policy2MessageFlags,
×
4274
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4275
                                Signature:               r.Policy2Signature,
×
4276
                        }
×
4277
                }
×
4278

4279
                return policy1, policy2, nil
×
4280

4281
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4282
                if r.Policy1ID.Valid {
×
4283
                        policy1 = &sqlc.GraphChannelPolicy{
×
4284
                                ID:                      r.Policy1ID.Int64,
×
4285
                                Version:                 r.Policy1Version.Int16,
×
4286
                                ChannelID:               r.GraphChannel.ID,
×
4287
                                NodeID:                  r.Policy1NodeID.Int64,
×
4288
                                Timelock:                r.Policy1Timelock.Int32,
×
4289
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4290
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4291
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4292
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4293
                                LastUpdate:              r.Policy1LastUpdate,
×
4294
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4295
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4296
                                Disabled:                r.Policy1Disabled,
×
4297
                                MessageFlags:            r.Policy1MessageFlags,
×
4298
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4299
                                Signature:               r.Policy1Signature,
×
4300
                        }
×
4301
                }
×
4302
                if r.Policy2ID.Valid {
×
4303
                        policy2 = &sqlc.GraphChannelPolicy{
×
4304
                                ID:                      r.Policy2ID.Int64,
×
4305
                                Version:                 r.Policy2Version.Int16,
×
4306
                                ChannelID:               r.GraphChannel.ID,
×
4307
                                NodeID:                  r.Policy2NodeID.Int64,
×
4308
                                Timelock:                r.Policy2Timelock.Int32,
×
4309
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4310
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4311
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4312
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4313
                                LastUpdate:              r.Policy2LastUpdate,
×
4314
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4315
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4316
                                Disabled:                r.Policy2Disabled,
×
4317
                                MessageFlags:            r.Policy2MessageFlags,
×
4318
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4319
                                Signature:               r.Policy2Signature,
×
4320
                        }
×
4321
                }
×
4322

4323
                return policy1, policy2, nil
×
4324

4325
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4326
                if r.Policy1ID.Valid {
×
4327
                        policy1 = &sqlc.GraphChannelPolicy{
×
4328
                                ID:                      r.Policy1ID.Int64,
×
4329
                                Version:                 r.Policy1Version.Int16,
×
4330
                                ChannelID:               r.GraphChannel.ID,
×
4331
                                NodeID:                  r.Policy1NodeID.Int64,
×
4332
                                Timelock:                r.Policy1Timelock.Int32,
×
4333
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4334
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4335
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4336
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4337
                                LastUpdate:              r.Policy1LastUpdate,
×
4338
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4339
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4340
                                Disabled:                r.Policy1Disabled,
×
4341
                                MessageFlags:            r.Policy1MessageFlags,
×
4342
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4343
                                Signature:               r.Policy1Signature,
×
4344
                        }
×
4345
                }
×
4346
                if r.Policy2ID.Valid {
×
4347
                        policy2 = &sqlc.GraphChannelPolicy{
×
4348
                                ID:                      r.Policy2ID.Int64,
×
4349
                                Version:                 r.Policy2Version.Int16,
×
4350
                                ChannelID:               r.GraphChannel.ID,
×
4351
                                NodeID:                  r.Policy2NodeID.Int64,
×
4352
                                Timelock:                r.Policy2Timelock.Int32,
×
4353
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4354
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4355
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4356
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4357
                                LastUpdate:              r.Policy2LastUpdate,
×
4358
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4359
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4360
                                Disabled:                r.Policy2Disabled,
×
4361
                                MessageFlags:            r.Policy2MessageFlags,
×
4362
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4363
                                Signature:               r.Policy2Signature,
×
4364
                        }
×
4365
                }
×
4366

4367
                return policy1, policy2, nil
×
4368

4369
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4370
                if r.Policy1ID.Valid {
×
4371
                        policy1 = &sqlc.GraphChannelPolicy{
×
4372
                                ID:                      r.Policy1ID.Int64,
×
4373
                                Version:                 r.Policy1Version.Int16,
×
4374
                                ChannelID:               r.GraphChannel.ID,
×
4375
                                NodeID:                  r.Policy1NodeID.Int64,
×
4376
                                Timelock:                r.Policy1Timelock.Int32,
×
4377
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4378
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4379
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4380
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4381
                                LastUpdate:              r.Policy1LastUpdate,
×
4382
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4383
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4384
                                Disabled:                r.Policy1Disabled,
×
4385
                                MessageFlags:            r.Policy1MessageFlags,
×
4386
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4387
                                Signature:               r.Policy1Signature,
×
4388
                        }
×
4389
                }
×
4390
                if r.Policy2ID.Valid {
×
4391
                        policy2 = &sqlc.GraphChannelPolicy{
×
4392
                                ID:                      r.Policy2ID.Int64,
×
4393
                                Version:                 r.Policy2Version.Int16,
×
4394
                                ChannelID:               r.GraphChannel.ID,
×
4395
                                NodeID:                  r.Policy2NodeID.Int64,
×
4396
                                Timelock:                r.Policy2Timelock.Int32,
×
4397
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4398
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4399
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4400
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4401
                                LastUpdate:              r.Policy2LastUpdate,
×
4402
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4403
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4404
                                Disabled:                r.Policy2Disabled,
×
4405
                                MessageFlags:            r.Policy2MessageFlags,
×
4406
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4407
                                Signature:               r.Policy2Signature,
×
4408
                        }
×
4409
                }
×
4410

4411
                return policy1, policy2, nil
×
4412

4413
        case sqlc.ListChannelsForNodeIDsRow:
×
4414
                if r.Policy1ID.Valid {
×
4415
                        policy1 = &sqlc.GraphChannelPolicy{
×
4416
                                ID:                      r.Policy1ID.Int64,
×
4417
                                Version:                 r.Policy1Version.Int16,
×
4418
                                ChannelID:               r.GraphChannel.ID,
×
4419
                                NodeID:                  r.Policy1NodeID.Int64,
×
4420
                                Timelock:                r.Policy1Timelock.Int32,
×
4421
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4422
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4423
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4424
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4425
                                LastUpdate:              r.Policy1LastUpdate,
×
4426
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4427
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4428
                                Disabled:                r.Policy1Disabled,
×
4429
                                MessageFlags:            r.Policy1MessageFlags,
×
4430
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4431
                                Signature:               r.Policy1Signature,
×
4432
                        }
×
4433
                }
×
4434
                if r.Policy2ID.Valid {
×
4435
                        policy2 = &sqlc.GraphChannelPolicy{
×
4436
                                ID:                      r.Policy2ID.Int64,
×
4437
                                Version:                 r.Policy2Version.Int16,
×
4438
                                ChannelID:               r.GraphChannel.ID,
×
4439
                                NodeID:                  r.Policy2NodeID.Int64,
×
4440
                                Timelock:                r.Policy2Timelock.Int32,
×
4441
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4442
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4443
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4444
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4445
                                LastUpdate:              r.Policy2LastUpdate,
×
4446
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4447
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4448
                                Disabled:                r.Policy2Disabled,
×
4449
                                MessageFlags:            r.Policy2MessageFlags,
×
4450
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4451
                                Signature:               r.Policy2Signature,
×
4452
                        }
×
4453
                }
×
4454

4455
                return policy1, policy2, nil
×
4456

4457
        case sqlc.ListChannelsByNodeIDRow:
×
4458
                if r.Policy1ID.Valid {
×
4459
                        policy1 = &sqlc.GraphChannelPolicy{
×
4460
                                ID:                      r.Policy1ID.Int64,
×
4461
                                Version:                 r.Policy1Version.Int16,
×
4462
                                ChannelID:               r.GraphChannel.ID,
×
4463
                                NodeID:                  r.Policy1NodeID.Int64,
×
4464
                                Timelock:                r.Policy1Timelock.Int32,
×
4465
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4466
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4467
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4468
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4469
                                LastUpdate:              r.Policy1LastUpdate,
×
4470
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4471
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4472
                                Disabled:                r.Policy1Disabled,
×
4473
                                MessageFlags:            r.Policy1MessageFlags,
×
4474
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4475
                                Signature:               r.Policy1Signature,
×
4476
                        }
×
4477
                }
×
4478
                if r.Policy2ID.Valid {
×
4479
                        policy2 = &sqlc.GraphChannelPolicy{
×
4480
                                ID:                      r.Policy2ID.Int64,
×
4481
                                Version:                 r.Policy2Version.Int16,
×
4482
                                ChannelID:               r.GraphChannel.ID,
×
4483
                                NodeID:                  r.Policy2NodeID.Int64,
×
4484
                                Timelock:                r.Policy2Timelock.Int32,
×
4485
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4486
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4487
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4488
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4489
                                LastUpdate:              r.Policy2LastUpdate,
×
4490
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4491
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4492
                                Disabled:                r.Policy2Disabled,
×
4493
                                MessageFlags:            r.Policy2MessageFlags,
×
4494
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4495
                                Signature:               r.Policy2Signature,
×
4496
                        }
×
4497
                }
×
4498

4499
                return policy1, policy2, nil
×
4500

4501
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4502
                if r.Policy1ID.Valid {
×
4503
                        policy1 = &sqlc.GraphChannelPolicy{
×
4504
                                ID:                      r.Policy1ID.Int64,
×
4505
                                Version:                 r.Policy1Version.Int16,
×
4506
                                ChannelID:               r.GraphChannel.ID,
×
4507
                                NodeID:                  r.Policy1NodeID.Int64,
×
4508
                                Timelock:                r.Policy1Timelock.Int32,
×
4509
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4510
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4511
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4512
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4513
                                LastUpdate:              r.Policy1LastUpdate,
×
4514
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4515
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4516
                                Disabled:                r.Policy1Disabled,
×
4517
                                MessageFlags:            r.Policy1MessageFlags,
×
4518
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4519
                                Signature:               r.Policy1Signature,
×
4520
                        }
×
4521
                }
×
4522
                if r.Policy2ID.Valid {
×
4523
                        policy2 = &sqlc.GraphChannelPolicy{
×
4524
                                ID:                      r.Policy2ID.Int64,
×
4525
                                Version:                 r.Policy2Version.Int16,
×
4526
                                ChannelID:               r.GraphChannel.ID,
×
4527
                                NodeID:                  r.Policy2NodeID.Int64,
×
4528
                                Timelock:                r.Policy2Timelock.Int32,
×
4529
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4530
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4531
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4532
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4533
                                LastUpdate:              r.Policy2LastUpdate,
×
4534
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4535
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4536
                                Disabled:                r.Policy2Disabled,
×
4537
                                MessageFlags:            r.Policy2MessageFlags,
×
4538
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4539
                                Signature:               r.Policy2Signature,
×
4540
                        }
×
4541
                }
×
4542

4543
                return policy1, policy2, nil
×
4544

4545
        case sqlc.GetChannelsByIDsRow:
×
4546
                if r.Policy1ID.Valid {
×
4547
                        policy1 = &sqlc.GraphChannelPolicy{
×
4548
                                ID:                      r.Policy1ID.Int64,
×
4549
                                Version:                 r.Policy1Version.Int16,
×
4550
                                ChannelID:               r.GraphChannel.ID,
×
4551
                                NodeID:                  r.Policy1NodeID.Int64,
×
4552
                                Timelock:                r.Policy1Timelock.Int32,
×
4553
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4554
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4555
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4556
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4557
                                LastUpdate:              r.Policy1LastUpdate,
×
4558
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4559
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4560
                                Disabled:                r.Policy1Disabled,
×
4561
                                MessageFlags:            r.Policy1MessageFlags,
×
4562
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4563
                                Signature:               r.Policy1Signature,
×
4564
                        }
×
4565
                }
×
4566
                if r.Policy2ID.Valid {
×
4567
                        policy2 = &sqlc.GraphChannelPolicy{
×
4568
                                ID:                      r.Policy2ID.Int64,
×
4569
                                Version:                 r.Policy2Version.Int16,
×
4570
                                ChannelID:               r.GraphChannel.ID,
×
4571
                                NodeID:                  r.Policy2NodeID.Int64,
×
4572
                                Timelock:                r.Policy2Timelock.Int32,
×
4573
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4574
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4575
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4576
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4577
                                LastUpdate:              r.Policy2LastUpdate,
×
4578
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4579
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4580
                                Disabled:                r.Policy2Disabled,
×
4581
                                MessageFlags:            r.Policy2MessageFlags,
×
4582
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4583
                                Signature:               r.Policy2Signature,
×
4584
                        }
×
4585
                }
×
4586

4587
                return policy1, policy2, nil
×
4588

4589
        default:
×
4590
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4591
                        "extractChannelPolicies: %T", r)
×
4592
        }
4593
}
4594

4595
// channelIDToBytes converts a channel ID (SCID) to its 8-byte slice
4596
// representation.
4597
func channelIDToBytes(channelID uint64) []byte {
×
4598
        var chanIDB [8]byte
×
4599
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4600

×
4601
        return chanIDB[:]
×
4602
}
×
4603

4604
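// exampleChannelIDToBytes is an illustrative sketch (not part of the store
// API) of the round trip between an SCID and its byte representation:
// decoding the returned slice with the same package-level byteOrder yields
// the original channel ID.
func exampleChannelIDToBytes(channelID uint64) bool {
        scidBytes := channelIDToBytes(channelID)

        return byteOrder.Uint64(scidBytes) == channelID
}
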
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
4605
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
×
4606
        if len(addresses) == 0 {
×
4607
                return nil, nil
×
4608
        }
×
4609

4610
        result := make([]net.Addr, 0, len(addresses))
×
4611
        for _, addr := range addresses {
×
4612
                netAddr, err := parseAddress(addr.addrType, addr.address)
×
4613
                if err != nil {
×
4614
                        return nil, fmt.Errorf("unable to parse address %s "+
×
4615
                                "of type %d: %w", addr.address, addr.addrType,
×
4616
                                err)
×
4617
                }
×
4618
                if netAddr != nil {
×
4619
                        result = append(result, netAddr)
×
4620
                }
×
4621
        }
4622

4623
        // If we have no valid addresses, return nil instead of an empty slice.
4624
        if len(result) == 0 {
×
4625
                return nil, nil
×
4626
        }
×
4627

4628
        return result, nil
×
4629
}
4630

4631
// parseAddress parses the given address string based on the address type
4632
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
4633
// DNS, and opaque addresses.
4634
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
×
4635
        switch addrType {
×
4636
        case addressTypeIPv4:
×
4637
                tcp, err := net.ResolveTCPAddr("tcp4", address)
×
4638
                if err != nil {
×
4639
                        return nil, err
×
4640
                }
×
4641

4642
                tcp.IP = tcp.IP.To4()
×
4643

×
4644
                return tcp, nil
×
4645

4646
        case addressTypeIPv6:
×
4647
                tcp, err := net.ResolveTCPAddr("tcp6", address)
×
4648
                if err != nil {
×
4649
                        return nil, err
×
4650
                }
×
4651

4652
                return tcp, nil
×
4653

4654
        case addressTypeTorV3, addressTypeTorV2:
×
4655
                service, portStr, err := net.SplitHostPort(address)
×
4656
                if err != nil {
×
4657
                        return nil, fmt.Errorf("unable to split tor "+
×
4658
                                "address: %v", address)
×
4659
                }
×
4660

4661
                port, err := strconv.Atoi(portStr)
×
4662
                if err != nil {
×
4663
                        return nil, err
×
4664
                }
×
4665

4666
                return &tor.OnionAddr{
×
4667
                        OnionService: service,
×
4668
                        Port:         port,
×
4669
                }, nil
×
4670

NEW
4671
        case addressTypeDNS:
×
NEW
4672
                hostname, portStr, err := net.SplitHostPort(address)
×
NEW
4673
                if err != nil {
×
NEW
4674
                        return nil, fmt.Errorf("unable to split DNS "+
×
NEW
4675
                                "address: %v", address)
×
NEW
4676
                }
×
4677

NEW
4678
                port, err := strconv.Atoi(portStr)
×
NEW
4679
                if err != nil {
×
NEW
4680
                        return nil, err
×
NEW
4681
                }
×
4682

NEW
4683
                return &lnwire.DNSAddress{
×
NEW
4684
                        Hostname: hostname,
×
NEW
4685
                        Port:     uint16(port),
×
NEW
4686
                }, nil
×
4687

4688
        case addressTypeOpaque:
×
4689
                opaque, err := hex.DecodeString(address)
×
4690
                if err != nil {
×
4691
                        return nil, fmt.Errorf("unable to decode opaque "+
×
4692
                                "address: %v", address)
×
4693
                }
×
4694

4695
                return &lnwire.OpaqueAddrs{
×
4696
                        Payload: opaque,
×
4697
                }, nil
×
4698

4699
        default:
×
4700
                return nil, fmt.Errorf("unknown address type: %v", addrType)
×
4701
        }
4702
}
4703

4704
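// exampleParseDNSAddress is an illustrative sketch (not part of the store
// API) of how a persisted DNS "host:port" string is converted back into an
// lnwire.DNSAddress by parseAddress. The hostname and port below are
// placeholder values.
func exampleParseDNSAddress() (net.Addr, error) {
        // The DNS branch above splits the host and port and returns an
        // lnwire.DNSAddress with the port stored as a uint16.
        return parseAddress(addressTypeDNS, "node.example.com:9735")
}
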
// batchNodeData holds all the related data for a batch of nodes.
4705
type batchNodeData struct {
4706
        // features is a map from a DB node ID to the feature bits for that
4707
        // node.
4708
        features map[int64][]int
4709

4710
        // addresses is a map from a DB node ID to the node's addresses.
4711
        addresses map[int64][]nodeAddress
4712

4713
        // extraFields is a map from a DB node ID to the extra signed fields
4714
        // for that node.
4715
        extraFields map[int64]map[uint64][]byte
4716
}
4717

4718
// nodeAddress holds the address type, position and address string for a
4719
// node. This is used to batch the fetching of node addresses.
4720
type nodeAddress struct {
4721
        addrType dbAddressType
4722
        position int32
4723
        address  string
4724
}
4725

4726
// batchLoadNodeData loads all related data for a batch of node IDs using the
// provided SQLQueries interface. It returns a batchNodeData instance containing
// the node features, addresses and extra signed fields.
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {

        // Batch load the node features.
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node "+
                        "features: %w", err)
        }

        // Batch load the node addresses.
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node "+
                        "addresses: %w", err)
        }

        // Batch load the node extra signed fields.
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node extra "+
                        "signed fields: %w", err)
        }

        return &batchNodeData{
                features:    features,
                addresses:   addrs,
                extraFields: extraTypes,
        }, nil
}

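// Editor-added sketch, not part of the original file: batchLoadNodeData is
// the single entry point the paginated iterators below use to avoid per-node
// queries. This hypothetical helper shows the shape of the result, namely
// three maps keyed by DB node ID.
func exampleNodeFeatureBits(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, nodeID int64) ([]int, error) {

        data, err := batchLoadNodeData(ctx, cfg, db, []int64{nodeID})
        if err != nil {
                return nil, err
        }

        return data.features[nodeID], nil
}
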
// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
// using ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
func batchLoadNodeFeaturesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64][]int, error) {

        features := make(map[int64][]int)

        return features, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
                        error) {

                        return db.GetNodeFeaturesBatch(ctx, ids)
                },
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
                        features[feature.NodeID] = append(
                                features[feature.NodeID],
                                int(feature.FeatureBit),
                        )

                        return nil
                },
        )
}

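// Editor-added sketch, not part of the original file: the per-node feature
// bits returned above are plain ints. A node builder would typically convert
// them into an lnwire.FeatureVector along these lines (the real conversion
// lives in buildNodeWithBatchData, which is defined elsewhere in this file).
func exampleFeatureVector(bits []int) *lnwire.FeatureVector {
        raw := lnwire.NewRawFeatureVector()
        for _, bit := range bits {
                raw.Set(lnwire.FeatureBit(bit))
        }

        return lnwire.NewFeatureVector(raw, lnwire.Features)
}
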
// batchLoadNodeAddressesHelper loads node addresses using ExecuteBatchQuery
// wrapper around the GetNodeAddressesBatch query. It returns a map from
// node ID to a slice of nodeAddress structs.
func batchLoadNodeAddressesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64][]nodeAddress, error) {

        addrs := make(map[int64][]nodeAddress)

        return addrs, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
                        error) {

                        return db.GetNodeAddressesBatch(ctx, ids)
                },
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
                        addrs[addr.NodeID] = append(
                                addrs[addr.NodeID], nodeAddress{
                                        addrType: dbAddressType(addr.Type),
                                        position: addr.Position,
                                        address:  addr.Address,
                                },
                        )

                        return nil
                },
        )
}

// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
// node IDs using ExecuteBatchQuery wrapper around the GetNodeExtraTypesBatch
// query.
func batchLoadNodeExtraTypesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {

        extraFields := make(map[int64]map[uint64][]byte)

        callback := func(ctx context.Context,
                field sqlc.GraphNodeExtraType) error {

                if extraFields[field.NodeID] == nil {
                        extraFields[field.NodeID] = make(map[uint64][]byte)
                }
                extraFields[field.NodeID][uint64(field.Type)] = field.Value

                return nil
        }

        return extraFields, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) (
                        []sqlc.GraphNodeExtraType, error) {

                        return db.GetNodeExtraTypesBatch(ctx, ids)
                },
                callback,
        )
}

// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
// from the provided sqlc.GraphChannelPolicy records and the
// provided batchChannelData.
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
        channelID uint64, node1, node2 route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
        *models.ChannelEdgePolicy, error) {

        pol1, err := buildChanPolicyWithBatchData(
                dbPol1, channelID, node2, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
        }

        pol2, err := buildChanPolicyWithBatchData(
                dbPol2, channelID, node1, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
        }

        return pol1, pol2, nil
}

// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
        channelID uint64, toNode route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {

        if dbPol == nil {
                return nil, nil
        }

        var dbPol1Extras map[uint64][]byte
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
                dbPol1Extras = extras
        } else {
                dbPol1Extras = make(map[uint64][]byte)
        }

        return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode)
}

// batchChannelData holds all the related data for a batch of channels.
type batchChannelData struct {
        // chanfeatures is a map from DB channel ID to a slice of feature bits.
        chanfeatures map[int64][]int

        // chanExtraTypes is a map from DB channel ID to a map of TLV type to
        // extra signed field bytes.
        chanExtraTypes map[int64]map[uint64][]byte

        // policyExtras is a map from DB channel policy ID to a map of TLV type
        // to extra signed field bytes.
        policyExtras map[int64]map[uint64][]byte
}

// batchLoadChannelData loads all related data for batches of channels and
// policies.
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, channelIDs []int64,
        policyIDs []int64) (*batchChannelData, error) {

        batchData := &batchChannelData{
                chanfeatures:   make(map[int64][]int),
                chanExtraTypes: make(map[int64]map[uint64][]byte),
                policyExtras:   make(map[int64]map[uint64][]byte),
        }

        // Batch load channel features and extras.
        var err error
        if len(channelIDs) > 0 {
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
                        ctx, cfg, db, channelIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "channel features: %w", err)
                }

                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
                        ctx, cfg, db, channelIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "channel extras: %w", err)
                }
        }

        if len(policyIDs) > 0 {
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
                        ctx, cfg, db, policyIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "policy extras: %w", err)
                }
                batchData.policyExtras = policyExtras
        }

        return batchData, nil
}

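// Editor-added sketch, not part of the original file: callers such as
// batchBuildChannelInfo below pass a nil policyIDs slice, in which case only
// the channel feature and extra-type queries run and policyExtras stays
// empty. The helper name here is hypothetical.
func exampleChannelOnlyData(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, channelIDs []int64) (*batchChannelData, error) {

        return batchLoadChannelData(ctx, cfg, db, channelIDs, nil)
}
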
// batchLoadChannelFeaturesHelper loads channel features for a batch of
// channel IDs using ExecuteBatchQuery wrapper around the
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
// slice of feature bits.
func batchLoadChannelFeaturesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        channelIDs []int64) (map[int64][]int, error) {

        features := make(map[int64][]int)

        return features, sqldb.ExecuteBatchQuery(
                ctx, cfg, channelIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context,
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {

                        return db.GetChannelFeaturesBatch(ctx, ids)
                },
                func(ctx context.Context,
                        feature sqlc.GraphChannelFeature) error {

                        features[feature.ChannelID] = append(
                                features[feature.ChannelID],
                                int(feature.FeatureBit),
                        )

                        return nil
                },
        )
}

// batchLoadChannelExtrasHelper loads channel extra types for a batch of
// channel IDs using ExecuteBatchQuery wrapper around the GetChannelExtrasBatch
// query. It returns a map from DB channel ID to a map of TLV type to extra
// signed field bytes.
func batchLoadChannelExtrasHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {

        extras := make(map[int64]map[uint64][]byte)

        cb := func(ctx context.Context,
                extra sqlc.GraphChannelExtraType) error {

                if extras[extra.ChannelID] == nil {
                        extras[extra.ChannelID] = make(map[uint64][]byte)
                }
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value

                return nil
        }

        return extras, sqldb.ExecuteBatchQuery(
                ctx, cfg, channelIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context,
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {

                        return db.GetChannelExtrasBatch(ctx, ids)
                }, cb,
        )
}

// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
// batch of policy IDs using ExecuteBatchQuery wrapper around the
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
// a map of TLV type to extra signed field bytes.
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {

        extras := make(map[int64]map[uint64][]byte)

        return extras, sqldb.ExecuteBatchQuery(
                ctx, cfg, policyIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) (
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {

                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
                },
                func(ctx context.Context,
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {

                        if extras[row.PolicyID] == nil {
                                extras[row.PolicyID] = make(map[uint64][]byte)
                        }
                        extras[row.PolicyID][uint64(row.Type)] = row.Value

                        return nil
                },
        )
}

// forEachNodePaginated executes a paginated query to process each node in the
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
// and applies the provided processNode function to each node.
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, protocol ProtocolVersion,
        processNode func(context.Context, int64,
                *models.LightningNode) error) error {

        pageQueryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.GraphNode, error) {

                return db.ListNodesPaginated(
                        ctx, sqlc.ListNodesPaginatedParams{
                                Version: int16(protocol),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractPageCursor := func(node sqlc.GraphNode) int64 {
                return node.ID
        }

        collectFunc := func(node sqlc.GraphNode) (int64, error) {
                return node.ID, nil
        }

        batchQueryFunc := func(ctx context.Context,
                nodeIDs []int64) (*batchNodeData, error) {

                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
        }

        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
                batchData *batchNodeData) error {

                node, err := buildNodeWithBatchData(dbNode, batchData)
                if err != nil {
                        return fmt.Errorf("unable to build "+
                                "node(id=%d): %w", dbNode.ID, err)
                }

                return processNode(ctx, dbNode.ID, node)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
                collectFunc, batchQueryFunc, processItem,
        )
}

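// Editor-added sketch, not part of the original file: a hypothetical caller
// of forEachNodePaginated. The callback receives the node's DB ID along with
// the fully assembled models.LightningNode for every node of the given
// protocol version.
func exampleCountNodes(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries) (int, error) {

        var count int
        err := forEachNodePaginated(
                ctx, cfg, db, ProtocolV1,
                func(_ context.Context, _ int64,
                        _ *models.LightningNode) error {

                        count++

                        return nil
                },
        )

        return count, err
}
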
// forEachChannelWithPolicies executes a paginated query to process each channel
// with policies in the graph.
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
        cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
                *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error) error {

        type channelBatchIDs struct {
                channelID int64
                policyIDs []int64
        }

        pageQueryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
                error) {

                return db.ListChannelsWithPoliciesPaginated(
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
                                Version: int16(ProtocolV1),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractPageCursor := func(
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {

                return row.GraphChannel.ID
        }

        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
                channelBatchIDs, error) {

                ids := channelBatchIDs{
                        channelID: row.GraphChannel.ID,
                }

                // Extract policy IDs from the row.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return ids, err
                }

                if dbPol1 != nil {
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
                }
                if dbPol2 != nil {
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
                }

                return ids, nil
        }

        batchDataFunc := func(ctx context.Context,
                allIDs []channelBatchIDs) (*batchChannelData, error) {

                // Separate channel IDs from policy IDs.
                var (
                        channelIDs = make([]int64, len(allIDs))
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
                )

                for i, ids := range allIDs {
                        channelIDs[i] = ids.channelID
                        policyIDs = append(policyIDs, ids.policyIDs...)
                }

                return batchLoadChannelData(
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
                )
        }

        processItem := func(ctx context.Context,
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
                batchData *batchChannelData) error {

                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return err
                }

                edge, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.GraphChannel, node1, node2,
                        batchData,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel info: %w",
                                err)
                }

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return err
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
                )
                if err != nil {
                        return err
                }

                return processChannel(edge, p1, p2)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
                collectFunc, batchDataFunc, processItem,
        )
}

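// Editor-added sketch, not part of the original file: a hypothetical caller
// of forEachChannelWithPolicies. The callback receives the channel's edge
// info together with its two directed policies, either of which may be nil
// when the corresponding channel update has not been seen.
func exampleCountChannels(ctx context.Context, cfg *SQLStoreConfig,
        db SQLQueries) (int, error) {

        var count int
        err := forEachChannelWithPolicies(
                ctx, db, cfg, func(_ *models.ChannelEdgeInfo,
                        _ *models.ChannelEdgePolicy,
                        _ *models.ChannelEdgePolicy) error {

                        count++

                        return nil
                },
        )

        return count, err
}
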
// buildDirectedChannel builds a DirectedChannel instance from the provided
// data.
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
        nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
        channelBatchData *batchChannelData, features *lnwire.FeatureVector,
        toNodeCallback func() route.Vertex) (*DirectedChannel, error) {

        node1, node2, err := buildNodeVertices(
                channelRow.Node1Pubkey, channelRow.Node2Pubkey,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build node vertices: %w", err)
        }

        edge, err := buildEdgeInfoWithBatchData(
                chain, channelRow.GraphChannel, node1, node2, channelBatchData,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build channel info: %w", err)
        }

        dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
        if err != nil {
                return nil, fmt.Errorf("unable to extract channel policies: %w",
                        err)
        }

        p1, p2, err := buildChanPoliciesWithBatchData(
                dbPol1, dbPol2, edge.ChannelID, node1, node2,
                channelBatchData,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build channel policies: %w",
                        err)
        }

        // Determine outgoing and incoming policy for this specific node.
        p1ToNode := channelRow.GraphChannel.NodeID2
        p2ToNode := channelRow.GraphChannel.NodeID1
        outPolicy, inPolicy := p1, p2
        if (p1 != nil && p1ToNode == nodeID) ||
                (p2 != nil && p2ToNode != nodeID) {

                outPolicy, inPolicy = p2, p1
        }

        // Build cached policy.
        var cachedInPolicy *models.CachedEdgePolicy
        if inPolicy != nil {
                cachedInPolicy = models.NewCachedPolicy(inPolicy)
                cachedInPolicy.ToNodePubKey = toNodeCallback
                cachedInPolicy.ToNodeFeatures = features
        }

        // Extract inbound fee.
        var inboundFee lnwire.Fee
        if outPolicy != nil {
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
                        inboundFee = fee
                })
        }

        // Build directed channel.
        directedChannel := &DirectedChannel{
                ChannelID:    edge.ChannelID,
                IsNode1:      nodePub == edge.NodeKey1Bytes,
                OtherNode:    edge.NodeKey2Bytes,
                Capacity:     edge.Capacity,
                OutPolicySet: outPolicy != nil,
                InPolicy:     cachedInPolicy,
                InboundFee:   inboundFee,
        }

        if nodePub == edge.NodeKey2Bytes {
                directedChannel.OtherNode = edge.NodeKey1Bytes
        }

        return directedChannel, nil
}

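// Editor-added sketch, not part of the original file, mirroring the
// inbound-fee handling in buildDirectedChannel above: when the outgoing
// policy is nil or advertises no inbound fee, the zero-valued lnwire.Fee is
// what ends up on the DirectedChannel.
func exampleInboundFee(outPolicy *models.ChannelEdgePolicy) lnwire.Fee {
        var inboundFee lnwire.Fee
        if outPolicy != nil {
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
                        inboundFee = fee
                })
        }

        return inboundFee
}
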
// batchBuildChannelEdges builds a slice of ChannelEdge instances from the
// provided rows. It uses batch loading for channels, policies, and nodes.
func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context,
        cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) {

        var (
                channelIDs = make([]int64, len(rows))
                policyIDs  = make([]int64, 0, len(rows)*2)
                nodeIDs    = make([]int64, 0, len(rows)*2)

                // nodeIDSet is used to ensure we only collect unique node IDs.
                nodeIDSet = make(map[int64]bool)

                // edges will hold the final channel edges built from the rows.
                edges = make([]ChannelEdge, 0, len(rows))
        )

        // Collect all IDs needed for batch loading.
        for i, row := range rows {
                channelIDs[i] = row.Channel().ID

                // Collect policy IDs
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return nil, fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }
                if dbPol1 != nil {
                        policyIDs = append(policyIDs, dbPol1.ID)
                }
                if dbPol2 != nil {
                        policyIDs = append(policyIDs, dbPol2.ID)
                }

                var (
                        node1ID = row.Node1().ID
                        node2ID = row.Node2().ID
                )

                // Collect unique node IDs.
                if !nodeIDSet[node1ID] {
                        nodeIDs = append(nodeIDs, node1ID)
                        nodeIDSet[node1ID] = true
                }

                if !nodeIDSet[node2ID] {
                        nodeIDs = append(nodeIDs, node2ID)
                        nodeIDSet[node2ID] = true
                }
        }

        // Batch the data for all the channels and policies.
        channelBatchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to batch load channel and "+
                        "policy data: %w", err)
        }

        // Batch the data for all the nodes.
        nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node data: %w",
                        err)
        }

        // Build all channel edges using batch data.
        for _, row := range rows {
                // Build nodes using batch data.
                node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData)
                if err != nil {
                        return nil, fmt.Errorf("unable to build node1: %w", err)
                }

                node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData)
                if err != nil {
                        return nil, fmt.Errorf("unable to build node2: %w", err)
                }

                // Build channel info using batch data.
                channel, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.Channel(), node1.PubKeyBytes,
                        node2.PubKeyBytes, channelBatchData,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to build channel "+
                                "info: %w", err)
                }

                // Extract and build policies using batch data.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return nil, fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, channel.ChannelID,
                        node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to build channel "+
                                "policies: %w", err)
                }

                edges = append(edges, ChannelEdge{
                        Info:    channel,
                        Policy1: p1,
                        Policy2: p2,
                        Node1:   node1,
                        Node2:   node2,
                })
        }

        return edges, nil
}

// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo
// instances from the provided rows using batch loading for channel data.
func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context,
        cfg *SQLStoreConfig, db SQLQueries, rows []T) (
        []*models.ChannelEdgeInfo, []int64, error) {

        if len(rows) == 0 {
                return nil, nil, nil
        }

        // Collect all the channel IDs needed for batch loading.
        channelIDs := make([]int64, len(rows))
        for i, row := range rows {
                channelIDs[i] = row.Channel().ID
        }

        // Batch load the channel data.
        channelBatchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, channelIDs, nil,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to batch load channel "+
                        "data: %w", err)
        }

        // Build all channel edges using batch data.
        edges := make([]*models.ChannelEdgeInfo, 0, len(rows))
        for _, row := range rows {
                node1, node2, err := buildNodeVertices(
                        row.Node1Pub(), row.Node2Pub(),
                )
                if err != nil {
                        return nil, nil, err
                }

                // Build channel info using batch data
                info, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.Channel(), node1, node2,
                        channelBatchData,
                )
                if err != nil {
                        return nil, nil, err
                }

                edges = append(edges, info)
        }

        return edges, channelIDs, nil
}

// handleZombieMarking is a helper function that handles the logic of
// marking a channel as a zombie in the database. It takes into account whether
// we are in strict zombie pruning mode, and adjusts the node public keys
// accordingly based on the last update timestamps of the channel policies.
func handleZombieMarking(ctx context.Context, db SQLQueries,
        row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo,
        strictZombiePruning bool, scid uint64) error {

        nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes

        if strictZombiePruning {
                var e1UpdateTime, e2UpdateTime *time.Time
                if row.Policy1LastUpdate.Valid {
                        e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0)
                        e1UpdateTime = &e1Time
                }
                if row.Policy2LastUpdate.Valid {
                        e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0)
                        e2UpdateTime = &e2Time
                }

                nodeKey1, nodeKey2 = makeZombiePubkeys(
                        info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime,
                        e2UpdateTime,
                )
        }

        return db.UpsertZombieChannel(
                ctx, sqlc.UpsertZombieChannelParams{
                        Version:  int16(ProtocolV1),
                        Scid:     channelIDToBytes(scid),
                        NodeKey1: nodeKey1[:],
                        NodeKey2: nodeKey2[:],
                },
        )
}