
lightningnetwork / lnd / build 18279366784 (push · github · web-flow)

06 Oct 2025 11:28AM UTC coverage: 66.641% (-0.005%) from 66.646%
Merge pull request #10269 from ellemouton/microSampleConf

1 of 20 new or added lines in 2 files covered. (5.0%)
67 existing lines in 22 files now uncovered.
137218 of 205906 relevant lines covered (66.64%)
21326.42 hits per line

Source File: /graph/db/sql_store.go (0.0%)

package graphdb

import (
        "bytes"
        "context"
        "database/sql"
        "encoding/hex"
        "errors"
        "fmt"
        "iter"
        "maps"
        "math"
        "net"
        "slices"
        "strconv"
        "sync"
        "time"

        "github.com/btcsuite/btcd/btcec/v2"
        "github.com/btcsuite/btcd/btcutil"
        "github.com/btcsuite/btcd/chaincfg/chainhash"
        "github.com/btcsuite/btcd/wire"
        "github.com/lightningnetwork/lnd/aliasmgr"
        "github.com/lightningnetwork/lnd/batch"
        "github.com/lightningnetwork/lnd/fn/v2"
        "github.com/lightningnetwork/lnd/graph/db/models"
        "github.com/lightningnetwork/lnd/lnwire"
        "github.com/lightningnetwork/lnd/routing/route"
        "github.com/lightningnetwork/lnd/sqldb"
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
        "github.com/lightningnetwork/lnd/tlv"
        "github.com/lightningnetwork/lnd/tor"
)

// ProtocolVersion is an enum that defines the gossip protocol version of a
// message.
type ProtocolVersion uint8

const (
        // ProtocolV1 is the gossip protocol version defined in BOLT #7.
        ProtocolV1 ProtocolVersion = 1
)

// String returns a string representation of the protocol version.
func (v ProtocolVersion) String() string {
        return fmt.Sprintf("V%d", v)
}

// SQLQueries is a subset of the sqlc.Querier interface that can be used to
// execute queries against the SQL graph tables.
//
//nolint:ll,interfacebloat
type SQLQueries interface {
        /*
                Node queries.
        */
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
        GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error)
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
        DeleteNode(ctx context.Context, id int64) error

        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
        GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error

        UpsertNodeAddress(ctx context.Context, arg sqlc.UpsertNodeAddressParams) error
        GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
        GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error

        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
        GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error

        /*
                Source node queries.
        */
        AddSourceNode(ctx context.Context, nodeID int64) error
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)

        /*
                Channel queries.
        */
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
        GetChannelsByIDs(ctx context.Context, ids []int64) ([]sqlc.GetChannelsByIDsRow, error)
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
        ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
        ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
        DeleteChannels(ctx context.Context, ids []int64) error

        UpsertChannelExtraType(ctx context.Context, arg sqlc.UpsertChannelExtraTypeParams) error
        GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
        GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)

        /*
                Channel Policy table queries.
        */
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)

        UpsertChanPolicyExtraType(ctx context.Context, arg sqlc.UpsertChanPolicyExtraTypeParams) error
        GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error

        /*
                Zombie index queries.
        */
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
        GetZombieChannelsSCIDs(ctx context.Context, arg sqlc.GetZombieChannelsSCIDsParams) ([]sqlc.GraphZombieChannel, error)
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)

        /*
                Prune log table queries.
        */
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
        GetPruneEntriesForHeights(ctx context.Context, heights []int64) ([]sqlc.GraphPruneLog, error)
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error

        /*
                Closed SCID table queries.
        */
        InsertClosedChannel(ctx context.Context, scid []byte) error
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
        GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error)

        /*
                Migration specific queries.

                NOTE: these should not be used in code other than migrations.
                Once sqldbv2 is in place, these can be removed from this struct
                as then migrations will have their own dedicated queries
                structs.
        */
        InsertNodeMig(ctx context.Context, arg sqlc.InsertNodeMigParams) (int64, error)
        InsertChannelMig(ctx context.Context, arg sqlc.InsertChannelMigParams) (int64, error)
        InsertEdgePolicyMig(ctx context.Context, arg sqlc.InsertEdgePolicyMigParams) (int64, error)
}

// BatchedSQLQueries is a version of SQLQueries that's capable of batched
// database operations.
type BatchedSQLQueries interface {
        SQLQueries
        sqldb.BatchedTx[SQLQueries]
}

// SQLStore is an implementation of the V1Store interface that uses a SQL
// database as the backend.
type SQLStore struct {
        cfg *SQLStoreConfig
        db  BatchedSQLQueries

        // cacheMu guards all caches (rejectCache and chanCache). If this
        // mutex is to be acquired at the same time as the DB mutex, then
        // cacheMu MUST be acquired first to prevent deadlock.
        cacheMu     sync.RWMutex
        rejectCache *rejectCache
        chanCache   *channelCache

        chanScheduler batch.Scheduler[SQLQueries]
        nodeScheduler batch.Scheduler[SQLQueries]

        srcNodes  map[ProtocolVersion]*srcNodeInfo
        srcNodeMu sync.Mutex
}

// A compile-time assertion to ensure that SQLStore implements the V1Store
// interface.
var _ V1Store = (*SQLStore)(nil)

// SQLStoreConfig holds the configuration for the SQLStore.
type SQLStoreConfig struct {
        // ChainHash is the genesis hash for the chain that all the gossip
        // messages in this store are aimed at.
        ChainHash chainhash.Hash

        // QueryConfig holds configuration values for SQL queries.
        QueryCfg *sqldb.QueryConfig
}

// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
// storage backend.
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
        options ...StoreOptionModifier) (*SQLStore, error) {

        opts := DefaultOptions()
        for _, o := range options {
                o(opts)
        }

        if opts.NoMigration {
                return nil, fmt.Errorf("the NoMigration option is not yet " +
                        "supported for SQL stores")
        }

        s := &SQLStore{
                cfg:         cfg,
                db:          db,
                rejectCache: newRejectCache(opts.RejectCacheSize),
                chanCache:   newChannelCache(opts.ChannelCacheSize),
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
        }

        s.chanScheduler = batch.NewTimeScheduler(
                db, &s.cacheMu, opts.BatchCommitInterval,
        )
        s.nodeScheduler = batch.NewTimeScheduler(
                db, nil, opts.BatchCommitInterval,
        )

        return s, nil
}
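
// Editor's sketch (not part of the original file): a hypothetical caller wires
// the store up roughly as follows, assuming it already has a BatchedSQLQueries
// implementation (here called `queries`) backed by an open SQL database:
//
//      store, err := NewSQLStore(
//              &SQLStoreConfig{
//                      ChainHash: *chaincfg.MainNetParams.GenesisHash,
//                      QueryCfg:  sqldb.DefaultQueryConfig(),
//              },
//              queries,
//      )
//      if err != nil {
//              return err
//      }
//      _ = store
//
// The helpers chaincfg.MainNetParams.GenesisHash and sqldb.DefaultQueryConfig
// are assumptions used for illustration only; this file defines just the
// SQLStoreConfig fields and the constructor itself.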

// AddNode adds a vertex/node to the graph database. If the node is not
// in the database from before, this will add a new, unconnected one to the
// graph. If it is present from before, this will update that node's
// information.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddNode(ctx context.Context,
        node *models.Node, opts ...batch.SchedulerOption) error {

        r := &batch.Request[SQLQueries]{
                Opts: batch.NewSchedulerOptions(opts...),
                Do: func(queries SQLQueries) error {
                        _, err := upsertNode(ctx, queries, node)

                        // It is possible that two of the same node
                        // announcements are both being processed in the same
                        // batch. This may cause the UpsertNode conflict to
                        // be hit since we require at the db layer that the
                        // new last_update is greater than the existing
                        // last_update. We need to gracefully handle this here.
                        if errors.Is(err, sql.ErrNoRows) {
                                return nil
                        }

                        return err
                },
        }

        return s.nodeScheduler.Execute(ctx, r)
}

// FetchNode attempts to look up a target node by its identity public
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
// returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchNode(ctx context.Context,
        pubKey route.Vertex) (*models.Node, error) {

        var node *models.Node
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                var err error
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, pubKey)

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch node: %w", err)
        }

        return node, nil
}

// HasNode determines if the graph has a vertex identified by the
// target node identity public key. If the node exists in the database, a
// timestamp of when the data for the node was last updated is returned along
// with a true boolean. Otherwise, an empty time.Time is returned with a false
// boolean.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) HasNode(ctx context.Context,
        pubKey [33]byte) (time.Time, bool, error) {

        var (
                exists     bool
                lastUpdate time.Time
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pubKey[:],
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                exists = true

                if dbNode.LastUpdate.Valid {
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return time.Time{}, false,
                        fmt.Errorf("unable to fetch node: %w", err)
        }

        return lastUpdate, exists, nil
}

// AddrsForNode returns all known addresses for the target node public key
// that the graph DB is aware of. The returned boolean indicates if the
// given node is unknown to the graph DB or not.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddrsForNode(ctx context.Context,
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {

        var (
                addresses []net.Addr
                known     bool
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                // First, check if the node exists and get its DB ID if it
                // does.
                dbID, err := db.GetNodeIDByPubKey(
                        ctx, sqlc.GetNodeIDByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  nodePub.SerializeCompressed(),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                }

                known = true

                addresses, err = getNodeAddresses(ctx, db, dbID)
                if err != nil {
                        return fmt.Errorf("unable to fetch node addresses: %w",
                                err)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return false, nil, fmt.Errorf("unable to get addresses for "+
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
        }

        return known, addresses, nil
}

// DeleteNode starts a new database transaction to remove a vertex/node
// from the database according to the node's public key.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DeleteNode(ctx context.Context,
        pubKey route.Vertex) error {

        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                res, err := db.DeleteNodeByPubKey(
                        ctx, sqlc.DeleteNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pubKey[:],
                        },
                )
                if err != nil {
                        return err
                }

                rows, err := res.RowsAffected()
                if err != nil {
                        return err
                }

                if rows == 0 {
                        return ErrGraphNodeNotFound
                } else if rows > 1 {
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
                }

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return fmt.Errorf("unable to delete node: %w", err)
        }

        return nil
}

// FetchNodeFeatures returns the features of the given node. If no features are
// known for the node, an empty feature vector is returned.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
        *lnwire.FeatureVector, error) {

        ctx := context.TODO()

        return fetchNodeFeatures(ctx, s.db, nodePub)
}

// DisabledChannelIDs returns the channel ids of disabled channels.
// A channel is disabled when two of the associated ChannelEdgePolicies
// have their disabled bit on.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
        var (
                ctx     = context.TODO()
                chanIDs []uint64
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
                if err != nil {
                        return fmt.Errorf("unable to fetch disabled "+
                                "channels: %w", err)
                }

                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
                        err)
        }

        return chanIDs, nil
}

// LookupAlias attempts to return the alias as advertised by the target node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) LookupAlias(ctx context.Context,
        pub *btcec.PublicKey) (string, error) {

        var alias string
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pub.SerializeCompressed(),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return ErrNodeAliasNotFound
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                if !dbNode.Alias.Valid {
                        return ErrNodeAliasNotFound
                }

                alias = dbNode.Alias.String

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return "", fmt.Errorf("unable to look up alias: %w", err)
        }

        return alias, nil
}

// SourceNode returns the source node of the graph. The source node is treated
// as the center node within a star-graph. This method may be used to kick off
// a path finding algorithm in order to explore the reachability of another
// node based off the source node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SourceNode(ctx context.Context) (*models.Node,
        error) {

        var node *models.Node
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
                if err != nil {
                        return fmt.Errorf("unable to fetch V1 source node: %w",
                                err)
                }

                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, nodePub)

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
        }

        return node, nil
}

// SetSourceNode sets the source node within the graph database. The source
// node is to be used as the center of a star-graph within path finding
// algorithms.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SetSourceNode(ctx context.Context,
        node *models.Node) error {

        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                id, err := upsertNode(ctx, db, node)
                if err != nil {
                        return fmt.Errorf("unable to upsert source node: %w",
                                err)
                }

                // Make sure that if a source node for this version is already
                // set, then the ID is the same as the one we are about to set.
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
                        return fmt.Errorf("unable to fetch source node: %w",
                                err)
                } else if err == nil {
                        if dbSourceNodeID != id {
                                return fmt.Errorf("v1 source node already "+
                                        "set to a different node: %d vs %d",
                                        dbSourceNodeID, id)
                        }

                        return nil
                }

                return db.AddSourceNode(ctx, id)
        }, sqldb.NoOpReset)
}

// NodeUpdatesInHorizon returns all the known lightning nodes which have an
// update timestamp within the passed range. This method can be used by two
// nodes to quickly determine if they have the same set of up-to-date node
// announcements.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) NodeUpdatesInHorizon(startTime, endTime time.Time,
        opts ...IteratorOption) iter.Seq2[models.Node, error] {

        cfg := defaultIteratorConfig()
        for _, opt := range opts {
                opt(cfg)
        }

        return func(yield func(models.Node, error) bool) {
                var (
                        ctx            = context.TODO()
                        lastUpdateTime sql.NullInt64
                        lastPubKey     = make([]byte, 33)
                        hasMore        = true
                )

                // Each iteration, we'll read a batch amount of nodes, yield
                // them, then decide if we have more or not.
                for hasMore {
                        var batch []models.Node

                        //nolint:ll
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                                //nolint:ll
                                params := sqlc.GetNodesByLastUpdateRangeParams{
                                        StartTime: sqldb.SQLInt64(
                                                startTime.Unix(),
                                        ),
                                        EndTime: sqldb.SQLInt64(
                                                endTime.Unix(),
                                        ),
                                        LastUpdate: lastUpdateTime,
                                        LastPubKey: lastPubKey,
                                        OnlyPublic: sql.NullBool{
                                                Bool:  cfg.iterPublicNodes,
                                                Valid: true,
                                        },
                                        MaxResults: sqldb.SQLInt32(
                                                cfg.nodeUpdateIterBatchSize,
                                        ),
                                }
                                rows, err := db.GetNodesByLastUpdateRange(
                                        ctx, params,
                                )
                                if err != nil {
                                        return err
                                }

                                hasMore = len(rows) == cfg.nodeUpdateIterBatchSize

                                err = forEachNodeInBatch(
                                        ctx, s.cfg.QueryCfg, db, rows,
                                        func(_ int64, node *models.Node) error {
                                                batch = append(batch, *node)

                                                // Update pagination cursors
                                                // based on the last processed
                                                // node.
                                                lastUpdateTime = sql.NullInt64{
                                                        Int64: node.LastUpdate.
                                                                Unix(),
                                                        Valid: true,
                                                }
                                                lastPubKey = node.PubKeyBytes[:]

                                                return nil
                                        },
                                )
                                if err != nil {
                                        return fmt.Errorf("unable to build "+
                                                "nodes: %w", err)
                                }

                                return nil
                        }, func() {
                                batch = []models.Node{}
                        })

                        if err != nil {
                                log.Errorf("NodeUpdatesInHorizon batch "+
                                        "error: %v", err)

                                yield(models.Node{}, err)

                                return
                        }

                        for _, node := range batch {
                                if !yield(node, nil) {
                                        return
                                }
                        }

                        // If the batch didn't yield anything, then we're done.
                        if len(batch) == 0 {
                                break
                        }
                }
        }
}
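
// Illustrative sketch (editor's addition, not part of the original file): the
// iter.Seq2 returned by NodeUpdatesInHorizon is meant to be consumed with Go's
// range-over-function syntax. A hypothetical caller, assuming a *SQLStore s
// and a start/end time window, might look like:
//
//      for node, err := range s.NodeUpdatesInHorizon(start, end) {
//              if err != nil {
//                      return err
//              }
//              // Process the node announcement.
//              _ = node.PubKeyBytes
//      }
//
// Breaking out of the loop (the yield callback returning false) stops the
// paginated database reads early.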

// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
// undirected edge from the two target nodes is created. The information stored
// denotes the static attributes of the channel, such as the channelID, the keys
// involved in creation of the channel, and the set of features that the channel
// supports. The chanPoint and chanID are used to uniquely identify the edge
// globally within the database.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddChannelEdge(ctx context.Context,
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {

        var alreadyExists bool
        r := &batch.Request[SQLQueries]{
                Opts: batch.NewSchedulerOptions(opts...),
                Reset: func() {
                        alreadyExists = false
                },
                Do: func(tx SQLQueries) error {
                        chanIDB := channelIDToBytes(edge.ChannelID)

                        // Make sure that the channel doesn't already exist. We
                        // do this explicitly instead of relying on catching a
                        // unique constraint error because relying on SQL to
                        // throw that error would abort the entire batch of
                        // transactions.
                        _, err := tx.GetChannelBySCID(
                                ctx, sqlc.GetChannelBySCIDParams{
                                        Scid:    chanIDB,
                                        Version: int16(ProtocolV1),
                                },
                        )
                        if err == nil {
                                alreadyExists = true
                                return nil
                        } else if !errors.Is(err, sql.ErrNoRows) {
                                return fmt.Errorf("unable to fetch channel: %w",
                                        err)
                        }

                        return insertChannel(ctx, tx, edge)
                },
                OnCommit: func(err error) error {
                        switch {
                        case err != nil:
                                return err
                        case alreadyExists:
                                return ErrEdgeAlreadyExist
                        default:
                                s.rejectCache.remove(edge.ChannelID)
                                s.chanCache.remove(edge.ChannelID)
                                return nil
                        }
                },
        }

        return s.chanScheduler.Execute(ctx, r)
}

// HighestChanID returns the "highest" known channel ID in the channel graph.
// This represents the "newest" channel from the PoV of the chain. This method
// can be used by peers to quickly determine if their graphs are in sync.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
        var highestChanID uint64
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
                                err)
                }

                highestChanID = byteOrder.Uint64(chanID)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
        }

        return highestChanID, nil
}

// UpdateEdgePolicy updates the edge routing policy for a single directed edge
// within the database for the referenced channel. The `flags` attribute within
// the ChannelEdgePolicy determines which of the directed edges are being
// updated. If the flag is 1, then the first node's information is being
// updated, otherwise it's the second node's information. The node ordering is
// determined by the lexicographical ordering of the identity public keys of the
// nodes on either side of the channel.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
        edge *models.ChannelEdgePolicy,
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {

        var (
                isUpdate1    bool
                edgeNotFound bool
                from, to     route.Vertex
        )

        r := &batch.Request[SQLQueries]{
                Opts: batch.NewSchedulerOptions(opts...),
                Reset: func() {
                        isUpdate1 = false
                        edgeNotFound = false
                },
                Do: func(tx SQLQueries) error {
                        var err error
                        from, to, isUpdate1, err = updateChanEdgePolicy(
                                ctx, tx, edge,
                        )
                        // It is possible that two of the same policy
                        // announcements are both being processed in the same
                        // batch. This may cause the UpsertEdgePolicy conflict
                        // to be hit since we require at the db layer that the
                        // new last_update is greater than the existing
                        // last_update. We need to gracefully handle this here.
                        if errors.Is(err, sql.ErrNoRows) {
                                return nil
                        } else if err != nil {
                                log.Errorf("UpdateEdgePolicy failed: %v", err)
                        }

                        // Silence ErrEdgeNotFound so that the batch can
                        // succeed, but propagate the error via local state.
                        if errors.Is(err, ErrEdgeNotFound) {
                                edgeNotFound = true
                                return nil
                        }

                        return err
                },
                OnCommit: func(err error) error {
                        switch {
                        case err != nil:
                                return err
                        case edgeNotFound:
                                return ErrEdgeNotFound
                        default:
                                s.updateEdgeCache(edge, isUpdate1)
                                return nil
                        }
                },
        }

        err := s.chanScheduler.Execute(ctx, r)

        return from, to, err
}

// updateEdgeCache updates our reject and channel caches with the new
// edge policy information.
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
        isUpdate1 bool) {

        // If an entry for this channel is found in reject cache, we'll modify
        // the entry with the updated timestamp for the direction that was just
        // written. If the edge doesn't exist, we'll load the cache entry lazily
        // during the next query for this edge.
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
                if isUpdate1 {
                        entry.upd1Time = e.LastUpdate.Unix()
                } else {
                        entry.upd2Time = e.LastUpdate.Unix()
                }
                s.rejectCache.insert(e.ChannelID, entry)
        }

        // If an entry for this channel is found in channel cache, we'll modify
        // the entry with the updated policy for the direction that was just
        // written. If the edge doesn't exist, we'll defer loading the info and
        // policies and lazily read from disk during the next query.
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
                if isUpdate1 {
                        channel.Policy1 = e
                } else {
                        channel.Policy2 = e
                }
                s.chanCache.insert(e.ChannelID, channel)
        }
}

// ForEachSourceNodeChannel iterates through all channels of the source node,
// executing the passed callback on each. The call-back is provided with the
// channel's outpoint, whether we have a policy for the channel and the channel
// peer's node information.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
        cb func(chanPoint wire.OutPoint, havePolicy bool,
                otherNode *models.Node) error, reset func()) error {

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
                if err != nil {
                        return fmt.Errorf("unable to fetch source node: %w",
                                err)
                }

                return forEachNodeChannel(
                        ctx, db, s.cfg, nodeID,
                        func(info *models.ChannelEdgeInfo,
                                outPolicy *models.ChannelEdgePolicy,
                                _ *models.ChannelEdgePolicy) error {

                                // Fetch the other node.
                                var (
                                        otherNodePub [33]byte
                                        node1        = info.NodeKey1Bytes
                                        node2        = info.NodeKey2Bytes
                                )
                                switch {
                                case bytes.Equal(node1[:], nodePub[:]):
                                        otherNodePub = node2
                                case bytes.Equal(node2[:], nodePub[:]):
                                        otherNodePub = node1
                                default:
                                        return fmt.Errorf("node not " +
                                                "participating in this channel")
                                }

                                _, otherNode, err := getNodeByPubKey(
                                        ctx, s.cfg.QueryCfg, db, otherNodePub,
                                )
                                if err != nil {
                                        return fmt.Errorf("unable to fetch "+
                                                "other node(%x): %w",
                                                otherNodePub, err)
                                }

                                return cb(
                                        info.ChannelPoint, outPolicy != nil,
                                        otherNode,
                                )
                        },
                )
        }, reset)
}

// ForEachNode iterates through all the stored vertices/nodes in the graph,
// executing the passed callback with each node encountered. If the callback
// returns an error, then the transaction is aborted and the iteration stops
// early.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNode(ctx context.Context,
        cb func(node *models.Node) error, reset func()) error {

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return forEachNodePaginated(
                        ctx, s.cfg.QueryCfg, db,
                        ProtocolV1, func(_ context.Context, _ int64,
                                node *models.Node) error {

                                return cb(node)
                        },
                )
        }, reset)
}

// ForEachNodeDirectedChannel iterates through all channels of a given node,
// executing the passed callback on the directed edge representing the channel
// and its incoming policy. If the callback returns an error, then the iteration
// is halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
        cb func(channel *DirectedChannel) error, reset func()) error {

        var ctx = context.TODO()

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
        }, reset)
}

// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
// graph, executing the passed callback with each node encountered. If the
// callback returns an error, then the transaction is aborted and the iteration
// stops early.
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
        cb func(route.Vertex, *lnwire.FeatureVector) error,
        reset func()) error {

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return forEachNodeCacheable(
                        ctx, s.cfg.QueryCfg, db,
                        func(_ int64, nodePub route.Vertex,
                                features *lnwire.FeatureVector) error {

                                return cb(nodePub, features)
                        },
                )
        }, reset)
        if err != nil {
                return fmt.Errorf("unable to fetch nodes: %w", err)
        }

        return nil
}

// ForEachNodeChannel iterates through all channels of the given node,
// executing the passed callback with an edge info structure and the policies
// of each end of the channel. The first edge policy is the outgoing edge *to*
// the connecting node, while the second is the incoming edge *from* the
// connecting node. If the callback returns an error, then the iteration is
// halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error, reset func()) error {

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  nodePub[:],
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
        }, reset)
}

// extractMaxUpdateTime returns the maximum of the two policy update times.
// This is used for pagination cursor tracking.
func extractMaxUpdateTime(
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) int64 {

        switch {
        case row.Policy1LastUpdate.Valid && row.Policy2LastUpdate.Valid:
                return max(row.Policy1LastUpdate.Int64,
                        row.Policy2LastUpdate.Int64)
        case row.Policy1LastUpdate.Valid:
                return row.Policy1LastUpdate.Int64
        case row.Policy2LastUpdate.Valid:
                return row.Policy2LastUpdate.Int64
        default:
                return 0
        }
}
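
// Editor's note (illustrative, not part of the original file): as a quick
// worked example of the cursor logic above, a row with
// Policy1LastUpdate = 1700000000 and Policy2LastUpdate = 1700000600 (both
// valid) yields 1700000600; a row where only Policy1LastUpdate is valid
// yields 1700000000; and a row with neither policy timestamp set yields 0.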

// buildChannelFromRow constructs a ChannelEdge from a database row.
// This includes building the nodes, channel info, and policies.
func (s *SQLStore) buildChannelFromRow(ctx context.Context, db SQLQueries,
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) (ChannelEdge, error) {

        node1, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode)
        if err != nil {
                return ChannelEdge{}, fmt.Errorf("unable to build node1: %w",
                        err)
        }

        node2, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode_2)
        if err != nil {
                return ChannelEdge{}, fmt.Errorf("unable to build node2: %w",
                        err)
        }

        channel, err := getAndBuildEdgeInfo(
                ctx, s.cfg, db,
                row.GraphChannel, node1.PubKeyBytes,
                node2.PubKeyBytes,
        )
        if err != nil {
                return ChannelEdge{}, fmt.Errorf("unable to build "+
                        "channel info: %w", err)
        }

        dbPol1, dbPol2, err := extractChannelPolicies(row)
        if err != nil {
                return ChannelEdge{}, fmt.Errorf("unable to extract "+
                        "channel policies: %w", err)
        }

        p1, p2, err := getAndBuildChanPolicies(
                ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, channel.ChannelID,
                node1.PubKeyBytes, node2.PubKeyBytes,
        )
        if err != nil {
                return ChannelEdge{}, fmt.Errorf("unable to build "+
                        "channel policies: %w", err)
        }

        return ChannelEdge{
                Info:    channel,
                Policy1: p1,
                Policy2: p2,
                Node1:   node1,
                Node2:   node2,
        }, nil
}

// updateChanCacheBatch updates the channel cache with multiple edges at once.
// This method acquires the cache lock only once for the entire batch.
func (s *SQLStore) updateChanCacheBatch(edgesToCache map[uint64]ChannelEdge) {
        if len(edgesToCache) == 0 {
                return
        }

        s.cacheMu.Lock()
        defer s.cacheMu.Unlock()

        for chanID, edge := range edgesToCache {
                s.chanCache.insert(chanID, edge)
        }
}

// ChanUpdatesInHorizon returns all the known channel edges which have at least
// one edge that has an update timestamp within the specified horizon.
//
// Iterator Lifecycle:
// 1. Initialize state (edgesSeen map, cache tracking, pagination cursors)
// 2. Query batch of channels with policies in time range
// 3. For each channel: check if seen, check cache, or build from DB
// 4. Yield channels to caller
// 5. Update cache after successful batch
// 6. Repeat with updated pagination cursor until no more results
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) ChanUpdatesInHorizon(startTime, endTime time.Time,
        opts ...IteratorOption) iter.Seq2[ChannelEdge, error] {

        // Apply options.
        cfg := defaultIteratorConfig()
        for _, opt := range opts {
                opt(cfg)
        }

        return func(yield func(ChannelEdge, error) bool) {
                var (
                        ctx            = context.TODO()
                        edgesSeen      = make(map[uint64]struct{})
                        edgesToCache   = make(map[uint64]ChannelEdge)
                        hits           int
                        total          int
                        lastUpdateTime sql.NullInt64
                        lastID         sql.NullInt64
                        hasMore        = true
                )

                // Each iteration, we'll read a batch amount of channel updates
                // (consulting the cache along the way), yield them, then loop
                // back to decide if we have any more updates to read out.
                for hasMore {
                        var batch []ChannelEdge

                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(),
                                func(db SQLQueries) error {
                                        //nolint:ll
                                        params := sqlc.GetChannelsByPolicyLastUpdateRangeParams{
                                                Version: int16(ProtocolV1),
                                                StartTime: sqldb.SQLInt64(
                                                        startTime.Unix(),
                                                ),
                                                EndTime: sqldb.SQLInt64(
                                                        endTime.Unix(),
                                                ),
                                                LastUpdateTime: lastUpdateTime,
                                                LastID:         lastID,
                                                MaxResults: sql.NullInt32{
                                                        Int32: int32(
                                                                cfg.chanUpdateIterBatchSize,
                                                        ),
                                                        Valid: true,
                                                },
                                        }
                                        //nolint:ll
                                        rows, err := db.GetChannelsByPolicyLastUpdateRange(
                                                ctx, params,
                                        )
                                        if err != nil {
                                                return err
                                        }

                                        //nolint:ll
                                        hasMore = len(rows) == cfg.chanUpdateIterBatchSize

                                        //nolint:ll
                                        for _, row := range rows {
                                                lastUpdateTime = sql.NullInt64{
                                                        Int64: extractMaxUpdateTime(row),
                                                        Valid: true,
                                                }
                                                lastID = sql.NullInt64{
                                                        Int64: row.GraphChannel.ID,
                                                        Valid: true,
                                                }

                                                // Skip if we've already
                                                // processed this channel.
                                                chanIDInt := byteOrder.Uint64(
                                                        row.GraphChannel.Scid,
                                                )
                                                _, ok := edgesSeen[chanIDInt]
                                                if ok {
                                                        continue
                                                }

                                                s.cacheMu.RLock()
                                                channel, ok := s.chanCache.get(
                                                        chanIDInt,
                                                )
                                                s.cacheMu.RUnlock()
                                                if ok {
                                                        hits++
                                                        total++
                                                        edgesSeen[chanIDInt] = struct{}{}
                                                        batch = append(batch, channel)

                                                        continue
                                                }

                                                chanEdge, err := s.buildChannelFromRow(
                                                        ctx, db, row,
                                                )
                                                if err != nil {
                                                        return err
                                                }

                                                edgesSeen[chanIDInt] = struct{}{}
                                                edgesToCache[chanIDInt] = chanEdge

                                                batch = append(batch, chanEdge)

                                                total++
                                        }

                                        return nil
                                }, func() {
                                        batch = nil
                                        edgesSeen = make(map[uint64]struct{})
                                        edgesToCache = make(
                                                map[uint64]ChannelEdge,
                                        )
                                })

                        if err != nil {
                                log.Errorf("ChanUpdatesInHorizon "+
                                        "batch error: %v", err)

                                yield(ChannelEdge{}, err)

                                return
                        }
×
1227

1228
                        for _, edge := range batch {
×
1229
                                if !yield(edge, nil) {
×
1230
                                        return
×
1231
                                }
×
1232
                        }
1233

1234
                        // Update cache after successful batch yield, setting
1235
                        // the cache lock only once for the entire batch.
1236
                        s.updateChanCacheBatch(edgesToCache)
×
1237
                        edgesToCache = make(map[uint64]ChannelEdge)
×
1238

×
1239
                        // If the batch didn't yield anything, then we're done.
×
1240
                        if len(batch) == 0 {
×
1241
                                break
×
1242
                        }
1243
                }
1244

1245
                if total > 0 {
×
1246
                        log.Debugf("ChanUpdatesInHorizon hit percentage: "+
×
1247
                                "%.2f (%d/%d)",
×
1248
                                float64(hits)*100/float64(total), hits, total)
×
1249
                } else {
×
1250
                        log.Debugf("ChanUpdatesInHorizon returned no edges "+
×
1251
                                "in horizon (%s, %s)", startTime, endTime)
×
1252
                }
×
1253
        }
1254
}
1255

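// collectChanUpdates is a purely illustrative sketch (the helper is
// hypothetical and not used by the store) showing how a caller can drain an
// iterator shaped like the one constructed above, i.e. an
// iter.Seq2[ChannelEdge, error]: the iterator yields a single terminal error
// and then stops, so the first non-nil error ends the loop.
func collectChanUpdates(updates iter.Seq2[ChannelEdge, error]) ([]ChannelEdge,
	error) {

	var edges []ChannelEdge
	for edge, err := range updates {
		if err != nil {
			return nil, err
		}

		edges = append(edges, edge)
	}

	return edges, nil
}
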
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
// data to the call-back. If withAddrs is true, then the call-back will also be
// provided with the addresses associated with the node. The address retrieval
// results in an additional round-trip to the database, so it should only be
// used if the addresses are actually needed.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNodeCached(ctx context.Context, withAddrs bool,
	cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
		chans map[uint64]*DirectedChannel) error, reset func()) error {

	type nodeCachedBatchData struct {
		features      map[int64][]int
		addrs         map[int64][]nodeAddress
		chanBatchData *batchChannelData
		chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
	}

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		// pageQueryFunc is used to query the next page of nodes.
		pageQueryFunc := func(ctx context.Context, lastID int64,
			limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {

			return db.ListNodeIDsAndPubKeys(
				ctx, sqlc.ListNodeIDsAndPubKeysParams{
					Version: int16(ProtocolV1),
					ID:      lastID,
					Limit:   limit,
				},
			)
		}

		// batchDataFunc is then used to batch load the data required
		// for each page of nodes.
		batchDataFunc := func(ctx context.Context,
			nodeIDs []int64) (*nodeCachedBatchData, error) {

			// Batch load node features.
			nodeFeatures, err := batchLoadNodeFeaturesHelper(
				ctx, s.cfg.QueryCfg, db, nodeIDs,
			)
			if err != nil {
				return nil, fmt.Errorf("unable to batch load "+
					"node features: %w", err)
			}

			// Maybe fetch the node's addresses if requested.
			var nodeAddrs map[int64][]nodeAddress
			if withAddrs {
				nodeAddrs, err = batchLoadNodeAddressesHelper(
					ctx, s.cfg.QueryCfg, db, nodeIDs,
				)
				if err != nil {
					return nil, fmt.Errorf("unable to "+
						"batch load node "+
						"addresses: %w", err)
				}
			}

			// Batch load ALL unique channels for ALL nodes in this
			// page.
			allChannels, err := db.ListChannelsForNodeIDs(
				ctx, sqlc.ListChannelsForNodeIDsParams{
					Version:  int16(ProtocolV1),
					Node1Ids: nodeIDs,
					Node2Ids: nodeIDs,
				},
			)
			if err != nil {
				return nil, fmt.Errorf("unable to batch "+
					"fetch channels for nodes: %w", err)
			}

			// Deduplicate channels and collect IDs.
			var (
				allChannelIDs []int64
				allPolicyIDs  []int64
			)
			uniqueChannels := make(
				map[int64]sqlc.ListChannelsForNodeIDsRow,
			)

			for _, channel := range allChannels {
				channelID := channel.GraphChannel.ID

				// Only process each unique channel once.
				_, exists := uniqueChannels[channelID]
				if exists {
					continue
				}

				uniqueChannels[channelID] = channel
				allChannelIDs = append(allChannelIDs, channelID)

				if channel.Policy1ID.Valid {
					allPolicyIDs = append(
						allPolicyIDs,
						channel.Policy1ID.Int64,
					)
				}
				if channel.Policy2ID.Valid {
					allPolicyIDs = append(
						allPolicyIDs,
						channel.Policy2ID.Int64,
					)
				}
			}

			// Batch load channel data for all unique channels.
			channelBatchData, err := batchLoadChannelData(
				ctx, s.cfg.QueryCfg, db, allChannelIDs,
				allPolicyIDs,
			)
			if err != nil {
				return nil, fmt.Errorf("unable to batch "+
					"load channel data: %w", err)
			}

			// Create map of node ID to channels that involve this
			// node.
			nodeIDSet := make(map[int64]bool)
			for _, nodeID := range nodeIDs {
				nodeIDSet[nodeID] = true
			}

			nodeChannelMap := make(
				map[int64][]sqlc.ListChannelsForNodeIDsRow,
			)
			for _, channel := range uniqueChannels {
				// Add channel to both nodes if they're in our
				// current page.
				node1 := channel.GraphChannel.NodeID1
				if nodeIDSet[node1] {
					nodeChannelMap[node1] = append(
						nodeChannelMap[node1], channel,
					)
				}
				node2 := channel.GraphChannel.NodeID2
				if nodeIDSet[node2] {
					nodeChannelMap[node2] = append(
						nodeChannelMap[node2], channel,
					)
				}
			}

			return &nodeCachedBatchData{
				features:      nodeFeatures,
				addrs:         nodeAddrs,
				chanBatchData: channelBatchData,
				chanMap:       nodeChannelMap,
			}, nil
		}

		// processItem is used to process each node in the current page.
		processItem := func(ctx context.Context,
			nodeData sqlc.ListNodeIDsAndPubKeysRow,
			batchData *nodeCachedBatchData) error {

			// Build feature vector for this node.
			fv := lnwire.EmptyFeatureVector()
			features, exists := batchData.features[nodeData.ID]
			if exists {
				for _, bit := range features {
					fv.Set(lnwire.FeatureBit(bit))
				}
			}

			var nodePub route.Vertex
			copy(nodePub[:], nodeData.PubKey)

			nodeChannels := batchData.chanMap[nodeData.ID]

			toNodeCallback := func() route.Vertex {
				return nodePub
			}

			// Build cached channels map for this node.
			channels := make(map[uint64]*DirectedChannel)
			for _, channelRow := range nodeChannels {
				directedChan, err := buildDirectedChannel(
					s.cfg.ChainHash, nodeData.ID, nodePub,
					channelRow, batchData.chanBatchData, fv,
					toNodeCallback,
				)
				if err != nil {
					return err
				}

				channels[directedChan.ChannelID] = directedChan
			}

			addrs, err := buildNodeAddresses(
				batchData.addrs[nodeData.ID],
			)
			if err != nil {
				return fmt.Errorf("unable to build node "+
					"addresses: %w", err)
			}

			return cb(ctx, nodePub, addrs, channels)
		}

		return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
			ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
			func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
				return node.ID
			},
			func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
				error) {

				return node.ID, nil
			},
			batchDataFunc, processItem,
		)
	}, reset)
}

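// countNodeChannels is a hypothetical sketch of how ForEachNodeCached might be
// driven by a caller: it tallies the number of cached directed channels per
// node without requesting addresses. The reset closure mirrors the pattern
// used elsewhere in this file and clears partial state in case the underlying
// transaction is retried.
func countNodeChannels(ctx context.Context, s *SQLStore) (map[route.Vertex]int,
	error) {

	counts := make(map[route.Vertex]int)
	err := s.ForEachNodeCached(ctx, false, func(_ context.Context,
		node route.Vertex, _ []net.Addr,
		chans map[uint64]*DirectedChannel) error {

		counts[node] = len(chans)

		return nil
	}, func() {
		counts = make(map[route.Vertex]int)
	})
	if err != nil {
		return nil, err
	}

	return counts, nil
}
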
// ForEachChannelCacheable iterates through all the channel edges stored
// within the graph and invokes the passed callback for each edge. The
// callback takes two edges since, as this is a directed graph, both the
// in/out edges are visited. If the callback returns an error, then the
// transaction is aborted and the iteration stops early.
//
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
// pointer for that particular channel edge routing policy will be
// passed into the callback.
//
// NOTE: this method is like ForEachChannel but fetches only the data
// required for the graph cache.
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
	*models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
	reset func()) error {

	ctx := context.TODO()

	handleChannel := func(_ context.Context,
		row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {

		node1, node2, err := buildNodeVertices(
			row.Node1Pubkey, row.Node2Pubkey,
		)
		if err != nil {
			return err
		}

		edge := buildCacheableChannelInfo(
			row.Scid, row.Capacity.Int64, node1, node2,
		)

		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return err
		}

		pol1, pol2, err := buildCachedChanPolicies(
			dbPol1, dbPol2, edge.ChannelID, node1, node2,
		)
		if err != nil {
			return err
		}

		return cb(edge, pol1, pol2)
	}

	extractCursor := func(
		row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {

		return row.ID
	}

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		//nolint:ll
		queryFunc := func(ctx context.Context, lastID int64,
			limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
			error) {

			return db.ListChannelsWithPoliciesForCachePaginated(
				ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
					Version: int16(ProtocolV1),
					ID:      lastID,
					Limit:   limit,
				},
			)
		}

		return sqldb.ExecutePaginatedQuery(
			ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
			extractCursor, handleChannel,
		)
	}, reset)
}

// ForEachChannel iterates through all the channel edges stored within the
// graph and invokes the passed callback for each edge. The callback takes two
// edges since, as this is a directed graph, both the in/out edges are visited.
// If the callback returns an error, then the transaction is aborted and the
// iteration stops early.
//
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
// for that particular channel edge routing policy will be passed into the
// callback.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachChannel(ctx context.Context,
	cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
		*models.ChannelEdgePolicy) error, reset func()) error {

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
	}, reset)
}

// FilterChannelRange returns the channel ID's of all known channels which were
// mined in a block height within the passed range. The channel IDs are grouped
// by their common block height. This method can be used to quickly share with a
// peer the set of channels we know of within a particular range to catch them
// up after a period of time offline. If withTimestamps is true then the
// timestamp info of the latest received channel update messages of the channel
// will be included in the response.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
	withTimestamps bool) ([]BlockChannelRange, error) {

	var (
		ctx       = context.TODO()
		startSCID = &lnwire.ShortChannelID{
			BlockHeight: startHeight,
		}
		endSCID = lnwire.ShortChannelID{
			BlockHeight: endHeight,
			TxIndex:     math.MaxUint32 & 0x00ffffff,
			TxPosition:  math.MaxUint16,
		}
		chanIDStart = channelIDToBytes(startSCID.ToUint64())
		chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
	)

	// 1) get all channels where channelID is between start and end chan ID.
	// 2) skip if not public (ie, no channel_proof)
	// 3) collect that channel.
	// 4) if timestamps are wanted, fetch both policies for node 1 and node2
	//    and add those timestamps to the collected channel.
	channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbChans, err := db.GetPublicV1ChannelsBySCID(
			ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
				StartScid: chanIDStart,
				EndScid:   chanIDEnd,
			},
		)
		if err != nil {
			return fmt.Errorf("unable to fetch channel range: %w",
				err)
		}

		for _, dbChan := range dbChans {
			cid := lnwire.NewShortChanIDFromInt(
				byteOrder.Uint64(dbChan.Scid),
			)
			chanInfo := NewChannelUpdateInfo(
				cid, time.Time{}, time.Time{},
			)

			if !withTimestamps {
				channelsPerBlock[cid.BlockHeight] = append(
					channelsPerBlock[cid.BlockHeight],
					chanInfo,
				)

				continue
			}

			//nolint:ll
			node1Policy, err := db.GetChannelPolicyByChannelAndNode(
				ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
					Version:   int16(ProtocolV1),
					ChannelID: dbChan.ID,
					NodeID:    dbChan.NodeID1,
				},
			)
			if err != nil && !errors.Is(err, sql.ErrNoRows) {
				return fmt.Errorf("unable to fetch node1 "+
					"policy: %w", err)
			} else if err == nil {
				chanInfo.Node1UpdateTimestamp = time.Unix(
					node1Policy.LastUpdate.Int64, 0,
				)
			}

			//nolint:ll
			node2Policy, err := db.GetChannelPolicyByChannelAndNode(
				ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
					Version:   int16(ProtocolV1),
					ChannelID: dbChan.ID,
					NodeID:    dbChan.NodeID2,
				},
			)
			if err != nil && !errors.Is(err, sql.ErrNoRows) {
				return fmt.Errorf("unable to fetch node2 "+
					"policy: %w", err)
			} else if err == nil {
				chanInfo.Node2UpdateTimestamp = time.Unix(
					node2Policy.LastUpdate.Int64, 0,
				)
			}

			channelsPerBlock[cid.BlockHeight] = append(
				channelsPerBlock[cid.BlockHeight], chanInfo,
			)
		}

		return nil
	}, func() {
		channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
	})
	if err != nil {
		return nil, fmt.Errorf("unable to fetch channel range: %w", err)
	}

	if len(channelsPerBlock) == 0 {
		return nil, nil
	}

	// Return the channel ranges in ascending block height order.
	blocks := slices.Collect(maps.Keys(channelsPerBlock))
	slices.Sort(blocks)

	return fn.Map(blocks, func(block uint32) BlockChannelRange {
		return BlockChannelRange{
			Height:   block,
			Channels: channelsPerBlock[block],
		}
	}), nil
}

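// logChannelRange is a hypothetical sketch showing how the grouped result of
// FilterChannelRange can be consumed, e.g. when preparing a reply to a peer's
// channel range query. Timestamps are requested so that the update times of
// both directions are populated where known.
func logChannelRange(s *SQLStore, startHeight, endHeight uint32) error {
	ranges, err := s.FilterChannelRange(startHeight, endHeight, true)
	if err != nil {
		return err
	}

	for _, blockRange := range ranges {
		log.Debugf("height %d: %d channels", blockRange.Height,
			len(blockRange.Channels))

		for _, info := range blockRange.Channels {
			log.Debugf("scid=%v node1_update=%v node2_update=%v",
				info.ShortChannelID, info.Node1UpdateTimestamp,
				info.Node2UpdateTimestamp)
		}
	}

	return nil
}
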
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
// zombie. This method is used on an ad-hoc basis, when channels need to be
// marked as zombies outside the normal pruning cycle.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
	pubKey1, pubKey2 [33]byte) error {

	ctx := context.TODO()

	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	chanIDB := channelIDToBytes(chanID)

	err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		return db.UpsertZombieChannel(
			ctx, sqlc.UpsertZombieChannelParams{
				Version:  int16(ProtocolV1),
				Scid:     chanIDB,
				NodeKey1: pubKey1[:],
				NodeKey2: pubKey2[:],
			},
		)
	}, sqldb.NoOpReset)
	if err != nil {
		return fmt.Errorf("unable to upsert zombie channel "+
			"(channel_id=%d): %w", chanID, err)
	}

	s.rejectCache.remove(chanID)
	s.chanCache.remove(chanID)

	return nil
}

// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	var (
		ctx     = context.TODO()
		chanIDB = channelIDToBytes(chanID)
	)

	err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		res, err := db.DeleteZombieChannel(
			ctx, sqlc.DeleteZombieChannelParams{
				Scid:    chanIDB,
				Version: int16(ProtocolV1),
			},
		)
		if err != nil {
			return fmt.Errorf("unable to delete zombie channel: %w",
				err)
		}

		rows, err := res.RowsAffected()
		if err != nil {
			return err
		}

		if rows == 0 {
			return ErrZombieEdgeNotFound
		} else if rows > 1 {
			return fmt.Errorf("deleted %d zombie rows, "+
				"expected 1", rows)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return fmt.Errorf("unable to mark edge live "+
			"(channel_id=%d): %w", chanID, err)
	}

	s.rejectCache.remove(chanID)
	s.chanCache.remove(chanID)

	return err
}

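// resurrectIfZombie is a hypothetical sketch tying the zombie helpers
// together: it checks whether a channel is currently marked as a zombie via
// IsZombieEdge and, if so, clears it from the zombie index again with
// MarkEdgeLive.
func resurrectIfZombie(s *SQLStore, chanID uint64) error {
	isZombie, _, _, err := s.IsZombieEdge(chanID)
	if err != nil {
		return err
	}

	if !isZombie {
		return nil
	}

	return s.MarkEdgeLive(chanID)
}
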
// IsZombieEdge returns whether the edge is considered zombie. If it is a
// zombie, then the two node public keys corresponding to this edge are also
// returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
	error) {

	var (
		ctx              = context.TODO()
		isZombie         bool
		pubKey1, pubKey2 route.Vertex
		chanIDB          = channelIDToBytes(chanID)
	)

	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		zombie, err := db.GetZombieChannel(
			ctx, sqlc.GetZombieChannelParams{
				Scid:    chanIDB,
				Version: int16(ProtocolV1),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		}
		if err != nil {
			return fmt.Errorf("unable to fetch zombie channel: %w",
				err)
		}

		copy(pubKey1[:], zombie.NodeKey1)
		copy(pubKey2[:], zombie.NodeKey2)
		isZombie = true

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return false, route.Vertex{}, route.Vertex{},
			fmt.Errorf("%w: %w (chanID=%d)",
				ErrCantCheckIfZombieEdgeStr, err, chanID)
	}

	return isZombie, pubKey1, pubKey2, nil
}

// NumZombies returns the current number of zombie channels in the graph.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) NumZombies() (uint64, error) {
	var (
		ctx        = context.TODO()
		numZombies uint64
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
		if err != nil {
			return fmt.Errorf("unable to count zombie channels: %w",
				err)
		}

		numZombies = uint64(count)

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return 0, fmt.Errorf("unable to count zombies: %w", err)
	}

	return numZombies, nil
}

// DeleteChannelEdges removes edges with the given channel IDs from the
// database and marks them as zombies. This ensures that we're unable to re-add
// them to our database once again. If an edge does not exist within the
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
// true, then when we mark these edges as zombies, we'll set up the keys such
// that we require the node that failed to send the fresh update to be the one
// that resurrects the channel from its zombie state. The markZombie bool
// denotes whether to mark the channel as a zombie.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
	chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {

	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	// Keep track of which channels we end up finding so that we can
	// correctly return ErrEdgeNotFound if we do not find a channel.
	chanLookup := make(map[uint64]struct{}, len(chanIDs))
	for _, chanID := range chanIDs {
		chanLookup[chanID] = struct{}{}
	}

	var (
		ctx   = context.TODO()
		edges []*models.ChannelEdgeInfo
	)
	err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		// First, collect all channel rows.
		var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
		chanCallBack := func(ctx context.Context,
			row sqlc.GetChannelsBySCIDWithPoliciesRow) error {

			// Deleting the entry from the map indicates that we
			// have found the channel.
			scid := byteOrder.Uint64(row.GraphChannel.Scid)
			delete(chanLookup, scid)

			channelRows = append(channelRows, row)

			return nil
		}

		err := s.forEachChanWithPoliciesInSCIDList(
			ctx, db, chanCallBack, chanIDs,
		)
		if err != nil {
			return err
		}

		if len(chanLookup) > 0 {
			return ErrEdgeNotFound
		}

		if len(channelRows) == 0 {
			return nil
		}

		// Batch build all channel edges.
		var chanIDsToDelete []int64
		edges, chanIDsToDelete, err = batchBuildChannelInfo(
			ctx, s.cfg, db, channelRows,
		)
		if err != nil {
			return err
		}

		if markZombie {
			for i, row := range channelRows {
				scid := byteOrder.Uint64(row.GraphChannel.Scid)

				err := handleZombieMarking(
					ctx, db, row, edges[i],
					strictZombiePruning, scid,
				)
				if err != nil {
					return fmt.Errorf("unable to mark "+
						"channel as zombie: %w", err)
				}
			}
		}

		return s.deleteChannels(ctx, db, chanIDsToDelete)
	}, func() {
		edges = nil

		// Re-fill the lookup map.
		for _, chanID := range chanIDs {
			chanLookup[chanID] = struct{}{}
		}
	})
	if err != nil {
		return nil, fmt.Errorf("unable to delete channel edges: %w",
			err)
	}

	for _, chanID := range chanIDs {
		s.rejectCache.remove(chanID)
		s.chanCache.remove(chanID)
	}

	return edges, nil
}

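// deleteAndZombie is a hypothetical sketch of how DeleteChannelEdges is
// typically invoked when closed channels should also be recorded as zombies so
// that stale gossip cannot re-add them. If any requested channel is unknown,
// the write transaction is rolled back and nothing is removed.
func deleteAndZombie(s *SQLStore, strict bool,
	chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {

	edges, err := s.DeleteChannelEdges(strict, true, chanIDs...)
	if errors.Is(err, ErrEdgeNotFound) {
		return nil, err
	}

	return edges, err
}
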
// FetchChannelEdgesByID attempts to look up the two directed edges for the
// channel identified by the channel ID. If the channel can't be found, then
// ErrEdgeNotFound is returned. A struct which houses the general information
// for the channel itself is returned as well as two structs that contain the
// routing policies for the channel in either direction.
//
// ErrZombieEdge can be returned if the edge is currently marked as a zombie
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
// the ChannelEdgeInfo will only include the public keys of each node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
	*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
	*models.ChannelEdgePolicy, error) {

	var (
		ctx              = context.TODO()
		edge             *models.ChannelEdgeInfo
		policy1, policy2 *models.ChannelEdgePolicy
		chanIDB          = channelIDToBytes(chanID)
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		row, err := db.GetChannelBySCIDWithPolicies(
			ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
				Scid:    chanIDB,
				Version: int16(ProtocolV1),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			// First check if this edge is perhaps in the zombie
			// index.
			zombie, err := db.GetZombieChannel(
				ctx, sqlc.GetZombieChannelParams{
					Scid:    chanIDB,
					Version: int16(ProtocolV1),
				},
			)
			if errors.Is(err, sql.ErrNoRows) {
				return ErrEdgeNotFound
			} else if err != nil {
				return fmt.Errorf("unable to check if "+
					"channel is zombie: %w", err)
			}

			// At this point, we know the channel is a zombie, so
			// we'll return an error indicating this, and we will
			// populate the edge info with the public keys of each
			// party as this is the only information we have about
			// it.
			edge = &models.ChannelEdgeInfo{}
			copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
			copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)

			return ErrZombieEdge
		} else if err != nil {
			return fmt.Errorf("unable to fetch channel: %w", err)
		}

		node1, node2, err := buildNodeVertices(
			row.GraphNode.PubKey, row.GraphNode_2.PubKey,
		)
		if err != nil {
			return err
		}

		edge, err = getAndBuildEdgeInfo(
			ctx, s.cfg, db, row.GraphChannel, node1, node2,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel info: %w",
				err)
		}

		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return fmt.Errorf("unable to extract channel "+
				"policies: %w", err)
		}

		policy1, policy2, err = getAndBuildChanPolicies(
			ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
			node1, node2,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel "+
				"policies: %w", err)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		// If we are returning the ErrZombieEdge, then we also need to
		// return the edge info as the method comment indicates that
		// this will be populated when the edge is a zombie.
		return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
			err)
	}

	return edge, policy1, policy2, nil
}

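// describeChannelByID is a hypothetical sketch showing how a caller can
// distinguish the zombie case of FetchChannelEdgesByID: when the returned
// error wraps ErrZombieEdge, the edge info is still populated with the two
// node public keys while both policies are nil.
func describeChannelByID(s *SQLStore, chanID uint64) (string, error) {
	info, pol1, pol2, err := s.FetchChannelEdgesByID(chanID)
	switch {
	case errors.Is(err, ErrZombieEdge):
		return fmt.Sprintf("zombie channel between %x and %x",
			info.NodeKey1Bytes, info.NodeKey2Bytes), nil

	case errors.Is(err, ErrEdgeNotFound):
		return "unknown channel", nil

	case err != nil:
		return "", err
	}

	return fmt.Sprintf("channel %d (have policy1=%v, policy2=%v)",
		info.ChannelID, pol1 != nil, pol2 != nil), nil
}
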
// FetchChannelEdgesByOutpoint attempts to look up the two directed edges for
// the channel identified by the funding outpoint. If the channel can't be
// found, then ErrEdgeNotFound is returned. A struct which houses the general
// information for the channel itself is returned as well as two structs that
// contain the routing policies for the channel in either direction.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
	*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
	*models.ChannelEdgePolicy, error) {

	var (
		ctx              = context.TODO()
		edge             *models.ChannelEdgeInfo
		policy1, policy2 *models.ChannelEdgePolicy
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		row, err := db.GetChannelByOutpointWithPolicies(
			ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
				Outpoint: op.String(),
				Version:  int16(ProtocolV1),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return ErrEdgeNotFound
		} else if err != nil {
			return fmt.Errorf("unable to fetch channel: %w", err)
		}

		node1, node2, err := buildNodeVertices(
			row.Node1Pubkey, row.Node2Pubkey,
		)
		if err != nil {
			return err
		}

		edge, err = getAndBuildEdgeInfo(
			ctx, s.cfg, db, row.GraphChannel, node1, node2,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel info: %w",
				err)
		}

		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return fmt.Errorf("unable to extract channel "+
				"policies: %w", err)
		}

		policy1, policy2, err = getAndBuildChanPolicies(
			ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
			node1, node2,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel "+
				"policies: %w", err)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
			err)
	}

	return edge, policy1, policy2, nil
}

// HasChannelEdge returns true if the database knows of a channel edge with the
// passed channel ID, and false otherwise. If an edge with that ID is found
// within the graph, then two time stamps representing the last time the edge
// was updated for both directed edges are returned along with the boolean. If
// it is not found, then the zombie index is checked and its result is returned
// as the second boolean.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
	bool, error) {

	ctx := context.TODO()

	var (
		exists          bool
		isZombie        bool
		node1LastUpdate time.Time
		node2LastUpdate time.Time
	)

	// We'll query the cache with the shared lock held to allow multiple
	// readers to access values in the cache concurrently if they exist.
	s.cacheMu.RLock()
	if entry, ok := s.rejectCache.get(chanID); ok {
		s.cacheMu.RUnlock()
		node1LastUpdate = time.Unix(entry.upd1Time, 0)
		node2LastUpdate = time.Unix(entry.upd2Time, 0)
		exists, isZombie = entry.flags.unpack()

		return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
	}
	s.cacheMu.RUnlock()

	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	// The item was not found with the shared lock, so we'll acquire the
	// exclusive lock and check the cache again in case another method added
	// the entry to the cache while no lock was held.
	if entry, ok := s.rejectCache.get(chanID); ok {
		node1LastUpdate = time.Unix(entry.upd1Time, 0)
		node2LastUpdate = time.Unix(entry.upd2Time, 0)
		exists, isZombie = entry.flags.unpack()

		return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
	}

	chanIDB := channelIDToBytes(chanID)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		channel, err := db.GetChannelBySCID(
			ctx, sqlc.GetChannelBySCIDParams{
				Scid:    chanIDB,
				Version: int16(ProtocolV1),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			// Check if it is a zombie channel.
			isZombie, err = db.IsZombieChannel(
				ctx, sqlc.IsZombieChannelParams{
					Scid:    chanIDB,
					Version: int16(ProtocolV1),
				},
			)
			if err != nil {
				return fmt.Errorf("could not check if channel "+
					"is zombie: %w", err)
			}

			return nil
		} else if err != nil {
			return fmt.Errorf("unable to fetch channel: %w", err)
		}

		exists = true

		policy1, err := db.GetChannelPolicyByChannelAndNode(
			ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
				Version:   int16(ProtocolV1),
				ChannelID: channel.ID,
				NodeID:    channel.NodeID1,
			},
		)
		if err != nil && !errors.Is(err, sql.ErrNoRows) {
			return fmt.Errorf("unable to fetch channel policy: %w",
				err)
		} else if err == nil {
			node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
		}

		policy2, err := db.GetChannelPolicyByChannelAndNode(
			ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
				Version:   int16(ProtocolV1),
				ChannelID: channel.ID,
				NodeID:    channel.NodeID2,
			},
		)
		if err != nil && !errors.Is(err, sql.ErrNoRows) {
			return fmt.Errorf("unable to fetch channel policy: %w",
				err)
		} else if err == nil {
			node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return time.Time{}, time.Time{}, false, false,
			fmt.Errorf("unable to fetch channel: %w", err)
	}

	s.rejectCache.insert(chanID, rejectCacheEntry{
		upd1Time: node1LastUpdate.Unix(),
		upd2Time: node2LastUpdate.Unix(),
		flags:    packRejectFlags(exists, isZombie),
	})

	return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
}

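// shouldRequestUpdates is a hypothetical sketch of how the return values of
// HasChannelEdge are commonly interpreted: a channel is only worth querying a
// peer about if we either don't have it at all (and it isn't a known zombie),
// or both of our directed updates pre-date the given horizon.
func shouldRequestUpdates(s *SQLStore, chanID uint64,
	horizon time.Time) (bool, error) {

	upd1, upd2, exists, isZombie, err := s.HasChannelEdge(chanID)
	if err != nil {
		return false, err
	}

	if isZombie {
		return false, nil
	}
	if !exists {
		return true, nil
	}

	return upd1.Before(horizon) && upd2.Before(horizon), nil
}
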
// ChannelID attempts to look up the 8-byte compact channel ID which maps to
// the passed channel point (outpoint). If the passed channel doesn't exist
// within the database, then ErrEdgeNotFound is returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
	var (
		ctx       = context.TODO()
		channelID uint64
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		chanID, err := db.GetSCIDByOutpoint(
			ctx, sqlc.GetSCIDByOutpointParams{
				Outpoint: chanPoint.String(),
				Version:  int16(ProtocolV1),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return ErrEdgeNotFound
		} else if err != nil {
			return fmt.Errorf("unable to fetch channel ID: %w",
				err)
		}

		channelID = byteOrder.Uint64(chanID)

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
	}

	return channelID, nil
}

// IsPublicNode is a helper method that determines whether the node with the
// given public key is seen as a public node in the graph from the graph's
// source node's point of view.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
	ctx := context.TODO()

	var isPublic bool
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		var err error
		isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])

		return err
	}, sqldb.NoOpReset)
	if err != nil {
		return false, fmt.Errorf("unable to check if node is "+
			"public: %w", err)
	}

	return isPublic, nil
}

2299
// FetchChanInfos returns the set of channel edges that correspond to the passed
2300
// channel ID's. If an edge is the query is unknown to the database, it will
2301
// skipped and the result will contain only those edges that exist at the time
2302
// of the query. This can be used to respond to peer queries that are seeking to
2303
// fill in gaps in their view of the channel graph.
2304
//
2305
// NOTE: part of the V1Store interface.
2306
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2307
        var (
×
2308
                ctx   = context.TODO()
×
2309
                edges = make(map[uint64]ChannelEdge)
×
2310
        )
×
2311
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2312
                // First, collect all channel rows.
×
2313
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
2314
                chanCallBack := func(ctx context.Context,
×
2315
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2316

×
2317
                        channelRows = append(channelRows, row)
×
2318
                        return nil
×
2319
                }
×
2320

2321
                err := s.forEachChanWithPoliciesInSCIDList(
×
2322
                        ctx, db, chanCallBack, chanIDs,
×
2323
                )
×
2324
                if err != nil {
×
2325
                        return err
×
2326
                }
×
2327

2328
                if len(channelRows) == 0 {
×
2329
                        return nil
×
2330
                }
×
2331

2332
                // Batch build all channel edges.
2333
                chans, err := batchBuildChannelEdges(
×
2334
                        ctx, s.cfg, db, channelRows,
×
2335
                )
×
2336
                if err != nil {
×
2337
                        return fmt.Errorf("unable to build channel edges: %w",
×
2338
                                err)
×
2339
                }
×
2340

2341
                for _, c := range chans {
×
2342
                        edges[c.Info.ChannelID] = c
×
2343
                }
×
2344

2345
                return err
×
2346
        }, func() {
×
2347
                clear(edges)
×
2348
        })
×
2349
        if err != nil {
×
2350
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2351
        }
×
2352

2353
        res := make([]ChannelEdge, 0, len(edges))
×
2354
        for _, chanID := range chanIDs {
×
2355
                edge, ok := edges[chanID]
×
2356
                if !ok {
×
2357
                        continue
×
2358
                }
2359

2360
                res = append(res, edge)
×
2361
        }
2362

2363
        return res, nil
×
2364
}
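
// exampleFetchChanInfosUsage is a hypothetical usage sketch (the function
// name and the SCID values are assumptions): it shows how a syncer could
// answer a peer's gossip query, relying on the fact that unknown channels are
// silently skipped rather than returned as errors.
func exampleFetchChanInfosUsage(s *SQLStore) error {
        // Placeholder SCIDs, e.g. taken from a peer's query_short_channel_ids
        // message.
        queried := []uint64{123456, 654321}

        edges, err := s.FetchChanInfos(queried)
        if err != nil {
                return err
        }

        // len(edges) may be smaller than len(queried) since unknown SCIDs are
        // dropped from the result.
        for _, edge := range edges {
                fmt.Printf("known channel: %d\n", edge.Info.ChannelID)
        }

        return nil
}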
2365

2366
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2367
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2368
// channels in a paginated manner.
2369
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2370
        db SQLQueries, cb func(ctx context.Context,
2371
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
2372
        chanIDs []uint64) error {
×
2373

×
2374
        queryWrapper := func(ctx context.Context,
×
2375
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
2376
                error) {
×
2377

×
2378
                return db.GetChannelsBySCIDWithPolicies(
×
2379
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
2380
                                Version: int16(ProtocolV1),
×
2381
                                Scids:   scids,
×
2382
                        },
×
2383
                )
×
2384
        }
×
2385

2386
        return sqldb.ExecuteBatchQuery(
×
2387
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
×
2388
                cb,
×
2389
        )
×
2390
}
2391

2392
// FilterKnownChanIDs takes a set of channel IDs and returns the subset of chan
2393
// IDs that we don't know and are not known zombies of the passed set. In other
2394
// words, we perform a set difference of our set of channel IDs and the ones
2395
// passed in. This method can be used by callers to determine the set of
2396
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2397
// known zombies are also returned.
2398
//
2399
// NOTE: part of the V1Store interface.
2400
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2401
        []ChannelUpdateInfo, error) {
×
2402

×
2403
        var (
×
2404
                ctx          = context.TODO()
×
2405
                newChanIDs   []uint64
×
2406
                knownZombies []ChannelUpdateInfo
×
2407
                infoLookup   = make(
×
2408
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2409
                )
×
2410
        )
×
2411

×
2412
        // We first build a lookup map of the channel ID's to the
×
2413
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
2414
        // already know about.
×
2415
        for _, chanInfo := range chansInfo {
×
2416
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
2417
        }
×
2418

2419
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2420
                // The call-back function deletes known channels from
×
2421
                // infoLookup, so that we can later check which channels are
×
2422
                // zombies by only looking at the remaining channels in the set.
×
2423
                cb := func(ctx context.Context,
×
2424
                        channel sqlc.GraphChannel) error {
×
2425

×
2426
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2427

×
2428
                        return nil
×
2429
                }
×
2430

2431
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
2432
                if err != nil {
×
2433
                        return fmt.Errorf("unable to iterate through "+
×
2434
                                "channels: %w", err)
×
2435
                }
×
2436

2437
                // We want to ensure that we deal with the channels in the
2438
                // same order that they were passed in, so we iterate over the
2439
                // original chansInfo slice and then check if that channel is
2440
                // still in the infoLookup map.
2441
                for _, chanInfo := range chansInfo {
×
2442
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2443
                        if _, ok := infoLookup[channelID]; !ok {
×
2444
                                continue
×
2445
                        }
2446

2447
                        isZombie, err := db.IsZombieChannel(
×
2448
                                ctx, sqlc.IsZombieChannelParams{
×
2449
                                        Scid:    channelIDToBytes(channelID),
×
2450
                                        Version: int16(ProtocolV1),
×
2451
                                },
×
2452
                        )
×
2453
                        if err != nil {
×
2454
                                return fmt.Errorf("unable to fetch zombie "+
×
2455
                                        "channel: %w", err)
×
2456
                        }
×
2457

2458
                        if isZombie {
×
2459
                                knownZombies = append(knownZombies, chanInfo)
×
2460

×
2461
                                continue
×
2462
                        }
2463

2464
                        newChanIDs = append(newChanIDs, channelID)
×
2465
                }
2466

2467
                return nil
×
2468
        }, func() {
×
2469
                newChanIDs = nil
×
2470
                knownZombies = nil
×
2471
                // Rebuild the infoLookup map in case of a rollback.
×
2472
                for _, chanInfo := range chansInfo {
×
2473
                        scid := chanInfo.ShortChannelID.ToUint64()
×
2474
                        infoLookup[scid] = chanInfo
×
2475
                }
×
2476
        })
2477
        if err != nil {
×
2478
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2479
        }
×
2480

2481
        return newChanIDs, knownZombies, nil
×
2482
}
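
// exampleFilterKnownChanIDsUsage is a hypothetical usage sketch (the function
// name is an assumption): it shows the set-difference behaviour described
// above, where only the channels we have never seen (and that are not known
// zombies) need to be fetched from the peer.
func exampleFilterKnownChanIDsUsage(s *SQLStore,
        peerChans []ChannelUpdateInfo) ([]uint64, error) {

        newSCIDs, zombies, err := s.FilterKnownChanIDs(peerChans)
        if err != nil {
                return nil, err
        }

        fmt.Printf("peer advertised %d channels: %d are new, %d are known "+
                "zombies\n", len(peerChans), len(newSCIDs), len(zombies))

        // Only the new SCIDs need to be queried further.
        return newSCIDs, nil
}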
2483

2484
// forEachChanInSCIDList is a helper method that executes a paged query
2485
// against the database to fetch all channels that match the passed
2486
// ChannelUpdateInfo slice. The callback function is called for each channel
2487
// that is found.
2488
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2489
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
2490
        chansInfo []ChannelUpdateInfo) error {
×
2491

×
2492
        queryWrapper := func(ctx context.Context,
×
2493
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
2494

×
2495
                return db.GetChannelsBySCIDs(
×
2496
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
2497
                                Version: int16(ProtocolV1),
×
2498
                                Scids:   scids,
×
2499
                        },
×
2500
                )
×
2501
        }
×
2502

2503
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
2504
                channelID := chanInfo.ShortChannelID.ToUint64()
×
2505

×
2506
                return channelIDToBytes(channelID)
×
2507
        }
×
2508

2509
        return sqldb.ExecuteBatchQuery(
×
2510
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
×
2511
                cb,
×
2512
        )
×
2513
}
2514

2515
// PruneGraphNodes is a garbage collection method which attempts to prune out
2516
// any nodes from the channel graph that are currently unconnected. This ensures
2517
// that we only maintain a graph of reachable nodes. In the event that a pruned
2518
// node gains more channels, it will be re-added to the graph.
2519
//
2520
// NOTE: this prunes nodes across protocol versions. It will never prune the
2521
// source nodes.
2522
//
2523
// NOTE: part of the V1Store interface.
2524
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2525
        var ctx = context.TODO()
×
2526

×
2527
        var prunedNodes []route.Vertex
×
2528
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2529
                var err error
×
2530
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2531

×
2532
                return err
×
2533
        }, func() {
×
2534
                prunedNodes = nil
×
2535
        })
×
2536
        if err != nil {
×
2537
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2538
        }
×
2539

2540
        return prunedNodes, nil
×
2541
}
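
// examplePruneGraphNodesUsage is a hypothetical usage sketch (the function
// name is an assumption): it runs the unconnected-node garbage collection and
// reports which vertices were removed.
func examplePruneGraphNodesUsage(s *SQLStore) error {
        pruned, err := s.PruneGraphNodes()
        if err != nil {
                return err
        }

        for _, vertex := range pruned {
                fmt.Printf("pruned unconnected node %x\n", vertex[:])
        }

        return nil
}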
2542

2543
// PruneGraph prunes newly closed channels from the channel graph in response
2544
// to a new block being solved on the network. Any transactions which spend the
2545
// funding output of any known channels within the graph will be deleted.
2546
// Additionally, the "prune tip", or the last block which has been used to
2547
// prune the graph, is stored so callers can ensure the graph is fully in sync
2548
// with the current UTXO state. A slice of channels that have been closed by
2549
// the target block along with any pruned nodes are returned if the function
2550
// succeeds without error.
2551
//
2552
// NOTE: part of the V1Store interface.
2553
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2554
        blockHash *chainhash.Hash, blockHeight uint32) (
2555
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2556

×
2557
        ctx := context.TODO()
×
2558

×
2559
        s.cacheMu.Lock()
×
2560
        defer s.cacheMu.Unlock()
×
2561

×
2562
        var (
×
2563
                closedChans []*models.ChannelEdgeInfo
×
2564
                prunedNodes []route.Vertex
×
2565
        )
×
2566
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2567
                // First, collect all channel rows that need to be pruned.
×
2568
                var channelRows []sqlc.GetChannelsByOutpointsRow
×
2569
                channelCallback := func(ctx context.Context,
×
2570
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2571

×
2572
                        channelRows = append(channelRows, row)
×
2573

×
2574
                        return nil
×
2575
                }
×
2576

2577
                err := s.forEachChanInOutpoints(
×
2578
                        ctx, db, spentOutputs, channelCallback,
×
2579
                )
×
2580
                if err != nil {
×
2581
                        return fmt.Errorf("unable to fetch channels by "+
×
2582
                                "outpoints: %w", err)
×
2583
                }
×
2584

2585
                if len(channelRows) == 0 {
×
2586
                        // There are no channels to prune. So we can exit early
×
2587
                        // after updating the prune log.
×
2588
                        err = db.UpsertPruneLogEntry(
×
2589
                                ctx, sqlc.UpsertPruneLogEntryParams{
×
2590
                                        BlockHash:   blockHash[:],
×
2591
                                        BlockHeight: int64(blockHeight),
×
2592
                                },
×
2593
                        )
×
2594
                        if err != nil {
×
2595
                                return fmt.Errorf("unable to insert prune log "+
×
2596
                                        "entry: %w", err)
×
2597
                        }
×
2598

2599
                        return nil
×
2600
                }
2601

2602
                // Batch build all channel edges for pruning.
2603
                var chansToDelete []int64
×
2604
                closedChans, chansToDelete, err = batchBuildChannelInfo(
×
2605
                        ctx, s.cfg, db, channelRows,
×
2606
                )
×
2607
                if err != nil {
×
2608
                        return err
×
2609
                }
×
2610

2611
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2612
                if err != nil {
×
2613
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2614
                }
×
2615

2616
                err = db.UpsertPruneLogEntry(
×
2617
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2618
                                BlockHash:   blockHash[:],
×
2619
                                BlockHeight: int64(blockHeight),
×
2620
                        },
×
2621
                )
×
2622
                if err != nil {
×
2623
                        return fmt.Errorf("unable to insert prune log "+
×
2624
                                "entry: %w", err)
×
2625
                }
×
2626

2627
                // Now that we've pruned some channels, we'll also prune any
2628
                // nodes that no longer have any channels.
2629
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2630
                if err != nil {
×
2631
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2632
                                err)
×
2633
                }
×
2634

2635
                return nil
×
2636
        }, func() {
×
2637
                prunedNodes = nil
×
2638
                closedChans = nil
×
2639
        })
×
2640
        if err != nil {
×
2641
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2642
        }
×
2643

2644
        for _, channel := range closedChans {
×
2645
                s.rejectCache.remove(channel.ChannelID)
×
2646
                s.chanCache.remove(channel.ChannelID)
×
2647
        }
×
2648

2649
        return closedChans, prunedNodes, nil
×
2650
}
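
// examplePruneGraphUsage is a hypothetical usage sketch (the function name is
// an assumption): it shows how a chain watcher might feed the spent funding
// outpoints of a newly connected block into PruneGraph.
func examplePruneGraphUsage(s *SQLStore, spentOutputs []*wire.OutPoint,
        blockHash *chainhash.Hash, blockHeight uint32) error {

        closedChans, prunedNodes, err := s.PruneGraph(
                spentOutputs, blockHash, blockHeight,
        )
        if err != nil {
                return err
        }

        fmt.Printf("block %d closed %d channels and pruned %d nodes\n",
                blockHeight, len(closedChans), len(prunedNodes))

        return nil
}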
2651

2652
// forEachChanInOutpoints is a helper function that executes a paginated
2653
// query to fetch channels by their outpoints and applies the given call-back
2654
// to each.
2655
//
2656
// NOTE: this fetches channels for all protocol versions.
2657
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2658
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2659
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
2660

×
2661
        // Create a wrapper that uses the transaction's db instance to execute
×
2662
        // the query.
×
2663
        queryWrapper := func(ctx context.Context,
×
2664
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
2665
                error) {
×
2666

×
2667
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
2668
        }
×
2669

2670
        // Define the conversion function from Outpoint to string.
2671
        outpointToString := func(outpoint *wire.OutPoint) string {
×
2672
                return outpoint.String()
×
2673
        }
×
2674

2675
        return sqldb.ExecuteBatchQuery(
×
2676
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
×
2677
                queryWrapper, cb,
×
2678
        )
×
2679
}
2680

2681
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
2682
        dbIDs []int64) error {
×
2683

×
2684
        // Create a wrapper that uses the transaction's db instance to execute
×
2685
        // the query.
×
2686
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
2687
                return nil, db.DeleteChannels(ctx, ids)
×
2688
        }
×
2689

2690
        idConverter := func(id int64) int64 {
×
2691
                return id
×
2692
        }
×
2693

2694
        return sqldb.ExecuteBatchQuery(
×
2695
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
×
2696
                queryWrapper, func(ctx context.Context, _ any) error {
×
2697
                        return nil
×
2698
                },
×
2699
        )
2700
}
2701

2702
// ChannelView returns the verifiable edge information for each active channel
2703
// within the known channel graph. The set of UTXOs (along with their scripts)
2704
// returned are the ones that need to be watched on chain to detect channel
2705
// closes on the resident blockchain.
2706
//
2707
// NOTE: part of the V1Store interface.
2708
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2709
        var (
×
2710
                ctx        = context.TODO()
×
2711
                edgePoints []EdgePoint
×
2712
        )
×
2713

×
2714
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2715
                handleChannel := func(_ context.Context,
×
2716
                        channel sqlc.ListChannelsPaginatedRow) error {
×
2717

×
2718
                        pkScript, err := genMultiSigP2WSH(
×
2719
                                channel.BitcoinKey1, channel.BitcoinKey2,
×
2720
                        )
×
2721
                        if err != nil {
×
2722
                                return err
×
2723
                        }
×
2724

2725
                        op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2726
                        if err != nil {
×
2727
                                return err
×
2728
                        }
×
2729

2730
                        edgePoints = append(edgePoints, EdgePoint{
×
2731
                                FundingPkScript: pkScript,
×
2732
                                OutPoint:        *op,
×
2733
                        })
×
2734

×
2735
                        return nil
×
2736
                }
2737

2738
                queryFunc := func(ctx context.Context, lastID int64,
×
2739
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {
×
2740

×
2741
                        return db.ListChannelsPaginated(
×
2742
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2743
                                        Version: int16(ProtocolV1),
×
2744
                                        ID:      lastID,
×
2745
                                        Limit:   limit,
×
2746
                                },
×
2747
                        )
×
2748
                }
×
2749

2750
                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
×
2751
                        return row.ID
×
2752
                }
×
2753

2754
                return sqldb.ExecutePaginatedQuery(
×
2755
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
2756
                        extractCursor, handleChannel,
×
2757
                )
×
2758
        }, func() {
×
2759
                edgePoints = nil
×
2760
        })
×
2761
        if err != nil {
×
2762
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2763
        }
×
2764

2765
        return edgePoints, nil
×
2766
}
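
// exampleChannelViewUsage is a hypothetical usage sketch (the function name
// is an assumption): it collects the funding outpoints returned by
// ChannelView, which is the set a spend notifier would watch to detect
// channel closes on-chain.
func exampleChannelViewUsage(s *SQLStore) ([]wire.OutPoint, error) {
        edgePoints, err := s.ChannelView()
        if err != nil {
                return nil, err
        }

        watchList := make([]wire.OutPoint, 0, len(edgePoints))
        for _, point := range edgePoints {
                watchList = append(watchList, point.OutPoint)
        }

        return watchList, nil
}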
2767

2768
// PruneTip returns the block height and hash of the latest block that has been
2769
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2770
// to tell if the graph is currently in sync with the current best known UTXO
2771
// state.
2772
//
2773
// NOTE: part of the V1Store interface.
2774
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2775
        var (
×
2776
                ctx       = context.TODO()
×
2777
                tipHash   chainhash.Hash
×
2778
                tipHeight uint32
×
2779
        )
×
2780
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2781
                pruneTip, err := db.GetPruneTip(ctx)
×
2782
                if errors.Is(err, sql.ErrNoRows) {
×
2783
                        return ErrGraphNeverPruned
×
2784
                } else if err != nil {
×
2785
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2786
                }
×
2787

2788
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2789
                tipHeight = uint32(pruneTip.BlockHeight)
×
2790

×
2791
                return nil
×
2792
        }, sqldb.NoOpReset)
2793
        if err != nil {
×
2794
                return nil, 0, err
×
2795
        }
×
2796

2797
        return &tipHash, tipHeight, nil
×
2798
}
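
// examplePruneTipUsage is a hypothetical usage sketch (the function name is
// an assumption): it shows the expected handling of ErrGraphNeverPruned,
// which signals that the graph has no prune log entries yet.
func examplePruneTipUsage(s *SQLStore) error {
        hash, height, err := s.PruneTip()
        switch {
        case errors.Is(err, ErrGraphNeverPruned):
                fmt.Println("graph has never been pruned; a full chain " +
                        "rescan is required")
                return nil

        case err != nil:
                return err
        }

        fmt.Printf("graph is pruned up to height %d (block %v)\n", height,
                hash)

        return nil
}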
2799

2800
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2801
//
2802
// NOTE: this prunes nodes across protocol versions. It will never prune the
2803
// source nodes.
2804
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2805
        db SQLQueries) ([]route.Vertex, error) {
×
2806

×
2807
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2808
        if err != nil {
×
2809
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2810
                        "nodes: %w", err)
×
2811
        }
×
2812

2813
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2814
        for i, nodeKey := range nodeKeys {
×
2815
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2816
                if err != nil {
×
2817
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2818
                                "from bytes: %w", err)
×
2819
                }
×
2820

2821
                prunedNodes[i] = pub
×
2822
        }
2823

2824
        return prunedNodes, nil
×
2825
}
2826

2827
// DisconnectBlockAtHeight is used to indicate that the block specified
2828
// by the passed height has been disconnected from the main chain. This
2829
// will "rewind" the graph back to the height below, deleting channels
2830
// that are no longer confirmed from the graph. The prune log will be
2831
// set to the last prune height valid for the remaining chain.
2832
// Channels that were removed from the graph resulting from the
2833
// disconnected block are returned.
2834
//
2835
// NOTE: part of the V1Store interface.
2836
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2837
        []*models.ChannelEdgeInfo, error) {
×
2838

×
2839
        ctx := context.TODO()
×
2840

×
2841
        var (
×
2842
                // Every channel having a ShortChannelID starting at 'height'
×
2843
                // will no longer be confirmed.
×
2844
                startShortChanID = lnwire.ShortChannelID{
×
2845
                        BlockHeight: height,
×
2846
                }
×
2847

×
2848
                // Delete everything after this height from the db up until the
×
2849
                // SCID alias range.
×
2850
                endShortChanID = aliasmgr.StartingAlias
×
2851

×
2852
                removedChans []*models.ChannelEdgeInfo
×
2853

×
2854
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2855
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2856
        )
×
2857

×
2858
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2859
                rows, err := db.GetChannelsBySCIDRange(
×
2860
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2861
                                StartScid: chanIDStart,
×
2862
                                EndScid:   chanIDEnd,
×
2863
                        },
×
2864
                )
×
2865
                if err != nil {
×
2866
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2867
                }
×
2868

2869
                if len(rows) == 0 {
×
2870
                        // No channels to disconnect, but still clean up prune
×
2871
                        // log.
×
2872
                        return db.DeletePruneLogEntriesInRange(
×
2873
                                ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2874
                                        StartHeight: int64(height),
×
2875
                                        EndHeight: int64(
×
2876
                                                endShortChanID.BlockHeight,
×
2877
                                        ),
×
2878
                                },
×
2879
                        )
×
2880
                }
×
2881

2882
                // Batch build all channel edges for disconnection.
2883
                channelEdges, chanIDsToDelete, err := batchBuildChannelInfo(
×
2884
                        ctx, s.cfg, db, rows,
×
2885
                )
×
2886
                if err != nil {
×
2887
                        return err
×
2888
                }
×
2889

2890
                removedChans = channelEdges
×
2891

×
2892
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
2893
                if err != nil {
×
2894
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2895
                }
×
2896

2897
                return db.DeletePruneLogEntriesInRange(
×
2898
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2899
                                StartHeight: int64(height),
×
2900
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2901
                        },
×
2902
                )
×
2903
        }, func() {
×
2904
                removedChans = nil
×
2905
        })
×
2906
        if err != nil {
×
2907
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2908
                        "height: %w", err)
×
2909
        }
×
2910

2911
        for _, channel := range removedChans {
×
2912
                s.rejectCache.remove(channel.ChannelID)
×
2913
                s.chanCache.remove(channel.ChannelID)
×
2914
        }
×
2915

2916
        return removedChans, nil
×
2917
}
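
// exampleDisconnectBlockUsage is a hypothetical usage sketch (the function
// name and the height value are assumptions): on a reorg, the graph is
// rewound so that no channel confirmed at or above the stale height remains.
func exampleDisconnectBlockUsage(s *SQLStore, staleHeight uint32) error {
        removed, err := s.DisconnectBlockAtHeight(staleHeight)
        if err != nil {
                return err
        }

        fmt.Printf("rewound block %d: removed %d channels\n", staleHeight,
                len(removed))

        return nil
}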
2918

2919
// AddEdgeProof sets the proof of an existing edge in the graph database.
2920
//
2921
// NOTE: part of the V1Store interface.
2922
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2923
        proof *models.ChannelAuthProof) error {
×
2924

×
2925
        var (
×
2926
                ctx       = context.TODO()
×
2927
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2928
        )
×
2929

×
2930
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2931
                res, err := db.AddV1ChannelProof(
×
2932
                        ctx, sqlc.AddV1ChannelProofParams{
×
2933
                                Scid:              scidBytes,
×
2934
                                Node1Signature:    proof.NodeSig1Bytes,
×
2935
                                Node2Signature:    proof.NodeSig2Bytes,
×
2936
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2937
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2938
                        },
×
2939
                )
×
2940
                if err != nil {
×
2941
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2942
                }
×
2943

2944
                n, err := res.RowsAffected()
×
2945
                if err != nil {
×
2946
                        return err
×
2947
                }
×
2948

2949
                if n == 0 {
×
2950
                        return fmt.Errorf("no rows affected when adding edge "+
×
2951
                                "proof for SCID %v", scid)
×
2952
                } else if n > 1 {
×
2953
                        return fmt.Errorf("multiple rows affected when adding "+
×
2954
                                "edge proof for SCID %v: %d rows affected",
×
2955
                                scid, n)
×
2956
                }
×
2957

2958
                return nil
×
2959
        }, sqldb.NoOpReset)
2960
        if err != nil {
×
2961
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2962
        }
×
2963

2964
        return nil
×
2965
}
2966

2967
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2968
// that we can ignore channel announcements that we know to be closed without
2969
// having to validate them and fetch a block.
2970
//
2971
// NOTE: part of the V1Store interface.
2972
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2973
        var (
×
2974
                ctx     = context.TODO()
×
2975
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2976
        )
×
2977

×
2978
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2979
                return db.InsertClosedChannel(ctx, chanIDB)
×
2980
        }, sqldb.NoOpReset)
×
2981
}
2982

2983
// IsClosedScid checks whether a channel identified by the passed in scid is
2984
// closed. This helps avoid having to perform expensive validation checks.
2985
//
2986
// NOTE: part of the V1Store interface.
2987
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2988
        var (
×
2989
                ctx      = context.TODO()
×
2990
                isClosed bool
×
2991
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2992
        )
×
2993
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2994
                var err error
×
2995
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2996
                if err != nil {
×
2997
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2998
                                err)
×
2999
                }
×
3000

3001
                return nil
×
3002
        }, sqldb.NoOpReset)
3003
        if err != nil {
×
3004
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
3005
                        err)
×
3006
        }
×
3007

3008
        return isClosed, nil
×
3009
}
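
// exampleClosedSCIDUsage is a hypothetical usage sketch (the function name is
// an assumption): it pairs PutClosedScid with IsClosedScid so that later
// announcements for a spent channel can be rejected without any on-chain
// lookups.
func exampleClosedSCIDUsage(s *SQLStore, scid lnwire.ShortChannelID) error {
        // Record the channel as closed once its funding output is seen as
        // spent.
        if err := s.PutClosedScid(scid); err != nil {
                return err
        }

        closed, err := s.IsClosedScid(scid)
        if err != nil {
                return err
        }

        if closed {
                fmt.Printf("ignoring announcement for closed channel %v\n",
                        scid)
        }

        return nil
}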
3010

3011
// GraphSession will provide the call-back with access to a NodeTraverser
3012
// instance which can be used to perform queries against the channel graph.
3013
//
3014
// NOTE: part of the V1Store interface.
3015
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
3016
        reset func()) error {
×
3017

×
3018
        var ctx = context.TODO()
×
3019

×
3020
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
3021
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
3022
        }, reset)
×
3023
}
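
// exampleGraphSessionUsage is a hypothetical usage sketch (the function name
// is an assumption): every query issued through the provided NodeTraverser
// runs inside a single read transaction, so pathfinding sees one consistent
// snapshot of the graph.
func exampleGraphSessionUsage(s *SQLStore, source route.Vertex) error {
        return s.GraphSession(func(graph NodeTraverser) error {
                return graph.ForEachNodeDirectedChannel(
                        source, func(channel *DirectedChannel) error {
                                fmt.Printf("channel %d towards node %x\n",
                                        channel.ChannelID, channel.OtherNode)

                                return nil
                        }, func() {},
                )
        }, func() {})
}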
3024

3025
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
3026
// read only transaction for a consistent view of the graph.
3027
type sqlNodeTraverser struct {
3028
        db    SQLQueries
3029
        chain chainhash.Hash
3030
}
3031

3032
// A compile-time assertion to ensure that sqlNodeTraverser implements the
3033
// NodeTraverser interface.
3034
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
3035

3036
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
3037
func newSQLNodeTraverser(db SQLQueries,
3038
        chain chainhash.Hash) *sqlNodeTraverser {
×
3039

×
3040
        return &sqlNodeTraverser{
×
3041
                db:    db,
×
3042
                chain: chain,
×
3043
        }
×
3044
}
×
3045

3046
// ForEachNodeDirectedChannel calls the callback for every channel of the given
3047
// node.
3048
//
3049
// NOTE: Part of the NodeTraverser interface.
3050
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
3051
        cb func(channel *DirectedChannel) error, _ func()) error {
×
3052

×
3053
        ctx := context.TODO()
×
3054

×
3055
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
3056
}
×
3057

3058
// FetchNodeFeatures returns the features of the given node. If the node is
3059
// unknown, assume no additional features are supported.
3060
//
3061
// NOTE: Part of the NodeTraverser interface.
3062
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
3063
        *lnwire.FeatureVector, error) {
×
3064

×
3065
        ctx := context.TODO()
×
3066

×
3067
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
3068
}
×
3069

3070
// forEachNodeDirectedChannel iterates through all channels of a given
3071
// node, executing the passed callback on the directed edge representing the
3072
// channel and its incoming policy. If the node is not found, no error is
3073
// returned.
3074
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
3075
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
3076

×
3077
        toNodeCallback := func() route.Vertex {
×
3078
                return nodePub
×
3079
        }
×
3080

3081
        dbID, err := db.GetNodeIDByPubKey(
×
3082
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
3083
                        Version: int16(ProtocolV1),
×
3084
                        PubKey:  nodePub[:],
×
3085
                },
×
3086
        )
×
3087
        if errors.Is(err, sql.ErrNoRows) {
×
3088
                return nil
×
3089
        } else if err != nil {
×
3090
                return fmt.Errorf("unable to fetch node: %w", err)
×
3091
        }
×
3092

3093
        rows, err := db.ListChannelsByNodeID(
×
3094
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3095
                        Version: int16(ProtocolV1),
×
3096
                        NodeID1: dbID,
×
3097
                },
×
3098
        )
×
3099
        if err != nil {
×
3100
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3101
        }
×
3102

3103
        // Exit early if there are no channels for this node so we don't
3104
        // do the unnecessary feature fetching.
3105
        if len(rows) == 0 {
×
3106
                return nil
×
3107
        }
×
3108

3109
        features, err := getNodeFeatures(ctx, db, dbID)
×
3110
        if err != nil {
×
3111
                return fmt.Errorf("unable to fetch node features: %w", err)
×
3112
        }
×
3113

3114
        for _, row := range rows {
×
3115
                node1, node2, err := buildNodeVertices(
×
3116
                        row.Node1Pubkey, row.Node2Pubkey,
×
3117
                )
×
3118
                if err != nil {
×
3119
                        return fmt.Errorf("unable to build node vertices: %w",
×
3120
                                err)
×
3121
                }
×
3122

3123
                edge := buildCacheableChannelInfo(
×
3124
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
3125
                        node1, node2,
×
3126
                )
×
3127

×
3128
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3129
                if err != nil {
×
3130
                        return err
×
3131
                }
×
3132

3133
                p1, p2, err := buildCachedChanPolicies(
×
3134
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3135
                )
×
3136
                if err != nil {
×
3137
                        return err
×
3138
                }
×
3139

3140
                // Determine the outgoing and incoming policy for this
3141
                // channel and node combo.
3142
                outPolicy, inPolicy := p1, p2
×
3143
                if p1 != nil && node2 == nodePub {
×
3144
                        outPolicy, inPolicy = p2, p1
×
3145
                } else if p2 != nil && node1 != nodePub {
×
3146
                        outPolicy, inPolicy = p2, p1
×
3147
                }
×
3148

3149
                var cachedInPolicy *models.CachedEdgePolicy
×
3150
                if inPolicy != nil {
×
3151
                        cachedInPolicy = inPolicy
×
3152
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
3153
                        cachedInPolicy.ToNodeFeatures = features
×
3154
                }
×
3155

3156
                directedChannel := &DirectedChannel{
×
3157
                        ChannelID:    edge.ChannelID,
×
3158
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
3159
                        OtherNode:    edge.NodeKey2Bytes,
×
3160
                        Capacity:     edge.Capacity,
×
3161
                        OutPolicySet: outPolicy != nil,
×
3162
                        InPolicy:     cachedInPolicy,
×
3163
                }
×
3164
                if outPolicy != nil {
×
3165
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3166
                                directedChannel.InboundFee = fee
×
3167
                        })
×
3168
                }
3169

3170
                if nodePub == edge.NodeKey2Bytes {
×
3171
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
3172
                }
×
3173

3174
                if err := cb(directedChannel); err != nil {
×
3175
                        return err
×
3176
                }
×
3177
        }
3178

3179
        return nil
×
3180
}
3181

3182
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
3183
// and executes the provided callback for each node. It does so via pagination
3184
// along with batch loading of the node feature bits.
3185
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
3186
        db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
3187
                features *lnwire.FeatureVector) error) error {
×
3188

×
3189
        handleNode := func(_ context.Context,
×
3190
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
×
3191
                featureBits map[int64][]int) error {
×
3192

×
3193
                fv := lnwire.EmptyFeatureVector()
×
3194
                if features, exists := featureBits[dbNode.ID]; exists {
×
3195
                        for _, bit := range features {
×
3196
                                fv.Set(lnwire.FeatureBit(bit))
×
3197
                        }
×
3198
                }
3199

3200
                var pub route.Vertex
×
3201
                copy(pub[:], dbNode.PubKey)
×
3202

×
3203
                return processNode(dbNode.ID, pub, fv)
×
3204
        }
3205

3206
        queryFunc := func(ctx context.Context, lastID int64,
×
3207
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
3208

×
3209
                return db.ListNodeIDsAndPubKeys(
×
3210
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
3211
                                Version: int16(ProtocolV1),
×
3212
                                ID:      lastID,
×
3213
                                Limit:   limit,
×
3214
                        },
×
3215
                )
×
3216
        }
×
3217

3218
        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
3219
                return row.ID
×
3220
        }
×
3221

3222
        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
×
3223
                return node.ID, nil
×
3224
        }
×
3225

3226
        batchQueryFunc := func(ctx context.Context,
×
3227
                nodeIDs []int64) (map[int64][]int, error) {
×
3228

×
3229
                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
3230
        }
×
3231

3232
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
3233
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
×
3234
                batchQueryFunc, handleNode,
×
3235
        )
×
3236
}
3237

3238
// forEachNodeChannel iterates through all channels of a node, executing
3239
// the passed callback on each. The call-back is provided with the channel's
3240
// edge information, the outgoing policy and the incoming policy for the
3241
// channel and node combo.
3242
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3243
        cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
3244
                *models.ChannelEdgePolicy,
3245
                *models.ChannelEdgePolicy) error) error {
×
3246

×
3247
        // Get all the V1 channels for this node.
×
3248
        rows, err := db.ListChannelsByNodeID(
×
3249
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3250
                        Version: int16(ProtocolV1),
×
3251
                        NodeID1: id,
×
3252
                },
×
3253
        )
×
3254
        if err != nil {
×
3255
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3256
        }
×
3257

3258
        // Collect all the channel and policy IDs.
3259
        var (
×
3260
                chanIDs   = make([]int64, 0, len(rows))
×
3261
                policyIDs = make([]int64, 0, 2*len(rows))
×
3262
        )
×
3263
        for _, row := range rows {
×
3264
                chanIDs = append(chanIDs, row.GraphChannel.ID)
×
3265

×
3266
                if row.Policy1ID.Valid {
×
3267
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
×
3268
                }
×
3269
                if row.Policy2ID.Valid {
×
3270
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
×
3271
                }
×
3272
        }
3273

3274
        batchData, err := batchLoadChannelData(
×
3275
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
×
3276
        )
×
3277
        if err != nil {
×
3278
                return fmt.Errorf("unable to batch load channel data: %w", err)
×
3279
        }
×
3280

3281
        // Call the call-back for each channel and its known policies.
3282
        for _, row := range rows {
×
3283
                node1, node2, err := buildNodeVertices(
×
3284
                        row.Node1Pubkey, row.Node2Pubkey,
×
3285
                )
×
3286
                if err != nil {
×
3287
                        return fmt.Errorf("unable to build node vertices: %w",
×
3288
                                err)
×
3289
                }
×
3290

3291
                edge, err := buildEdgeInfoWithBatchData(
×
3292
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
3293
                        batchData,
×
3294
                )
×
3295
                if err != nil {
×
3296
                        return fmt.Errorf("unable to build channel info: %w",
×
3297
                                err)
×
3298
                }
×
3299

3300
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3301
                if err != nil {
×
3302
                        return fmt.Errorf("unable to extract channel "+
×
3303
                                "policies: %w", err)
×
3304
                }
×
3305

3306
                p1, p2, err := buildChanPoliciesWithBatchData(
×
3307
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
3308
                )
×
3309
                if err != nil {
×
3310
                        return fmt.Errorf("unable to build channel "+
×
3311
                                "policies: %w", err)
×
3312
                }
×
3313

3314
                // Determine the outgoing and incoming policy for this
3315
                // channel and node combo.
3316
                p1ToNode := row.GraphChannel.NodeID2
×
3317
                p2ToNode := row.GraphChannel.NodeID1
×
3318
                outPolicy, inPolicy := p1, p2
×
3319
                if (p1 != nil && p1ToNode == id) ||
×
3320
                        (p2 != nil && p2ToNode != id) {
×
3321

×
3322
                        outPolicy, inPolicy = p2, p1
×
3323
                }
×
3324

3325
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3326
                        return err
×
3327
                }
×
3328
        }
3329

3330
        return nil
×
3331
}
3332

3333
// updateChanEdgePolicy upserts the channel policy info we have stored for
3334
// a channel we already know of.
3335
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3336
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3337
        error) {
×
3338

×
3339
        var (
×
3340
                node1Pub, node2Pub route.Vertex
×
3341
                isNode1            bool
×
3342
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3343
        )
×
3344

×
3345
        // Check that this edge policy refers to a channel that we already
×
3346
        // know of. We do this explicitly so that we can return the appropriate
×
3347
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3348
        // abort the transaction which would abort the entire batch.
×
3349
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3350
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3351
                        Scid:    chanIDB,
×
3352
                        Version: int16(ProtocolV1),
×
3353
                },
×
3354
        )
×
3355
        if errors.Is(err, sql.ErrNoRows) {
×
3356
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3357
        } else if err != nil {
×
3358
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3359
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3360
        }
×
3361

3362
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3363
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3364

×
3365
        // Figure out which node this edge is from.
×
3366
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3367
        nodeID := dbChan.NodeID1
×
3368
        if !isNode1 {
×
3369
                nodeID = dbChan.NodeID2
×
3370
        }
×
3371

3372
        var (
×
3373
                inboundBase sql.NullInt64
×
3374
                inboundRate sql.NullInt64
×
3375
        )
×
3376
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3377
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3378
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3379
        })
×
3380

3381
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3382
                Version:     int16(ProtocolV1),
×
3383
                ChannelID:   dbChan.ID,
×
3384
                NodeID:      nodeID,
×
3385
                Timelock:    int32(edge.TimeLockDelta),
×
3386
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3387
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3388
                MinHtlcMsat: int64(edge.MinHTLC),
×
3389
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3390
                Disabled: sql.NullBool{
×
3391
                        Valid: true,
×
3392
                        Bool:  edge.IsDisabled(),
×
3393
                },
×
3394
                MaxHtlcMsat: sql.NullInt64{
×
3395
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3396
                        Int64: int64(edge.MaxHTLC),
×
3397
                },
×
3398
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3399
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3400
                InboundBaseFeeMsat:      inboundBase,
×
3401
                InboundFeeRateMilliMsat: inboundRate,
×
3402
                Signature:               edge.SigBytes,
×
3403
        })
×
3404
        if err != nil {
×
3405
                return node1Pub, node2Pub, isNode1,
×
3406
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3407
        }
×
3408

3409
        // Convert the flat extra opaque data into a map of TLV types to
3410
        // values.
3411
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3412
        if err != nil {
×
3413
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3414
                        "marshal extra opaque data: %w", err)
×
3415
        }
×
3416

3417
        // Update the channel policy's extra signed fields.
3418
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3419
        if err != nil {
×
3420
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3421
                        "policy extra TLVs: %w", err)
×
3422
        }
×
3423

3424
        return node1Pub, node2Pub, isNode1, nil
×
3425
}
3426

3427
// getNodeByPubKey attempts to look up a target node by its public key.
3428
func getNodeByPubKey(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3429
        pubKey route.Vertex) (int64, *models.Node, error) {
×
3430

×
3431
        dbNode, err := db.GetNodeByPubKey(
×
3432
                ctx, sqlc.GetNodeByPubKeyParams{
×
3433
                        Version: int16(ProtocolV1),
×
3434
                        PubKey:  pubKey[:],
×
3435
                },
×
3436
        )
×
3437
        if errors.Is(err, sql.ErrNoRows) {
×
3438
                return 0, nil, ErrGraphNodeNotFound
×
3439
        } else if err != nil {
×
3440
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3441
        }
×
3442

3443
        node, err := buildNode(ctx, cfg, db, dbNode)
×
3444
        if err != nil {
×
3445
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3446
        }
×
3447

3448
        return dbNode.ID, node, nil
×
3449
}
3450

3451
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3452
// provided parameters.
3453
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
3454
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3455

×
3456
        return &models.CachedEdgeInfo{
×
3457
                ChannelID:     byteOrder.Uint64(scid),
×
3458
                NodeKey1Bytes: node1Pub,
×
3459
                NodeKey2Bytes: node2Pub,
×
3460
                Capacity:      btcutil.Amount(capacity),
×
3461
        }
×
3462
}
×
3463

3464
// buildNode constructs a Node instance from the given database node
3465
// record. The node's features, addresses and extra signed fields are also
3466
// fetched from the database and set on the node.
3467
func buildNode(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3468
        dbNode sqlc.GraphNode) (*models.Node, error) {
×
3469

×
3470
        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
×
3471
        if err != nil {
×
3472
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
3473
                        err)
×
3474
        }
×
3475

3476
        return buildNodeWithBatchData(dbNode, data)
×
3477
}
3478

3479
// buildNodeWithBatchData builds a models.Node instance
3480
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
3481
// features/addresses/extra fields, then the corresponding fields are expected
3482
// to be present in the batchNodeData.
3483
func buildNodeWithBatchData(dbNode sqlc.GraphNode,
3484
        batchData *batchNodeData) (*models.Node, error) {
×
3485

×
3486
        if dbNode.Version != int16(ProtocolV1) {
×
3487
                return nil, fmt.Errorf("unsupported node version: %d",
×
3488
                        dbNode.Version)
×
3489
        }
×
3490

3491
        var pub [33]byte
×
3492
        copy(pub[:], dbNode.PubKey)
×
3493

×
3494
        node := &models.Node{
×
3495
                PubKeyBytes: pub,
×
3496
                Features:    lnwire.EmptyFeatureVector(),
×
3497
                LastUpdate:  time.Unix(0, 0),
×
3498
        }
×
3499

×
3500
        if len(dbNode.Signature) == 0 {
×
3501
                return node, nil
×
3502
        }
×
3503

3504
        node.HaveNodeAnnouncement = true
×
3505
        node.AuthSigBytes = dbNode.Signature
×
3506
        node.Alias = dbNode.Alias.String
×
3507
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3508

×
3509
        var err error
×
3510
        if dbNode.Color.Valid {
×
3511
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3512
                if err != nil {
×
3513
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3514
                                err)
×
3515
                }
×
3516
        }
3517

3518
        // Use preloaded features.
3519
        if features, exists := batchData.features[dbNode.ID]; exists {
×
3520
                fv := lnwire.EmptyFeatureVector()
×
3521
                for _, bit := range features {
×
3522
                        fv.Set(lnwire.FeatureBit(bit))
×
3523
                }
×
3524
                node.Features = fv
×
3525
        }
3526

3527
        // Use preloaded addresses.
3528
        addresses, exists := batchData.addresses[dbNode.ID]
×
3529
        if exists && len(addresses) > 0 {
×
3530
                node.Addresses, err = buildNodeAddresses(addresses)
×
3531
                if err != nil {
×
3532
                        return nil, fmt.Errorf("unable to build addresses "+
×
3533
                                "for node(%d): %w", dbNode.ID, err)
×
3534
                }
×
3535
        }
3536

3537
        // Use preloaded extra fields.
3538
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
×
3539
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
×
3540
                if err != nil {
×
3541
                        return nil, fmt.Errorf("unable to serialize extra "+
×
3542
                                "signed fields: %w", err)
×
3543
                }
×
3544
                if len(recs) != 0 {
×
3545
                        node.ExtraOpaqueData = recs
×
3546
                }
×
3547
        }
3548

3549
        return node, nil
×
3550
}
3551

3552
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
3553
// with the preloaded data, and executes the provided callback for each node.
3554
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
3555
        db SQLQueries, nodes []sqlc.GraphNode,
3556
        cb func(dbID int64, node *models.Node) error) error {
×
3557

×
3558
        // Extract node IDs for batch loading.
×
3559
        nodeIDs := make([]int64, len(nodes))
×
3560
        for i, node := range nodes {
×
3561
                nodeIDs[i] = node.ID
×
3562
        }
×
3563

3564
        // Batch load all related data for this page.
3565
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
3566
        if err != nil {
×
3567
                return fmt.Errorf("unable to batch load node data: %w", err)
×
3568
        }
×
3569

3570
        for _, dbNode := range nodes {
×
3571
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
3572
                if err != nil {
×
3573
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
3574
                                dbNode.ID, err)
×
3575
                }
×
3576

3577
                if err := cb(dbNode.ID, node); err != nil {
×
3578
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
3579
                                dbNode.ID, err)
×
3580
                }
×
3581
        }
3582

3583
        return nil
×
3584
}
3585

3586
// getNodeFeatures fetches the feature bits and constructs the feature vector
3587
// for a node with the given DB ID.
3588
func getNodeFeatures(ctx context.Context, db SQLQueries,
3589
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3590

×
3591
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3592
        if err != nil {
×
3593
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3594
                        nodeID, err)
×
3595
        }
×
3596

3597
        features := lnwire.EmptyFeatureVector()
×
3598
        for _, feature := range rows {
×
3599
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3600
        }
×
3601

3602
        return features, nil
×
3603
}

// upsertNode upserts the node record into the database. If the node already
// exists, then the node's information is updated. If the node doesn't exist,
// then a new node is created. The node's features, addresses and extra TLV
// types are also updated. The node's DB ID is returned.
func upsertNode(ctx context.Context, db SQLQueries,
        node *models.Node) (int64, error) {

        params := sqlc.UpsertNodeParams{
                Version: int16(ProtocolV1),
                PubKey:  node.PubKeyBytes[:],
        }

        if node.HaveNodeAnnouncement {
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
                params.Color = sqldb.SQLStrValid(EncodeHexColor(node.Color))
                params.Alias = sqldb.SQLStrValid(node.Alias)
                params.Signature = node.AuthSigBytes
        }

        nodeID, err := db.UpsertNode(ctx, params)
        if err != nil {
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
                        err)
        }

        // We can exit here if we don't have the announcement yet.
        if !node.HaveNodeAnnouncement {
                return nodeID, nil
        }

        // Update the node's features.
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
        if err != nil {
                return 0, fmt.Errorf("inserting node features: %w", err)
        }

        // Update the node's addresses.
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
        if err != nil {
                return 0, fmt.Errorf("inserting node addresses: %w", err)
        }

        // Convert the flat extra opaque data into a map of TLV types to
        // values.
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
        if err != nil {
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
                        err)
        }

        // Update the node's extra signed fields.
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
        if err != nil {
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
        }

        return nodeID, nil
}

// upsertNodeFeatures updates the node's features in the node_features table.
// This includes deleting any feature bits no longer present and inserting any
// new feature bits. If the feature bit does not yet exist in the features
// table, then an entry is created in that table first.
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
        features *lnwire.FeatureVector) error {

        // Get any existing features for the node.
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
                return err
        }

        // Copy the node's latest set of feature bits.
        newFeatures := make(map[int32]struct{})
        if features != nil {
                for feature := range features.Features() {
                        newFeatures[int32(feature)] = struct{}{}
                }
        }

        // For any current feature that already exists in the DB, remove it from
        // the in-memory map. For any existing feature that does not exist in
        // the in-memory map, delete it from the database.
        for _, feature := range existingFeatures {
                // The feature is still present, so there are no updates to be
                // made.
                if _, ok := newFeatures[feature.FeatureBit]; ok {
                        delete(newFeatures, feature.FeatureBit)
                        continue
                }

                // The feature is no longer present, so we remove it from the
                // database.
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
                        NodeID:     nodeID,
                        FeatureBit: feature.FeatureBit,
                })
                if err != nil {
                        return fmt.Errorf("unable to delete node(%d) "+
                                "feature(%v): %w", nodeID, feature.FeatureBit,
                                err)
                }
        }

        // Any remaining entries in newFeatures are new features that need to be
        // added to the database for the first time.
        for feature := range newFeatures {
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
                        NodeID:     nodeID,
                        FeatureBit: feature,
                })
                if err != nil {
                        return fmt.Errorf("unable to insert node(%d) "+
                                "feature(%v): %w", nodeID, feature, err)
                }
        }

        return nil
}

// fetchNodeFeatures fetches the features for a node with the given public key.
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {

        rows, err := queries.GetNodeFeaturesByPubKey(
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
                        PubKey:  nodePub[:],
                        Version: int16(ProtocolV1),
                },
        )
        if err != nil {
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
                        nodePub, err)
        }

        features := lnwire.EmptyFeatureVector()
        for _, bit := range rows {
                features.Set(lnwire.FeatureBit(bit))
        }

        return features, nil
}
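
// Example (editor's illustrative sketch, not part of the original file):
// fetching a node's feature vector by public key and checking whether a
// particular feature bit is set. The concrete bit used here is only a
// placeholder.
func exampleNodeHasFeature(ctx context.Context, db SQLQueries,
        nodePub route.Vertex) (bool, error) {

        features, err := fetchNodeFeatures(ctx, db, nodePub)
        if err != nil {
                return false, err
        }

        return features.HasFeature(lnwire.FeatureBit(9)), nil
}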

// dbAddressType is an enum type that represents the different address types
// that we store in the node_addresses table. The address type determines how
// the address is to be serialised/deserialised.
type dbAddressType uint8

const (
        addressTypeIPv4   dbAddressType = 1
        addressTypeIPv6   dbAddressType = 2
        addressTypeTorV2  dbAddressType = 3
        addressTypeTorV3  dbAddressType = 4
        addressTypeDNS    dbAddressType = 5
        addressTypeOpaque dbAddressType = math.MaxInt8
)

// collectAddressRecords collects the addresses from the provided
// net.Addr slice and returns a map of dbAddressType to a slice of address
// strings.
func collectAddressRecords(addresses []net.Addr) (map[dbAddressType][]string,
        error) {

        // Copy the node's latest set of addresses.
        newAddresses := map[dbAddressType][]string{
                addressTypeIPv4:   {},
                addressTypeIPv6:   {},
                addressTypeTorV2:  {},
                addressTypeTorV3:  {},
                addressTypeDNS:    {},
                addressTypeOpaque: {},
        }
        addAddr := func(t dbAddressType, addr net.Addr) {
                newAddresses[t] = append(newAddresses[t], addr.String())
        }

        for _, address := range addresses {
                switch addr := address.(type) {
                case *net.TCPAddr:
                        if ip4 := addr.IP.To4(); ip4 != nil {
                                addAddr(addressTypeIPv4, addr)
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
                                addAddr(addressTypeIPv6, addr)
                        } else {
                                return nil, fmt.Errorf("unhandled IP "+
                                        "address: %v", addr)
                        }

                case *tor.OnionAddr:
                        switch len(addr.OnionService) {
                        case tor.V2Len:
                                addAddr(addressTypeTorV2, addr)
                        case tor.V3Len:
                                addAddr(addressTypeTorV3, addr)
                        default:
                                return nil, fmt.Errorf("invalid length for " +
                                        "a tor address")
                        }

                case *lnwire.DNSAddress:
                        addAddr(addressTypeDNS, addr)

                case *lnwire.OpaqueAddrs:
                        addAddr(addressTypeOpaque, addr)

                default:
                        return nil, fmt.Errorf("unhandled address type: %T",
                                addr)
                }
        }

        return newAddresses, nil
}
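
// Example (editor's illustrative sketch, not part of the original file):
// bucketing a small set of addresses by database address type. A plain
// TCP/IPv4 address is expected to end up under addressTypeIPv4.
func exampleCollectAddressRecords() (map[dbAddressType][]string, error) {
        addrs := []net.Addr{
                &net.TCPAddr{
                        IP:   net.ParseIP("192.0.2.1"),
                        Port: 9735,
                },
        }

        return collectAddressRecords(addrs)
}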

// upsertNodeAddresses updates the node's addresses in the database. This
// includes deleting any existing addresses and inserting the new set of
// addresses. The deletion is necessary since the ordering of the addresses may
// change, and we need to ensure that the database reflects the latest set of
// addresses so that at the time of reconstructing the node announcement, the
// order is preserved and the signature over the message remains valid.
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
        addresses []net.Addr) error {

        // Delete any existing addresses for the node. This is required since
        // even if the new set of addresses is the same, the ordering may have
        // changed for a given address type.
        err := db.DeleteNodeAddresses(ctx, nodeID)
        if err != nil {
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
                        nodeID, err)
        }

        newAddresses, err := collectAddressRecords(addresses)
        if err != nil {
                return err
        }

        // Insert the new set of addresses, recording the position of each
        // address within its address type so that the original ordering is
        // preserved.
        for addrType, addrList := range newAddresses {
                for position, addr := range addrList {
                        err := db.UpsertNodeAddress(
                                ctx, sqlc.UpsertNodeAddressParams{
                                        NodeID:   nodeID,
                                        Type:     int16(addrType),
                                        Address:  addr,
                                        Position: int32(position),
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("unable to insert "+
                                        "node(%d) address(%v): %w", nodeID,
                                        addr, err)
                        }
                }
        }

        return nil
}

// getNodeAddresses fetches the addresses for a node with the given DB ID.
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
        error) {

        // GetNodeAddresses ensures that the addresses for a given type are
        // returned in the same order as they were inserted.
        rows, err := db.GetNodeAddresses(ctx, id)
        if err != nil {
                return nil, err
        }

        addresses := make([]net.Addr, 0, len(rows))
        for _, row := range rows {
                address := row.Address

                addr, err := parseAddress(dbAddressType(row.Type), address)
                if err != nil {
                        return nil, fmt.Errorf("unable to parse address "+
                                "for node(%d): %v: %w", id, address, err)
                }

                addresses = append(addresses, addr)
        }

        // If we have no addresses, then we'll return nil instead of an
        // empty slice.
        if len(addresses) == 0 {
                addresses = nil
        }

        return addresses, nil
}

// upsertNodeExtraSignedFields updates the node's extra signed fields in the
// database. This includes updating any existing types, inserting any new types,
// and deleting any types that are no longer present.
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
        nodeID int64, extraFields map[uint64][]byte) error {

        // Get any existing extra signed fields for the node.
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
        if err != nil {
                return err
        }

        // Make a lookup map of the existing field types so that we can use it
        // to keep track of any fields we should delete.
        m := make(map[uint64]bool)
        for _, field := range existingFields {
                m[uint64(field.Type)] = true
        }

        // For all the new fields, we'll upsert them and remove them from the
        // map of existing fields.
        for tlvType, value := range extraFields {
                err = db.UpsertNodeExtraType(
                        ctx, sqlc.UpsertNodeExtraTypeParams{
                                NodeID: nodeID,
                                Type:   int64(tlvType),
                                Value:  value,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to upsert node(%d) extra "+
                                "signed field(%v): %w", nodeID, tlvType, err)
                }

                // Remove the field from the map of existing fields if it was
                // present.
                delete(m, tlvType)
        }

        // For all the fields that are left in the map of existing fields, we'll
        // delete them as they are no longer present in the new set of fields.
        for tlvType := range m {
                err = db.DeleteExtraNodeType(
                        ctx, sqlc.DeleteExtraNodeTypeParams{
                                NodeID: nodeID,
                                Type:   int64(tlvType),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to delete node(%d) extra "+
                                "signed field(%v): %w", nodeID, tlvType, err)
                }
        }

        return nil
}

// srcNodeInfo holds the information about the source node of the graph.
type srcNodeInfo struct {
        // id is the DB level ID of the source node entry in the "nodes" table.
        id int64

        // pub is the public key of the source node.
        pub route.Vertex
}

// getSourceNode returns the DB node ID and pub key of the source node for the
// specified protocol version.
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
        version ProtocolVersion) (int64, route.Vertex, error) {

        s.srcNodeMu.Lock()
        defer s.srcNodeMu.Unlock()

        // If we already have the source node ID and pub key cached, then
        // return them.
        if info, ok := s.srcNodes[version]; ok {
                return info.id, info.pub, nil
        }

        var pubKey route.Vertex

        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
        if err != nil {
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
                        err)
        }

        if len(nodes) == 0 {
                return 0, pubKey, ErrSourceNodeNotSet
        } else if len(nodes) > 1 {
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
                        "protocol %s found", version)
        }

        copy(pubKey[:], nodes[0].PubKey)

        s.srcNodes[version] = &srcNodeInfo{
                id:  nodes[0].NodeID,
                pub: pubKey,
        }

        return nodes[0].NodeID, pubKey, nil
}

// marshalExtraOpaqueData takes a flat byte slice and parses it as a TLV
// stream. This then produces a map from TLV type to value. If the input is
// not a valid TLV stream, then an error is returned.
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
        r := bytes.NewReader(data)

        tlvStream, err := tlv.NewStream()
        if err != nil {
                return nil, err
        }

        // Since ExtraOpaqueData is provided by a potentially malicious peer,
        // pass it into the P2P decoding variant.
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
        if err != nil {
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
        }
        if len(parsedTypes) == 0 {
                return nil, nil
        }

        records := make(map[uint64][]byte)
        for k, v := range parsedTypes {
                records[uint64(k)] = v
        }

        return records, nil
}
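
// Example (editor's illustrative sketch, not part of the original file):
// parsing a minimal TLV stream containing a single record with (odd) type 1,
// length 2 and value 0xaa 0xbb, which should decode into a map with a single
// entry keyed by type 1.
func exampleMarshalExtraOpaqueData() (map[uint64][]byte, error) {
        // BigSize type 1, BigSize length 2, followed by the two value bytes.
        raw := []byte{0x01, 0x02, 0xaa, 0xbb}

        return marshalExtraOpaqueData(raw)
}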

// insertChannel inserts a new channel record into the database.
func insertChannel(ctx context.Context, db SQLQueries,
        edge *models.ChannelEdgeInfo) error {

        // Make sure that at least a "shell" entry for each node is present in
        // the nodes table.
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
        if err != nil {
                return fmt.Errorf("unable to create shell node: %w", err)
        }

        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
        if err != nil {
                return fmt.Errorf("unable to create shell node: %w", err)
        }

        var capacity sql.NullInt64
        if edge.Capacity != 0 {
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
        }

        createParams := sqlc.CreateChannelParams{
                Version:     int16(ProtocolV1),
                Scid:        channelIDToBytes(edge.ChannelID),
                NodeID1:     node1DBID,
                NodeID2:     node2DBID,
                Outpoint:    edge.ChannelPoint.String(),
                Capacity:    capacity,
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
        }

        if edge.AuthProof != nil {
                proof := edge.AuthProof

                createParams.Node1Signature = proof.NodeSig1Bytes
                createParams.Node2Signature = proof.NodeSig2Bytes
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
        }

        // Insert the new channel record.
        dbChanID, err := db.CreateChannel(ctx, createParams)
        if err != nil {
                return err
        }

        // Insert any channel features.
        for feature := range edge.Features.Features() {
                err = db.InsertChannelFeature(
                        ctx, sqlc.InsertChannelFeatureParams{
                                ChannelID:  dbChanID,
                                FeatureBit: int32(feature),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to insert channel(%d) "+
                                "feature(%v): %w", dbChanID, feature, err)
                }
        }

        // Finally, insert any extra TLV fields in the channel announcement.
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
        if err != nil {
                return fmt.Errorf("unable to marshal extra opaque data: %w",
                        err)
        }

        for tlvType, value := range extra {
                err := db.UpsertChannelExtraType(
                        ctx, sqlc.UpsertChannelExtraTypeParams{
                                ChannelID: dbChanID,
                                Type:      int64(tlvType),
                                Value:     value,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to upsert channel(%d) "+
                                "extra signed field(%v): %w", edge.ChannelID,
                                tlvType, err)
                }
        }

        return nil
}

// maybeCreateShellNode checks if a shell node entry exists for the
// given public key. If it does not exist, then a new shell node entry is
// created. The ID of the node is returned. A shell node only has a protocol
// version and public key persisted.
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
        pubKey route.Vertex) (int64, error) {

        dbNode, err := db.GetNodeByPubKey(
                ctx, sqlc.GetNodeByPubKeyParams{
                        PubKey:  pubKey[:],
                        Version: int16(ProtocolV1),
                },
        )
        // The node exists. Return the ID.
        if err == nil {
                return dbNode.ID, nil
        } else if !errors.Is(err, sql.ErrNoRows) {
                return 0, err
        }

        // Otherwise, the node does not exist, so we create a shell entry for
        // it.
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
                Version: int16(ProtocolV1),
                PubKey:  pubKey[:],
        })
        if err != nil {
                return 0, fmt.Errorf("unable to create shell node: %w", err)
        }

        return id, nil
}
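
// Example (editor's illustrative sketch, not part of the original file):
// making sure both endpoints of a channel have at least shell node entries
// before the channel row itself is inserted, mirroring the first step of
// insertChannel above.
func exampleEnsureShellNodes(ctx context.Context, db SQLQueries,
        node1, node2 route.Vertex) error {

        if _, err := maybeCreateShellNode(ctx, db, node1); err != nil {
                return fmt.Errorf("unable to create shell node: %w", err)
        }

        if _, err := maybeCreateShellNode(ctx, db, node2); err != nil {
                return fmt.Errorf("unable to create shell node: %w", err)
        }

        return nil
}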

// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
// the database. This includes deleting any existing types and then inserting
// the new types.
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
        chanPolicyID int64, extraFields map[uint64][]byte) error {

        // Delete all existing extra signed fields for the channel policy.
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
        if err != nil {
                return fmt.Errorf("unable to delete "+
                        "existing policy extra signed fields for policy %d: %w",
                        chanPolicyID, err)
        }

        // Insert all new extra signed fields for the channel policy.
        for tlvType, value := range extraFields {
                err = db.UpsertChanPolicyExtraType(
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
                                ChannelPolicyID: chanPolicyID,
                                Type:            int64(tlvType),
                                Value:           value,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to insert "+
                                "channel_policy(%d) extra signed field(%v): %w",
                                chanPolicyID, tlvType, err)
                }
        }

        return nil
}

// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
// provided dbChan row and also fetches any other required information
// to construct the edge info.
func getAndBuildEdgeInfo(ctx context.Context, cfg *SQLStoreConfig,
        db SQLQueries, dbChan sqlc.GraphChannel, node1,
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {

        data, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, []int64{dbChan.ID}, nil,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to batch load channel data: %w",
                        err)
        }

        return buildEdgeInfoWithBatchData(
                cfg.ChainHash, dbChan, node1, node2, data,
        )
}

// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {

        if dbChan.Version != int16(ProtocolV1) {
                return nil, fmt.Errorf("unsupported channel version: %d",
                        dbChan.Version)
        }

        // Use pre-loaded features and extra types.
        fv := lnwire.EmptyFeatureVector()
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
                for _, bit := range features {
                        fv.Set(lnwire.FeatureBit(bit))
                }
        }

        var extras map[uint64][]byte
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
        if exists {
                extras = channelExtras
        } else {
                extras = make(map[uint64][]byte)
        }

        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
        if err != nil {
                return nil, err
        }

        recs, err := lnwire.CustomRecords(extras).Serialize()
        if err != nil {
                return nil, fmt.Errorf("unable to serialize extra signed "+
                        "fields: %w", err)
        }
        if recs == nil {
                recs = make([]byte, 0)
        }

        var btcKey1, btcKey2 route.Vertex
        copy(btcKey1[:], dbChan.BitcoinKey1)
        copy(btcKey2[:], dbChan.BitcoinKey2)

        channel := &models.ChannelEdgeInfo{
                ChainHash:        chain,
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
                NodeKey1Bytes:    node1,
                NodeKey2Bytes:    node2,
                BitcoinKey1Bytes: btcKey1,
                BitcoinKey2Bytes: btcKey2,
                ChannelPoint:     *op,
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
                Features:         fv,
                ExtraOpaqueData:  recs,
        }

        // We always set all the signatures at the same time, so we can
        // safely check if one signature is present to determine if we have the
        // rest of the signatures for the auth proof.
        if len(dbChan.Bitcoin1Signature) > 0 {
                channel.AuthProof = &models.ChannelAuthProof{
                        NodeSig1Bytes:    dbChan.Node1Signature,
                        NodeSig2Bytes:    dbChan.Node2Signature,
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
                }
        }

        return channel, nil
}

// buildNodeVertices is a helper that converts raw node public keys
// into route.Vertex instances.
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
        route.Vertex, error) {

        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
        if err != nil {
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
                        "create vertex from node1 pubkey: %w", err)
        }

        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
        if err != nil {
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
                        "create vertex from node2 pubkey: %w", err)
        }

        return node1Vertex, node2Vertex, nil
}

// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
// retrieves all the extra info required to build the complete
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
// the provided sqlc.GraphChannelPolicy records are nil.
func getAndBuildChanPolicies(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
        channelID uint64, node1, node2 route.Vertex) (*models.ChannelEdgePolicy,
        *models.ChannelEdgePolicy, error) {

        if dbPol1 == nil && dbPol2 == nil {
                return nil, nil, nil
        }

        var policyIDs = make([]int64, 0, 2)
        if dbPol1 != nil {
                policyIDs = append(policyIDs, dbPol1.ID)
        }
        if dbPol2 != nil {
                policyIDs = append(policyIDs, dbPol2.ID)
        }

        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
        if err != nil {
                return nil, nil, fmt.Errorf("unable to batch load channel "+
                        "data: %w", err)
        }

        pol1, err := buildChanPolicyWithBatchData(
                dbPol1, channelID, node2, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
        }

        pol2, err := buildChanPolicyWithBatchData(
                dbPol2, channelID, node1, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
        }

        return pol1, pol2, nil
}

// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
// then nil is returned for it.
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
        *models.CachedEdgePolicy, error) {

        var p1, p2 *models.CachedEdgePolicy
        if dbPol1 != nil {
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
                if err != nil {
                        return nil, nil, err
                }

                p1 = models.NewCachedPolicy(policy1)
        }
        if dbPol2 != nil {
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
                if err != nil {
                        return nil, nil, err
                }

                p2 = models.NewCachedPolicy(policy2)
        }

        return p1, p2, nil
}

// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
// provided sqlc.GraphChannelPolicy and other required information.
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
        extras map[uint64][]byte,
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {

        recs, err := lnwire.CustomRecords(extras).Serialize()
        if err != nil {
                return nil, fmt.Errorf("unable to serialize extra signed "+
                        "fields: %w", err)
        }

        var inboundFee fn.Option[lnwire.Fee]
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
                dbPolicy.InboundBaseFeeMsat.Valid {

                inboundFee = fn.Some(lnwire.Fee{
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
                })
        }

        return &models.ChannelEdgePolicy{
                SigBytes:  dbPolicy.Signature,
                ChannelID: channelID,
                LastUpdate: time.Unix(
                        dbPolicy.LastUpdate.Int64, 0,
                ),
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
                        dbPolicy.MessageFlags,
                ),
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
                        dbPolicy.ChannelFlags,
                ),
                TimeLockDelta: uint16(dbPolicy.Timelock),
                MinHTLC: lnwire.MilliSatoshi(
                        dbPolicy.MinHtlcMsat,
                ),
                MaxHTLC: lnwire.MilliSatoshi(
                        dbPolicy.MaxHtlcMsat.Int64,
                ),
                FeeBaseMSat: lnwire.MilliSatoshi(
                        dbPolicy.BaseFeeMsat,
                ),
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
                ToNode:                    toNode,
                InboundFee:                inboundFee,
                ExtraOpaqueData:           recs,
        }, nil
}
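
// Example (editor's illustrative sketch, not part of the original file):
// converting a minimal sqlc.GraphChannelPolicy row into a
// models.ChannelEdgePolicy. Only a handful of fields are populated here; the
// rest keep their zero values, and the channel ID and toNode values are
// placeholders chosen by the caller.
func exampleBuildChanPolicy(toNode route.Vertex) (*models.ChannelEdgePolicy,
        error) {

        dbPolicy := sqlc.GraphChannelPolicy{
                Timelock:    40,
                FeePpm:      1000,
                BaseFeeMsat: 1000,
                MinHtlcMsat: 1,
        }

        return buildChanPolicy(dbPolicy, 123, nil, toNode)
}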

// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the given
// row which is expected to be a sqlc type that contains channel policy
// information. It returns two policies, which may be nil if the policy
// information is not present in the row.
//
//nolint:ll,dupl,funlen
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
        *sqlc.GraphChannelPolicy, error) {

        var policy1, policy2 *sqlc.GraphChannelPolicy
        switch r := row.(type) {
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
                if r.Policy1Timelock.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                        }
                }
                if r.Policy2Timelock.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                        }
                }

                return policy1, policy2, nil

        case sqlc.GetChannelsBySCIDWithPoliciesRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.GetChannelByOutpointWithPoliciesRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.GetChannelBySCIDWithPoliciesRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.ListChannelsForNodeIDsRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.ListChannelsByNodeIDRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.ListChannelsWithPoliciesPaginatedRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
×
4766
                }
×
4767

4768
                return policy1, policy2, nil
×
4769

4770
        case sqlc.GetChannelsByIDsRow:
×
4771
                if r.Policy1ID.Valid {
×
4772
                        policy1 = &sqlc.GraphChannelPolicy{
×
4773
                                ID:                      r.Policy1ID.Int64,
×
4774
                                Version:                 r.Policy1Version.Int16,
×
4775
                                ChannelID:               r.GraphChannel.ID,
×
4776
                                NodeID:                  r.Policy1NodeID.Int64,
×
4777
                                Timelock:                r.Policy1Timelock.Int32,
×
4778
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4779
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4780
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4781
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4782
                                LastUpdate:              r.Policy1LastUpdate,
×
4783
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4784
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4785
                                Disabled:                r.Policy1Disabled,
×
4786
                                MessageFlags:            r.Policy1MessageFlags,
×
4787
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4788
                                Signature:               r.Policy1Signature,
×
4789
                        }
×
4790
                }
×
4791
                if r.Policy2ID.Valid {
×
4792
                        policy2 = &sqlc.GraphChannelPolicy{
×
4793
                                ID:                      r.Policy2ID.Int64,
×
4794
                                Version:                 r.Policy2Version.Int16,
×
4795
                                ChannelID:               r.GraphChannel.ID,
×
4796
                                NodeID:                  r.Policy2NodeID.Int64,
×
4797
                                Timelock:                r.Policy2Timelock.Int32,
×
4798
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4799
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4800
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4801
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4802
                                LastUpdate:              r.Policy2LastUpdate,
×
4803
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4804
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4805
                                Disabled:                r.Policy2Disabled,
×
4806
                                MessageFlags:            r.Policy2MessageFlags,
×
4807
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4808
                                Signature:               r.Policy2Signature,
×
4809
                        }
×
4810
                }
×
4811

4812
                return policy1, policy2, nil
×
4813

4814
        default:
×
4815
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4816
                        "extractChannelPolicies: %T", r)
×
4817
        }
4818
}
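
// A typical call site (an illustrative sketch, not additional API surface)
// passes one of the supported sqlc row types in directly and treats a nil
// result as "no policy stored for that side of the channel":
//
//        dbPol1, dbPol2, err := extractChannelPolicies(row)
//        if err != nil {
//                return err
//        }
//        // dbPol1 and/or dbPol2 may be nil here.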

// channelIDToBytes converts a channel ID (SCID) to a byte array
// representation.
func channelIDToBytes(channelID uint64) []byte {
        var chanIDB [8]byte
        byteOrder.PutUint64(chanIDB[:], channelID)

        return chanIDB[:]
}
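
// As an illustrative sketch (assuming byteOrder is the package-level
// big-endian binary.ByteOrder used for SCID keys, and chanID is some uint64
// SCID), the encoding round-trips:
//
//        key := channelIDToBytes(chanID)
//        // len(key) == 8 and byteOrder.Uint64(key) == chanID.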

// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
        if len(addresses) == 0 {
                return nil, nil
        }

        result := make([]net.Addr, 0, len(addresses))
        for _, addr := range addresses {
                netAddr, err := parseAddress(addr.addrType, addr.address)
                if err != nil {
                        return nil, fmt.Errorf("unable to parse address %s "+
                                "of type %d: %w", addr.address, addr.addrType,
                                err)
                }
                if netAddr != nil {
                        result = append(result, netAddr)
                }
        }

        // If we have no valid addresses, return nil instead of an empty slice.
        if len(result) == 0 {
                return nil, nil
        }

        return result, nil
}

// parseAddress parses the given address string based on the address type
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
// DNS and opaque addresses.
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
        switch addrType {
        case addressTypeIPv4:
                tcp, err := net.ResolveTCPAddr("tcp4", address)
                if err != nil {
                        return nil, err
                }

                tcp.IP = tcp.IP.To4()

                return tcp, nil

        case addressTypeIPv6:
                tcp, err := net.ResolveTCPAddr("tcp6", address)
                if err != nil {
                        return nil, err
                }

                return tcp, nil

        case addressTypeTorV3, addressTypeTorV2:
                service, portStr, err := net.SplitHostPort(address)
                if err != nil {
                        return nil, fmt.Errorf("unable to split tor "+
                                "address: %v", address)
                }

                port, err := strconv.Atoi(portStr)
                if err != nil {
                        return nil, err
                }

                return &tor.OnionAddr{
                        OnionService: service,
                        Port:         port,
                }, nil

        case addressTypeDNS:
                hostname, portStr, err := net.SplitHostPort(address)
                if err != nil {
                        return nil, fmt.Errorf("unable to split DNS "+
                                "address: %v", address)
                }

                port, err := strconv.Atoi(portStr)
                if err != nil {
                        return nil, err
                }

                return &lnwire.DNSAddress{
                        Hostname: hostname,
                        Port:     uint16(port),
                }, nil

        case addressTypeOpaque:
                opaque, err := hex.DecodeString(address)
                if err != nil {
                        return nil, fmt.Errorf("unable to decode opaque "+
                                "address: %v", address)
                }

                return &lnwire.OpaqueAddrs{
                        Payload: opaque,
                }, nil

        default:
                return nil, fmt.Errorf("unknown address type: %v", addrType)
        }
}
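
// For example (an illustrative sketch with hypothetical host:port strings,
// not test vectors from this package):
//
//        tcpAddr, err := parseAddress(addressTypeIPv4, "203.0.113.7:9735")
//        // On success tcpAddr is a *net.TCPAddr with a 4-byte IP.
//
//        onionAddr, err := parseAddress(
//                addressTypeTorV3, "<56-char-onion-service>.onion:9735",
//        )
//        // On success onionAddr is a *tor.OnionAddr.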

// batchNodeData holds all the related data for a batch of nodes.
type batchNodeData struct {
        // features is a map from a DB node ID to the feature bits for that
        // node.
        features map[int64][]int

        // addresses is a map from a DB node ID to the node's addresses.
        addresses map[int64][]nodeAddress

        // extraFields is a map from a DB node ID to the extra signed fields
        // for that node.
        extraFields map[int64]map[uint64][]byte
}

// nodeAddress holds the address type, position and address string for a
// node. This is used to batch the fetching of node addresses.
type nodeAddress struct {
        addrType dbAddressType
        position int32
        address  string
}

// batchLoadNodeData loads all related data for a batch of node IDs using the
// provided SQLQueries interface. It returns a batchNodeData instance
// containing the node features, addresses and extra signed fields.
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {

        // Batch load the node features.
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node "+
                        "features: %w", err)
        }

        // Batch load the node addresses.
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node "+
                        "addresses: %w", err)
        }

        // Batch load the node extra signed fields.
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node extra "+
                        "signed fields: %w", err)
        }

        return &batchNodeData{
                features:    features,
                addresses:   addrs,
                extraFields: extraTypes,
        }, nil
}
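
// A hypothetical caller (illustrative sketch only, assuming cfg is the
// store's SQLStoreConfig) hydrates a page of nodes with three batched
// queries instead of three queries per node:
//
//        data, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
//        if err != nil {
//                return err
//        }
//        // For some nodeID in nodeIDs:
//        bits := data.features[nodeID]     // raw feature bits.
//        addrs := data.addresses[nodeID]   // typed, positioned addresses.
//        extra := data.extraFields[nodeID] // TLV type -> raw value.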

// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
// using the ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
func batchLoadNodeFeaturesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64][]int, error) {

        features := make(map[int64][]int)

        return features, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
                        error) {

                        return db.GetNodeFeaturesBatch(ctx, ids)
                },
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
                        features[feature.NodeID] = append(
                                features[feature.NodeID],
                                int(feature.FeatureBit),
                        )

                        return nil
                },
        )
}
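
// The batch helpers in this file all follow the same shape: a key function
// that maps an input ID to the query parameter, a query function that runs a
// single chunk of IDs against the generated sqlc batch query, and a callback
// that folds each returned row into an in-memory map. A condensed sketch of
// the call pattern (Row and SomeBatchQuery are placeholders, not real types
// or queries):
//
//        sqldb.ExecuteBatchQuery(
//                ctx, cfg, inputIDs,
//                func(id int64) int64 { return id },
//                func(ctx context.Context, ids []int64) ([]Row, error) {
//                        return db.SomeBatchQuery(ctx, ids)
//                },
//                func(ctx context.Context, row Row) error {
//                        // Fold row into the result map keyed by its
//                        // parent ID.
//                        return nil
//                },
//        )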

// batchLoadNodeAddressesHelper loads node addresses using the
// ExecuteBatchQuery wrapper around the GetNodeAddressesBatch query. It
// returns a map from node ID to a slice of nodeAddress structs.
func batchLoadNodeAddressesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64][]nodeAddress, error) {

        addrs := make(map[int64][]nodeAddress)

        return addrs, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
                        error) {

                        return db.GetNodeAddressesBatch(ctx, ids)
                },
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
                        addrs[addr.NodeID] = append(
                                addrs[addr.NodeID], nodeAddress{
                                        addrType: dbAddressType(addr.Type),
                                        position: addr.Position,
                                        address:  addr.Address,
                                },
                        )

                        return nil
                },
        )
}

// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
// node IDs using the ExecuteBatchQuery wrapper around the
// GetNodeExtraTypesBatch query.
func batchLoadNodeExtraTypesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {

        extraFields := make(map[int64]map[uint64][]byte)

        callback := func(ctx context.Context,
                field sqlc.GraphNodeExtraType) error {

                if extraFields[field.NodeID] == nil {
                        extraFields[field.NodeID] = make(map[uint64][]byte)
                }
                extraFields[field.NodeID][uint64(field.Type)] = field.Value

                return nil
        }

        return extraFields, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) (
                        []sqlc.GraphNodeExtraType, error) {

                        return db.GetNodeExtraTypesBatch(ctx, ids)
                },
                callback,
        )
}

// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
// from the provided sqlc.GraphChannelPolicy records and the provided
// batchChannelData.
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
        channelID uint64, node1, node2 route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
        *models.ChannelEdgePolicy, error) {

        pol1, err := buildChanPolicyWithBatchData(
                dbPol1, channelID, node2, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
        }

        pol2, err := buildChanPolicyWithBatchData(
                dbPol2, channelID, node1, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
        }

        return pol1, pol2, nil
}

// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
        channelID uint64, toNode route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {

        if dbPol == nil {
                return nil, nil
        }

        var dbPolExtras map[uint64][]byte
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
                dbPolExtras = extras
        } else {
                dbPolExtras = make(map[uint64][]byte)
        }

        return buildChanPolicy(*dbPol, channelID, dbPolExtras, toNode)
}

// batchChannelData holds all the related data for a batch of channels.
type batchChannelData struct {
        // chanfeatures is a map from a DB channel ID to a slice of feature
        // bits.
        chanfeatures map[int64][]int

        // chanExtraTypes is a map from a DB channel ID to a map of TLV type
        // to extra signed field bytes.
        chanExtraTypes map[int64]map[uint64][]byte

        // policyExtras is a map from a DB channel policy ID to a map of TLV
        // type to extra signed field bytes.
        policyExtras map[int64]map[uint64][]byte
}

// batchLoadChannelData loads all related data for batches of channels and
// policies.
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, channelIDs []int64,
        policyIDs []int64) (*batchChannelData, error) {

        batchData := &batchChannelData{
                chanfeatures:   make(map[int64][]int),
                chanExtraTypes: make(map[int64]map[uint64][]byte),
                policyExtras:   make(map[int64]map[uint64][]byte),
        }

        // Batch load the channel features and extras.
        var err error
        if len(channelIDs) > 0 {
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
                        ctx, cfg, db, channelIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "channel features: %w", err)
                }

                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
                        ctx, cfg, db, channelIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "channel extras: %w", err)
                }
        }

        if len(policyIDs) > 0 {
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
                        ctx, cfg, db, policyIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "policy extras: %w", err)
                }
                batchData.policyExtras = policyExtras
        }

        return batchData, nil
}
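
// When only channel-level data is needed, callers may pass a nil policy ID
// slice to skip the policy extras query entirely, as batchBuildChannelInfo
// does below (illustrative sketch):
//
//        chanData, err := batchLoadChannelData(
//                ctx, cfg.QueryCfg, db, channelIDs, nil,
//        )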
5185

5186
// batchLoadChannelFeaturesHelper loads channel features for a batch of
5187
// channel IDs using ExecuteBatchQuery wrapper around the
5188
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
5189
// slice of feature bits.
5190
func batchLoadChannelFeaturesHelper(ctx context.Context,
5191
        cfg *sqldb.QueryConfig, db SQLQueries,
5192
        channelIDs []int64) (map[int64][]int, error) {
×
5193

×
5194
        features := make(map[int64][]int)
×
5195

×
5196
        return features, sqldb.ExecuteBatchQuery(
×
5197
                ctx, cfg, channelIDs,
×
5198
                func(id int64) int64 {
×
5199
                        return id
×
5200
                },
×
5201
                func(ctx context.Context,
5202
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {
×
5203

×
5204
                        return db.GetChannelFeaturesBatch(ctx, ids)
×
5205
                },
×
5206
                func(ctx context.Context,
5207
                        feature sqlc.GraphChannelFeature) error {
×
5208

×
5209
                        features[feature.ChannelID] = append(
×
5210
                                features[feature.ChannelID],
×
5211
                                int(feature.FeatureBit),
×
5212
                        )
×
5213

×
5214
                        return nil
×
5215
                },
×
5216
        )
5217
}
5218

5219
// batchLoadChannelExtrasHelper loads channel extra types for a batch of
5220
// channel IDs using ExecuteBatchQuery wrapper around the GetChannelExtrasBatch
5221
// query. It returns a map from DB channel ID to a map of TLV type to extra
5222
// signed field bytes.
5223
func batchLoadChannelExtrasHelper(ctx context.Context,
5224
        cfg *sqldb.QueryConfig, db SQLQueries,
5225
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5226

×
5227
        extras := make(map[int64]map[uint64][]byte)
×
5228

×
5229
        cb := func(ctx context.Context,
×
5230
                extra sqlc.GraphChannelExtraType) error {
×
5231

×
5232
                if extras[extra.ChannelID] == nil {
×
5233
                        extras[extra.ChannelID] = make(map[uint64][]byte)
×
5234
                }
×
5235
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value
×
5236

×
5237
                return nil
×
5238
        }
5239

5240
        return extras, sqldb.ExecuteBatchQuery(
×
5241
                ctx, cfg, channelIDs,
×
5242
                func(id int64) int64 {
×
5243
                        return id
×
5244
                },
×
5245
                func(ctx context.Context,
5246
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {
×
5247

×
5248
                        return db.GetChannelExtrasBatch(ctx, ids)
×
5249
                }, cb,
×
5250
        )
5251
}
5252

5253
// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
5254
// batch of policy IDs using ExecuteBatchQuery wrapper around the
5255
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
5256
// a map of TLV type to extra signed field bytes.
5257
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
5258
        cfg *sqldb.QueryConfig, db SQLQueries,
5259
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5260

×
5261
        extras := make(map[int64]map[uint64][]byte)
×
5262

×
5263
        return extras, sqldb.ExecuteBatchQuery(
×
5264
                ctx, cfg, policyIDs,
×
5265
                func(id int64) int64 {
×
5266
                        return id
×
5267
                },
×
5268
                func(ctx context.Context, ids []int64) (
5269
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {
×
5270

×
5271
                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
×
5272
                },
×
5273
                func(ctx context.Context,
5274
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {
×
5275

×
5276
                        if extras[row.PolicyID] == nil {
×
5277
                                extras[row.PolicyID] = make(map[uint64][]byte)
×
5278
                        }
×
5279
                        extras[row.PolicyID][uint64(row.Type)] = row.Value
×
5280

×
5281
                        return nil
×
5282
                },
5283
        )
5284
}
5285

5286
// forEachNodePaginated executes a paginated query to process each node in the
5287
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
5288
// and applies the provided processNode function to each node.
5289
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
5290
        db SQLQueries, protocol ProtocolVersion,
5291
        processNode func(context.Context, int64,
5292
                *models.Node) error) error {
×
5293

×
5294
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5295
                limit int32) ([]sqlc.GraphNode, error) {
×
5296

×
5297
                return db.ListNodesPaginated(
×
5298
                        ctx, sqlc.ListNodesPaginatedParams{
×
5299
                                Version: int16(protocol),
×
5300
                                ID:      lastID,
×
5301
                                Limit:   limit,
×
5302
                        },
×
5303
                )
×
5304
        }
×
5305

5306
        extractPageCursor := func(node sqlc.GraphNode) int64 {
×
5307
                return node.ID
×
5308
        }
×
5309

5310
        collectFunc := func(node sqlc.GraphNode) (int64, error) {
×
5311
                return node.ID, nil
×
5312
        }
×
5313

5314
        batchQueryFunc := func(ctx context.Context,
×
5315
                nodeIDs []int64) (*batchNodeData, error) {
×
5316

×
5317
                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
5318
        }
×
5319

5320
        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
×
5321
                batchData *batchNodeData) error {
×
5322

×
5323
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
5324
                if err != nil {
×
5325
                        return fmt.Errorf("unable to build "+
×
5326
                                "node(id=%d): %w", dbNode.ID, err)
×
5327
                }
×
5328

5329
                return processNode(ctx, dbNode.ID, node)
×
5330
        }
5331

5332
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5333
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5334
                collectFunc, batchQueryFunc, processItem,
×
5335
        )
×
5336
}
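
// A hypothetical caller (illustrative sketch only) that walks every V1 node
// without holding the full node set in memory:
//
//        var count int
//        err := forEachNodePaginated(
//                ctx, cfg, db, ProtocolV1,
//                func(_ context.Context, _ int64, _ *models.Node) error {
//                        count++
//                        return nil
//                },
//        )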

// forEachChannelWithPolicies executes a paginated query to process each
// channel with policies in the graph.
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
        cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
                *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error) error {

        type channelBatchIDs struct {
                channelID int64
                policyIDs []int64
        }

        pageQueryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
                error) {

                return db.ListChannelsWithPoliciesPaginated(
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
                                Version: int16(ProtocolV1),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractPageCursor := func(
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {

                return row.GraphChannel.ID
        }

        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
                channelBatchIDs, error) {

                ids := channelBatchIDs{
                        channelID: row.GraphChannel.ID,
                }

                // Extract policy IDs from the row.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return ids, err
                }

                if dbPol1 != nil {
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
                }
                if dbPol2 != nil {
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
                }

                return ids, nil
        }

        batchDataFunc := func(ctx context.Context,
                allIDs []channelBatchIDs) (*batchChannelData, error) {

                // Separate channel IDs from policy IDs.
                var (
                        channelIDs = make([]int64, len(allIDs))
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
                )

                for i, ids := range allIDs {
                        channelIDs[i] = ids.channelID
                        policyIDs = append(policyIDs, ids.policyIDs...)
                }

                return batchLoadChannelData(
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
                )
        }

        processItem := func(ctx context.Context,
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
                batchData *batchChannelData) error {

                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return err
                }

                edge, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.GraphChannel, node1, node2,
                        batchData,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel info: %w",
                                err)
                }

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return err
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
                )
                if err != nil {
                        return err
                }

                return processChannel(edge, p1, p2)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
                collectFunc, batchDataFunc, processItem,
        )
}
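
// A hypothetical caller (illustrative sketch only) that counts channels for
// which both directed policies are known:
//
//        var complete int
//        err := forEachChannelWithPolicies(
//                ctx, db, cfg, func(info *models.ChannelEdgeInfo,
//                        p1, p2 *models.ChannelEdgePolicy) error {
//
//                        if p1 != nil && p2 != nil {
//                                complete++
//                        }
//                        return nil
//                },
//        )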

// buildDirectedChannel builds a DirectedChannel instance from the provided
// data.
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
        nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
        channelBatchData *batchChannelData, features *lnwire.FeatureVector,
        toNodeCallback func() route.Vertex) (*DirectedChannel, error) {

        node1, node2, err := buildNodeVertices(
                channelRow.Node1Pubkey, channelRow.Node2Pubkey,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build node vertices: %w", err)
        }

        edge, err := buildEdgeInfoWithBatchData(
                chain, channelRow.GraphChannel, node1, node2, channelBatchData,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build channel info: %w", err)
        }

        dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
        if err != nil {
                return nil, fmt.Errorf("unable to extract channel policies: %w",
                        err)
        }

        p1, p2, err := buildChanPoliciesWithBatchData(
                dbPol1, dbPol2, edge.ChannelID, node1, node2,
                channelBatchData,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build channel policies: %w",
                        err)
        }

        // Determine the outgoing and incoming policy for this specific node.
        p1ToNode := channelRow.GraphChannel.NodeID2
        p2ToNode := channelRow.GraphChannel.NodeID1
        outPolicy, inPolicy := p1, p2
        if (p1 != nil && p1ToNode == nodeID) ||
                (p2 != nil && p2ToNode != nodeID) {

                outPolicy, inPolicy = p2, p1
        }

        // Build the cached incoming policy.
        var cachedInPolicy *models.CachedEdgePolicy
        if inPolicy != nil {
                cachedInPolicy = models.NewCachedPolicy(inPolicy)
                cachedInPolicy.ToNodePubKey = toNodeCallback
                cachedInPolicy.ToNodeFeatures = features
        }

        // Extract the inbound fee from the outgoing policy.
        var inboundFee lnwire.Fee
        if outPolicy != nil {
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
                        inboundFee = fee
                })
        }

        // Build the directed channel.
        directedChannel := &DirectedChannel{
                ChannelID:    edge.ChannelID,
                IsNode1:      nodePub == edge.NodeKey1Bytes,
                OtherNode:    edge.NodeKey2Bytes,
                Capacity:     edge.Capacity,
                OutPolicySet: outPolicy != nil,
                InPolicy:     cachedInPolicy,
                InboundFee:   inboundFee,
        }

        if nodePub == edge.NodeKey2Bytes {
                directedChannel.OtherNode = edge.NodeKey1Bytes
        }

        return directedChannel, nil
}
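
// Orientation note: policy1 is the policy set by node1 and points towards
// node2, and vice versa for policy2. So when nodeID refers to node2 of the
// channel, the swap in buildDirectedChannel above makes policy2 the outgoing
// policy and policy1 the incoming one; viewed from node1, the defaults
// already hold.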

// batchBuildChannelEdges builds a slice of ChannelEdge instances from the
// provided rows. It uses batch loading for channels, policies, and nodes.
func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context,
        cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) {

        var (
                channelIDs = make([]int64, len(rows))
                policyIDs  = make([]int64, 0, len(rows)*2)
                nodeIDs    = make([]int64, 0, len(rows)*2)

                // nodeIDSet is used to ensure we only collect unique node IDs.
                nodeIDSet = make(map[int64]bool)

                // edges will hold the final channel edges built from the rows.
                edges = make([]ChannelEdge, 0, len(rows))
        )

        // Collect all IDs needed for batch loading.
        for i, row := range rows {
                channelIDs[i] = row.Channel().ID

                // Collect the policy IDs.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return nil, fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }
                if dbPol1 != nil {
                        policyIDs = append(policyIDs, dbPol1.ID)
                }
                if dbPol2 != nil {
                        policyIDs = append(policyIDs, dbPol2.ID)
                }

                var (
                        node1ID = row.Node1().ID
                        node2ID = row.Node2().ID
                )

                // Collect unique node IDs.
                if !nodeIDSet[node1ID] {
                        nodeIDs = append(nodeIDs, node1ID)
                        nodeIDSet[node1ID] = true
                }

                if !nodeIDSet[node2ID] {
                        nodeIDs = append(nodeIDs, node2ID)
                        nodeIDSet[node2ID] = true
                }
        }

        // Batch the data for all the channels and policies.
        channelBatchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to batch load channel and "+
                        "policy data: %w", err)
        }

        // Batch the data for all the nodes.
        nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node data: %w",
                        err)
        }

        // Build all channel edges using batch data.
        for _, row := range rows {
                // Build nodes using batch data.
                node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData)
                if err != nil {
                        return nil, fmt.Errorf("unable to build node1: %w", err)
                }

                node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData)
                if err != nil {
                        return nil, fmt.Errorf("unable to build node2: %w", err)
                }

                // Build channel info using batch data.
                channel, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.Channel(), node1.PubKeyBytes,
                        node2.PubKeyBytes, channelBatchData,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to build channel "+
                                "info: %w", err)
                }

                // Extract and build policies using batch data.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return nil, fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, channel.ChannelID,
                        node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to build channel "+
                                "policies: %w", err)
                }

                edges = append(edges, ChannelEdge{
                        Info:    channel,
                        Policy1: p1,
                        Policy2: p2,
                        Node1:   node1,
                        Node2:   node2,
                })
        }

        return edges, nil
}

// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo instances
// from the provided rows using batch loading for channel data.
func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context,
        cfg *SQLStoreConfig, db SQLQueries, rows []T) (
        []*models.ChannelEdgeInfo, []int64, error) {

        if len(rows) == 0 {
                return nil, nil, nil
        }

        // Collect all the channel IDs needed for batch loading.
        channelIDs := make([]int64, len(rows))
        for i, row := range rows {
                channelIDs[i] = row.Channel().ID
        }

        // Batch load the channel data.
        channelBatchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, channelIDs, nil,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to batch load channel "+
                        "data: %w", err)
        }

        // Build all channel edges using batch data.
        edges := make([]*models.ChannelEdgeInfo, 0, len(rows))
        for _, row := range rows {
                node1, node2, err := buildNodeVertices(
                        row.Node1Pub(), row.Node2Pub(),
                )
                if err != nil {
                        return nil, nil, err
                }

                // Build channel info using batch data.
                info, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.Channel(), node1, node2,
                        channelBatchData,
                )
                if err != nil {
                        return nil, nil, err
                }

                edges = append(edges, info)
        }

        return edges, channelIDs, nil
}

// handleZombieMarking is a helper function that handles the logic of marking
// a channel as a zombie in the database. It takes into account whether we are
// in strict zombie pruning mode, and adjusts the node public keys accordingly
// based on the last update timestamps of the channel policies.
func handleZombieMarking(ctx context.Context, db SQLQueries,
        row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo,
        strictZombiePruning bool, scid uint64) error {

        nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes

        if strictZombiePruning {
                var e1UpdateTime, e2UpdateTime *time.Time
                if row.Policy1LastUpdate.Valid {
                        e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0)
                        e1UpdateTime = &e1Time
                }
                if row.Policy2LastUpdate.Valid {
                        e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0)
                        e2UpdateTime = &e2Time
                }

                nodeKey1, nodeKey2 = makeZombiePubkeys(
                        info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime,
                        e2UpdateTime,
                )
        }

        return db.UpsertZombieChannel(
                ctx, sqlc.UpsertZombieChannelParams{
                        Version:  int16(ProtocolV1),
                        Scid:     channelIDToBytes(scid),
                        NodeKey1: nodeKey1[:],
                        NodeKey2: nodeKey2[:],
                },
        )
}
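
// A hypothetical call site (illustrative sketch only): after deciding that a
// channel should be pruned, record it as a zombie entry. With strict zombie
// pruning enabled, the node keys stored alongside the zombie record depend on
// which side's policy is stale (see makeZombiePubkeys); otherwise both
// original node keys are kept.
//
//        err := handleZombieMarking(ctx, db, row, info, strict, scid)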