lightningnetwork / lnd / build 16930624683

13 Aug 2025 07:31AM UTC coverage: 66.908% (+10.0%) from 56.955%

Pull Request #10148: graph/db+sqldb: different defaults for SQLite and Postgres query options
Merge faa71c073 into 8810793e6

11 of 80 new or added lines in 7 files covered. (13.75%)
62 existing lines in 9 files now uncovered.
135807 of 202975 relevant lines covered (66.91%)
21557.94 hits per line

Source File: /graph/db/sql_store.go

package graphdb

import (
        "bytes"
        "context"
        "database/sql"
        "encoding/hex"
        "errors"
        "fmt"
        "maps"
        "math"
        "net"
        "slices"
        "strconv"
        "sync"
        "time"

        "github.com/btcsuite/btcd/btcec/v2"
        "github.com/btcsuite/btcd/btcutil"
        "github.com/btcsuite/btcd/chaincfg/chainhash"
        "github.com/btcsuite/btcd/wire"
        "github.com/lightningnetwork/lnd/aliasmgr"
        "github.com/lightningnetwork/lnd/batch"
        "github.com/lightningnetwork/lnd/fn/v2"
        "github.com/lightningnetwork/lnd/graph/db/models"
        "github.com/lightningnetwork/lnd/lnwire"
        "github.com/lightningnetwork/lnd/routing/route"
        "github.com/lightningnetwork/lnd/sqldb"
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
        "github.com/lightningnetwork/lnd/tlv"
        "github.com/lightningnetwork/lnd/tor"
)

// ProtocolVersion is an enum that defines the gossip protocol version of a
// message.
type ProtocolVersion uint8

const (
        // ProtocolV1 is the gossip protocol version defined in BOLT #7.
        ProtocolV1 ProtocolVersion = 1
)

// String returns a string representation of the protocol version.
func (v ProtocolVersion) String() string {
        return fmt.Sprintf("V%d", v)
}

// SQLQueries is a subset of the sqlc.Querier interface that can be used to
// execute queries against the SQL graph tables.
//
//nolint:ll,interfacebloat
type SQLQueries interface {
        /*
                Node queries.
        */
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
        DeleteNode(ctx context.Context, id int64) error

        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
        GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error

        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
        GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
        GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error

        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
        GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error

        /*
                Source node queries.
        */
        AddSourceNode(ctx context.Context, nodeID int64) error
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)

        /*
                Channel queries.
        */
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
        ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
        ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
        DeleteChannels(ctx context.Context, ids []int64) error

        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
        GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
        GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)

        /*
                Channel Policy table queries.
        */
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)

        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
        GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error

        /*
                Zombie index queries.
        */
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)

        /*
                Prune log table queries.
        */
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error

        /*
                Closed SCID table queries.
        */
        InsertClosedChannel(ctx context.Context, scid []byte) error
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
}

// BatchedSQLQueries is a version of SQLQueries that's capable of batched
// database operations.
type BatchedSQLQueries interface {
        SQLQueries
        sqldb.BatchedTx[SQLQueries]
}

// SQLStore is an implementation of the V1Store interface that uses a SQL
// database as the backend.
type SQLStore struct {
        cfg *SQLStoreConfig
        db  BatchedSQLQueries

        // cacheMu guards all caches (rejectCache and chanCache). If
        // this mutex will be acquired at the same time as the DB mutex then
        // the cacheMu MUST be acquired first to prevent deadlock.
        cacheMu     sync.RWMutex
        rejectCache *rejectCache
        chanCache   *channelCache

        chanScheduler batch.Scheduler[SQLQueries]
        nodeScheduler batch.Scheduler[SQLQueries]

        srcNodes  map[ProtocolVersion]*srcNodeInfo
        srcNodeMu sync.Mutex
}

// A compile-time assertion to ensure that SQLStore implements the V1Store
// interface.
var _ V1Store = (*SQLStore)(nil)

// SQLStoreConfig holds the configuration for the SQLStore.
type SQLStoreConfig struct {
        // ChainHash is the genesis hash for the chain that all the gossip
        // messages in this store are aimed at.
        ChainHash chainhash.Hash

        // QueryConfig holds configuration values for SQL queries.
        QueryCfg *sqldb.QueryConfig
}

// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
// storage backend.
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
        options ...StoreOptionModifier) (*SQLStore, error) {

        opts := DefaultOptions()
        for _, o := range options {
                o(opts)
        }

        if opts.NoMigration {
                return nil, fmt.Errorf("the NoMigration option is not yet " +
                        "supported for SQL stores")
        }

        s := &SQLStore{
                cfg:         cfg,
                db:          db,
                rejectCache: newRejectCache(opts.RejectCacheSize),
                chanCache:   newChannelCache(opts.ChannelCacheSize),
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
        }

        s.chanScheduler = batch.NewTimeScheduler(
                db, &s.cacheMu, opts.BatchCommitInterval,
        )
        s.nodeScheduler = batch.NewTimeScheduler(
                db, nil, opts.BatchCommitInterval,
        )

        return s, nil
}

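// NOTE (editorial illustration, not part of the original file): the sketch
// below shows how NewSQLStore might be wired up, assuming a BatchedSQLQueries
// implementation `db` is already available from the sqldb package. The zero
// chain hash and empty QueryConfig are placeholders only, not recommended
// values.
func exampleNewSQLStore(db BatchedSQLQueries) (*SQLStore, error) {
        // Assemble a minimal config. A real caller would set the genesis
        // hash of its target chain and the query defaults appropriate for
        // its backend (SQLite or Postgres).
        cfg := &SQLStoreConfig{
                ChainHash: chainhash.Hash{},
                QueryCfg:  &sqldb.QueryConfig{},
        }

        // Construct the store with the default store options.
        return NewSQLStore(cfg, db)
}
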
// AddLightningNode adds a vertex/node to the graph database. If the node is not
// in the database from before, this will add a new, unconnected one to the
// graph. If it is present from before, this will update that node's
// information.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddLightningNode(ctx context.Context,
        node *models.LightningNode, opts ...batch.SchedulerOption) error {

        r := &batch.Request[SQLQueries]{
                Opts: batch.NewSchedulerOptions(opts...),
                Do: func(queries SQLQueries) error {
                        _, err := upsertNode(ctx, queries, node)
                        return err
                },
        }

        return s.nodeScheduler.Execute(ctx, r)
}

// FetchLightningNode attempts to look up a target node by its identity public
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
// returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchLightningNode(ctx context.Context,
        pubKey route.Vertex) (*models.LightningNode, error) {

        var node *models.LightningNode
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                var err error
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, pubKey)

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch node: %w", err)
        }

        return node, nil
}

// HasLightningNode determines if the graph has a vertex identified by the
// target node identity public key. If the node exists in the database, a
// timestamp of when the data for the node was last updated is returned along
// with a true boolean. Otherwise, an empty time.Time is returned with a false
// boolean.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) HasLightningNode(ctx context.Context,
        pubKey [33]byte) (time.Time, bool, error) {

        var (
                exists     bool
                lastUpdate time.Time
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pubKey[:],
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                exists = true

                if dbNode.LastUpdate.Valid {
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return time.Time{}, false,
                        fmt.Errorf("unable to fetch node: %w", err)
        }

        return lastUpdate, exists, nil
}

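// NOTE (editorial illustration, not part of the original file): a minimal
// sketch of calling HasLightningNode, assuming `store` has been created via
// NewSQLStore and `pubKey` is a 33-byte identity key.
func exampleHasLightningNode(ctx context.Context, store *SQLStore,
        pubKey [33]byte) error {

        // Look up the node and, if it is known, log when its announcement
        // was last updated.
        lastUpdate, exists, err := store.HasLightningNode(ctx, pubKey)
        if err != nil {
                return err
        }

        if exists {
                log.Debugf("node %x last updated at %v", pubKey, lastUpdate)
        }

        return nil
}
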
// AddrsForNode returns all known addresses for the target node public key
// that the graph DB is aware of. The returned boolean indicates if the
// given node is unknown to the graph DB or not.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddrsForNode(ctx context.Context,
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {

        var (
                addresses []net.Addr
                known     bool
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                // First, check if the node exists and get its DB ID if it
                // does.
                dbID, err := db.GetNodeIDByPubKey(
                        ctx, sqlc.GetNodeIDByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  nodePub.SerializeCompressed(),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                }

                known = true

                addresses, err = getNodeAddresses(ctx, db, dbID)
                if err != nil {
                        return fmt.Errorf("unable to fetch node addresses: %w",
                                err)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return false, nil, fmt.Errorf("unable to get addresses for "+
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
        }

        return known, addresses, nil
}

// DeleteLightningNode starts a new database transaction to remove a vertex/node
// from the database according to the node's public key.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
        pubKey route.Vertex) error {

        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                res, err := db.DeleteNodeByPubKey(
                        ctx, sqlc.DeleteNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pubKey[:],
                        },
                )
                if err != nil {
                        return err
                }

                rows, err := res.RowsAffected()
                if err != nil {
                        return err
                }

                if rows == 0 {
                        return ErrGraphNodeNotFound
                } else if rows > 1 {
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
                }

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return fmt.Errorf("unable to delete node: %w", err)
        }

        return nil
}

// FetchNodeFeatures returns the features of the given node. If no features are
// known for the node, an empty feature vector is returned.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
        *lnwire.FeatureVector, error) {

        ctx := context.TODO()

        return fetchNodeFeatures(ctx, s.db, nodePub)
}

// DisabledChannelIDs returns the channel ids of disabled channels.
// A channel is disabled when both of the associated ChannelEdgePolicies
// have their disabled bit on.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
        var (
                ctx     = context.TODO()
                chanIDs []uint64
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
                if err != nil {
                        return fmt.Errorf("unable to fetch disabled "+
                                "channels: %w", err)
                }

                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
                        err)
        }

        return chanIDs, nil
}

// LookupAlias attempts to return the alias as advertised by the target node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) LookupAlias(ctx context.Context,
        pub *btcec.PublicKey) (string, error) {

        var alias string
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pub.SerializeCompressed(),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return ErrNodeAliasNotFound
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                if !dbNode.Alias.Valid {
                        return ErrNodeAliasNotFound
                }

                alias = dbNode.Alias.String

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return "", fmt.Errorf("unable to look up alias: %w", err)
        }

        return alias, nil
}

// SourceNode returns the source node of the graph. The source node is treated
// as the center node within a star-graph. This method may be used to kick off
// a path finding algorithm in order to explore the reachability of another
// node based off the source node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
        error) {

        var node *models.LightningNode
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
                if err != nil {
                        return fmt.Errorf("unable to fetch V1 source node: %w",
                                err)
                }

                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, nodePub)

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
        }

        return node, nil
}

// SetSourceNode sets the source node within the graph database. The source
// node is to be used as the center of a star-graph within path finding
// algorithms.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SetSourceNode(ctx context.Context,
        node *models.LightningNode) error {

        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                id, err := upsertNode(ctx, db, node)
                if err != nil {
                        return fmt.Errorf("unable to upsert source node: %w",
                                err)
                }

                // Make sure that if a source node for this version is already
                // set, then the ID is the same as the one we are about to set.
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
                        return fmt.Errorf("unable to fetch source node: %w",
                                err)
                } else if err == nil {
                        if dbSourceNodeID != id {
                                return fmt.Errorf("v1 source node already "+
                                        "set to a different node: %d vs %d",
                                        dbSourceNodeID, id)
                        }

                        return nil
                }

                return db.AddSourceNode(ctx, id)
        }, sqldb.NoOpReset)
}

// NodeUpdatesInHorizon returns all the known lightning nodes which have an
// update timestamp within the passed range. This method can be used by two
// nodes to quickly determine if they have the same set of up-to-date node
// announcements.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
        endTime time.Time) ([]models.LightningNode, error) {

        ctx := context.TODO()

        var nodes []models.LightningNode
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNodes, err := db.GetNodesByLastUpdateRange(
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch nodes: %w", err)
                }

                err = forEachNodeInBatch(
                        ctx, s.cfg.QueryCfg, db, dbNodes,
                        func(_ int64, node *models.LightningNode) error {
                                nodes = append(nodes, *node)

                                return nil
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to build nodes: %w", err)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
        }

        return nodes, nil
}

// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
// undirected edge from the two target nodes is created. The information stored
// denotes the static attributes of the channel, such as the channelID, the keys
// involved in creation of the channel, and the set of features that the channel
// supports. The chanPoint and chanID are used to uniquely identify the edge
// globally within the database.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddChannelEdge(ctx context.Context,
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {

        var alreadyExists bool
        r := &batch.Request[SQLQueries]{
                Opts: batch.NewSchedulerOptions(opts...),
                Reset: func() {
                        alreadyExists = false
                },
                Do: func(tx SQLQueries) error {
                        _, err := insertChannel(ctx, tx, edge)

                        // Silence ErrEdgeAlreadyExist so that the batch can
                        // succeed, but propagate the error via local state.
                        if errors.Is(err, ErrEdgeAlreadyExist) {
                                alreadyExists = true
                                return nil
                        }

                        return err
                },
                OnCommit: func(err error) error {
                        switch {
                        case err != nil:
                                return err
                        case alreadyExists:
                                return ErrEdgeAlreadyExist
                        default:
                                s.rejectCache.remove(edge.ChannelID)
                                s.chanCache.remove(edge.ChannelID)
                                return nil
                        }
                },
        }

        return s.chanScheduler.Execute(ctx, r)
}

// HighestChanID returns the "highest" known channel ID in the channel graph.
// This represents the "newest" channel from the PoV of the chain. This method
// can be used by peers to quickly determine if their graphs are in sync.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
        var highestChanID uint64
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
                                err)
                }

                highestChanID = byteOrder.Uint64(chanID)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
        }

        return highestChanID, nil
}

// UpdateEdgePolicy updates the edge routing policy for a single directed edge
// within the database for the referenced channel. The `flags` attribute within
// the ChannelEdgePolicy determines which of the directed edges are being
// updated. If the flag is 1, then the first node's information is being
// updated, otherwise it's the second node's information. The node ordering is
// determined by the lexicographical ordering of the identity public keys of the
// nodes on either side of the channel.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
        edge *models.ChannelEdgePolicy,
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {

        var (
                isUpdate1    bool
                edgeNotFound bool
                from, to     route.Vertex
        )

        r := &batch.Request[SQLQueries]{
                Opts: batch.NewSchedulerOptions(opts...),
                Reset: func() {
                        isUpdate1 = false
                        edgeNotFound = false
                },
                Do: func(tx SQLQueries) error {
                        var err error
                        from, to, isUpdate1, err = updateChanEdgePolicy(
                                ctx, tx, edge,
                        )
                        if err != nil {
                                log.Errorf("UpdateEdgePolicy failed: %v", err)
                        }

                        // Silence ErrEdgeNotFound so that the batch can
                        // succeed, but propagate the error via local state.
                        if errors.Is(err, ErrEdgeNotFound) {
                                edgeNotFound = true
                                return nil
                        }

                        return err
                },
                OnCommit: func(err error) error {
                        switch {
                        case err != nil:
                                return err
                        case edgeNotFound:
                                return ErrEdgeNotFound
                        default:
                                s.updateEdgeCache(edge, isUpdate1)
                                return nil
                        }
                },
        }

        err := s.chanScheduler.Execute(ctx, r)

        return from, to, err
}

// updateEdgeCache updates our reject and channel caches with the new
// edge policy information.
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
        isUpdate1 bool) {

        // If an entry for this channel is found in reject cache, we'll modify
        // the entry with the updated timestamp for the direction that was just
        // written. If the edge doesn't exist, we'll load the cache entry lazily
        // during the next query for this edge.
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
                if isUpdate1 {
                        entry.upd1Time = e.LastUpdate.Unix()
                } else {
                        entry.upd2Time = e.LastUpdate.Unix()
                }
                s.rejectCache.insert(e.ChannelID, entry)
        }

        // If an entry for this channel is found in channel cache, we'll modify
        // the entry with the updated policy for the direction that was just
        // written. If the edge doesn't exist, we'll defer loading the info and
        // policies and lazily read from disk during the next query.
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
                if isUpdate1 {
                        channel.Policy1 = e
                } else {
                        channel.Policy2 = e
                }
                s.chanCache.insert(e.ChannelID, channel)
        }
}

// ForEachSourceNodeChannel iterates through all channels of the source node,
// executing the passed callback on each. The call-back is provided with the
// channel's outpoint, whether we have a policy for the channel and the channel
// peer's node information.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
        cb func(chanPoint wire.OutPoint, havePolicy bool,
                otherNode *models.LightningNode) error, reset func()) error {

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
                if err != nil {
                        return fmt.Errorf("unable to fetch source node: %w",
                                err)
                }

                return forEachNodeChannel(
                        ctx, db, s.cfg, nodeID,
                        func(info *models.ChannelEdgeInfo,
                                outPolicy *models.ChannelEdgePolicy,
                                _ *models.ChannelEdgePolicy) error {

                                // Fetch the other node.
                                var (
                                        otherNodePub [33]byte
                                        node1        = info.NodeKey1Bytes
                                        node2        = info.NodeKey2Bytes
                                )
                                switch {
                                case bytes.Equal(node1[:], nodePub[:]):
                                        otherNodePub = node2
                                case bytes.Equal(node2[:], nodePub[:]):
                                        otherNodePub = node1
                                default:
                                        return fmt.Errorf("node not " +
                                                "participating in this channel")
                                }

                                _, otherNode, err := getNodeByPubKey(
                                        ctx, s.cfg.QueryCfg, db, otherNodePub,
                                )
                                if err != nil {
                                        return fmt.Errorf("unable to fetch "+
                                                "other node(%x): %w",
                                                otherNodePub, err)
                                }

                                return cb(
                                        info.ChannelPoint, outPolicy != nil,
                                        otherNode,
                                )
                        },
                )
        }, reset)
}

// ForEachNode iterates through all the stored vertices/nodes in the graph,
// executing the passed callback with each node encountered. If the callback
// returns an error, then the transaction is aborted and the iteration stops
// early.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNode(ctx context.Context,
        cb func(node *models.LightningNode) error, reset func()) error {

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return forEachNodePaginated(
                        ctx, s.cfg.QueryCfg, db,
                        ProtocolV1, func(_ context.Context, _ int64,
                                node *models.LightningNode) error {

                                return cb(node)
                        },
                )
        }, reset)
}

// ForEachNodeDirectedChannel iterates through all channels of a given node,
// executing the passed callback on the directed edge representing the channel
// and its incoming policy. If the callback returns an error, then the iteration
// is halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
        cb func(channel *DirectedChannel) error, reset func()) error {

        var ctx = context.TODO()

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
        }, reset)
}

// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
// graph, executing the passed callback with each node encountered. If the
// callback returns an error, then the transaction is aborted and the iteration
// stops early.
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
        cb func(route.Vertex, *lnwire.FeatureVector) error,
        reset func()) error {

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return forEachNodeCacheable(
                        ctx, s.cfg.QueryCfg, db,
                        func(_ int64, nodePub route.Vertex,
                                features *lnwire.FeatureVector) error {

                                return cb(nodePub, features)
                        },
                )
        }, reset)
        if err != nil {
                return fmt.Errorf("unable to fetch nodes: %w", err)
        }

        return nil
}

// ForEachNodeChannel iterates through all channels of the given node,
// executing the passed callback with an edge info structure and the policies
// of each end of the channel. The first edge policy is the outgoing edge *to*
// the connecting node, while the second is the incoming edge *from* the
// connecting node. If the callback returns an error, then the iteration is
// halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error, reset func()) error {

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  nodePub[:],
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
        }, reset)
}

// ChanUpdatesInHorizon returns all the known channel edges which have at least
// one edge that has an update timestamp within the specified horizon.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
        endTime time.Time) ([]ChannelEdge, error) {

        s.cacheMu.Lock()
        defer s.cacheMu.Unlock()

        var (
                ctx = context.TODO()
                // To ensure we don't return duplicate ChannelEdges, we'll use
                // an additional map to keep track of the edges already seen to
                // prevent re-adding it.
                edgesSeen    = make(map[uint64]struct{})
                edgesToCache = make(map[uint64]ChannelEdge)
                edges        []ChannelEdge
                hits         int
        )

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
                                Version:   int16(ProtocolV1),
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
                        },
                )
                if err != nil {
                        return err
                }

                if len(rows) == 0 {
                        return nil
                }

                // We'll pre-allocate the slices and maps here with a best
                // effort size in order to avoid unnecessary allocations later
                // on.
                uncachedRows := make(
                        []sqlc.GetChannelsByPolicyLastUpdateRangeRow, 0,
                        len(rows),
                )
                edgesToCache = make(map[uint64]ChannelEdge, len(rows))
                edgesSeen = make(map[uint64]struct{}, len(rows))
                edges = make([]ChannelEdge, 0, len(rows))

                // Separate cached from non-cached channels since we will only
                // batch load the data for the ones we haven't cached yet.
                for _, row := range rows {
                        chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid)

                        // Skip duplicates.
                        if _, ok := edgesSeen[chanIDInt]; ok {
                                continue
                        }
                        edgesSeen[chanIDInt] = struct{}{}

                        // Check cache first.
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
                                hits++
                                edges = append(edges, channel)
                                continue
                        }

                        // Mark this row as one we need to batch load data for.
                        uncachedRows = append(uncachedRows, row)
                }

                // If there are no uncached rows, then we can return early.
                if len(uncachedRows) == 0 {
                        return nil
                }

                // Batch load data for all uncached channels.
                newEdges, err := batchBuildChannelEdges(
                        ctx, s.cfg, db, uncachedRows,
                )
                if err != nil {
                        return fmt.Errorf("unable to batch build channel "+
                                "edges: %w", err)
                }

                edges = append(edges, newEdges...)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
        }

        // Insert any edges loaded from disk into the cache.
        for chanid, channel := range edgesToCache {
                s.chanCache.insert(chanid, channel)
        }

        if len(edges) > 0 {
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
        } else {
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
                        "horizon (%s, %s)", startTime, endTime)
        }

        return edges, nil
}

// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
// data to the call-back. If withAddrs is true, then the call-back will also be
// provided with the addresses associated with the node. The address retrieval
// results in an additional round-trip to the database, so it should only be
// used if the addresses are actually needed.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNodeCached(ctx context.Context, withAddrs bool,
        cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
                chans map[uint64]*DirectedChannel) error, reset func()) error {

        type nodeCachedBatchData struct {
                features      map[int64][]int
                addrs         map[int64][]nodeAddress
                chanBatchData *batchChannelData
                chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
        }

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                // pageQueryFunc is used to query the next page of nodes.
                pageQueryFunc := func(ctx context.Context, lastID int64,
                        limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {

                        return db.ListNodeIDsAndPubKeys(
                                ctx, sqlc.ListNodeIDsAndPubKeysParams{
                                        Version: int16(ProtocolV1),
                                        ID:      lastID,
                                        Limit:   limit,
                                },
                        )
                }

                // batchDataFunc is then used to batch load the data required
                // for each page of nodes.
                batchDataFunc := func(ctx context.Context,
                        nodeIDs []int64) (*nodeCachedBatchData, error) {

                        // Batch load node features.
                        nodeFeatures, err := batchLoadNodeFeaturesHelper(
                                ctx, s.cfg.QueryCfg, db, nodeIDs,
                        )
                        if err != nil {
                                return nil, fmt.Errorf("unable to batch load "+
                                        "node features: %w", err)
                        }

                        // Maybe fetch the node's addresses if requested.
                        var nodeAddrs map[int64][]nodeAddress
                        if withAddrs {
                                nodeAddrs, err = batchLoadNodeAddressesHelper(
                                        ctx, s.cfg.QueryCfg, db, nodeIDs,
                                )
                                if err != nil {
                                        return nil, fmt.Errorf("unable to "+
                                                "batch load node "+
                                                "addresses: %w", err)
                                }
                        }

                        // Batch load ALL unique channels for ALL nodes in this
                        // page.
                        allChannels, err := db.ListChannelsForNodeIDs(
                                ctx, sqlc.ListChannelsForNodeIDsParams{
                                        Version:  int16(ProtocolV1),
                                        Node1Ids: nodeIDs,
                                        Node2Ids: nodeIDs,
                                },
                        )
                        if err != nil {
                                return nil, fmt.Errorf("unable to batch "+
                                        "fetch channels for nodes: %w", err)
                        }

                        // Deduplicate channels and collect IDs.
                        var (
                                allChannelIDs []int64
                                allPolicyIDs  []int64
                        )
                        uniqueChannels := make(
                                map[int64]sqlc.ListChannelsForNodeIDsRow,
                        )

                        for _, channel := range allChannels {
                                channelID := channel.GraphChannel.ID

                                // Only process each unique channel once.
                                _, exists := uniqueChannels[channelID]
                                if exists {
                                        continue
                                }

                                uniqueChannels[channelID] = channel
                                allChannelIDs = append(allChannelIDs, channelID)

                                if channel.Policy1ID.Valid {
                                        allPolicyIDs = append(
                                                allPolicyIDs,
                                                channel.Policy1ID.Int64,
                                        )
                                }
                                if channel.Policy2ID.Valid {
                                        allPolicyIDs = append(
                                                allPolicyIDs,
                                                channel.Policy2ID.Int64,
                                        )
                                }
                        }

                        // Batch load channel data for all unique channels.
                        channelBatchData, err := batchLoadChannelData(
                                ctx, s.cfg.QueryCfg, db, allChannelIDs,
                                allPolicyIDs,
                        )
                        if err != nil {
                                return nil, fmt.Errorf("unable to batch "+
                                        "load channel data: %w", err)
                        }

                        // Create map of node ID to channels that involve this
                        // node.
                        nodeIDSet := make(map[int64]bool)
                        for _, nodeID := range nodeIDs {
                                nodeIDSet[nodeID] = true
                        }

                        nodeChannelMap := make(
                                map[int64][]sqlc.ListChannelsForNodeIDsRow,
                        )
                        for _, channel := range uniqueChannels {
                                // Add channel to both nodes if they're in our
                                // current page.
                                node1 := channel.GraphChannel.NodeID1
                                if nodeIDSet[node1] {
                                        nodeChannelMap[node1] = append(
                                                nodeChannelMap[node1], channel,
                                        )
                                }
                                node2 := channel.GraphChannel.NodeID2
                                if nodeIDSet[node2] {
                                        nodeChannelMap[node2] = append(
                                                nodeChannelMap[node2], channel,
                                        )
                                }
                        }

                        return &nodeCachedBatchData{
                                features:      nodeFeatures,
                                addrs:         nodeAddrs,
                                chanBatchData: channelBatchData,
                                chanMap:       nodeChannelMap,
                        }, nil
                }

                // processItem is used to process each node in the current page.
                processItem := func(ctx context.Context,
                        nodeData sqlc.ListNodeIDsAndPubKeysRow,
                        batchData *nodeCachedBatchData) error {

                        // Build feature vector for this node.
                        fv := lnwire.EmptyFeatureVector()
                        features, exists := batchData.features[nodeData.ID]
                        if exists {
                                for _, bit := range features {
                                        fv.Set(lnwire.FeatureBit(bit))
                                }
                        }

                        var nodePub route.Vertex
                        copy(nodePub[:], nodeData.PubKey)

                        nodeChannels := batchData.chanMap[nodeData.ID]

                        toNodeCallback := func() route.Vertex {
                                return nodePub
                        }

                        // Build cached channels map for this node.
                        channels := make(map[uint64]*DirectedChannel)
                        for _, channelRow := range nodeChannels {
                                directedChan, err := buildDirectedChannel(
                                        s.cfg.ChainHash, nodeData.ID, nodePub,
                                        channelRow, batchData.chanBatchData, fv,
                                        toNodeCallback,
                                )
                                if err != nil {
                                        return err
                                }

                                channels[directedChan.ChannelID] = directedChan
                        }

                        addrs, err := buildNodeAddresses(
                                batchData.addrs[nodeData.ID],
                        )
                        if err != nil {
                                return fmt.Errorf("unable to build node "+
                                        "addresses: %w", err)
                        }

                        return cb(ctx, nodePub, addrs, channels)
                }

                return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                        ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
                        func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
                                return node.ID
                        },
                        func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
                                error) {

                                return node.ID, nil
                        },
                        batchDataFunc, processItem,
                )
        }, reset)
}

// ForEachChannelCacheable iterates through all the channel edges stored
1220
// within the graph and invokes the passed callback for each edge. The
1221
// callback takes two edges since this is a directed graph and both the
1222
// in/out edges are visited. If the callback returns an error, then the
1223
// transaction is aborted and the iteration stops early.
1224
//
1225
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1226
// pointer for that particular channel edge routing policy will be
1227
// passed into the callback.
1228
//
1229
// NOTE: this method is like ForEachChannel but fetches only the data
1230
// required for the graph cache.
1231
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1232
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1233
        reset func()) error {
×
1234

×
1235
        ctx := context.TODO()
×
1236

×
1237
        handleChannel := func(_ context.Context,
×
1238
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
1239

×
1240
                node1, node2, err := buildNodeVertices(
×
1241
                        row.Node1Pubkey, row.Node2Pubkey,
×
1242
                )
×
1243
                if err != nil {
×
1244
                        return err
×
1245
                }
×
1246

1247
                edge := buildCacheableChannelInfo(
×
1248
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1249
                )
×
1250

×
1251
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1252
                if err != nil {
×
1253
                        return err
×
1254
                }
×
1255

1256
                pol1, pol2, err := buildCachedChanPolicies(
×
1257
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1258
                )
×
1259
                if err != nil {
×
1260
                        return err
×
1261
                }
×
1262

1263
                return cb(edge, pol1, pol2)
×
1264
        }
1265

1266
        extractCursor := func(
×
1267
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {
×
1268

×
1269
                return row.ID
×
1270
        }
×
1271

1272
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1273
                //nolint:ll
×
1274
                queryFunc := func(ctx context.Context, lastID int64,
×
1275
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
×
1276
                        error) {
×
1277

×
1278
                        return db.ListChannelsWithPoliciesForCachePaginated(
×
1279
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
1280
                                        Version: int16(ProtocolV1),
×
1281
                                        ID:      lastID,
×
1282
                                        Limit:   limit,
×
1283
                                },
×
1284
                        )
×
1285
                }
×
1286

1287
                return sqldb.ExecutePaginatedQuery(
×
1288
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
1289
                        extractCursor, handleChannel,
×
1290
                )
×
1291
        }, reset)
1292
}
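
// The sketch below is illustrative only and not part of the store's API: it
// shows how a caller might use ForEachChannelCacheable to warm an in-memory
// view of the graph. The function name and counters are hypothetical.
func exampleWarmChannelCache(s *SQLStore) (int, error) {
        var (
                numEdges        int
                numWithPolicies int
        )
        err := s.ForEachChannelCacheable(func(info *models.CachedEdgeInfo,
                pol1, pol2 *models.CachedEdgePolicy) error {

                numEdges++

                // Per the method contract, either policy may be nil if that
                // direction was never advertised.
                if pol1 != nil && pol2 != nil {
                        numWithPolicies++
                }

                return nil
        }, func() {
                // The reset closure must undo any partial work since the
                // callback may be replayed if the read transaction retries.
                numEdges = 0
                numWithPolicies = 0
        })
        if err != nil {
                return 0, err
        }

        fmt.Printf("%d of %d channels have both policies cached\n",
                numWithPolicies, numEdges)

        return numEdges, nil
}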
1293

1294
// ForEachChannel iterates through all the channel edges stored within the
1295
// graph and invokes the passed callback for each edge. The callback takes two
1296
// edges since this is a directed graph and both the in/out edges are visited.
1297
// If the callback returns an error, then the transaction is aborted and the
1298
// iteration stops early.
1299
//
1300
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1301
// for that particular channel edge routing policy will be passed into the
1302
// callback.
1303
//
1304
// NOTE: part of the V1Store interface.
1305
func (s *SQLStore) ForEachChannel(ctx context.Context,
1306
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1307
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1308

×
1309
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1310
                return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
×
1311
        }, reset)
×
1312
}
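
// Illustrative sketch (not part of the store's API): a caller might use
// ForEachChannel to tally the public graph, for example to sum the capacity
// of all known channels. The function name is hypothetical and the Capacity
// field is assumed to be populated for announced channels.
func exampleTotalGraphCapacity(ctx context.Context,
        s *SQLStore) (btcutil.Amount, error) {

        var total btcutil.Amount
        err := s.ForEachChannel(ctx, func(info *models.ChannelEdgeInfo,
                _, _ *models.ChannelEdgePolicy) error {

                total += info.Capacity

                return nil
        }, func() {
                // Reset partial sums in case the read transaction is retried.
                total = 0
        })
        if err != nil {
                return 0, err
        }

        return total, nil
}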
1313

1314
// FilterChannelRange returns the channel ID's of all known channels which were
1315
// mined in a block height within the passed range. The channel IDs are grouped
1316
// by their common block height. This method can be used to quickly share with a
1317
// peer the set of channels we know of within a particular range to catch them
1318
// up after a period of time offline. If withTimestamps is true then the
1319
// timestamp info of the latest received channel update messages of the channel
1320
// will be included in the response.
1321
//
1322
// NOTE: This is part of the V1Store interface.
1323
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1324
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1325

×
1326
        var (
×
1327
                ctx       = context.TODO()
×
1328
                startSCID = &lnwire.ShortChannelID{
×
1329
                        BlockHeight: startHeight,
×
1330
                }
×
1331
                endSCID = lnwire.ShortChannelID{
×
1332
                        BlockHeight: endHeight,
×
1333
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1334
                        TxPosition:  math.MaxUint16,
×
1335
                }
×
1336
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1337
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1338
        )
×
1339

×
1340
        // 1) get all channels where channelID is between start and end chan ID.
×
1341
        // 2) skip if not public (ie, no channel_proof)
×
1342
        // 3) collect that channel.
×
1343
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1344
        //    and add those timestamps to the collected channel.
×
1345
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1346
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1347
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1348
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1349
                                StartScid: chanIDStart,
×
1350
                                EndScid:   chanIDEnd,
×
1351
                        },
×
1352
                )
×
1353
                if err != nil {
×
1354
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1355
                                err)
×
1356
                }
×
1357

1358
                for _, dbChan := range dbChans {
×
1359
                        cid := lnwire.NewShortChanIDFromInt(
×
1360
                                byteOrder.Uint64(dbChan.Scid),
×
1361
                        )
×
1362
                        chanInfo := NewChannelUpdateInfo(
×
1363
                                cid, time.Time{}, time.Time{},
×
1364
                        )
×
1365

×
1366
                        if !withTimestamps {
×
1367
                                channelsPerBlock[cid.BlockHeight] = append(
×
1368
                                        channelsPerBlock[cid.BlockHeight],
×
1369
                                        chanInfo,
×
1370
                                )
×
1371

×
1372
                                continue
×
1373
                        }
1374

1375
                        //nolint:ll
1376
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1377
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1378
                                        Version:   int16(ProtocolV1),
×
1379
                                        ChannelID: dbChan.ID,
×
1380
                                        NodeID:    dbChan.NodeID1,
×
1381
                                },
×
1382
                        )
×
1383
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1384
                                return fmt.Errorf("unable to fetch node1 "+
×
1385
                                        "policy: %w", err)
×
1386
                        } else if err == nil {
×
1387
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1388
                                        node1Policy.LastUpdate.Int64, 0,
×
1389
                                )
×
1390
                        }
×
1391

1392
                        //nolint:ll
1393
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1394
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1395
                                        Version:   int16(ProtocolV1),
×
1396
                                        ChannelID: dbChan.ID,
×
1397
                                        NodeID:    dbChan.NodeID2,
×
1398
                                },
×
1399
                        )
×
1400
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1401
                                return fmt.Errorf("unable to fetch node2 "+
×
1402
                                        "policy: %w", err)
×
1403
                        } else if err == nil {
×
1404
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1405
                                        node2Policy.LastUpdate.Int64, 0,
×
1406
                                )
×
1407
                        }
×
1408

1409
                        channelsPerBlock[cid.BlockHeight] = append(
×
1410
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1411
                        )
×
1412
                }
1413

1414
                return nil
×
1415
        }, func() {
×
1416
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1417
        })
×
1418
        if err != nil {
×
1419
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1420
        }
×
1421

1422
        if len(channelsPerBlock) == 0 {
×
1423
                return nil, nil
×
1424
        }
×
1425

1426
        // Return the channel ranges in ascending block height order.
1427
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1428
        slices.Sort(blocks)
×
1429

×
1430
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1431
                return BlockChannelRange{
×
1432
                        Height:   block,
×
1433
                        Channels: channelsPerBlock[block],
×
1434
                }
×
1435
        }), nil
×
1436
}
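
// Illustrative sketch (not part of the store's API): answering a gossip range
// query by collecting the SCIDs mined between two block heights. The function
// name is hypothetical.
func exampleCollectSCIDsInRange(s *SQLStore, start, end uint32) ([]uint64,
        error) {

        ranges, err := s.FilterChannelRange(start, end, false)
        if err != nil {
                return nil, err
        }

        var scids []uint64
        for _, blockRange := range ranges {
                for _, info := range blockRange.Channels {
                        scids = append(
                                scids, info.ShortChannelID.ToUint64(),
                        )
                }
        }

        return scids, nil
}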
1437

1438
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1439
// zombie. This method is used on an ad-hoc basis, when channels need to be
1440
// marked as zombies outside the normal pruning cycle.
1441
//
1442
// NOTE: part of the V1Store interface.
1443
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1444
        pubKey1, pubKey2 [33]byte) error {
×
1445

×
1446
        ctx := context.TODO()
×
1447

×
1448
        s.cacheMu.Lock()
×
1449
        defer s.cacheMu.Unlock()
×
1450

×
1451
        chanIDB := channelIDToBytes(chanID)
×
1452

×
1453
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1454
                return db.UpsertZombieChannel(
×
1455
                        ctx, sqlc.UpsertZombieChannelParams{
×
1456
                                Version:  int16(ProtocolV1),
×
1457
                                Scid:     chanIDB,
×
1458
                                NodeKey1: pubKey1[:],
×
1459
                                NodeKey2: pubKey2[:],
×
1460
                        },
×
1461
                )
×
1462
        }, sqldb.NoOpReset)
×
1463
        if err != nil {
×
1464
                return fmt.Errorf("unable to upsert zombie channel "+
×
1465
                        "(channel_id=%d): %w", chanID, err)
×
1466
        }
×
1467

1468
        s.rejectCache.remove(chanID)
×
1469
        s.chanCache.remove(chanID)
×
1470

×
1471
        return nil
×
1472
}
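
// Illustrative sketch (not part of the store's API): marking a channel as a
// zombie, checking the zombie index, and then resurrecting the edge. The
// function name is hypothetical; IsZombieEdge and MarkEdgeLive are defined
// below in this file.
func exampleZombieRoundTrip(s *SQLStore, chanID uint64,
        pub1, pub2 [33]byte) error {

        if err := s.MarkEdgeZombie(chanID, pub1, pub2); err != nil {
                return err
        }

        isZombie, key1, key2, err := s.IsZombieEdge(chanID)
        if err != nil {
                return err
        }
        if isZombie {
                fmt.Printf("channel %d is a zombie (keys %x, %x)\n",
                        chanID, key1, key2)
        }

        // Clearing the zombie entry makes the edge eligible to be re-added.
        return s.MarkEdgeLive(chanID)
}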
1473

1474
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1475
//
1476
// NOTE: part of the V1Store interface.
1477
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1478
        s.cacheMu.Lock()
×
1479
        defer s.cacheMu.Unlock()
×
1480

×
1481
        var (
×
1482
                ctx     = context.TODO()
×
1483
                chanIDB = channelIDToBytes(chanID)
×
1484
        )
×
1485

×
1486
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1487
                res, err := db.DeleteZombieChannel(
×
1488
                        ctx, sqlc.DeleteZombieChannelParams{
×
1489
                                Scid:    chanIDB,
×
1490
                                Version: int16(ProtocolV1),
×
1491
                        },
×
1492
                )
×
1493
                if err != nil {
×
1494
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1495
                                err)
×
1496
                }
×
1497

1498
                rows, err := res.RowsAffected()
×
1499
                if err != nil {
×
1500
                        return err
×
1501
                }
×
1502

1503
                if rows == 0 {
×
1504
                        return ErrZombieEdgeNotFound
×
1505
                } else if rows > 1 {
×
1506
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1507
                                "expected 1", rows)
×
1508
                }
×
1509

1510
                return nil
×
1511
        }, sqldb.NoOpReset)
1512
        if err != nil {
×
1513
                return fmt.Errorf("unable to mark edge live "+
×
1514
                        "(channel_id=%d): %w", chanID, err)
×
1515
        }
×
1516

1517
        s.rejectCache.remove(chanID)
×
1518
        s.chanCache.remove(chanID)
×
1519

×
1520
        return err
×
1521
}
1522

1523
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1524
// zombie, then the two node public keys corresponding to this edge are also
1525
// returned.
1526
//
1527
// NOTE: part of the V1Store interface.
1528
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1529
        error) {
×
1530

×
1531
        var (
×
1532
                ctx              = context.TODO()
×
1533
                isZombie         bool
×
1534
                pubKey1, pubKey2 route.Vertex
×
1535
                chanIDB          = channelIDToBytes(chanID)
×
1536
        )
×
1537

×
1538
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1539
                zombie, err := db.GetZombieChannel(
×
1540
                        ctx, sqlc.GetZombieChannelParams{
×
1541
                                Scid:    chanIDB,
×
1542
                                Version: int16(ProtocolV1),
×
1543
                        },
×
1544
                )
×
1545
                if errors.Is(err, sql.ErrNoRows) {
×
1546
                        return nil
×
1547
                }
×
1548
                if err != nil {
×
1549
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1550
                                err)
×
1551
                }
×
1552

1553
                copy(pubKey1[:], zombie.NodeKey1)
×
1554
                copy(pubKey2[:], zombie.NodeKey2)
×
1555
                isZombie = true
×
1556

×
1557
                return nil
×
1558
        }, sqldb.NoOpReset)
1559
        if err != nil {
×
1560
                return false, route.Vertex{}, route.Vertex{},
×
1561
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1562
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1563
        }
×
1564

1565
        return isZombie, pubKey1, pubKey2, nil
×
1566
}
1567

1568
// NumZombies returns the current number of zombie channels in the graph.
1569
//
1570
// NOTE: part of the V1Store interface.
1571
func (s *SQLStore) NumZombies() (uint64, error) {
×
1572
        var (
×
1573
                ctx        = context.TODO()
×
1574
                numZombies uint64
×
1575
        )
×
1576
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1577
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1578
                if err != nil {
×
1579
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1580
                                err)
×
1581
                }
×
1582

1583
                numZombies = uint64(count)
×
1584

×
1585
                return nil
×
1586
        }, sqldb.NoOpReset)
1587
        if err != nil {
×
1588
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1589
        }
×
1590

1591
        return numZombies, nil
×
1592
}
1593

1594
// DeleteChannelEdges removes edges with the given channel IDs from the
1595
// database and marks them as zombies. This ensures that we're unable to re-add
1596
// them to our database. If an edge does not exist within the
1597
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1598
// true, then when we mark these edges as zombies, we'll set up the keys such
1599
// that we require the node that failed to send the fresh update to be the one
1600
// that resurrects the channel from its zombie state. The markZombie bool
1601
// denotes whether to mark the channel as a zombie.
1602
//
1603
// NOTE: part of the V1Store interface.
1604
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1605
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1606

×
1607
        s.cacheMu.Lock()
×
1608
        defer s.cacheMu.Unlock()
×
1609

×
1610
        // Keep track of which channels we end up finding so that we can
×
1611
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
1612
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
1613
        for _, chanID := range chanIDs {
×
1614
                chanLookup[chanID] = struct{}{}
×
1615
        }
×
1616

1617
        var (
×
1618
                ctx   = context.TODO()
×
1619
                edges []*models.ChannelEdgeInfo
×
1620
        )
×
1621
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1622
                // First, collect all channel rows.
×
1623
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
1624
                chanCallBack := func(ctx context.Context,
×
1625
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1626

×
1627
                        // Deleting the entry from the map indicates that we
×
1628
                        // have found the channel.
×
1629
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1630
                        delete(chanLookup, scid)
×
1631

×
1632
                        channelRows = append(channelRows, row)
×
1633

×
1634
                        return nil
×
1635
                }
×
1636

1637
                err := s.forEachChanWithPoliciesInSCIDList(
×
1638
                        ctx, db, chanCallBack, chanIDs,
×
1639
                )
×
1640
                if err != nil {
×
1641
                        return err
×
1642
                }
×
1643

1644
                if len(chanLookup) > 0 {
×
1645
                        return ErrEdgeNotFound
×
1646
                }
×
1647

1648
                if len(channelRows) == 0 {
×
1649
                        return nil
×
1650
                }
×
1651

1652
                // Batch build all channel edges.
1653
                var chanIDsToDelete []int64
×
1654
                edges, chanIDsToDelete, err = batchBuildChannelInfo(
×
1655
                        ctx, s.cfg, db, channelRows,
×
1656
                )
×
1657
                if err != nil {
×
1658
                        return err
×
1659
                }
×
1660

1661
                if markZombie {
×
1662
                        for i, row := range channelRows {
×
1663
                                scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1664

×
1665
                                err := handleZombieMarking(
×
1666
                                        ctx, db, row, edges[i],
×
1667
                                        strictZombiePruning, scid,
×
1668
                                )
×
1669
                                if err != nil {
×
1670
                                        return fmt.Errorf("unable to mark "+
×
1671
                                                "channel as zombie: %w", err)
×
1672
                                }
×
1673
                        }
1674
                }
1675

1676
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1677
        }, func() {
×
1678
                edges = nil
×
1679

×
1680
                // Re-fill the lookup map.
×
1681
                for _, chanID := range chanIDs {
×
1682
                        chanLookup[chanID] = struct{}{}
×
1683
                }
×
1684
        })
1685
        if err != nil {
×
1686
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1687
                        err)
×
1688
        }
×
1689

1690
        for _, chanID := range chanIDs {
×
1691
                s.rejectCache.remove(chanID)
×
1692
                s.chanCache.remove(chanID)
×
1693
        }
×
1694

1695
        return edges, nil
×
1696
}
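
// Illustrative sketch (not part of the store's API): removing a batch of
// closed channels and marking them as zombies, while treating an unknown SCID
// as a non-fatal condition. The function name is hypothetical.
func exampleDeleteClosedChannels(s *SQLStore, scids []uint64) error {
        removed, err := s.DeleteChannelEdges(true, true, scids...)
        if errors.Is(err, ErrEdgeNotFound) {
                // At least one of the SCIDs is not in the graph; the
                // transaction rolls back and nothing is deleted in that case.
                return nil
        }
        if err != nil {
                return err
        }

        for _, edge := range removed {
                fmt.Printf("removed channel %d\n", edge.ChannelID)
        }

        return nil
}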
1697

1698
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1699
// channel identified by the channel ID. If the channel can't be found, then
1700
// ErrEdgeNotFound is returned. A struct which houses the general information
1701
// for the channel itself is returned as well as two structs that contain the
1702
// routing policies for the channel in either direction.
1703
//
1704
// ErrZombieEdge can be returned if the edge is currently marked as a zombie
1705
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1706
// the ChannelEdgeInfo will only include the public keys of each node.
1707
//
1708
// NOTE: part of the V1Store interface.
1709
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1710
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1711
        *models.ChannelEdgePolicy, error) {
×
1712

×
1713
        var (
×
1714
                ctx              = context.TODO()
×
1715
                edge             *models.ChannelEdgeInfo
×
1716
                policy1, policy2 *models.ChannelEdgePolicy
×
1717
                chanIDB          = channelIDToBytes(chanID)
×
1718
        )
×
1719
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1720
                row, err := db.GetChannelBySCIDWithPolicies(
×
1721
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1722
                                Scid:    chanIDB,
×
1723
                                Version: int16(ProtocolV1),
×
1724
                        },
×
1725
                )
×
1726
                if errors.Is(err, sql.ErrNoRows) {
×
1727
                        // First check if this edge is perhaps in the zombie
×
1728
                        // index.
×
1729
                        zombie, err := db.GetZombieChannel(
×
1730
                                ctx, sqlc.GetZombieChannelParams{
×
1731
                                        Scid:    chanIDB,
×
1732
                                        Version: int16(ProtocolV1),
×
1733
                                },
×
1734
                        )
×
1735
                        if errors.Is(err, sql.ErrNoRows) {
×
1736
                                return ErrEdgeNotFound
×
1737
                        } else if err != nil {
×
1738
                                return fmt.Errorf("unable to check if "+
×
1739
                                        "channel is zombie: %w", err)
×
1740
                        }
×
1741

1742
                        // At this point, we know the channel is a zombie, so
1743
                        // we'll return an error indicating this, and we will
1744
                        // populate the edge info with the public keys of each
1745
                        // party as this is the only information we have about
1746
                        // it.
1747
                        edge = &models.ChannelEdgeInfo{}
×
1748
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1749
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1750

×
1751
                        return ErrZombieEdge
×
1752
                } else if err != nil {
×
1753
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1754
                }
×
1755

1756
                node1, node2, err := buildNodeVertices(
×
1757
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1758
                )
×
1759
                if err != nil {
×
1760
                        return err
×
1761
                }
×
1762

1763
                edge, err = getAndBuildEdgeInfo(
×
1764
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
1765
                )
×
1766
                if err != nil {
×
1767
                        return fmt.Errorf("unable to build channel info: %w",
×
1768
                                err)
×
1769
                }
×
1770

1771
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1772
                if err != nil {
×
1773
                        return fmt.Errorf("unable to extract channel "+
×
1774
                                "policies: %w", err)
×
1775
                }
×
1776

1777
                policy1, policy2, err = getAndBuildChanPolicies(
×
1778
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
1779
                        node1, node2,
×
1780
                )
×
1781
                if err != nil {
×
1782
                        return fmt.Errorf("unable to build channel "+
×
1783
                                "policies: %w", err)
×
1784
                }
×
1785

1786
                return nil
×
1787
        }, sqldb.NoOpReset)
1788
        if err != nil {
×
1789
                // If we are returning the ErrZombieEdge, then we also need to
×
1790
                // return the edge info as the method comment indicates that
×
1791
                // this will be populated when the edge is a zombie.
×
1792
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1793
                        err)
×
1794
        }
×
1795

1796
        return edge, policy1, policy2, nil
×
1797
}
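
// Illustrative sketch (not part of the store's API): fetching both directed
// policies of a channel and handling the zombie case, in which only the node
// keys of the returned edge info are populated. The function name is
// hypothetical.
func exampleFetchChannel(s *SQLStore, chanID uint64) error {
        info, pol1, pol2, err := s.FetchChannelEdgesByID(chanID)
        switch {
        case errors.Is(err, ErrZombieEdge):
                fmt.Printf("channel %d is a zombie between %x and %x\n",
                        chanID, info.NodeKey1Bytes, info.NodeKey2Bytes)

                return nil

        case err != nil:
                return err
        }

        // Either policy may be nil if that direction was never advertised.
        fmt.Printf("channel %d: have policy1=%v policy2=%v\n",
                info.ChannelID, pol1 != nil, pol2 != nil)

        return nil
}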
1798

1799
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1800
// the channel identified by the funding outpoint. If the channel can't be
1801
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1802
// information for the channel itself is returned as well as two structs that
1803
// contain the routing policies for the channel in either direction.
1804
//
1805
// NOTE: part of the V1Store interface.
1806
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1807
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1808
        *models.ChannelEdgePolicy, error) {
×
1809

×
1810
        var (
×
1811
                ctx              = context.TODO()
×
1812
                edge             *models.ChannelEdgeInfo
×
1813
                policy1, policy2 *models.ChannelEdgePolicy
×
1814
        )
×
1815
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1816
                row, err := db.GetChannelByOutpointWithPolicies(
×
1817
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1818
                                Outpoint: op.String(),
×
1819
                                Version:  int16(ProtocolV1),
×
1820
                        },
×
1821
                )
×
1822
                if errors.Is(err, sql.ErrNoRows) {
×
1823
                        return ErrEdgeNotFound
×
1824
                } else if err != nil {
×
1825
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1826
                }
×
1827

1828
                node1, node2, err := buildNodeVertices(
×
1829
                        row.Node1Pubkey, row.Node2Pubkey,
×
1830
                )
×
1831
                if err != nil {
×
1832
                        return err
×
1833
                }
×
1834

1835
                edge, err = getAndBuildEdgeInfo(
×
1836
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
1837
                )
×
1838
                if err != nil {
×
1839
                        return fmt.Errorf("unable to build channel info: %w",
×
1840
                                err)
×
1841
                }
×
1842

1843
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1844
                if err != nil {
×
1845
                        return fmt.Errorf("unable to extract channel "+
×
1846
                                "policies: %w", err)
×
1847
                }
×
1848

1849
                policy1, policy2, err = getAndBuildChanPolicies(
×
1850
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
1851
                        node1, node2,
×
1852
                )
×
1853
                if err != nil {
×
1854
                        return fmt.Errorf("unable to build channel "+
×
1855
                                "policies: %w", err)
×
1856
                }
×
1857

1858
                return nil
×
1859
        }, sqldb.NoOpReset)
1860
        if err != nil {
×
1861
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1862
                        err)
×
1863
        }
×
1864

1865
        return edge, policy1, policy2, nil
×
1866
}
1867

1868
// HasChannelEdge returns true if the database knows of a channel edge with the
1869
// passed channel ID, and false otherwise. If an edge with that ID is found
1870
// within the graph, then two time stamps representing the last time the edge
1871
// was updated for both directed edges are returned along with the boolean. If
1872
// it is not found, then the zombie index is checked and its result is returned
1873
// as the second boolean.
1874
//
1875
// NOTE: part of the V1Store interface.
1876
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1877
        bool, error) {
×
1878

×
1879
        ctx := context.TODO()
×
1880

×
1881
        var (
×
1882
                exists          bool
×
1883
                isZombie        bool
×
1884
                node1LastUpdate time.Time
×
1885
                node2LastUpdate time.Time
×
1886
        )
×
1887

×
1888
        // We'll query the cache with the shared lock held to allow multiple
×
1889
        // readers to access values in the cache concurrently if they exist.
×
1890
        s.cacheMu.RLock()
×
1891
        if entry, ok := s.rejectCache.get(chanID); ok {
×
1892
                s.cacheMu.RUnlock()
×
1893
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
1894
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
1895
                exists, isZombie = entry.flags.unpack()
×
1896

×
1897
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
1898
        }
×
1899
        s.cacheMu.RUnlock()
×
1900

×
1901
        s.cacheMu.Lock()
×
1902
        defer s.cacheMu.Unlock()
×
1903

×
1904
        // The item was not found with the shared lock, so we'll acquire the
×
1905
        // exclusive lock and check the cache again in case another method added
×
1906
        // the entry to the cache while no lock was held.
×
1907
        if entry, ok := s.rejectCache.get(chanID); ok {
×
1908
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
1909
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
1910
                exists, isZombie = entry.flags.unpack()
×
1911

×
1912
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
1913
        }
×
1914

1915
        chanIDB := channelIDToBytes(chanID)
×
1916
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1917
                channel, err := db.GetChannelBySCID(
×
1918
                        ctx, sqlc.GetChannelBySCIDParams{
×
1919
                                Scid:    chanIDB,
×
1920
                                Version: int16(ProtocolV1),
×
1921
                        },
×
1922
                )
×
1923
                if errors.Is(err, sql.ErrNoRows) {
×
1924
                        // Check if it is a zombie channel.
×
1925
                        isZombie, err = db.IsZombieChannel(
×
1926
                                ctx, sqlc.IsZombieChannelParams{
×
1927
                                        Scid:    chanIDB,
×
1928
                                        Version: int16(ProtocolV1),
×
1929
                                },
×
1930
                        )
×
1931
                        if err != nil {
×
1932
                                return fmt.Errorf("could not check if channel "+
×
1933
                                        "is zombie: %w", err)
×
1934
                        }
×
1935

1936
                        return nil
×
1937
                } else if err != nil {
×
1938
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1939
                }
×
1940

1941
                exists = true
×
1942

×
1943
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
1944
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1945
                                Version:   int16(ProtocolV1),
×
1946
                                ChannelID: channel.ID,
×
1947
                                NodeID:    channel.NodeID1,
×
1948
                        },
×
1949
                )
×
1950
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1951
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
1952
                                err)
×
1953
                } else if err == nil {
×
1954
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
1955
                }
×
1956

1957
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
1958
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1959
                                Version:   int16(ProtocolV1),
×
1960
                                ChannelID: channel.ID,
×
1961
                                NodeID:    channel.NodeID2,
×
1962
                        },
×
1963
                )
×
1964
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1965
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
1966
                                err)
×
1967
                } else if err == nil {
×
1968
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
1969
                }
×
1970

1971
                return nil
×
1972
        }, sqldb.NoOpReset)
1973
        if err != nil {
×
1974
                return time.Time{}, time.Time{}, false, false,
×
1975
                        fmt.Errorf("unable to fetch channel: %w", err)
×
1976
        }
×
1977

1978
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
1979
                upd1Time: node1LastUpdate.Unix(),
×
1980
                upd2Time: node2LastUpdate.Unix(),
×
1981
                flags:    packRejectFlags(exists, isZombie),
×
1982
        })
×
1983

×
1984
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
1985
}
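
// Illustrative sketch (not part of the store's API): interpreting the values
// returned by HasChannelEdge when deciding whether a fresh channel update is
// worth processing. The function name and the staleness rule are
// hypothetical.
func exampleIsUpdateStale(s *SQLStore, chanID uint64,
        updateTime time.Time, isNode1 bool) (bool, error) {

        upd1, upd2, exists, isZombie, err := s.HasChannelEdge(chanID)
        if err != nil {
                return false, err
        }

        // Unknown channels and zombies are both treated as "not interesting"
        // here; a real caller may want to handle them differently.
        if !exists || isZombie {
                return true, nil
        }

        lastUpdate := upd2
        if isNode1 {
                lastUpdate = upd1
        }

        return !updateTime.After(lastUpdate), nil
}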
1986

1987
// ChannelID attempts to look up the 8-byte compact channel ID which maps to the
1988
// passed channel point (outpoint). If the passed channel doesn't exist within
1989
// the database, then ErrEdgeNotFound is returned.
1990
//
1991
// NOTE: part of the V1Store interface.
1992
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
1993
        var (
×
1994
                ctx       = context.TODO()
×
1995
                channelID uint64
×
1996
        )
×
1997
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1998
                chanID, err := db.GetSCIDByOutpoint(
×
1999
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2000
                                Outpoint: chanPoint.String(),
×
2001
                                Version:  int16(ProtocolV1),
×
2002
                        },
×
2003
                )
×
2004
                if errors.Is(err, sql.ErrNoRows) {
×
2005
                        return ErrEdgeNotFound
×
2006
                } else if err != nil {
×
2007
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2008
                                err)
×
2009
                }
×
2010

2011
                channelID = byteOrder.Uint64(chanID)
×
2012

×
2013
                return nil
×
2014
        }, sqldb.NoOpReset)
2015
        if err != nil {
×
2016
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2017
        }
×
2018

2019
        return channelID, nil
×
2020
}
2021

2022
// IsPublicNode is a helper method that determines whether the node with the
2023
// given public key is seen as a public node in the graph from the graph's
2024
// source node's point of view.
2025
//
2026
// NOTE: part of the V1Store interface.
2027
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2028
        ctx := context.TODO()
×
2029

×
2030
        var isPublic bool
×
2031
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2032
                var err error
×
2033
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2034

×
2035
                return err
×
2036
        }, sqldb.NoOpReset)
×
2037
        if err != nil {
×
2038
                return false, fmt.Errorf("unable to check if node is "+
×
2039
                        "public: %w", err)
×
2040
        }
×
2041

2042
        return isPublic, nil
×
2043
}
2044

2045
// FetchChanInfos returns the set of channel edges that correspond to the passed
2046
// channel ID's. If an edge in the query is unknown to the database, it will be
2047
// skipped and the result will contain only those edges that exist at the time
2048
// of the query. This can be used to respond to peer queries that are seeking to
2049
// fill in gaps in their view of the channel graph.
2050
//
2051
// NOTE: part of the V1Store interface.
2052
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2053
        var (
×
2054
                ctx   = context.TODO()
×
2055
                edges = make(map[uint64]ChannelEdge)
×
2056
        )
×
2057
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2058
                // First, collect all channel rows.
×
2059
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
2060
                chanCallBack := func(ctx context.Context,
×
2061
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2062

×
2063
                        channelRows = append(channelRows, row)
×
2064
                        return nil
×
2065
                }
×
2066

2067
                err := s.forEachChanWithPoliciesInSCIDList(
×
2068
                        ctx, db, chanCallBack, chanIDs,
×
2069
                )
×
2070
                if err != nil {
×
2071
                        return err
×
2072
                }
×
2073

2074
                if len(channelRows) == 0 {
×
2075
                        return nil
×
2076
                }
×
2077

2078
                // Batch build all channel edges.
2079
                chans, err := batchBuildChannelEdges(
×
2080
                        ctx, s.cfg, db, channelRows,
×
2081
                )
×
2082
                if err != nil {
×
2083
                        return fmt.Errorf("unable to build channel edges: %w",
×
2084
                                err)
×
2085
                }
×
2086

2087
                for _, c := range chans {
×
2088
                        edges[c.Info.ChannelID] = c
×
2089
                }
×
2090

2091
                return err
×
2092
        }, func() {
×
2093
                clear(edges)
×
2094
        })
×
2095
        if err != nil {
×
2096
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2097
        }
×
2098

2099
        res := make([]ChannelEdge, 0, len(edges))
×
2100
        for _, chanID := range chanIDs {
×
2101
                edge, ok := edges[chanID]
×
2102
                if !ok {
×
2103
                        continue
×
2104
                }
2105

2106
                res = append(res, edge)
×
2107
        }
2108

2109
        return res, nil
×
2110
}
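
// Illustrative sketch (not part of the store's API): answering a peer query
// for a set of SCIDs, where any SCID unknown to the store is simply absent
// from the result. The function name is hypothetical.
func exampleAnswerChanQuery(s *SQLStore, scids []uint64) error {
        edges, err := s.FetchChanInfos(scids)
        if err != nil {
                return err
        }

        fmt.Printf("known %d of %d requested channels\n", len(edges),
                len(scids))

        for _, edge := range edges {
                fmt.Printf("channel %d\n", edge.Info.ChannelID)
        }

        return nil
}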
2111

2112
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2113
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2114
// channels in a paginated manner.
2115
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2116
        db SQLQueries, cb func(ctx context.Context,
2117
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
2118
        chanIDs []uint64) error {
×
2119

×
2120
        queryWrapper := func(ctx context.Context,
×
2121
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
2122
                error) {
×
2123

×
2124
                return db.GetChannelsBySCIDWithPolicies(
×
2125
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
2126
                                Version: int16(ProtocolV1),
×
2127
                                Scids:   scids,
×
2128
                        },
×
2129
                )
×
2130
        }
×
2131

2132
        return sqldb.ExecuteBatchQuery(
×
2133
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
×
2134
                cb,
×
2135
        )
×
2136
}
2137

2138
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2139
// ID's that we don't know and are not known zombies of the passed set. In other
2140
// words, we perform a set difference of our set of chan ID's and the ones
2141
// passed in. This method can be used by callers to determine the set of
2142
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2143
// known zombies is also returned.
2144
//
2145
// NOTE: part of the V1Store interface.
2146
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2147
        []ChannelUpdateInfo, error) {
×
2148

×
2149
        var (
×
2150
                ctx          = context.TODO()
×
2151
                newChanIDs   []uint64
×
2152
                knownZombies []ChannelUpdateInfo
×
2153
                infoLookup   = make(
×
2154
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2155
                )
×
2156
        )
×
2157

×
2158
        // We first build a lookup map of the channel ID's to the
×
2159
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
2160
        // already know about.
×
2161
        for _, chanInfo := range chansInfo {
×
2162
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
2163
        }
×
2164

2165
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2166
                // The call-back function deletes known channels from
×
2167
                // infoLookup, so that we can later check which channels are
×
2168
                // zombies by only looking at the remaining channels in the set.
×
2169
                cb := func(ctx context.Context,
×
2170
                        channel sqlc.GraphChannel) error {
×
2171

×
2172
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2173

×
2174
                        return nil
×
2175
                }
×
2176

2177
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
2178
                if err != nil {
×
2179
                        return fmt.Errorf("unable to iterate through "+
×
2180
                                "channels: %w", err)
×
2181
                }
×
2182

2183
                // We want to ensure that we deal with the channels in the
2184
                // same order that they were passed in, so we iterate over the
2185
                // original chansInfo slice and then check if that channel is
2186
                // still in the infoLookup map.
2187
                for _, chanInfo := range chansInfo {
×
2188
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2189
                        if _, ok := infoLookup[channelID]; !ok {
×
2190
                                continue
×
2191
                        }
2192

2193
                        isZombie, err := db.IsZombieChannel(
×
2194
                                ctx, sqlc.IsZombieChannelParams{
×
2195
                                        Scid:    channelIDToBytes(channelID),
×
2196
                                        Version: int16(ProtocolV1),
×
2197
                                },
×
2198
                        )
×
2199
                        if err != nil {
×
2200
                                return fmt.Errorf("unable to fetch zombie "+
×
2201
                                        "channel: %w", err)
×
2202
                        }
×
2203

2204
                        if isZombie {
×
2205
                                knownZombies = append(knownZombies, chanInfo)
×
2206

×
2207
                                continue
×
2208
                        }
2209

2210
                        newChanIDs = append(newChanIDs, channelID)
×
2211
                }
2212

2213
                return nil
×
2214
        }, func() {
×
2215
                newChanIDs = nil
×
2216
                knownZombies = nil
×
2217
                // Rebuild the infoLookup map in case of a rollback.
×
2218
                for _, chanInfo := range chansInfo {
×
2219
                        scid := chanInfo.ShortChannelID.ToUint64()
×
2220
                        infoLookup[scid] = chanInfo
×
2221
                }
×
2222
        })
2223
        if err != nil {
×
2224
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2225
        }
×
2226

2227
        return newChanIDs, knownZombies, nil
×
2228
}
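
// Illustrative sketch (not part of the store's API): computing which of a
// peer's announced SCIDs we still need to fetch. The function name is
// hypothetical; NewChannelUpdateInfo and lnwire.NewShortChanIDFromInt are
// both used elsewhere in this package.
func exampleUnknownSCIDs(s *SQLStore, announced []uint64) ([]uint64, error) {
        chansInfo := make([]ChannelUpdateInfo, 0, len(announced))
        for _, scid := range announced {
                chansInfo = append(chansInfo, NewChannelUpdateInfo(
                        lnwire.NewShortChanIDFromInt(scid), time.Time{},
                        time.Time{},
                ))
        }

        newSCIDs, zombies, err := s.FilterKnownChanIDs(chansInfo)
        if err != nil {
                return nil, err
        }

        fmt.Printf("%d new channels, %d known zombies\n", len(newSCIDs),
                len(zombies))

        return newSCIDs, nil
}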
2229

2230
// forEachChanInSCIDList is a helper method that executes a paged query
2231
// against the database to fetch all channels that match the passed
2232
// ChannelUpdateInfo slice. The callback function is called for each channel
2233
// that is found.
2234
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2235
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
2236
        chansInfo []ChannelUpdateInfo) error {
×
2237

×
2238
        queryWrapper := func(ctx context.Context,
×
2239
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
2240

×
2241
                return db.GetChannelsBySCIDs(
×
2242
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
2243
                                Version: int16(ProtocolV1),
×
2244
                                Scids:   scids,
×
2245
                        },
×
2246
                )
×
2247
        }
×
2248

2249
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
2250
                channelID := chanInfo.ShortChannelID.ToUint64()
×
2251

×
2252
                return channelIDToBytes(channelID)
×
2253
        }
×
2254

2255
        return sqldb.ExecuteBatchQuery(
×
2256
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
×
2257
                cb,
×
2258
        )
×
2259
}
2260

2261
// PruneGraphNodes is a garbage collection method which attempts to prune out
2262
// any nodes from the channel graph that are currently unconnected. This ensures
2263
// that we only maintain a graph of reachable nodes. In the event that a pruned
2264
// node gains more channels, it will be re-added back to the graph.
2265
//
2266
// NOTE: this prunes nodes across protocol versions. It will never prune the
2267
// source nodes.
2268
//
2269
// NOTE: part of the V1Store interface.
2270
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2271
        var ctx = context.TODO()
×
2272

×
2273
        var prunedNodes []route.Vertex
×
2274
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2275
                var err error
×
2276
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2277

×
2278
                return err
×
2279
        }, func() {
×
2280
                prunedNodes = nil
×
2281
        })
×
2282
        if err != nil {
×
2283
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2284
        }
×
2285

2286
        return prunedNodes, nil
×
2287
}
2288

2289
// PruneGraph prunes newly closed channels from the channel graph in response
2290
// to a new block being solved on the network. Any transactions which spend the
2291
// funding output of any known channels within the graph will be deleted.
2292
// Additionally, the "prune tip", or the last block which has been used to
2293
// prune the graph is stored so callers can ensure the graph is fully in sync
2294
// with the current UTXO state. A slice of channels that have been closed by
2295
// the target block along with any pruned nodes are returned if the function
2296
// succeeds without error.
2297
//
2298
// NOTE: part of the V1Store interface.
2299
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2300
        blockHash *chainhash.Hash, blockHeight uint32) (
2301
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2302

×
2303
        ctx := context.TODO()
×
2304

×
2305
        s.cacheMu.Lock()
×
2306
        defer s.cacheMu.Unlock()
×
2307

×
2308
        var (
×
2309
                closedChans []*models.ChannelEdgeInfo
×
2310
                prunedNodes []route.Vertex
×
2311
        )
×
2312
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2313
                // First, collect all channel rows that need to be pruned.
×
2314
                var channelRows []sqlc.GetChannelsByOutpointsRow
×
2315
                channelCallback := func(ctx context.Context,
×
2316
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2317

×
2318
                        channelRows = append(channelRows, row)
×
2319

×
2320
                        return nil
×
2321
                }
×
2322

2323
                err := s.forEachChanInOutpoints(
×
2324
                        ctx, db, spentOutputs, channelCallback,
×
2325
                )
×
2326
                if err != nil {
×
2327
                        return fmt.Errorf("unable to fetch channels by "+
×
2328
                                "outpoints: %w", err)
×
2329
                }
×
2330

2331
                if len(channelRows) == 0 {
×
2332
                        // There are no channels to prune. So we can exit early
×
2333
                        // after updating the prune log.
×
2334
                        err = db.UpsertPruneLogEntry(
×
2335
                                ctx, sqlc.UpsertPruneLogEntryParams{
×
2336
                                        BlockHash:   blockHash[:],
×
2337
                                        BlockHeight: int64(blockHeight),
×
2338
                                },
×
2339
                        )
×
2340
                        if err != nil {
×
2341
                                return fmt.Errorf("unable to insert prune log "+
×
2342
                                        "entry: %w", err)
×
2343
                        }
×
2344

2345
                        return nil
×
2346
                }
2347

2348
                // Batch build all channel edges for pruning.
2349
                var chansToDelete []int64
×
2350
                closedChans, chansToDelete, err = batchBuildChannelInfo(
×
2351
                        ctx, s.cfg, db, channelRows,
×
2352
                )
×
2353
                if err != nil {
×
2354
                        return err
×
2355
                }
×
2356

2357
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2358
                if err != nil {
×
2359
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2360
                }
×
2361

2362
                err = db.UpsertPruneLogEntry(
×
2363
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2364
                                BlockHash:   blockHash[:],
×
2365
                                BlockHeight: int64(blockHeight),
×
2366
                        },
×
2367
                )
×
2368
                if err != nil {
×
2369
                        return fmt.Errorf("unable to insert prune log "+
×
2370
                                "entry: %w", err)
×
2371
                }
×
2372

2373
                // Now that we've pruned some channels, we'll also prune any
2374
                // nodes that no longer have any channels.
2375
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2376
                if err != nil {
×
2377
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2378
                                err)
×
2379
                }
×
2380

2381
                return nil
×
2382
        }, func() {
×
2383
                prunedNodes = nil
×
2384
                closedChans = nil
×
2385
        })
×
2386
        if err != nil {
×
2387
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2388
        }
×
2389

2390
        for _, channel := range closedChans {
×
2391
                s.rejectCache.remove(channel.ChannelID)
×
2392
                s.chanCache.remove(channel.ChannelID)
×
2393
        }
×
2394

2395
        return closedChans, prunedNodes, nil
×
2396
}
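
// Illustrative sketch (not part of the store's API): pruning the graph when a
// new block arrives. The function name and its parameters are hypothetical; a
// real caller derives the spent outpoints from the block's transactions.
func examplePruneOnBlock(s *SQLStore, spent []*wire.OutPoint,
        hash *chainhash.Hash, height uint32) error {

        closed, prunedNodes, err := s.PruneGraph(spent, hash, height)
        if err != nil {
                return err
        }

        fmt.Printf("block %d closed %d channels and pruned %d nodes\n",
                height, len(closed), len(prunedNodes))

        return nil
}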
2397

2398
// forEachChanInOutpoints is a helper function that executes a paginated
2399
// query to fetch channels by their outpoints and applies the given call-back
2400
// to each.
2401
//
2402
// NOTE: this fetches channels for all protocol versions.
2403
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2404
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2405
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
2406

×
2407
        // Create a wrapper that uses the transaction's db instance to execute
×
2408
        // the query.
×
2409
        queryWrapper := func(ctx context.Context,
×
2410
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
2411
                error) {
×
2412

×
2413
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
2414
        }
×
2415

2416
        // Define the conversion function from Outpoint to string.
2417
        outpointToString := func(outpoint *wire.OutPoint) string {
×
2418
                return outpoint.String()
×
2419
        }
×
2420

2421
        return sqldb.ExecuteBatchQuery(
×
2422
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
×
2423
                queryWrapper, cb,
×
2424
        )
×
2425
}
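
// Editor's note: the following is an illustrative sketch, not part of the
// original file. It shows the chunking idea that sqldb.ExecuteBatchQuery
// (used by forEachChanInOutpoints above) abstracts away: a large key set is
// split into fixed-size pages so that each "WHERE ... IN (...)" query stays
// within the backend's parameter limits. forEachInChunks, chunkSize,
// queryPage and cb are hypothetical stand-ins for the configured page size,
// the per-page query and the per-row callback.
func forEachInChunks[K any, R any](keys []K, chunkSize int,
        queryPage func([]K) ([]R, error), cb func(R) error) error {

        for start := 0; start < len(keys); start += chunkSize {
                end := start + chunkSize
                if end > len(keys) {
                        end = len(keys)
                }

                rows, err := queryPage(keys[start:end])
                if err != nil {
                        return err
                }

                for _, row := range rows {
                        if err := cb(row); err != nil {
                                return err
                        }
                }
        }

        return nil
}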
2426

2427
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
2428
        dbIDs []int64) error {
×
2429

×
2430
        // Create a wrapper that uses the transaction's db instance to execute
×
2431
        // the query.
×
2432
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
2433
                return nil, db.DeleteChannels(ctx, ids)
×
2434
        }
×
2435

2436
        idConverter := func(id int64) int64 {
×
2437
                return id
×
2438
        }
×
2439

2440
        return sqldb.ExecuteBatchQuery(
×
2441
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
×
2442
                queryWrapper, func(ctx context.Context, _ any) error {
×
2443
                        return nil
×
2444
                },
×
2445
        )
2446
}
2447

2448
// ChannelView returns the verifiable edge information for each active channel
2449
// within the known channel graph. The set of UTXOs (along with their scripts)
2450
// returned are the ones that need to be watched on chain to detect channel
2451
// closes on the resident blockchain.
2452
//
2453
// NOTE: part of the V1Store interface.
2454
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2455
        var (
×
2456
                ctx        = context.TODO()
×
2457
                edgePoints []EdgePoint
×
2458
        )
×
2459

×
2460
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2461
                handleChannel := func(_ context.Context,
×
2462
                        channel sqlc.ListChannelsPaginatedRow) error {
×
2463

×
2464
                        pkScript, err := genMultiSigP2WSH(
×
2465
                                channel.BitcoinKey1, channel.BitcoinKey2,
×
2466
                        )
×
2467
                        if err != nil {
×
2468
                                return err
×
2469
                        }
×
2470

2471
                        op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2472
                        if err != nil {
×
2473
                                return err
×
2474
                        }
×
2475

2476
                        edgePoints = append(edgePoints, EdgePoint{
×
2477
                                FundingPkScript: pkScript,
×
2478
                                OutPoint:        *op,
×
2479
                        })
×
2480

×
2481
                        return nil
×
2482
                }
2483

2484
                queryFunc := func(ctx context.Context, lastID int64,
×
2485
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {
×
2486

×
2487
                        return db.ListChannelsPaginated(
×
2488
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2489
                                        Version: int16(ProtocolV1),
×
2490
                                        ID:      lastID,
×
2491
                                        Limit:   limit,
×
2492
                                },
×
2493
                        )
×
2494
                }
×
2495

2496
                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
×
2497
                        return row.ID
×
2498
                }
×
2499

2500
                return sqldb.ExecutePaginatedQuery(
×
2501
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
2502
                        extractCursor, handleChannel,
×
2503
                )
×
2504
        }, func() {
×
2505
                edgePoints = nil
×
2506
        })
×
2507
        if err != nil {
×
2508
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2509
        }
×
2510

2511
        return edgePoints, nil
×
2512
}
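
// Editor's note: illustrative sketch, not part of the original file. It
// spells out the keyset pagination that sqldb.ExecutePaginatedQuery performs
// for ChannelView above: each page is requested with the last seen row ID as
// the cursor, so no OFFSET scans are needed. paginateByID, queryPage,
// extractCursor and handle are hypothetical names mirroring the closures
// defined above.
func paginateByID[R any](pageSize int32,
        queryPage func(lastID int64, limit int32) ([]R, error),
        extractCursor func(R) int64, handle func(R) error) error {

        cursor := int64(-1)
        for {
                rows, err := queryPage(cursor, pageSize)
                if err != nil {
                        return err
                }
                if len(rows) == 0 {
                        return nil
                }

                for _, row := range rows {
                        if err := handle(row); err != nil {
                                return err
                        }

                        // Advance the cursor to the last row we've seen.
                        cursor = extractCursor(row)
                }

                // A short page means there is nothing left to fetch.
                if int32(len(rows)) < pageSize {
                        return nil
                }
        }
}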
2513

2514
// PruneTip returns the block height and hash of the latest block that has been
2515
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2516
// to tell if the graph is currently in sync with the current best known UTXO
2517
// state.
2518
//
2519
// NOTE: part of the V1Store interface.
2520
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2521
        var (
×
2522
                ctx       = context.TODO()
×
2523
                tipHash   chainhash.Hash
×
2524
                tipHeight uint32
×
2525
        )
×
2526
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2527
                pruneTip, err := db.GetPruneTip(ctx)
×
2528
                if errors.Is(err, sql.ErrNoRows) {
×
2529
                        return ErrGraphNeverPruned
×
2530
                } else if err != nil {
×
2531
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2532
                }
×
2533

2534
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2535
                tipHeight = uint32(pruneTip.BlockHeight)
×
2536

×
2537
                return nil
×
2538
        }, sqldb.NoOpReset)
2539
        if err != nil {
×
2540
                return nil, 0, err
×
2541
        }
×
2542

2543
        return &tipHash, tipHeight, nil
×
2544
}
2545

2546
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2547
//
2548
// NOTE: this prunes nodes across protocol versions. It will never prune the
2549
// source nodes.
2550
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2551
        db SQLQueries) ([]route.Vertex, error) {
×
2552

×
2553
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2554
        if err != nil {
×
2555
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2556
                        "nodes: %w", err)
×
2557
        }
×
2558

2559
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2560
        for i, nodeKey := range nodeKeys {
×
2561
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2562
                if err != nil {
×
2563
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2564
                                "from bytes: %w", err)
×
2565
                }
×
2566

2567
                prunedNodes[i] = pub
×
2568
        }
2569

2570
        return prunedNodes, nil
×
2571
}
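
// Editor's note: illustrative sketch, not part of the original file. It
// mirrors the conversion performed by pruneGraphNodes above, where the raw
// 33-byte public keys returned by DeleteUnconnectedNodes are turned into
// route.Vertex values. verticesFromRawKeys and rawKeys are hypothetical.
func verticesFromRawKeys(rawKeys [][]byte) ([]route.Vertex, error) {
        vertices := make([]route.Vertex, 0, len(rawKeys))
        for _, rawKey := range rawKeys {
                // NewVertexFromBytes rejects keys that are not exactly
                // 33 bytes long.
                vertex, err := route.NewVertexFromBytes(rawKey)
                if err != nil {
                        return nil, err
                }

                vertices = append(vertices, vertex)
        }

        return vertices, nil
}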
2572

2573
// DisconnectBlockAtHeight is used to indicate that the block specified
2574
// by the passed height has been disconnected from the main chain. This
2575
// will "rewind" the graph back to the height below, deleting channels
2576
// that are no longer confirmed from the graph. The prune log will be
2577
// set to the last prune height valid for the remaining chain.
2578
// Channels that were removed from the graph resulting from the
2579
// disconnected block are returned.
2580
//
2581
// NOTE: part of the V1Store interface.
2582
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2583
        []*models.ChannelEdgeInfo, error) {
×
2584

×
2585
        ctx := context.TODO()
×
2586

×
2587
        var (
×
2588
                // Every channel having a ShortChannelID starting at 'height'
×
2589
                // will no longer be confirmed.
×
2590
                startShortChanID = lnwire.ShortChannelID{
×
2591
                        BlockHeight: height,
×
2592
                }
×
2593

×
2594
                // Delete everything after this height from the db up until the
×
2595
                // SCID alias range.
×
2596
                endShortChanID = aliasmgr.StartingAlias
×
2597

×
2598
                removedChans []*models.ChannelEdgeInfo
×
2599

×
2600
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2601
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2602
        )
×
2603

×
2604
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2605
                rows, err := db.GetChannelsBySCIDRange(
×
2606
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2607
                                StartScid: chanIDStart,
×
2608
                                EndScid:   chanIDEnd,
×
2609
                        },
×
2610
                )
×
2611
                if err != nil {
×
2612
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2613
                }
×
2614

2615
                if len(rows) == 0 {
×
2616
                        // No channels to disconnect, but still clean up prune
×
2617
                        // log.
×
2618
                        return db.DeletePruneLogEntriesInRange(
×
2619
                                ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2620
                                        StartHeight: int64(height),
×
2621
                                        EndHeight: int64(
×
2622
                                                endShortChanID.BlockHeight,
×
2623
                                        ),
×
2624
                                },
×
2625
                        )
×
2626
                }
×
2627

2628
                // Batch build all channel edges for disconnection.
2629
                channelEdges, chanIDsToDelete, err := batchBuildChannelInfo(
×
2630
                        ctx, s.cfg, db, rows,
×
2631
                )
×
2632
                if err != nil {
×
2633
                        return err
×
2634
                }
×
2635

2636
                removedChans = channelEdges
×
2637

×
2638
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
2639
                if err != nil {
×
2640
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2641
                }
×
2642

2643
                return db.DeletePruneLogEntriesInRange(
×
2644
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2645
                                StartHeight: int64(height),
×
2646
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2647
                        },
×
2648
                )
×
2649
        }, func() {
×
2650
                removedChans = nil
×
2651
        })
×
2652
        if err != nil {
×
2653
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2654
                        "height: %w", err)
×
2655
        }
×
2656

2657
        for _, channel := range removedChans {
×
2658
                s.rejectCache.remove(channel.ChannelID)
×
2659
                s.chanCache.remove(channel.ChannelID)
×
2660
        }
×
2661

2662
        return removedChans, nil
×
2663
}
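
// Editor's note: illustrative sketch, not part of the original file. It shows
// why DisconnectBlockAtHeight can bound its SCID range query by block height:
// a short channel ID packs the funding block height into its most significant
// bits, so every channel confirmed at or after a given height sorts at or
// after the SCID built from that height alone. scidRangeStart is a
// hypothetical helper.
func scidRangeStart(height uint32) uint64 {
        start := lnwire.ShortChannelID{
                BlockHeight: height,
                TxIndex:     0,
                TxPosition:  0,
        }

        // ToUint64 lays the SCID out as
        // blockHeight (24 bits) | txIndex (24 bits) | txPosition (16 bits).
        return start.ToUint64()
}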
2664

2665
// AddEdgeProof sets the proof of an existing edge in the graph database.
2666
//
2667
// NOTE: part of the V1Store interface.
2668
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2669
        proof *models.ChannelAuthProof) error {
×
2670

×
2671
        var (
×
2672
                ctx       = context.TODO()
×
2673
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2674
        )
×
2675

×
2676
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2677
                res, err := db.AddV1ChannelProof(
×
2678
                        ctx, sqlc.AddV1ChannelProofParams{
×
2679
                                Scid:              scidBytes,
×
2680
                                Node1Signature:    proof.NodeSig1Bytes,
×
2681
                                Node2Signature:    proof.NodeSig2Bytes,
×
2682
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2683
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2684
                        },
×
2685
                )
×
2686
                if err != nil {
×
2687
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2688
                }
×
2689

2690
                n, err := res.RowsAffected()
×
2691
                if err != nil {
×
2692
                        return err
×
2693
                }
×
2694

2695
                if n == 0 {
×
2696
                        return fmt.Errorf("no rows affected when adding edge "+
×
2697
                                "proof for SCID %v", scid)
×
2698
                } else if n > 1 {
×
2699
                        return fmt.Errorf("multiple rows affected when adding "+
×
2700
                                "edge proof for SCID %v: %d rows affected",
×
2701
                                scid, n)
×
2702
                }
×
2703

2704
                return nil
×
2705
        }, sqldb.NoOpReset)
2706
        if err != nil {
×
2707
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2708
        }
×
2709

2710
        return nil
×
2711
}
2712

2713
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2714
// that we can ignore channel announcements that we know to be closed without
2715
// having to validate them and fetch a block.
2716
//
2717
// NOTE: part of the V1Store interface.
2718
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2719
        var (
×
2720
                ctx     = context.TODO()
×
2721
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2722
        )
×
2723

×
2724
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2725
                return db.InsertClosedChannel(ctx, chanIDB)
×
2726
        }, sqldb.NoOpReset)
×
2727
}
2728

2729
// IsClosedScid checks whether a channel identified by the passed in scid is
2730
// closed. This helps avoid having to perform expensive validation checks.
2731
//
2732
// NOTE: part of the V1Store interface.
2733
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2734
        var (
×
2735
                ctx      = context.TODO()
×
2736
                isClosed bool
×
2737
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2738
        )
×
2739
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2740
                var err error
×
2741
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2742
                if err != nil {
×
2743
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2744
                                err)
×
2745
                }
×
2746

2747
                return nil
×
2748
        }, sqldb.NoOpReset)
2749
        if err != nil {
×
2750
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2751
                        err)
×
2752
        }
×
2753

2754
        return isClosed, nil
×
2755
}
2756

2757
// GraphSession will provide the call-back with access to a NodeTraverser
2758
// instance which can be used to perform queries against the channel graph.
2759
//
2760
// NOTE: part of the V1Store interface.
2761
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
2762
        reset func()) error {
×
2763

×
2764
        var ctx = context.TODO()
×
2765

×
2766
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2767
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2768
        }, reset)
×
2769
}
2770

2771
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2772
// read-only transaction for a consistent view of the graph.
2773
type sqlNodeTraverser struct {
2774
        db    SQLQueries
2775
        chain chainhash.Hash
2776
}
2777

2778
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2779
// NodeTraverser interface.
2780
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2781

2782
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2783
func newSQLNodeTraverser(db SQLQueries,
2784
        chain chainhash.Hash) *sqlNodeTraverser {
×
2785

×
2786
        return &sqlNodeTraverser{
×
2787
                db:    db,
×
2788
                chain: chain,
×
2789
        }
×
2790
}
×
2791

2792
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2793
// node.
2794
//
2795
// NOTE: Part of the NodeTraverser interface.
2796
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2797
        cb func(channel *DirectedChannel) error, _ func()) error {
×
2798

×
2799
        ctx := context.TODO()
×
2800

×
2801
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2802
}
×
2803

2804
// FetchNodeFeatures returns the features of the given node. If the node is
2805
// unknown, assume no additional features are supported.
2806
//
2807
// NOTE: Part of the NodeTraverser interface.
2808
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2809
        *lnwire.FeatureVector, error) {
×
2810

×
2811
        ctx := context.TODO()
×
2812

×
2813
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2814
}
×
2815

2816
// forEachNodeDirectedChannel iterates through all channels of a given
2817
// node, executing the passed callback on the directed edge representing the
2818
// channel and its incoming policy. If the node is not found, no error is
2819
// returned.
2820
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2821
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2822

×
2823
        toNodeCallback := func() route.Vertex {
×
2824
                return nodePub
×
2825
        }
×
2826

2827
        dbID, err := db.GetNodeIDByPubKey(
×
2828
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2829
                        Version: int16(ProtocolV1),
×
2830
                        PubKey:  nodePub[:],
×
2831
                },
×
2832
        )
×
2833
        if errors.Is(err, sql.ErrNoRows) {
×
2834
                return nil
×
2835
        } else if err != nil {
×
2836
                return fmt.Errorf("unable to fetch node: %w", err)
×
2837
        }
×
2838

2839
        rows, err := db.ListChannelsByNodeID(
×
2840
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2841
                        Version: int16(ProtocolV1),
×
2842
                        NodeID1: dbID,
×
2843
                },
×
2844
        )
×
2845
        if err != nil {
×
2846
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2847
        }
×
2848

2849
        // Exit early if there are no channels for this node so we don't
2850
        // do the unnecessary feature fetching.
2851
        if len(rows) == 0 {
×
2852
                return nil
×
2853
        }
×
2854

2855
        features, err := getNodeFeatures(ctx, db, dbID)
×
2856
        if err != nil {
×
2857
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2858
        }
×
2859

2860
        for _, row := range rows {
×
2861
                node1, node2, err := buildNodeVertices(
×
2862
                        row.Node1Pubkey, row.Node2Pubkey,
×
2863
                )
×
2864
                if err != nil {
×
2865
                        return fmt.Errorf("unable to build node vertices: %w",
×
2866
                                err)
×
2867
                }
×
2868

2869
                edge := buildCacheableChannelInfo(
×
2870
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
2871
                        node1, node2,
×
2872
                )
×
2873

×
2874
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2875
                if err != nil {
×
2876
                        return err
×
2877
                }
×
2878

2879
                p1, p2, err := buildCachedChanPolicies(
×
2880
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
2881
                )
×
2882
                if err != nil {
×
2883
                        return err
×
2884
                }
×
2885

2886
                // Determine the outgoing and incoming policy for this
2887
                // channel and node combo.
2888
                outPolicy, inPolicy := p1, p2
×
2889
                if p1 != nil && node2 == nodePub {
×
2890
                        outPolicy, inPolicy = p2, p1
×
2891
                } else if p2 != nil && node1 != nodePub {
×
2892
                        outPolicy, inPolicy = p2, p1
×
2893
                }
×
2894

2895
                var cachedInPolicy *models.CachedEdgePolicy
×
2896
                if inPolicy != nil {
×
2897
                        cachedInPolicy = inPolicy
×
2898
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
2899
                        cachedInPolicy.ToNodeFeatures = features
×
2900
                }
×
2901

2902
                directedChannel := &DirectedChannel{
×
2903
                        ChannelID:    edge.ChannelID,
×
2904
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2905
                        OtherNode:    edge.NodeKey2Bytes,
×
2906
                        Capacity:     edge.Capacity,
×
2907
                        OutPolicySet: outPolicy != nil,
×
2908
                        InPolicy:     cachedInPolicy,
×
2909
                }
×
2910
                if outPolicy != nil {
×
2911
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2912
                                directedChannel.InboundFee = fee
×
2913
                        })
×
2914
                }
2915

2916
                if nodePub == edge.NodeKey2Bytes {
×
2917
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
2918
                }
×
2919

2920
                if err := cb(directedChannel); err != nil {
×
2921
                        return err
×
2922
                }
×
2923
        }
2924

2925
        return nil
×
2926
}
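
// Editor's note: illustrative sketch, not part of the original file. It
// isolates the direction rule applied above: policy 1 is the one advertised
// by node 1 and governs HTLCs flowing towards node 2, so the "outgoing"
// policy from the queried node's perspective is always the one that node
// signed itself. orientPolicies and its parameters are hypothetical.
func orientPolicies[P any](queriedIsNode1 bool, policy1, policy2 P) (P, P) {
        if queriedIsNode1 {
                // Outgoing first, then incoming.
                return policy1, policy2
        }

        return policy2, policy1
}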
2927

2928
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
2929
// and executes the provided callback for each node. It does so via pagination
2930
// along with batch loading of the node feature bits.
2931
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
2932
        db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
2933
                features *lnwire.FeatureVector) error) error {
×
2934

×
2935
        handleNode := func(_ context.Context,
×
2936
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
×
2937
                featureBits map[int64][]int) error {
×
2938

×
2939
                fv := lnwire.EmptyFeatureVector()
×
2940
                if features, exists := featureBits[dbNode.ID]; exists {
×
2941
                        for _, bit := range features {
×
2942
                                fv.Set(lnwire.FeatureBit(bit))
×
2943
                        }
×
2944
                }
2945

2946
                var pub route.Vertex
×
2947
                copy(pub[:], dbNode.PubKey)
×
2948

×
2949
                return processNode(dbNode.ID, pub, fv)
×
2950
        }
2951

2952
        queryFunc := func(ctx context.Context, lastID int64,
×
2953
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
2954

×
2955
                return db.ListNodeIDsAndPubKeys(
×
2956
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
2957
                                Version: int16(ProtocolV1),
×
2958
                                ID:      lastID,
×
2959
                                Limit:   limit,
×
2960
                        },
×
2961
                )
×
2962
        }
×
2963

2964
        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
2965
                return row.ID
×
2966
        }
×
2967

2968
        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
×
2969
                return node.ID, nil
×
2970
        }
×
2971

2972
        batchQueryFunc := func(ctx context.Context,
×
2973
                nodeIDs []int64) (map[int64][]int, error) {
×
2974

×
2975
                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
2976
        }
×
2977

2978
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
2979
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
×
2980
                batchQueryFunc, handleNode,
×
2981
        )
×
2982
}
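
// Editor's note: illustrative sketch, not part of the original file. It shows
// the feature-bit expansion done in the handleNode closure above: the integer
// bits batch-loaded from the node_features table are folded into an
// lnwire.FeatureVector. featureVectorFromBits and bits are hypothetical.
func featureVectorFromBits(bits []int) *lnwire.FeatureVector {
        fv := lnwire.EmptyFeatureVector()
        for _, bit := range bits {
                fv.Set(lnwire.FeatureBit(bit))
        }

        return fv
}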
2983

2984
// forEachNodeChannel iterates through all channels of a node, executing
2985
// the passed callback on each. The call-back is provided with the channel's
2986
// edge information, the outgoing policy and the incoming policy for the
2987
// channel and node combo.
2988
func forEachNodeChannel(ctx context.Context, db SQLQueries,
2989
        cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
2990
                *models.ChannelEdgePolicy,
2991
                *models.ChannelEdgePolicy) error) error {
×
2992

×
2993
        // Get all the V1 channels for this node.
×
2994
        rows, err := db.ListChannelsByNodeID(
×
2995
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2996
                        Version: int16(ProtocolV1),
×
2997
                        NodeID1: id,
×
2998
                },
×
2999
        )
×
3000
        if err != nil {
×
3001
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3002
        }
×
3003

3004
        // Collect all the channel and policy IDs.
3005
        var (
×
3006
                chanIDs   = make([]int64, 0, len(rows))
×
3007
                policyIDs = make([]int64, 0, 2*len(rows))
×
3008
        )
×
3009
        for _, row := range rows {
×
3010
                chanIDs = append(chanIDs, row.GraphChannel.ID)
×
3011

×
3012
                if row.Policy1ID.Valid {
×
3013
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
×
3014
                }
×
3015
                if row.Policy2ID.Valid {
×
3016
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
×
3017
                }
×
3018
        }
3019

3020
        batchData, err := batchLoadChannelData(
×
3021
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
×
3022
        )
×
3023
        if err != nil {
×
3024
                return fmt.Errorf("unable to batch load channel data: %w", err)
×
3025
        }
×
3026

3027
        // Call the call-back for each channel and its known policies.
3028
        for _, row := range rows {
×
3029
                node1, node2, err := buildNodeVertices(
×
3030
                        row.Node1Pubkey, row.Node2Pubkey,
×
3031
                )
×
3032
                if err != nil {
×
3033
                        return fmt.Errorf("unable to build node vertices: %w",
×
3034
                                err)
×
3035
                }
×
3036

3037
                edge, err := buildEdgeInfoWithBatchData(
×
3038
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
3039
                        batchData,
×
3040
                )
×
3041
                if err != nil {
×
3042
                        return fmt.Errorf("unable to build channel info: %w",
×
3043
                                err)
×
3044
                }
×
3045

3046
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3047
                if err != nil {
×
3048
                        return fmt.Errorf("unable to extract channel "+
×
3049
                                "policies: %w", err)
×
3050
                }
×
3051

3052
                p1, p2, err := buildChanPoliciesWithBatchData(
×
3053
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
3054
                )
×
3055
                if err != nil {
×
3056
                        return fmt.Errorf("unable to build channel "+
×
3057
                                "policies: %w", err)
×
3058
                }
×
3059

3060
                // Determine the outgoing and incoming policy for this
3061
                // channel and node combo.
3062
                p1ToNode := row.GraphChannel.NodeID2
×
3063
                p2ToNode := row.GraphChannel.NodeID1
×
3064
                outPolicy, inPolicy := p1, p2
×
3065
                if (p1 != nil && p1ToNode == id) ||
×
3066
                        (p2 != nil && p2ToNode != id) {
×
3067

×
3068
                        outPolicy, inPolicy = p2, p1
×
3069
                }
×
3070

3071
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3072
                        return err
×
3073
                }
×
3074
        }
3075

3076
        return nil
×
3077
}
3078

3079
// updateChanEdgePolicy upserts the channel policy info we have stored for
3080
// a channel we already know of.
3081
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3082
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3083
        error) {
×
3084

×
3085
        var (
×
3086
                node1Pub, node2Pub route.Vertex
×
3087
                isNode1            bool
×
3088
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3089
        )
×
3090

×
3091
        // Check that this edge policy refers to a channel that we already
×
3092
        // know of. We do this explicitly so that we can return the appropriate
×
3093
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3094
        // abort the transaction which would abort the entire batch.
×
3095
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3096
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3097
                        Scid:    chanIDB,
×
3098
                        Version: int16(ProtocolV1),
×
3099
                },
×
3100
        )
×
3101
        if errors.Is(err, sql.ErrNoRows) {
×
3102
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3103
        } else if err != nil {
×
3104
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3105
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3106
        }
×
3107

3108
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3109
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3110

×
3111
        // Figure out which node this edge is from.
×
3112
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3113
        nodeID := dbChan.NodeID1
×
3114
        if !isNode1 {
×
3115
                nodeID = dbChan.NodeID2
×
3116
        }
×
3117

3118
        var (
×
3119
                inboundBase sql.NullInt64
×
3120
                inboundRate sql.NullInt64
×
3121
        )
×
3122
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3123
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3124
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3125
        })
×
3126

3127
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3128
                Version:     int16(ProtocolV1),
×
3129
                ChannelID:   dbChan.ID,
×
3130
                NodeID:      nodeID,
×
3131
                Timelock:    int32(edge.TimeLockDelta),
×
3132
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3133
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3134
                MinHtlcMsat: int64(edge.MinHTLC),
×
3135
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3136
                Disabled: sql.NullBool{
×
3137
                        Valid: true,
×
3138
                        Bool:  edge.IsDisabled(),
×
3139
                },
×
3140
                MaxHtlcMsat: sql.NullInt64{
×
3141
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3142
                        Int64: int64(edge.MaxHTLC),
×
3143
                },
×
3144
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3145
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3146
                InboundBaseFeeMsat:      inboundBase,
×
3147
                InboundFeeRateMilliMsat: inboundRate,
×
3148
                Signature:               edge.SigBytes,
×
3149
        })
×
3150
        if err != nil {
×
3151
                return node1Pub, node2Pub, isNode1,
×
3152
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3153
        }
×
3154

3155
        // Convert the flat extra opaque data into a map of TLV types to
3156
        // values.
3157
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3158
        if err != nil {
×
3159
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3160
                        "marshal extra opaque data: %w", err)
×
3161
        }
×
3162

3163
        // Update the channel policy's extra signed fields.
3164
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3165
        if err != nil {
×
3166
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3167
                        "policy extra TLVs: %w", err)
×
3168
        }
×
3169

3170
        return node1Pub, node2Pub, isNode1, nil
×
3171
}
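
// Editor's note: illustrative sketch, not part of the original file. It
// spells out the direction check used by updateChanEdgePolicy above when
// deciding which node a channel_update belongs to. isNode1Update is a
// hypothetical helper name.
func isNode1Update(chanFlags lnwire.ChanUpdateChanFlags) bool {
        // A cleared direction bit means the update was produced by node 1 of
        // the channel; a set bit means node 2.
        return chanFlags&lnwire.ChanUpdateDirection == 0
}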
3172

3173
// getNodeByPubKey attempts to look up a target node by its public key.
3174
func getNodeByPubKey(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3175
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3176

×
3177
        dbNode, err := db.GetNodeByPubKey(
×
3178
                ctx, sqlc.GetNodeByPubKeyParams{
×
3179
                        Version: int16(ProtocolV1),
×
3180
                        PubKey:  pubKey[:],
×
3181
                },
×
3182
        )
×
3183
        if errors.Is(err, sql.ErrNoRows) {
×
3184
                return 0, nil, ErrGraphNodeNotFound
×
3185
        } else if err != nil {
×
3186
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3187
        }
×
3188

NEW
3189
        node, err := buildNode(ctx, cfg, db, dbNode)
×
3190
        if err != nil {
×
3191
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3192
        }
×
3193

3194
        return dbNode.ID, node, nil
×
3195
}
3196

3197
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3198
// provided parameters.
3199
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
3200
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3201

×
3202
        return &models.CachedEdgeInfo{
×
3203
                ChannelID:     byteOrder.Uint64(scid),
×
3204
                NodeKey1Bytes: node1Pub,
×
3205
                NodeKey2Bytes: node2Pub,
×
3206
                Capacity:      btcutil.Amount(capacity),
×
3207
        }
×
3208
}
×
3209

3210
// buildNode constructs a LightningNode instance from the given database node
3211
// record. The node's features, addresses and extra signed fields are also
3212
// fetched from the database and set on the node.
3213
func buildNode(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3214
        dbNode sqlc.GraphNode) (*models.LightningNode, error) {
×
3215

×
3216
        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
×
3217
        if err != nil {
×
3218
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
3219
                        err)
×
3220
        }
×
3221

3222
        return buildNodeWithBatchData(dbNode, data)
×
3223
}
3224

3225
// buildNodeWithBatchData builds a models.LightningNode instance
3226
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
3227
// features/addresses/extra fields, then the corresponding fields are expected
3228
// to be present in the batchNodeData.
3229
func buildNodeWithBatchData(dbNode sqlc.GraphNode,
3230
        batchData *batchNodeData) (*models.LightningNode, error) {
×
3231

×
3232
        if dbNode.Version != int16(ProtocolV1) {
×
3233
                return nil, fmt.Errorf("unsupported node version: %d",
×
3234
                        dbNode.Version)
×
3235
        }
×
3236

3237
        var pub [33]byte
×
3238
        copy(pub[:], dbNode.PubKey)
×
3239

×
3240
        node := &models.LightningNode{
×
3241
                PubKeyBytes: pub,
×
3242
                Features:    lnwire.EmptyFeatureVector(),
×
3243
                LastUpdate:  time.Unix(0, 0),
×
3244
        }
×
3245

×
3246
        if len(dbNode.Signature) == 0 {
×
3247
                return node, nil
×
3248
        }
×
3249

3250
        node.HaveNodeAnnouncement = true
×
3251
        node.AuthSigBytes = dbNode.Signature
×
3252
        node.Alias = dbNode.Alias.String
×
3253
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3254

×
3255
        var err error
×
3256
        if dbNode.Color.Valid {
×
3257
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3258
                if err != nil {
×
3259
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3260
                                err)
×
3261
                }
×
3262
        }
3263

3264
        // Use preloaded features.
3265
        if features, exists := batchData.features[dbNode.ID]; exists {
×
3266
                fv := lnwire.EmptyFeatureVector()
×
3267
                for _, bit := range features {
×
3268
                        fv.Set(lnwire.FeatureBit(bit))
×
3269
                }
×
3270
                node.Features = fv
×
3271
        }
3272

3273
        // Use preloaded addresses.
3274
        addresses, exists := batchData.addresses[dbNode.ID]
×
3275
        if exists && len(addresses) > 0 {
×
3276
                node.Addresses, err = buildNodeAddresses(addresses)
×
3277
                if err != nil {
×
3278
                        return nil, fmt.Errorf("unable to build addresses "+
×
3279
                                "for node(%d): %w", dbNode.ID, err)
×
3280
                }
×
3281
        }
3282

3283
        // Use preloaded extra fields.
3284
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
×
3285
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
×
3286
                if err != nil {
×
3287
                        return nil, fmt.Errorf("unable to serialize extra "+
×
3288
                                "signed fields: %w", err)
×
3289
                }
×
3290
                if len(recs) != 0 {
×
3291
                        node.ExtraOpaqueData = recs
×
3292
                }
×
3293
        }
3294

3295
        return node, nil
×
3296
}
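
// Editor's note: illustrative sketch, not part of the original file. It
// assumes the node colour is stored as an HTML-style "#rrggbb" string (the
// format produced by EncodeHexColor in this package) and shows one way such a
// value can be parsed with the standard library (assumes the "image/color"
// import). parseRGB is a hypothetical helper; the real DecodeHexColor may
// differ in its validation.
func parseRGB(s string) (color.RGBA, error) {
        var c color.RGBA
        if len(s) != 7 || s[0] != '#' {
                return c, fmt.Errorf("invalid hex colour string: %q", s)
        }

        // Scan the three two-digit hex components into R, G and B.
        _, err := fmt.Sscanf(s, "#%02x%02x%02x", &c.R, &c.G, &c.B)

        return c, err
}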
3297

3298
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
3299
// with the preloaded data, and executes the provided callback for each node.
3300
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
3301
        db SQLQueries, nodes []sqlc.GraphNode,
3302
        cb func(dbID int64, node *models.LightningNode) error) error {
×
3303

×
3304
        // Extract node IDs for batch loading.
×
3305
        nodeIDs := make([]int64, len(nodes))
×
3306
        for i, node := range nodes {
×
3307
                nodeIDs[i] = node.ID
×
3308
        }
×
3309

3310
        // Batch load all related data for this page.
3311
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
3312
        if err != nil {
×
3313
                return fmt.Errorf("unable to batch load node data: %w", err)
×
3314
        }
×
3315

3316
        for _, dbNode := range nodes {
×
3317
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
3318
                if err != nil {
×
3319
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
3320
                                dbNode.ID, err)
×
3321
                }
×
3322

3323
                if err := cb(dbNode.ID, node); err != nil {
×
3324
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
3325
                                dbNode.ID, err)
×
3326
                }
×
3327
        }
3328

3329
        return nil
×
3330
}
3331

3332
// getNodeFeatures fetches the feature bits and constructs the feature vector
3333
// for a node with the given DB ID.
3334
func getNodeFeatures(ctx context.Context, db SQLQueries,
3335
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3336

×
3337
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3338
        if err != nil {
×
3339
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3340
                        nodeID, err)
×
3341
        }
×
3342

3343
        features := lnwire.EmptyFeatureVector()
×
3344
        for _, feature := range rows {
×
3345
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3346
        }
×
3347

3348
        return features, nil
×
3349
}
3350

3351
// upsertNode upserts the node record into the database. If the node already
3352
// exists, then the node's information is updated. If the node doesn't exist,
3353
// then a new node is created. The node's features, addresses and extra TLV
3354
// types are also updated. The node's DB ID is returned.
3355
func upsertNode(ctx context.Context, db SQLQueries,
3356
        node *models.LightningNode) (int64, error) {
×
3357

×
3358
        params := sqlc.UpsertNodeParams{
×
3359
                Version: int16(ProtocolV1),
×
3360
                PubKey:  node.PubKeyBytes[:],
×
3361
        }
×
3362

×
3363
        if node.HaveNodeAnnouncement {
×
3364
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3365
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3366
                params.Alias = sqldb.SQLStr(node.Alias)
×
3367
                params.Signature = node.AuthSigBytes
×
3368
        }
×
3369

3370
        nodeID, err := db.UpsertNode(ctx, params)
×
3371
        if err != nil {
×
3372
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3373
                        err)
×
3374
        }
×
3375

3376
        // We can exit here if we don't have the announcement yet.
3377
        if !node.HaveNodeAnnouncement {
×
3378
                return nodeID, nil
×
3379
        }
×
3380

3381
        // Update the node's features.
3382
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3383
        if err != nil {
×
3384
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3385
        }
×
3386

3387
        // Update the node's addresses.
3388
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3389
        if err != nil {
×
3390
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3391
        }
×
3392

3393
        // Convert the flat extra opaque data into a map of TLV types to
3394
        // values.
3395
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3396
        if err != nil {
×
3397
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3398
                        err)
×
3399
        }
×
3400

3401
        // Update the node's extra signed fields.
3402
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3403
        if err != nil {
×
3404
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3405
        }
×
3406

3407
        return nodeID, nil
×
3408
}
3409

3410
// upsertNodeFeatures updates the node's features in the node_features table. This
3411
// includes deleting any feature bits no longer present and inserting any new
3412
// feature bits. If the feature bit does not yet exist in the features table,
3413
// then an entry is created in that table first.
3414
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3415
        features *lnwire.FeatureVector) error {
×
3416

×
3417
        // Get any existing features for the node.
×
3418
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3419
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3420
                return err
×
3421
        }
×
3422

3423
        // Copy the node's latest set of feature bits.
3424
        newFeatures := make(map[int32]struct{})
×
3425
        if features != nil {
×
3426
                for feature := range features.Features() {
×
3427
                        newFeatures[int32(feature)] = struct{}{}
×
3428
                }
×
3429
        }
3430

3431
        // For any current feature that already exists in the DB, remove it from
3432
        // the in-memory map. For any existing feature that does not exist in
3433
        // the in-memory map, delete it from the database.
3434
        for _, feature := range existingFeatures {
×
3435
                // The feature is still present, so there are no updates to be
×
3436
                // made.
×
3437
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3438
                        delete(newFeatures, feature.FeatureBit)
×
3439
                        continue
×
3440
                }
3441

3442
                // The feature is no longer present, so we remove it from the
3443
                // database.
3444
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3445
                        NodeID:     nodeID,
×
3446
                        FeatureBit: feature.FeatureBit,
×
3447
                })
×
3448
                if err != nil {
×
3449
                        return fmt.Errorf("unable to delete node(%d) "+
×
3450
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3451
                                err)
×
3452
                }
×
3453
        }
3454

3455
        // Any remaining entries in newFeatures are new features that need to be
3456
        // added to the database for the first time.
3457
        for feature := range newFeatures {
×
3458
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3459
                        NodeID:     nodeID,
×
3460
                        FeatureBit: feature,
×
3461
                })
×
3462
                if err != nil {
×
3463
                        return fmt.Errorf("unable to insert node(%d) "+
×
3464
                                "feature(%v): %w", nodeID, feature, err)
×
3465
                }
×
3466
        }
3467

3468
        return nil
×
3469
}
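
// Editor's note: illustrative sketch, not part of the original file. It
// captures the reconciliation pattern used by upsertNodeFeatures above:
// entries present in both the stored and the advertised set are left alone,
// stale rows are deleted, and anything left over is inserted for the first
// time. reconcileSet and its callbacks are hypothetical.
func reconcileSet(existing, desired []int32, insert func(int32) error,
        remove func(int32) error) error {

        want := make(map[int32]struct{}, len(desired))
        for _, bit := range desired {
                want[bit] = struct{}{}
        }

        for _, bit := range existing {
                if _, ok := want[bit]; ok {
                        // Already stored, nothing to do.
                        delete(want, bit)
                        continue
                }

                // Stored but no longer advertised, so delete it.
                if err := remove(bit); err != nil {
                        return err
                }
        }

        // Whatever remains was never stored and must be inserted.
        for bit := range want {
                if err := insert(bit); err != nil {
                        return err
                }
        }

        return nil
}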
3470

3471
// fetchNodeFeatures fetches the features for a node with the given public key.
3472
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3473
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3474

×
3475
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3476
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3477
                        PubKey:  nodePub[:],
×
3478
                        Version: int16(ProtocolV1),
×
3479
                },
×
3480
        )
×
3481
        if err != nil {
×
3482
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3483
                        nodePub, err)
×
3484
        }
×
3485

3486
        features := lnwire.EmptyFeatureVector()
×
3487
        for _, bit := range rows {
×
3488
                features.Set(lnwire.FeatureBit(bit))
×
3489
        }
×
3490

3491
        return features, nil
×
3492
}
3493

3494
// dbAddressType is an enum type that represents the different address types
3495
// that we store in the node_addresses table. The address type determines how
3496
// the address is to be serialised/deserialised.
3497
type dbAddressType uint8
3498

3499
const (
3500
        addressTypeIPv4   dbAddressType = 1
3501
        addressTypeIPv6   dbAddressType = 2
3502
        addressTypeTorV2  dbAddressType = 3
3503
        addressTypeTorV3  dbAddressType = 4
3504
        addressTypeOpaque dbAddressType = math.MaxInt8
3505
)
3506

3507
// upsertNodeAddresses updates the node's addresses in the database. This
3508
// includes deleting any existing addresses and inserting the new set of
3509
// addresses. The deletion is necessary since the ordering of the addresses may
3510
// change, and we need to ensure that the database reflects the latest set of
3511
// addresses so that at the time of reconstructing the node announcement, the
3512
// order is preserved and the signature over the message remains valid.
3513
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3514
        addresses []net.Addr) error {
×
3515

×
3516
        // Delete any existing addresses for the node. This is required since
×
3517
        // even if the new set of addresses is the same, the ordering may have
×
3518
        // changed for a given address type.
×
3519
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3520
        if err != nil {
×
3521
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3522
                        nodeID, err)
×
3523
        }
×
3524

3525
        // Copy the node's latest set of addresses.
3526
        newAddresses := map[dbAddressType][]string{
×
3527
                addressTypeIPv4:   {},
×
3528
                addressTypeIPv6:   {},
×
3529
                addressTypeTorV2:  {},
×
3530
                addressTypeTorV3:  {},
×
3531
                addressTypeOpaque: {},
×
3532
        }
×
3533
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3534
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3535
        }
×
3536

3537
        for _, address := range addresses {
×
3538
                switch addr := address.(type) {
×
3539
                case *net.TCPAddr:
×
3540
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3541
                                addAddr(addressTypeIPv4, addr)
×
3542
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3543
                                addAddr(addressTypeIPv6, addr)
×
3544
                        } else {
×
3545
                                return fmt.Errorf("unhandled IP address: %v",
×
3546
                                        addr)
×
3547
                        }
×
3548

3549
                case *tor.OnionAddr:
×
3550
                        switch len(addr.OnionService) {
×
3551
                        case tor.V2Len:
×
3552
                                addAddr(addressTypeTorV2, addr)
×
3553
                        case tor.V3Len:
×
3554
                                addAddr(addressTypeTorV3, addr)
×
3555
                        default:
×
3556
                                return fmt.Errorf("invalid length for a tor " +
×
3557
                                        "address")
×
3558
                        }
3559

3560
                case *lnwire.OpaqueAddrs:
×
3561
                        addAddr(addressTypeOpaque, addr)
×
3562

3563
                default:
×
3564
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3565
                }
3566
        }
3567

3568
        // Any remaining entries in newAddresses are new addresses that need to
3569
        // be added to the database for the first time.
3570
        for addrType, addrList := range newAddresses {
×
3571
                for position, addr := range addrList {
×
3572
                        err := db.InsertNodeAddress(
×
3573
                                ctx, sqlc.InsertNodeAddressParams{
×
3574
                                        NodeID:   nodeID,
×
3575
                                        Type:     int16(addrType),
×
3576
                                        Address:  addr,
×
3577
                                        Position: int32(position),
×
3578
                                },
×
3579
                        )
×
3580
                        if err != nil {
×
3581
                                return fmt.Errorf("unable to insert "+
×
3582
                                        "node(%d) address(%v): %w", nodeID,
×
3583
                                        addr, err)
×
3584
                        }
×
3585
                }
3586
        }
3587

3588
        return nil
×
3589
}
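
// Editor's note: illustrative sketch, not part of the original file. It
// isolates the address classification performed by upsertNodeAddresses above:
// TCP addresses are split into IPv4/IPv6 via net.IP.To4, onion addresses by
// the length of their onion service string, and lnwire.OpaqueAddrs fall back
// to the opaque bucket. classifyAddr is a hypothetical helper.
func classifyAddr(address net.Addr) (dbAddressType, error) {
        switch addr := address.(type) {
        case *net.TCPAddr:
                if addr.IP.To4() != nil {
                        return addressTypeIPv4, nil
                }
                if addr.IP.To16() != nil {
                        return addressTypeIPv6, nil
                }

                return 0, fmt.Errorf("unhandled IP address: %v", addr)

        case *tor.OnionAddr:
                switch len(addr.OnionService) {
                case tor.V2Len:
                        return addressTypeTorV2, nil
                case tor.V3Len:
                        return addressTypeTorV3, nil
                default:
                        return 0, fmt.Errorf("invalid onion service length")
                }

        case *lnwire.OpaqueAddrs:
                return addressTypeOpaque, nil

        default:
                return 0, fmt.Errorf("unhandled address type: %T", address)
        }
}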
3590

3591
// getNodeAddresses fetches the addresses for a node with the given DB ID.
3592
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
3593
        error) {
×
3594

×
3595
        // GetNodeAddresses ensures that the addresses for a given type are
×
3596
        // returned in the same order as they were inserted.
×
3597
        rows, err := db.GetNodeAddresses(ctx, id)
×
3598
        if err != nil {
×
3599
                return nil, err
×
3600
        }
×
3601

3602
        addresses := make([]net.Addr, 0, len(rows))
×
3603
        for _, row := range rows {
×
3604
                address := row.Address
×
3605

×
3606
                addr, err := parseAddress(dbAddressType(row.Type), address)
×
3607
                if err != nil {
×
3608
                        return nil, fmt.Errorf("unable to parse address "+
×
3609
                                "for node(%d): %v: %w", id, address, err)
×
3610
                }
×
3611

3612
                addresses = append(addresses, addr)
×
3613
        }
3614

3615
        // If we have no addresses, then we'll return nil instead of an
3616
        // empty slice.
3617
        if len(addresses) == 0 {
×
3618
                addresses = nil
×
3619
        }
×
3620

3621
        return addresses, nil
×
3622
}
3623

3624
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3625
// database. This includes updating any existing types, inserting any new types,
3626
// and deleting any types that are no longer present.
3627
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3628
        nodeID int64, extraFields map[uint64][]byte) error {
×
3629

×
3630
        // Get any existing extra signed fields for the node.
×
3631
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3632
        if err != nil {
×
3633
                return err
×
3634
        }
×
3635

3636
        // Make a lookup map of the existing field types so that we can use it
3637
        // to keep track of any fields we should delete.
3638
        m := make(map[uint64]bool)
×
3639
        for _, field := range existingFields {
×
3640
                m[uint64(field.Type)] = true
×
3641
        }
×
3642

3643
        // For all the new fields, we'll upsert them and remove them from the
3644
        // map of existing fields.
3645
        for tlvType, value := range extraFields {
×
3646
                err = db.UpsertNodeExtraType(
×
3647
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3648
                                NodeID: nodeID,
×
3649
                                Type:   int64(tlvType),
×
3650
                                Value:  value,
×
3651
                        },
×
3652
                )
×
3653
                if err != nil {
×
3654
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3655
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3656
                }
×
3657

3658
                // Remove the field from the map of existing fields if it was
3659
                // present.
3660
                delete(m, tlvType)
×
3661
        }
3662

3663
        // For all the fields that are left in the map of existing fields, we'll
3664
        // delete them as they are no longer present in the new set of fields.
3665
        for tlvType := range m {
×
3666
                err = db.DeleteExtraNodeType(
×
3667
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3668
                                NodeID: nodeID,
×
3669
                                Type:   int64(tlvType),
×
3670
                        },
×
3671
                )
×
3672
                if err != nil {
×
3673
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3674
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3675
                }
×
3676
        }
3677

3678
        return nil
×
3679
}
3680

3681
// srcNodeInfo holds the information about the source node of the graph.
3682
type srcNodeInfo struct {
3683
        // id is the DB level ID of the source node entry in the "nodes" table.
3684
        id int64
3685

3686
        // pub is the public key of the source node.
3687
        pub route.Vertex
3688
}
3689

3690
// getSourceNode returns the DB node ID and pub key of the source node for the
3691
// specified protocol version.
3692
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3693
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3694

×
3695
        s.srcNodeMu.Lock()
×
3696
        defer s.srcNodeMu.Unlock()
×
3697

×
3698
        // If we already have the source node ID and pub key cached, then
×
3699
        // return them.
×
3700
        if info, ok := s.srcNodes[version]; ok {
×
3701
                return info.id, info.pub, nil
×
3702
        }
×
3703

3704
        var pubKey route.Vertex
×
3705

×
3706
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3707
        if err != nil {
×
3708
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3709
                        err)
×
3710
        }
×
3711

3712
        if len(nodes) == 0 {
×
3713
                return 0, pubKey, ErrSourceNodeNotSet
×
3714
        } else if len(nodes) > 1 {
×
3715
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3716
                        "protocol %s found", version)
×
3717
        }
×
3718

3719
        copy(pubKey[:], nodes[0].PubKey)
×
3720

×
3721
        s.srcNodes[version] = &srcNodeInfo{
×
3722
                id:  nodes[0].NodeID,
×
3723
                pub: pubKey,
×
3724
        }
×
3725

×
3726
        return nodes[0].NodeID, pubKey, nil
×
3727
}
3728

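// NOTE: illustrative sketch, not part of the original file. It spells out
// the shape of the cache consulted by getSourceNode: once a protocol
// version has been resolved, subsequent lookups are answered from the map
// without touching the database. The function name is hypothetical.
func exampleSrcNodeCacheLookup(
        cache map[ProtocolVersion]*srcNodeInfo) (int64, route.Vertex, bool) {

        info, ok := cache[ProtocolV1]
        if !ok {
                return 0, route.Vertex{}, false
        }

        return info.id, info.pub, true
}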
3729
// marshalExtraOpaqueData takes a flat byte slice and parses it as a TLV
3730
// stream, producing a map from TLV type to value. If the input is not a
3731
// valid TLV stream, then an error is returned.
3732
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3733
        r := bytes.NewReader(data)
×
3734

×
3735
        tlvStream, err := tlv.NewStream()
×
3736
        if err != nil {
×
3737
                return nil, err
×
3738
        }
×
3739

3740
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3741
        // pass it into the P2P decoding variant.
3742
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3743
        if err != nil {
×
3744
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3745
        }
×
3746
        if len(parsedTypes) == 0 {
×
3747
                return nil, nil
×
3748
        }
×
3749

3750
        records := make(map[uint64][]byte)
×
3751
        for k, v := range parsedTypes {
×
3752
                records[uint64(k)] = v
×
3753
        }
×
3754

3755
        return records, nil
×
3756
}
3757

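// NOTE: illustrative sketch, not part of the original file. It shows the
// expected round trip through marshalExtraOpaqueData: a custom record set
// serialized via lnwire.CustomRecords comes back as the same type->value
// map. The record type 65537 is an arbitrary value in the custom TLV
// range, and the sketch relies only on this package's existing imports.
func exampleMarshalExtraOpaqueData() (map[uint64][]byte, error) {
        recs := lnwire.CustomRecords{
                65537: {0x01, 0x02, 0x03},
        }

        raw, err := recs.Serialize()
        if err != nil {
                return nil, err
        }

        // Parsing the flat byte slice back yields {65537: 0x010203}.
        return marshalExtraOpaqueData(raw)
}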
3758
// dbChanInfo holds the DB level IDs of a channel and the nodes involved in the
3759
// channel.
3760
type dbChanInfo struct {
3761
        channelID int64
3762
        node1ID   int64
3763
        node2ID   int64
3764
}
3765

3766
// insertChannel inserts a new channel record into the database.
3767
func insertChannel(ctx context.Context, db SQLQueries,
3768
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
3769

×
3770
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3771

×
3772
        // Make sure that the channel doesn't already exist. We do this
×
3773
        // explicitly instead of relying on catching a unique constraint error
×
3774
        // because letting SQL raise that error would abort the entire
×
3775
        // batch of transactions.
×
3776
        _, err := db.GetChannelBySCID(
×
3777
                ctx, sqlc.GetChannelBySCIDParams{
×
3778
                        Scid:    chanIDB,
×
3779
                        Version: int16(ProtocolV1),
×
3780
                },
×
3781
        )
×
3782
        if err == nil {
×
3783
                return nil, ErrEdgeAlreadyExist
×
3784
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3785
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
×
3786
        }
×
3787

3788
        // Make sure that at least a "shell" entry for each node is present in
3789
        // the nodes table.
3790
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3791
        if err != nil {
×
3792
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3793
        }
×
3794

3795
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3796
        if err != nil {
×
3797
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3798
        }
×
3799

3800
        var capacity sql.NullInt64
×
3801
        if edge.Capacity != 0 {
×
3802
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3803
        }
×
3804

3805
        createParams := sqlc.CreateChannelParams{
×
3806
                Version:     int16(ProtocolV1),
×
3807
                Scid:        chanIDB,
×
3808
                NodeID1:     node1DBID,
×
3809
                NodeID2:     node2DBID,
×
3810
                Outpoint:    edge.ChannelPoint.String(),
×
3811
                Capacity:    capacity,
×
3812
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3813
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3814
        }
×
3815

×
3816
        if edge.AuthProof != nil {
×
3817
                proof := edge.AuthProof
×
3818

×
3819
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3820
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3821
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3822
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3823
        }
×
3824

3825
        // Insert the new channel record.
3826
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3827
        if err != nil {
×
3828
                return nil, err
×
3829
        }
×
3830

3831
        // Insert any channel features.
3832
        for feature := range edge.Features.Features() {
×
3833
                err = db.InsertChannelFeature(
×
3834
                        ctx, sqlc.InsertChannelFeatureParams{
×
3835
                                ChannelID:  dbChanID,
×
3836
                                FeatureBit: int32(feature),
×
3837
                        },
×
3838
                )
×
3839
                if err != nil {
×
3840
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
×
3841
                                "feature(%v): %w", dbChanID, feature, err)
×
3842
                }
×
3843
        }
3844

3845
        // Finally, insert any extra TLV fields in the channel announcement.
3846
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3847
        if err != nil {
×
3848
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
3849
                        "data: %w", err)
×
3850
        }
×
3851

3852
        for tlvType, value := range extra {
×
3853
                err := db.CreateChannelExtraType(
×
3854
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3855
                                ChannelID: dbChanID,
×
3856
                                Type:      int64(tlvType),
×
3857
                                Value:     value,
×
3858
                        },
×
3859
                )
×
3860
                if err != nil {
×
3861
                        return nil, fmt.Errorf("unable to upsert "+
×
3862
                                "channel(%d) extra signed field(%v): %w",
×
3863
                                edge.ChannelID, tlvType, err)
×
3864
                }
×
3865
        }
3866

3867
        return &dbChanInfo{
×
3868
                channelID: dbChanID,
×
3869
                node1ID:   node1DBID,
×
3870
                node2ID:   node2DBID,
×
3871
        }, nil
×
3872
}
3873

3874
// maybeCreateShellNode checks if a node entry already exists for the
3875
// given public key. If it does not exist, then a new shell node entry is
3876
// created. The ID of the node is returned. A shell node only has a protocol
3877
// version and public key persisted.
3878
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3879
        pubKey route.Vertex) (int64, error) {
×
3880

×
3881
        dbNode, err := db.GetNodeByPubKey(
×
3882
                ctx, sqlc.GetNodeByPubKeyParams{
×
3883
                        PubKey:  pubKey[:],
×
3884
                        Version: int16(ProtocolV1),
×
3885
                },
×
3886
        )
×
3887
        // The node exists. Return the ID.
×
3888
        if err == nil {
×
3889
                return dbNode.ID, nil
×
3890
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3891
                return 0, err
×
3892
        }
×
3893

3894
        // Otherwise, the node does not exist, so we create a shell entry for
3895
        // it.
3896
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3897
                Version: int16(ProtocolV1),
×
3898
                PubKey:  pubKey[:],
×
3899
        })
×
3900
        if err != nil {
×
3901
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3902
        }
×
3903

3904
        return id, nil
×
3905
}
3906

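// NOTE: illustrative sketch, not part of the original file. It mirrors
// how insertChannel uses maybeCreateShellNode to guarantee that both
// channel endpoints have at least a shell entry before the channel row is
// created. The helper name ensureShellNodes is hypothetical.
func ensureShellNodes(ctx context.Context, db SQLQueries, key1,
        key2 route.Vertex) (int64, int64, error) {

        node1ID, err := maybeCreateShellNode(ctx, db, key1)
        if err != nil {
                return 0, 0, err
        }

        node2ID, err := maybeCreateShellNode(ctx, db, key2)
        if err != nil {
                return 0, 0, err
        }

        return node1ID, node2ID, nil
}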
3907
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3908
// the database. This includes deleting any existing types and then inserting
3909
// the new types.
3910
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3911
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3912

×
3913
        // Delete all existing extra signed fields for the channel policy.
×
3914
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
3915
        if err != nil {
×
3916
                return fmt.Errorf("unable to delete "+
×
3917
                        "existing policy extra signed fields for policy %d: %w",
×
3918
                        chanPolicyID, err)
×
3919
        }
×
3920

3921
        // Insert all new extra signed fields for the channel policy.
3922
        for tlvType, value := range extraFields {
×
3923
                err = db.InsertChanPolicyExtraType(
×
3924
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
3925
                                ChannelPolicyID: chanPolicyID,
×
3926
                                Type:            int64(tlvType),
×
3927
                                Value:           value,
×
3928
                        },
×
3929
                )
×
3930
                if err != nil {
×
3931
                        return fmt.Errorf("unable to insert "+
×
3932
                                "channel_policy(%d) extra signed field(%v): %w",
×
3933
                                chanPolicyID, tlvType, err)
×
3934
                }
×
3935
        }
3936

3937
        return nil
×
3938
}
3939

3940
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3941
// provided dbChan row and also fetches any other required information
3942
// to construct the edge info.
3943
func getAndBuildEdgeInfo(ctx context.Context, cfg *SQLStoreConfig,
3944
        db SQLQueries, dbChan sqlc.GraphChannel, node1,
3945
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
3946

×
NEW
3947
        data, err := batchLoadChannelData(
×
NEW
3948
                ctx, cfg.QueryCfg, db, []int64{dbChan.ID}, nil,
×
NEW
3949
        )
×
3950
        if err != nil {
×
3951
                return nil, fmt.Errorf("unable to batch load channel data: %w",
×
3952
                        err)
×
3953
        }
×
3954

NEW
3955
        return buildEdgeInfoWithBatchData(
×
NEW
3956
                cfg.ChainHash, dbChan, node1, node2, data,
×
NEW
3957
        )
×
3958
}
3959

3960
// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
3961
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
3962
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
3963
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {
×
3964

×
3965
        if dbChan.Version != int16(ProtocolV1) {
×
3966
                return nil, fmt.Errorf("unsupported channel version: %d",
×
3967
                        dbChan.Version)
×
3968
        }
×
3969

3970
        // Use the pre-loaded features and extra types.
3971
        fv := lnwire.EmptyFeatureVector()
×
3972
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
×
3973
                for _, bit := range features {
×
3974
                        fv.Set(lnwire.FeatureBit(bit))
×
3975
                }
×
3976
        }
3977

3978
        var extras map[uint64][]byte
×
3979
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
×
3980
        if exists {
×
3981
                extras = channelExtras
×
3982
        } else {
×
3983
                extras = make(map[uint64][]byte)
×
3984
        }
×
3985

3986
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
3987
        if err != nil {
×
3988
                return nil, err
×
3989
        }
×
3990

3991
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
3992
        if err != nil {
×
3993
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3994
                        "fields: %w", err)
×
3995
        }
×
3996
        if recs == nil {
×
3997
                recs = make([]byte, 0)
×
3998
        }
×
3999

4000
        var btcKey1, btcKey2 route.Vertex
×
4001
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4002
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4003

×
4004
        channel := &models.ChannelEdgeInfo{
×
4005
                ChainHash:        chain,
×
4006
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4007
                NodeKey1Bytes:    node1,
×
4008
                NodeKey2Bytes:    node2,
×
4009
                BitcoinKey1Bytes: btcKey1,
×
4010
                BitcoinKey2Bytes: btcKey2,
×
4011
                ChannelPoint:     *op,
×
4012
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4013
                Features:         fv,
×
4014
                ExtraOpaqueData:  recs,
×
4015
        }
×
4016

×
4017
        // We always set all the signatures at the same time, so we can
×
4018
        // safely check if one signature is present to determine if we have the
×
4019
        // rest of the signatures for the auth proof.
×
4020
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4021
                channel.AuthProof = &models.ChannelAuthProof{
×
4022
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4023
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4024
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4025
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4026
                }
×
4027
        }
×
4028

4029
        return channel, nil
×
4030
}
4031

4032
// buildNodeVertices is a helper that converts raw node public keys
4033
// into route.Vertex instances.
4034
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4035
        route.Vertex, error) {
×
4036

×
4037
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4038
        if err != nil {
×
4039
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4040
                        "create vertex from node1 pubkey: %w", err)
×
4041
        }
×
4042

4043
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4044
        if err != nil {
×
4045
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4046
                        "create vertex from node2 pubkey: %w", err)
×
4047
        }
×
4048

4049
        return node1Vertex, node2Vertex, nil
×
4050
}
4051

4052
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy records and
4053
// retrieves all the extra info required to build the complete
4054
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4055
// the provided sqlc.GraphChannelPolicy records are nil.
4056
func getAndBuildChanPolicies(ctx context.Context, cfg *sqldb.QueryConfig,
4057
        db SQLQueries, dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4058
        channelID uint64, node1, node2 route.Vertex) (*models.ChannelEdgePolicy,
4059
        *models.ChannelEdgePolicy, error) {
×
4060

×
4061
        if dbPol1 == nil && dbPol2 == nil {
×
4062
                return nil, nil, nil
×
4063
        }
×
4064

4065
        var policyIDs = make([]int64, 0, 2)
×
4066
        if dbPol1 != nil {
×
4067
                policyIDs = append(policyIDs, dbPol1.ID)
×
4068
        }
×
4069
        if dbPol2 != nil {
×
4070
                policyIDs = append(policyIDs, dbPol2.ID)
×
4071
        }
×
4072

UNCOV
4073
        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
×
UNCOV
4074
        if err != nil {
×
4075
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
4076
                        "data: %w", err)
×
4077
        }
×
4078

4079
        pol1, err := buildChanPolicyWithBatchData(
×
4080
                dbPol1, channelID, node2, batchData,
×
4081
        )
×
4082
        if err != nil {
×
4083
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4084
        }
×
4085

4086
        pol2, err := buildChanPolicyWithBatchData(
×
4087
                dbPol2, channelID, node1, batchData,
×
4088
        )
×
4089
        if err != nil {
×
4090
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4091
        }
×
4092

4093
        return pol1, pol2, nil
×
4094
}
4095

4096
// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
4097
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
4098
// then nil is returned for it.
4099
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4100
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
4101
        *models.CachedEdgePolicy, error) {
×
4102

×
4103
        var p1, p2 *models.CachedEdgePolicy
×
4104
        if dbPol1 != nil {
×
4105
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
×
4106
                if err != nil {
×
4107
                        return nil, nil, err
×
4108
                }
×
4109

4110
                p1 = models.NewCachedPolicy(policy1)
×
4111
        }
4112
        if dbPol2 != nil {
×
4113
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
×
4114
                if err != nil {
×
4115
                        return nil, nil, err
×
4116
                }
×
4117

4118
                p2 = models.NewCachedPolicy(policy2)
×
4119
        }
4120

4121
        return p1, p2, nil
×
4122
}
4123

4124
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4125
// provided sqlc.GraphChannelPolicy and other required information.
4126
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4127
        extras map[uint64][]byte,
4128
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4129

×
4130
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4131
        if err != nil {
×
4132
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4133
                        "fields: %w", err)
×
4134
        }
×
4135

4136
        var inboundFee fn.Option[lnwire.Fee]
×
4137
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4138
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4139

×
4140
                inboundFee = fn.Some(lnwire.Fee{
×
4141
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4142
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4143
                })
×
4144
        }
×
4145

4146
        return &models.ChannelEdgePolicy{
×
4147
                SigBytes:  dbPolicy.Signature,
×
4148
                ChannelID: channelID,
×
4149
                LastUpdate: time.Unix(
×
4150
                        dbPolicy.LastUpdate.Int64, 0,
×
4151
                ),
×
4152
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4153
                        dbPolicy.MessageFlags,
×
4154
                ),
×
4155
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4156
                        dbPolicy.ChannelFlags,
×
4157
                ),
×
4158
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4159
                MinHTLC: lnwire.MilliSatoshi(
×
4160
                        dbPolicy.MinHtlcMsat,
×
4161
                ),
×
4162
                MaxHTLC: lnwire.MilliSatoshi(
×
4163
                        dbPolicy.MaxHtlcMsat.Int64,
×
4164
                ),
×
4165
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4166
                        dbPolicy.BaseFeeMsat,
×
4167
                ),
×
4168
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4169
                ToNode:                    toNode,
×
4170
                InboundFee:                inboundFee,
×
4171
                ExtraOpaqueData:           recs,
×
4172
        }, nil
×
4173
}
4174

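// NOTE: illustrative sketch, not part of the original file. It shows how
// the fee fields populated by buildChanPolicy are typically combined: the
// forwarding fee for an amount is the base fee plus the proportional fee,
// where FeeProportionalMillionths is expressed in parts per million.
func exampleForwardingFee(pol *models.ChannelEdgePolicy,
        amt lnwire.MilliSatoshi) lnwire.MilliSatoshi {

        return pol.FeeBaseMSat + (amt*pol.FeeProportionalMillionths)/1_000_000
}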
4175
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the given
4176
// row which is expected to be a sqlc type that contains channel policy
4177
// information. It returns two policies, which may be nil if the policy
4178
// information is not present in the row.
4179
//
4180
//nolint:ll,dupl,funlen
4181
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4182
        *sqlc.GraphChannelPolicy, error) {
×
4183

×
4184
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4185
        switch r := row.(type) {
×
4186
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
×
4187
                if r.Policy1Timelock.Valid {
×
4188
                        policy1 = &sqlc.GraphChannelPolicy{
×
4189
                                Timelock:                r.Policy1Timelock.Int32,
×
4190
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4191
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4192
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4193
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4194
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4195
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4196
                                Disabled:                r.Policy1Disabled,
×
4197
                                MessageFlags:            r.Policy1MessageFlags,
×
4198
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4199
                        }
×
4200
                }
×
4201
                if r.Policy2Timelock.Valid {
×
4202
                        policy2 = &sqlc.GraphChannelPolicy{
×
4203
                                Timelock:                r.Policy2Timelock.Int32,
×
4204
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4205
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4206
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4207
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4208
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4209
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4210
                                Disabled:                r.Policy2Disabled,
×
4211
                                MessageFlags:            r.Policy2MessageFlags,
×
4212
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4213
                        }
×
4214
                }
×
4215

4216
                return policy1, policy2, nil
×
4217

4218
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
4219
                if r.Policy1ID.Valid {
×
4220
                        policy1 = &sqlc.GraphChannelPolicy{
×
4221
                                ID:                      r.Policy1ID.Int64,
×
4222
                                Version:                 r.Policy1Version.Int16,
×
4223
                                ChannelID:               r.GraphChannel.ID,
×
4224
                                NodeID:                  r.Policy1NodeID.Int64,
×
4225
                                Timelock:                r.Policy1Timelock.Int32,
×
4226
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4227
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4228
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4229
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4230
                                LastUpdate:              r.Policy1LastUpdate,
×
4231
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4232
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4233
                                Disabled:                r.Policy1Disabled,
×
4234
                                MessageFlags:            r.Policy1MessageFlags,
×
4235
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4236
                                Signature:               r.Policy1Signature,
×
4237
                        }
×
4238
                }
×
4239
                if r.Policy2ID.Valid {
×
4240
                        policy2 = &sqlc.GraphChannelPolicy{
×
4241
                                ID:                      r.Policy2ID.Int64,
×
4242
                                Version:                 r.Policy2Version.Int16,
×
4243
                                ChannelID:               r.GraphChannel.ID,
×
4244
                                NodeID:                  r.Policy2NodeID.Int64,
×
4245
                                Timelock:                r.Policy2Timelock.Int32,
×
4246
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4247
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4248
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4249
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4250
                                LastUpdate:              r.Policy2LastUpdate,
×
4251
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4252
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4253
                                Disabled:                r.Policy2Disabled,
×
4254
                                MessageFlags:            r.Policy2MessageFlags,
×
4255
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4256
                                Signature:               r.Policy2Signature,
×
4257
                        }
×
4258
                }
×
4259

4260
                return policy1, policy2, nil
×
4261

4262
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4263
                if r.Policy1ID.Valid {
×
4264
                        policy1 = &sqlc.GraphChannelPolicy{
×
4265
                                ID:                      r.Policy1ID.Int64,
×
4266
                                Version:                 r.Policy1Version.Int16,
×
4267
                                ChannelID:               r.GraphChannel.ID,
×
4268
                                NodeID:                  r.Policy1NodeID.Int64,
×
4269
                                Timelock:                r.Policy1Timelock.Int32,
×
4270
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4271
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4272
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4273
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4274
                                LastUpdate:              r.Policy1LastUpdate,
×
4275
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4276
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4277
                                Disabled:                r.Policy1Disabled,
×
4278
                                MessageFlags:            r.Policy1MessageFlags,
×
4279
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4280
                                Signature:               r.Policy1Signature,
×
4281
                        }
×
4282
                }
×
4283
                if r.Policy2ID.Valid {
×
4284
                        policy2 = &sqlc.GraphChannelPolicy{
×
4285
                                ID:                      r.Policy2ID.Int64,
×
4286
                                Version:                 r.Policy2Version.Int16,
×
4287
                                ChannelID:               r.GraphChannel.ID,
×
4288
                                NodeID:                  r.Policy2NodeID.Int64,
×
4289
                                Timelock:                r.Policy2Timelock.Int32,
×
4290
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4291
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4292
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4293
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4294
                                LastUpdate:              r.Policy2LastUpdate,
×
4295
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4296
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4297
                                Disabled:                r.Policy2Disabled,
×
4298
                                MessageFlags:            r.Policy2MessageFlags,
×
4299
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4300
                                Signature:               r.Policy2Signature,
×
4301
                        }
×
4302
                }
×
4303

4304
                return policy1, policy2, nil
×
4305

4306
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4307
                if r.Policy1ID.Valid {
×
4308
                        policy1 = &sqlc.GraphChannelPolicy{
×
4309
                                ID:                      r.Policy1ID.Int64,
×
4310
                                Version:                 r.Policy1Version.Int16,
×
4311
                                ChannelID:               r.GraphChannel.ID,
×
4312
                                NodeID:                  r.Policy1NodeID.Int64,
×
4313
                                Timelock:                r.Policy1Timelock.Int32,
×
4314
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4315
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4316
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4317
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4318
                                LastUpdate:              r.Policy1LastUpdate,
×
4319
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4320
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4321
                                Disabled:                r.Policy1Disabled,
×
4322
                                MessageFlags:            r.Policy1MessageFlags,
×
4323
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4324
                                Signature:               r.Policy1Signature,
×
4325
                        }
×
4326
                }
×
4327
                if r.Policy2ID.Valid {
×
4328
                        policy2 = &sqlc.GraphChannelPolicy{
×
4329
                                ID:                      r.Policy2ID.Int64,
×
4330
                                Version:                 r.Policy2Version.Int16,
×
4331
                                ChannelID:               r.GraphChannel.ID,
×
4332
                                NodeID:                  r.Policy2NodeID.Int64,
×
4333
                                Timelock:                r.Policy2Timelock.Int32,
×
4334
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4335
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4336
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4337
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4338
                                LastUpdate:              r.Policy2LastUpdate,
×
4339
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4340
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4341
                                Disabled:                r.Policy2Disabled,
×
4342
                                MessageFlags:            r.Policy2MessageFlags,
×
4343
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4344
                                Signature:               r.Policy2Signature,
×
4345
                        }
×
4346
                }
×
4347

4348
                return policy1, policy2, nil
×
4349

4350
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4351
                if r.Policy1ID.Valid {
×
4352
                        policy1 = &sqlc.GraphChannelPolicy{
×
4353
                                ID:                      r.Policy1ID.Int64,
×
4354
                                Version:                 r.Policy1Version.Int16,
×
4355
                                ChannelID:               r.GraphChannel.ID,
×
4356
                                NodeID:                  r.Policy1NodeID.Int64,
×
4357
                                Timelock:                r.Policy1Timelock.Int32,
×
4358
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4359
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4360
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4361
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4362
                                LastUpdate:              r.Policy1LastUpdate,
×
4363
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4364
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4365
                                Disabled:                r.Policy1Disabled,
×
4366
                                MessageFlags:            r.Policy1MessageFlags,
×
4367
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4368
                                Signature:               r.Policy1Signature,
×
4369
                        }
×
4370
                }
×
4371
                if r.Policy2ID.Valid {
×
4372
                        policy2 = &sqlc.GraphChannelPolicy{
×
4373
                                ID:                      r.Policy2ID.Int64,
×
4374
                                Version:                 r.Policy2Version.Int16,
×
4375
                                ChannelID:               r.GraphChannel.ID,
×
4376
                                NodeID:                  r.Policy2NodeID.Int64,
×
4377
                                Timelock:                r.Policy2Timelock.Int32,
×
4378
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4379
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4380
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4381
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4382
                                LastUpdate:              r.Policy2LastUpdate,
×
4383
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4384
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4385
                                Disabled:                r.Policy2Disabled,
×
4386
                                MessageFlags:            r.Policy2MessageFlags,
×
4387
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4388
                                Signature:               r.Policy2Signature,
×
4389
                        }
×
4390
                }
×
4391

4392
                return policy1, policy2, nil
×
4393

4394
        case sqlc.ListChannelsForNodeIDsRow:
×
4395
                if r.Policy1ID.Valid {
×
4396
                        policy1 = &sqlc.GraphChannelPolicy{
×
4397
                                ID:                      r.Policy1ID.Int64,
×
4398
                                Version:                 r.Policy1Version.Int16,
×
4399
                                ChannelID:               r.GraphChannel.ID,
×
4400
                                NodeID:                  r.Policy1NodeID.Int64,
×
4401
                                Timelock:                r.Policy1Timelock.Int32,
×
4402
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4403
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4404
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4405
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4406
                                LastUpdate:              r.Policy1LastUpdate,
×
4407
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4408
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4409
                                Disabled:                r.Policy1Disabled,
×
4410
                                MessageFlags:            r.Policy1MessageFlags,
×
4411
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4412
                                Signature:               r.Policy1Signature,
×
4413
                        }
×
4414
                }
×
4415
                if r.Policy2ID.Valid {
×
4416
                        policy2 = &sqlc.GraphChannelPolicy{
×
4417
                                ID:                      r.Policy2ID.Int64,
×
4418
                                Version:                 r.Policy2Version.Int16,
×
4419
                                ChannelID:               r.GraphChannel.ID,
×
4420
                                NodeID:                  r.Policy2NodeID.Int64,
×
4421
                                Timelock:                r.Policy2Timelock.Int32,
×
4422
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4423
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4424
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4425
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4426
                                LastUpdate:              r.Policy2LastUpdate,
×
4427
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4428
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4429
                                Disabled:                r.Policy2Disabled,
×
4430
                                MessageFlags:            r.Policy2MessageFlags,
×
4431
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4432
                                Signature:               r.Policy2Signature,
×
4433
                        }
×
4434
                }
×
4435

4436
                return policy1, policy2, nil
×
4437

4438
        case sqlc.ListChannelsByNodeIDRow:
×
4439
                if r.Policy1ID.Valid {
×
4440
                        policy1 = &sqlc.GraphChannelPolicy{
×
4441
                                ID:                      r.Policy1ID.Int64,
×
4442
                                Version:                 r.Policy1Version.Int16,
×
4443
                                ChannelID:               r.GraphChannel.ID,
×
4444
                                NodeID:                  r.Policy1NodeID.Int64,
×
4445
                                Timelock:                r.Policy1Timelock.Int32,
×
4446
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4447
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4448
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4449
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4450
                                LastUpdate:              r.Policy1LastUpdate,
×
4451
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4452
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4453
                                Disabled:                r.Policy1Disabled,
×
4454
                                MessageFlags:            r.Policy1MessageFlags,
×
4455
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4456
                                Signature:               r.Policy1Signature,
×
4457
                        }
×
4458
                }
×
4459
                if r.Policy2ID.Valid {
×
4460
                        policy2 = &sqlc.GraphChannelPolicy{
×
4461
                                ID:                      r.Policy2ID.Int64,
×
4462
                                Version:                 r.Policy2Version.Int16,
×
4463
                                ChannelID:               r.GraphChannel.ID,
×
4464
                                NodeID:                  r.Policy2NodeID.Int64,
×
4465
                                Timelock:                r.Policy2Timelock.Int32,
×
4466
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4467
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4468
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4469
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4470
                                LastUpdate:              r.Policy2LastUpdate,
×
4471
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4472
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4473
                                Disabled:                r.Policy2Disabled,
×
4474
                                MessageFlags:            r.Policy2MessageFlags,
×
4475
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4476
                                Signature:               r.Policy2Signature,
×
4477
                        }
×
4478
                }
×
4479

4480
                return policy1, policy2, nil
×
4481

4482
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4483
                if r.Policy1ID.Valid {
×
4484
                        policy1 = &sqlc.GraphChannelPolicy{
×
4485
                                ID:                      r.Policy1ID.Int64,
×
4486
                                Version:                 r.Policy1Version.Int16,
×
4487
                                ChannelID:               r.GraphChannel.ID,
×
4488
                                NodeID:                  r.Policy1NodeID.Int64,
×
4489
                                Timelock:                r.Policy1Timelock.Int32,
×
4490
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4491
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4492
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4493
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4494
                                LastUpdate:              r.Policy1LastUpdate,
×
4495
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4496
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4497
                                Disabled:                r.Policy1Disabled,
×
4498
                                MessageFlags:            r.Policy1MessageFlags,
×
4499
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4500
                                Signature:               r.Policy1Signature,
×
4501
                        }
×
4502
                }
×
4503
                if r.Policy2ID.Valid {
×
4504
                        policy2 = &sqlc.GraphChannelPolicy{
×
4505
                                ID:                      r.Policy2ID.Int64,
×
4506
                                Version:                 r.Policy2Version.Int16,
×
4507
                                ChannelID:               r.GraphChannel.ID,
×
4508
                                NodeID:                  r.Policy2NodeID.Int64,
×
4509
                                Timelock:                r.Policy2Timelock.Int32,
×
4510
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4511
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4512
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4513
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4514
                                LastUpdate:              r.Policy2LastUpdate,
×
4515
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4516
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4517
                                Disabled:                r.Policy2Disabled,
×
4518
                                MessageFlags:            r.Policy2MessageFlags,
×
4519
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4520
                                Signature:               r.Policy2Signature,
×
4521
                        }
×
4522
                }
×
4523

4524
                return policy1, policy2, nil
×
4525
        default:
×
4526
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4527
                        "extractChannelPolicies: %T", r)
×
4528
        }
4529
}
4530

4531
// channelIDToBytes converts a channel ID (SCID) to a byte array
4532
// representation.
4533
func channelIDToBytes(channelID uint64) []byte {
×
4534
        var chanIDB [8]byte
×
4535
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4536

×
4537
        return chanIDB[:]
×
4538
}
×
4539

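// NOTE: illustrative sketch, not part of the original file. It documents
// the invariant relied on elsewhere in this file (for example when
// buildEdgeInfoWithBatchData reads dbChan.Scid): encoding an SCID with
// channelIDToBytes and decoding it with byteOrder round-trips the value.
func exampleSCIDRoundTrip(channelID uint64) bool {
        return byteOrder.Uint64(channelIDToBytes(channelID)) == channelID
}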
4540
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
4541
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
×
4542
        if len(addresses) == 0 {
×
4543
                return nil, nil
×
4544
        }
×
4545

4546
        result := make([]net.Addr, 0, len(addresses))
×
4547
        for _, addr := range addresses {
×
4548
                netAddr, err := parseAddress(addr.addrType, addr.address)
×
4549
                if err != nil {
×
4550
                        return nil, fmt.Errorf("unable to parse address %s "+
×
4551
                                "of type %d: %w", addr.address, addr.addrType,
×
4552
                                err)
×
4553
                }
×
4554
                if netAddr != nil {
×
4555
                        result = append(result, netAddr)
×
4556
                }
×
4557
        }
4558

4559
        // If we have no valid addresses, return nil instead of empty slice.
4560
        if len(result) == 0 {
×
4561
                return nil, nil
×
4562
        }
×
4563

4564
        return result, nil
×
4565
}
4566

4567
// parseAddress parses the given address string based on the address type
4568
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
4569
// and opaque addresses.
4570
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
×
4571
        switch addrType {
×
4572
        case addressTypeIPv4:
×
4573
                tcp, err := net.ResolveTCPAddr("tcp4", address)
×
4574
                if err != nil {
×
4575
                        return nil, err
×
4576
                }
×
4577

4578
                tcp.IP = tcp.IP.To4()
×
4579

×
4580
                return tcp, nil
×
4581

4582
        case addressTypeIPv6:
×
4583
                tcp, err := net.ResolveTCPAddr("tcp6", address)
×
4584
                if err != nil {
×
4585
                        return nil, err
×
4586
                }
×
4587

4588
                return tcp, nil
×
4589

4590
        case addressTypeTorV3, addressTypeTorV2:
×
4591
                service, portStr, err := net.SplitHostPort(address)
×
4592
                if err != nil {
×
4593
                        return nil, fmt.Errorf("unable to split tor "+
×
4594
                                "address: %v", address)
×
4595
                }
×
4596

4597
                port, err := strconv.Atoi(portStr)
×
4598
                if err != nil {
×
4599
                        return nil, err
×
4600
                }
×
4601

4602
                return &tor.OnionAddr{
×
4603
                        OnionService: service,
×
4604
                        Port:         port,
×
4605
                }, nil
×
4606

4607
        case addressTypeOpaque:
×
4608
                opaque, err := hex.DecodeString(address)
×
4609
                if err != nil {
×
4610
                        return nil, fmt.Errorf("unable to decode opaque "+
×
4611
                                "address: %v", address)
×
4612
                }
×
4613

4614
                return &lnwire.OpaqueAddrs{
×
4615
                        Payload: opaque,
×
4616
                }, nil
×
4617

4618
        default:
×
4619
                return nil, fmt.Errorf("unknown address type: %v", addrType)
×
4620
        }
4621
}
4622

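// NOTE: illustrative sketch, not part of the original file. It feeds two
// of the supported address types through parseAddress; the literal values
// are arbitrary examples. IPv4/IPv6 addresses are stored as host:port
// strings, while opaque addresses are stored hex-encoded.
func exampleParseAddresses() ([]net.Addr, error) {
        ipv4, err := parseAddress(addressTypeIPv4, "203.0.113.1:9735")
        if err != nil {
                return nil, err
        }

        opaque, err := parseAddress(addressTypeOpaque, "0102ff")
        if err != nil {
                return nil, err
        }

        return []net.Addr{ipv4, opaque}, nil
}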
4623
// batchNodeData holds all the related data for a batch of nodes.
4624
type batchNodeData struct {
4625
        // features is a map from a DB node ID to the feature bits for that
4626
        // node.
4627
        features map[int64][]int
4628

4629
        // addresses is a map from a DB node ID to the node's addresses.
4630
        addresses map[int64][]nodeAddress
4631

4632
        // extraFields is a map from a DB node ID to the extra signed fields
4633
        // for that node.
4634
        extraFields map[int64]map[uint64][]byte
4635
}
4636

4637
// nodeAddress holds the address type, position and address string for a
4638
// node. This is used to batch the fetching of node addresses.
4639
type nodeAddress struct {
4640
        addrType dbAddressType
4641
        position int32
4642
        address  string
4643
}
4644

4645
// batchLoadNodeData loads all related data for a batch of node IDs using the
4646
// provided SQLQueries interface. It returns a batchNodeData instance containing
4647
// the node features, addresses and extra signed fields.
4648
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
4649
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {
×
4650

×
4651
        // Batch load the node features.
×
4652
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
4653
        if err != nil {
×
4654
                return nil, fmt.Errorf("unable to batch load node "+
×
4655
                        "features: %w", err)
×
4656
        }
×
4657

4658
        // Batch load the node addresses.
4659
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
×
4660
        if err != nil {
×
4661
                return nil, fmt.Errorf("unable to batch load node "+
×
4662
                        "addresses: %w", err)
×
4663
        }
×
4664

4665
        // Batch load the node extra signed fields.
4666
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
×
4667
        if err != nil {
×
4668
                return nil, fmt.Errorf("unable to batch load node extra "+
×
4669
                        "signed fields: %w", err)
×
4670
        }
×
4671

4672
        return &batchNodeData{
×
4673
                features:    features,
×
4674
                addresses:   addrs,
×
4675
                extraFields: extraTypes,
×
4676
        }, nil
×
4677
}
4678

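// NOTE: illustrative sketch, not part of the original file. It shows how
// the raw feature bits collected by batchLoadNodeData are typically turned
// into an lnwire.FeatureVector, mirroring what buildEdgeInfoWithBatchData
// does for channel features. A missing map entry simply means the node has
// no known features.
func exampleNodeFeatureVector(data *batchNodeData,
        nodeID int64) *lnwire.FeatureVector {

        fv := lnwire.EmptyFeatureVector()
        for _, bit := range data.features[nodeID] {
                fv.Set(lnwire.FeatureBit(bit))
        }

        return fv
}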
4679
// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
4680
// using the ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
4681
func batchLoadNodeFeaturesHelper(ctx context.Context,
4682
        cfg *sqldb.QueryConfig, db SQLQueries,
4683
        nodeIDs []int64) (map[int64][]int, error) {
×
4684

×
4685
        features := make(map[int64][]int)
×
4686

×
4687
        return features, sqldb.ExecuteBatchQuery(
×
4688
                ctx, cfg, nodeIDs,
×
4689
                func(id int64) int64 {
×
4690
                        return id
×
4691
                },
×
4692
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
4693
                        error) {
×
4694

×
4695
                        return db.GetNodeFeaturesBatch(ctx, ids)
×
4696
                },
×
4697
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
×
4698
                        features[feature.NodeID] = append(
×
4699
                                features[feature.NodeID],
×
4700
                                int(feature.FeatureBit),
×
4701
                        )
×
4702

×
4703
                        return nil
×
4704
                },
×
4705
        )
4706
}

// batchLoadNodeAddressesHelper loads node addresses using the
// ExecuteBatchQuery wrapper around the GetNodeAddressesBatch query. It
// returns a map from node ID to a slice of nodeAddress structs.
func batchLoadNodeAddressesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64][]nodeAddress, error) {

        addrs := make(map[int64][]nodeAddress)

        return addrs, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
                        error) {

                        return db.GetNodeAddressesBatch(ctx, ids)
                },
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
                        addrs[addr.NodeID] = append(
                                addrs[addr.NodeID], nodeAddress{
                                        addrType: dbAddressType(addr.Type),
                                        position: addr.Position,
                                        address:  addr.Address,
                                },
                        )

                        return nil
                },
        )
}

// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
// node IDs using the ExecuteBatchQuery wrapper around the
// GetNodeExtraTypesBatch query.
func batchLoadNodeExtraTypesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {

        extraFields := make(map[int64]map[uint64][]byte)

        callback := func(ctx context.Context,
                field sqlc.GraphNodeExtraType) error {

                if extraFields[field.NodeID] == nil {
                        extraFields[field.NodeID] = make(map[uint64][]byte)
                }
                extraFields[field.NodeID][uint64(field.Type)] = field.Value

                return nil
        }

        return extraFields, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) (
                        []sqlc.GraphNodeExtraType, error) {

                        return db.GetNodeExtraTypesBatch(ctx, ids)
                },
                callback,
        )
}
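
// NOTE: the batch-loading helpers above all follow the same
// sqldb.ExecuteBatchQuery shape: a key extractor that maps each input ID to
// the value handed to the batched query, a query function that fetches one
// chunk of rows, and a per-row callback that folds the results into a map
// keyed by the owning ID. As an illustrative sketch only (sqlc.Foo,
// GetFooBatch and OwnerID are hypothetical, not part of the generated sqlc
// code), a new helper would follow the same pattern:
//
//        foos := make(map[int64][]sqlc.Foo)
//        err := sqldb.ExecuteBatchQuery(
//                ctx, cfg, ids,
//                func(id int64) int64 { return id },
//                func(ctx context.Context, ids []int64) ([]sqlc.Foo, error) {
//                        return db.GetFooBatch(ctx, ids)
//                },
//                func(ctx context.Context, foo sqlc.Foo) error {
//                        // Group each returned row by its owning ID.
//                        foos[foo.OwnerID] = append(foos[foo.OwnerID], foo)
//                        return nil
//                },
//        )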

// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
// from the provided sqlc.GraphChannelPolicy records and the provided
// batchChannelData.
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
        channelID uint64, node1, node2 route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
        *models.ChannelEdgePolicy, error) {

        pol1, err := buildChanPolicyWithBatchData(
                dbPol1, channelID, node2, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
        }

        pol2, err := buildChanPolicyWithBatchData(
                dbPol2, channelID, node1, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
        }

        return pol1, pol2, nil
}

// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
        channelID uint64, toNode route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {

        if dbPol == nil {
                return nil, nil
        }

        var dbPol1Extras map[uint64][]byte
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
                dbPol1Extras = extras
        } else {
                dbPol1Extras = make(map[uint64][]byte)
        }

        return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode)
}

// batchChannelData holds all the related data for a batch of channels.
type batchChannelData struct {
        // chanfeatures is a map from DB channel ID to a slice of feature bits.
        chanfeatures map[int64][]int

        // chanExtraTypes is a map from DB channel ID to a map of TLV type to
        // extra signed field bytes.
        chanExtraTypes map[int64]map[uint64][]byte

        // policyExtras is a map from DB channel policy ID to a map of TLV type
        // to extra signed field bytes.
        policyExtras map[int64]map[uint64][]byte
}

// batchLoadChannelData loads all related data for batches of channels and
// policies.
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, channelIDs []int64,
        policyIDs []int64) (*batchChannelData, error) {

        batchData := &batchChannelData{
                chanfeatures:   make(map[int64][]int),
                chanExtraTypes: make(map[int64]map[uint64][]byte),
                policyExtras:   make(map[int64]map[uint64][]byte),
        }

        // Batch load channel features and extras.
        var err error
        if len(channelIDs) > 0 {
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
                        ctx, cfg, db, channelIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "channel features: %w", err)
                }

                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
                        ctx, cfg, db, channelIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "channel extras: %w", err)
                }
        }

        if len(policyIDs) > 0 {
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
                        ctx, cfg, db, policyIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "policy extras: %w", err)
                }
                batchData.policyExtras = policyExtras
        }

        return batchData, nil
}
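
// As a usage sketch (chanID and policyID stand for IDs that the caller has
// already collected from the rows it is processing, as batchBuildChannelEdges
// below does), the returned maps are then consulted while building each edge:
//
//        batchData, err := batchLoadChannelData(
//                ctx, cfg, db, channelIDs, policyIDs,
//        )
//        if err != nil {
//                return err
//        }
//        featureBits := batchData.chanfeatures[chanID]
//        chanExtras := batchData.chanExtraTypes[chanID]
//        policyExtras := batchData.policyExtras[policyID]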

// batchLoadChannelFeaturesHelper loads channel features for a batch of
// channel IDs using the ExecuteBatchQuery wrapper around the
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
// slice of feature bits.
func batchLoadChannelFeaturesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        channelIDs []int64) (map[int64][]int, error) {

        features := make(map[int64][]int)

        return features, sqldb.ExecuteBatchQuery(
                ctx, cfg, channelIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context,
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {

                        return db.GetChannelFeaturesBatch(ctx, ids)
                },
                func(ctx context.Context,
                        feature sqlc.GraphChannelFeature) error {

                        features[feature.ChannelID] = append(
                                features[feature.ChannelID],
                                int(feature.FeatureBit),
                        )

                        return nil
                },
        )
}

// batchLoadChannelExtrasHelper loads channel extra types for a batch of
// channel IDs using the ExecuteBatchQuery wrapper around the
// GetChannelExtrasBatch query. It returns a map from DB channel ID to a map
// of TLV type to extra signed field bytes.
func batchLoadChannelExtrasHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {

        extras := make(map[int64]map[uint64][]byte)

        cb := func(ctx context.Context,
                extra sqlc.GraphChannelExtraType) error {

                if extras[extra.ChannelID] == nil {
                        extras[extra.ChannelID] = make(map[uint64][]byte)
                }
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value

                return nil
        }

        return extras, sqldb.ExecuteBatchQuery(
                ctx, cfg, channelIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context,
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {

                        return db.GetChannelExtrasBatch(ctx, ids)
                }, cb,
        )
}

// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
// batch of policy IDs using the ExecuteBatchQuery wrapper around the
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID
// to a map of TLV type to extra signed field bytes.
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {

        extras := make(map[int64]map[uint64][]byte)

        return extras, sqldb.ExecuteBatchQuery(
                ctx, cfg, policyIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) (
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {

                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
                },
                func(ctx context.Context,
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {

                        if extras[row.PolicyID] == nil {
                                extras[row.PolicyID] = make(map[uint64][]byte)
                        }
                        extras[row.PolicyID][uint64(row.Type)] = row.Value

                        return nil
                },
        )
}

// forEachNodePaginated executes a paginated query to process each node in the
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
// and applies the provided processNode function to each node.
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, protocol ProtocolVersion,
        processNode func(context.Context, int64,
                *models.LightningNode) error) error {

        pageQueryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.GraphNode, error) {

                return db.ListNodesPaginated(
                        ctx, sqlc.ListNodesPaginatedParams{
                                Version: int16(protocol),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractPageCursor := func(node sqlc.GraphNode) int64 {
                return node.ID
        }

        collectFunc := func(node sqlc.GraphNode) (int64, error) {
                return node.ID, nil
        }

        batchQueryFunc := func(ctx context.Context,
                nodeIDs []int64) (*batchNodeData, error) {

                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
        }

        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
                batchData *batchNodeData) error {

                node, err := buildNodeWithBatchData(dbNode, batchData)
                if err != nil {
                        return fmt.Errorf("unable to build "+
                                "node(id=%d): %w", dbNode.ID, err)
                }

                return processNode(ctx, dbNode.ID, node)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
                collectFunc, batchQueryFunc, processItem,
        )
}
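
// As an illustrative usage sketch (not something the store exports), a caller
// could collect the pub key of every V1 node like so:
//
//        var pubKeys []route.Vertex
//        err := forEachNodePaginated(
//                ctx, cfg, db, ProtocolV1,
//                func(_ context.Context, _ int64,
//                        node *models.LightningNode) error {
//
//                        pubKeys = append(pubKeys, node.PubKeyBytes)
//                        return nil
//                },
//        )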

// forEachChannelWithPolicies executes a paginated query to process each
// channel with policies in the graph.
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
        cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
                *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error) error {

        type channelBatchIDs struct {
                channelID int64
                policyIDs []int64
        }

        pageQueryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
                error) {

                return db.ListChannelsWithPoliciesPaginated(
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
                                Version: int16(ProtocolV1),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractPageCursor := func(
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {

                return row.GraphChannel.ID
        }

        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
                channelBatchIDs, error) {

                ids := channelBatchIDs{
                        channelID: row.GraphChannel.ID,
                }

                // Extract policy IDs from the row.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return ids, err
                }

                if dbPol1 != nil {
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
                }
                if dbPol2 != nil {
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
                }

                return ids, nil
        }

        batchDataFunc := func(ctx context.Context,
                allIDs []channelBatchIDs) (*batchChannelData, error) {

                // Separate channel IDs from policy IDs.
                var (
                        channelIDs = make([]int64, len(allIDs))
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
                )

                for i, ids := range allIDs {
                        channelIDs[i] = ids.channelID
                        policyIDs = append(policyIDs, ids.policyIDs...)
                }

                return batchLoadChannelData(
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
                )
        }

        processItem := func(ctx context.Context,
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
                batchData *batchChannelData) error {

                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return err
                }

                edge, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.GraphChannel, node1, node2,
                        batchData,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel info: %w",
                                err)
                }

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return err
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
                )
                if err != nil {
                        return err
                }

                return processChannel(edge, p1, p2)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
                collectFunc, batchDataFunc, processItem,
        )
}
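
// Illustrative usage sketch: counting channels and the directions that have
// an advertised policy (either policy may be nil if no channel update has
// been seen for that direction):
//
//        var chans, policies int
//        err := forEachChannelWithPolicies(
//                ctx, db, cfg,
//                func(info *models.ChannelEdgeInfo, p1,
//                        p2 *models.ChannelEdgePolicy) error {
//
//                        chans++
//                        if p1 != nil {
//                                policies++
//                        }
//                        if p2 != nil {
//                                policies++
//                        }
//                        return nil
//                },
//        )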

// buildDirectedChannel builds a DirectedChannel instance from the provided
// data.
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
        nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
        channelBatchData *batchChannelData, features *lnwire.FeatureVector,
        toNodeCallback func() route.Vertex) (*DirectedChannel, error) {

        node1, node2, err := buildNodeVertices(
                channelRow.Node1Pubkey, channelRow.Node2Pubkey,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build node vertices: %w", err)
        }

        edge, err := buildEdgeInfoWithBatchData(
                chain, channelRow.GraphChannel, node1, node2, channelBatchData,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build channel info: %w", err)
        }

        dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
        if err != nil {
                return nil, fmt.Errorf("unable to extract channel policies: %w",
                        err)
        }

        p1, p2, err := buildChanPoliciesWithBatchData(
                dbPol1, dbPol2, edge.ChannelID, node1, node2,
                channelBatchData,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build channel policies: %w",
                        err)
        }

        // Determine outgoing and incoming policy for this specific node.
        p1ToNode := channelRow.GraphChannel.NodeID2
        p2ToNode := channelRow.GraphChannel.NodeID1
        outPolicy, inPolicy := p1, p2
        if (p1 != nil && p1ToNode == nodeID) ||
                (p2 != nil && p2ToNode != nodeID) {

                outPolicy, inPolicy = p2, p1
        }

        // Build cached policy.
        var cachedInPolicy *models.CachedEdgePolicy
        if inPolicy != nil {
                cachedInPolicy = models.NewCachedPolicy(inPolicy)
                cachedInPolicy.ToNodePubKey = toNodeCallback
                cachedInPolicy.ToNodeFeatures = features
        }

        // Extract inbound fee.
        var inboundFee lnwire.Fee
        if outPolicy != nil {
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
                        inboundFee = fee
                })
        }

        // Build directed channel.
        directedChannel := &DirectedChannel{
                ChannelID:    edge.ChannelID,
                IsNode1:      nodePub == edge.NodeKey1Bytes,
                OtherNode:    edge.NodeKey2Bytes,
                Capacity:     edge.Capacity,
                OutPolicySet: outPolicy != nil,
                InPolicy:     cachedInPolicy,
                InboundFee:   inboundFee,
        }

        if nodePub == edge.NodeKey2Bytes {
                directedChannel.OtherNode = edge.NodeKey1Bytes
        }

        return directedChannel, nil
}

// batchBuildChannelEdges builds a slice of ChannelEdge instances from the
// provided rows. It uses batch loading for channels, policies, and nodes.
func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context,
        cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) {

        var (
                channelIDs = make([]int64, len(rows))
                policyIDs  = make([]int64, 0, len(rows)*2)
                nodeIDs    = make([]int64, 0, len(rows)*2)

                // nodeIDSet is used to ensure we only collect unique node IDs.
                nodeIDSet = make(map[int64]bool)

                // edges will hold the final channel edges built from the rows.
                edges = make([]ChannelEdge, 0, len(rows))
        )

        // Collect all IDs needed for batch loading.
        for i, row := range rows {
                channelIDs[i] = row.Channel().ID

                // Collect policy IDs.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return nil, fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }
                if dbPol1 != nil {
                        policyIDs = append(policyIDs, dbPol1.ID)
                }
                if dbPol2 != nil {
                        policyIDs = append(policyIDs, dbPol2.ID)
                }

                var (
                        node1ID = row.Node1().ID
                        node2ID = row.Node2().ID
                )

                // Collect unique node IDs.
                if !nodeIDSet[node1ID] {
                        nodeIDs = append(nodeIDs, node1ID)
                        nodeIDSet[node1ID] = true
                }

                if !nodeIDSet[node2ID] {
                        nodeIDs = append(nodeIDs, node2ID)
                        nodeIDSet[node2ID] = true
                }
        }

        // Batch the data for all the channels and policies.
        channelBatchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to batch load channel and "+
                        "policy data: %w", err)
        }

        // Batch the data for all the nodes.
        nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node data: %w",
                        err)
        }

        // Build all channel edges using batch data.
        for _, row := range rows {
                // Build nodes using batch data.
                node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData)
                if err != nil {
                        return nil, fmt.Errorf("unable to build node1: %w", err)
                }

                node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData)
                if err != nil {
                        return nil, fmt.Errorf("unable to build node2: %w", err)
                }

                // Build channel info using batch data.
                channel, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.Channel(), node1.PubKeyBytes,
                        node2.PubKeyBytes, channelBatchData,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to build channel "+
                                "info: %w", err)
                }

                // Extract and build policies using batch data.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return nil, fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, channel.ChannelID,
                        node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to build channel "+
                                "policies: %w", err)
                }

                edges = append(edges, ChannelEdge{
                        Info:    channel,
                        Policy1: p1,
                        Policy2: p2,
                        Node1:   node1,
                        Node2:   node2,
                })
        }

        return edges, nil
}
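
// NOTE: batchBuildChannelEdges deliberately runs in two phases: it first
// collects the channel, policy and (deduplicated) node IDs from the rows,
// then issues the batched lookups once, and only then assembles the
// ChannelEdge values from the in-memory maps. This keeps the number of
// database round trips independent of the number of rows.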

// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo
// instances from the provided rows using batch loading for channel data.
func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context,
        cfg *SQLStoreConfig, db SQLQueries, rows []T) (
        []*models.ChannelEdgeInfo, []int64, error) {

        if len(rows) == 0 {
                return nil, nil, nil
        }

        // Collect all the channel IDs needed for batch loading.
        channelIDs := make([]int64, len(rows))
        for i, row := range rows {
                channelIDs[i] = row.Channel().ID
        }

        // Batch load the channel data.
        channelBatchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, channelIDs, nil,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to batch load channel "+
                        "data: %w", err)
        }

        // Build all channel edges using batch data.
        edges := make([]*models.ChannelEdgeInfo, 0, len(rows))
        for _, row := range rows {
                node1, node2, err := buildNodeVertices(
                        row.Node1Pub(), row.Node2Pub(),
                )
                if err != nil {
                        return nil, nil, err
                }

                // Build channel info using batch data.
                info, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.Channel(), node1, node2,
                        channelBatchData,
                )
                if err != nil {
                        return nil, nil, err
                }

                edges = append(edges, info)
        }

        return edges, channelIDs, nil
}

// handleZombieMarking is a helper function that handles the logic of
// marking a channel as a zombie in the database. It takes into account whether
// we are in strict zombie pruning mode, and adjusts the node public keys
// accordingly based on the last update timestamps of the channel policies.
func handleZombieMarking(ctx context.Context, db SQLQueries,
        row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo,
        strictZombiePruning bool, scid uint64) error {

        nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes

        if strictZombiePruning {
                var e1UpdateTime, e2UpdateTime *time.Time
                if row.Policy1LastUpdate.Valid {
                        e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0)
                        e1UpdateTime = &e1Time
                }
                if row.Policy2LastUpdate.Valid {
                        e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0)
                        e2UpdateTime = &e2Time
                }

                nodeKey1, nodeKey2 = makeZombiePubkeys(
                        info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime,
                        e2UpdateTime,
                )
        }

        return db.UpsertZombieChannel(
                ctx, sqlc.UpsertZombieChannelParams{
                        Version:  int16(ProtocolV1),
                        Scid:     channelIDToBytes(scid),
                        NodeKey1: nodeKey1[:],
                        NodeKey2: nodeKey2[:],
                },
        )
}