lightningnetwork / lnd, build 16777740336
06 Aug 2025 01:04PM UTC coverage: 54.85% (-12.1%) from 66.954%

Pull Request #10135: docs: move v0.19.3 items to correct file
Merge 429aa830c into e512770f1

108702 of 198181 relevant lines covered (54.85%)
22045.97 hits per line
Source File
/graph/db/sql_store.go

package graphdb

import (
	"bytes"
	"context"
	"database/sql"
	"encoding/hex"
	"errors"
	"fmt"
	"maps"
	"math"
	"net"
	"slices"
	"strconv"
	"sync"
	"time"

	"github.com/btcsuite/btcd/btcec/v2"
	"github.com/btcsuite/btcd/btcutil"
	"github.com/btcsuite/btcd/chaincfg/chainhash"
	"github.com/btcsuite/btcd/wire"
	"github.com/lightningnetwork/lnd/aliasmgr"
	"github.com/lightningnetwork/lnd/batch"
	"github.com/lightningnetwork/lnd/fn/v2"
	"github.com/lightningnetwork/lnd/graph/db/models"
	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/routing/route"
	"github.com/lightningnetwork/lnd/sqldb"
	"github.com/lightningnetwork/lnd/sqldb/sqlc"
	"github.com/lightningnetwork/lnd/tlv"
	"github.com/lightningnetwork/lnd/tor"
)

// ProtocolVersion is an enum that defines the gossip protocol version of a
// message.
type ProtocolVersion uint8

const (
	// ProtocolV1 is the gossip protocol version defined in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String returns a string representation of the protocol version.
func (v ProtocolVersion) String() string {
	return fmt.Sprintf("V%d", v)
}

// SQLQueries is a subset of the sqlc.Querier interface that can be used to
// execute queries against the SQL graph tables.
//
//nolint:ll,interfacebloat
type SQLQueries interface {
	/*
		Node queries.
	*/
	UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
	GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
	GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
	GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
	ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
	ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
	IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
	DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
	DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
	DeleteNode(ctx context.Context, id int64) error

	GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
	GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
	UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
	DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error

	InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
	GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
	GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
	DeleteNodeAddresses(ctx context.Context, nodeID int64) error

	InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
	GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
	GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
	GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
	DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error

	/*
		Source node queries.
	*/
	AddSourceNode(ctx context.Context, nodeID int64) error
	GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)

	/*
		Channel queries.
	*/
	CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
	AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
	GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
	GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
	GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
	GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
	GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
	GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
	GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
	HighestSCID(ctx context.Context, version int16) ([]byte, error)
	ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
	ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
	ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
	ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
	ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
	GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
	GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
	GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
	GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
	DeleteChannels(ctx context.Context, ids []int64) error

	CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
	GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
	InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
	GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)

	/*
		Channel Policy table queries.
	*/
	UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
	GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
	GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)

	InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
	GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
	DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error

	/*
		Zombie index queries.
	*/
	UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
	GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
	CountZombieChannels(ctx context.Context, version int16) (int64, error)
	DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
	IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)

	/*
		Prune log table queries.
	*/
	GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
	GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
	UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
	DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error

	/*
		Closed SCID table queries.
	*/
	InsertClosedChannel(ctx context.Context, scid []byte) error
	IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
}

// BatchedSQLQueries is a version of SQLQueries that's capable of batched
// database operations.
type BatchedSQLQueries interface {
	SQLQueries
	sqldb.BatchedTx[SQLQueries]
}

// SQLStore is an implementation of the V1Store interface that uses a SQL
// database as the backend.
type SQLStore struct {
	cfg *SQLStoreConfig
	db  BatchedSQLQueries

	// cacheMu guards all caches (rejectCache and chanCache). If this mutex
	// is to be acquired at the same time as the DB mutex, then cacheMu
	// MUST be acquired first to prevent deadlock.
	cacheMu     sync.RWMutex
	rejectCache *rejectCache
	chanCache   *channelCache

	chanScheduler batch.Scheduler[SQLQueries]
	nodeScheduler batch.Scheduler[SQLQueries]

	srcNodes  map[ProtocolVersion]*srcNodeInfo
	srcNodeMu sync.Mutex
}

// A compile-time assertion to ensure that SQLStore implements the V1Store
// interface.
var _ V1Store = (*SQLStore)(nil)

// SQLStoreConfig holds the configuration for the SQLStore.
type SQLStoreConfig struct {
	// ChainHash is the genesis hash for the chain that all the gossip
	// messages in this store are aimed at.
	ChainHash chainhash.Hash

	// QueryCfg holds configuration values for SQL queries.
	QueryCfg *sqldb.QueryConfig
}

// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
// storage backend.
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
	options ...StoreOptionModifier) (*SQLStore, error) {

	opts := DefaultOptions()
	for _, o := range options {
		o(opts)
	}

	if opts.NoMigration {
		return nil, fmt.Errorf("the NoMigration option is not yet " +
			"supported for SQL stores")
	}

	s := &SQLStore{
		cfg:         cfg,
		db:          db,
		rejectCache: newRejectCache(opts.RejectCacheSize),
		chanCache:   newChannelCache(opts.ChannelCacheSize),
		srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
	}

	s.chanScheduler = batch.NewTimeScheduler(
		db, &s.cacheMu, opts.BatchCommitInterval,
	)
	s.nodeScheduler = batch.NewTimeScheduler(
		db, nil, opts.BatchCommitInterval,
	)

	return s, nil
}
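
// The following is an illustrative sketch added for documentation purposes and
// is not part of the upstream file. It shows how a caller might construct an
// SQLStore once a BatchedSQLQueries backend, a chain hash and a query config
// are available; opening the database and building the sqlc querier are
// assumed to happen elsewhere.
func exampleNewSQLStore(db BatchedSQLQueries, chain chainhash.Hash,
	queryCfg *sqldb.QueryConfig) (*SQLStore, error) {

	cfg := &SQLStoreConfig{
		ChainHash: chain,
		QueryCfg:  queryCfg,
	}

	// With no StoreOptionModifier values supplied, DefaultOptions()
	// determines the reject/channel cache sizes and the batch commit
	// interval used by the schedulers.
	return NewSQLStore(cfg, db)
}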

// AddLightningNode adds a vertex/node to the graph database. If the node is
// not already present in the database, this will add a new, unconnected node
// to the graph. If it is already present, this will update that node's
// information.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddLightningNode(ctx context.Context,
	node *models.LightningNode, opts ...batch.SchedulerOption) error {

	r := &batch.Request[SQLQueries]{
		Opts: batch.NewSchedulerOptions(opts...),
		Do: func(queries SQLQueries) error {
			_, err := upsertNode(ctx, queries, node)
			return err
		},
	}

	return s.nodeScheduler.Execute(ctx, r)
}
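
// Illustrative sketch (not part of the upstream file): persisting a minimal
// node announcement via AddLightningNode. The concrete field set shown here is
// an assumption chosen only to demonstrate the shape of the call; the write is
// funneled through the node scheduler so concurrent callers can share a single
// database transaction.
func exampleAddLightningNode(ctx context.Context, s *SQLStore,
	pub route.Vertex) error {

	// Assumed field values for illustration only.
	node := &models.LightningNode{
		PubKeyBytes:          pub,
		HaveNodeAnnouncement: true,
		LastUpdate:           time.Now(),
		Alias:                "example-node",
		Features:             lnwire.EmptyFeatureVector(),
	}

	return s.AddLightningNode(ctx, node)
}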

// FetchLightningNode attempts to look up a target node by its identity public
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
// returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchLightningNode(ctx context.Context,
	pubKey route.Vertex) (*models.LightningNode, error) {

	var node *models.LightningNode
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		var err error
		_, node, err = getNodeByPubKey(ctx, db, pubKey)

		return err
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch node: %w", err)
	}

	return node, nil
}

// HasLightningNode determines if the graph has a vertex identified by the
// target node identity public key. If the node exists in the database, a
// timestamp of when the data for the node was last updated is returned along
// with a true boolean. Otherwise, an empty time.Time is returned with a false
// boolean.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) HasLightningNode(ctx context.Context,
	pubKey [33]byte) (time.Time, bool, error) {

	var (
		exists     bool
		lastUpdate time.Time
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbNode, err := db.GetNodeByPubKey(
			ctx, sqlc.GetNodeByPubKeyParams{
				Version: int16(ProtocolV1),
				PubKey:  pubKey[:],
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		} else if err != nil {
			return fmt.Errorf("unable to fetch node: %w", err)
		}

		exists = true

		if dbNode.LastUpdate.Valid {
			lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return time.Time{}, false,
			fmt.Errorf("unable to fetch node: %w", err)
	}

	return lastUpdate, exists, nil
}
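
// Illustrative sketch (not part of the upstream file): using HasLightningNode
// to decide whether a freshly received announcement is newer than what is
// already stored. The staleness rule shown is an assumption about how a caller
// might combine the returned values, not logic taken from this file.
func exampleIsStaleNode(ctx context.Context, s *SQLStore, pub [33]byte,
	msgTimestamp time.Time) (bool, error) {

	lastUpdate, exists, err := s.HasLightningNode(ctx, pub)
	if err != nil {
		return false, err
	}

	// Unknown nodes are never stale; otherwise the announcement is stale
	// if it does not postdate the stored update time.
	return exists && !msgTimestamp.After(lastUpdate), nil
}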

// AddrsForNode returns all known addresses for the target node public key
// that the graph DB is aware of. The returned boolean indicates whether the
// given node is known to the graph DB.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddrsForNode(ctx context.Context,
	nodePub *btcec.PublicKey) (bool, []net.Addr, error) {

	var (
		addresses []net.Addr
		known     bool
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		// First, check if the node exists and get its DB ID if it
		// does.
		dbID, err := db.GetNodeIDByPubKey(
			ctx, sqlc.GetNodeIDByPubKeyParams{
				Version: int16(ProtocolV1),
				PubKey:  nodePub.SerializeCompressed(),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		}

		known = true

		addresses, err = getNodeAddresses(ctx, db, dbID)
		if err != nil {
			return fmt.Errorf("unable to fetch node addresses: %w",
				err)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return false, nil, fmt.Errorf("unable to get addresses for "+
			"node(%x): %w", nodePub.SerializeCompressed(), err)
	}

	return known, addresses, nil
}

// DeleteLightningNode starts a new database transaction to remove a vertex/node
// from the database according to the node's public key.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
	pubKey route.Vertex) error {

	err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		res, err := db.DeleteNodeByPubKey(
			ctx, sqlc.DeleteNodeByPubKeyParams{
				Version: int16(ProtocolV1),
				PubKey:  pubKey[:],
			},
		)
		if err != nil {
			return err
		}

		rows, err := res.RowsAffected()
		if err != nil {
			return err
		}

		if rows == 0 {
			return ErrGraphNodeNotFound
		} else if rows > 1 {
			return fmt.Errorf("deleted %d rows, expected 1", rows)
		}

		return err
	}, sqldb.NoOpReset)
	if err != nil {
		return fmt.Errorf("unable to delete node: %w", err)
	}

	return nil
}

// FetchNodeFeatures returns the features of the given node. If no features are
// known for the node, an empty feature vector is returned.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
	*lnwire.FeatureVector, error) {

	ctx := context.TODO()

	return fetchNodeFeatures(ctx, s.db, nodePub)
}

// DisabledChannelIDs returns the channel ids of disabled channels.
// A channel is disabled when both of the associated ChannelEdgePolicies
// have their disabled bit set.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
	var (
		ctx     = context.TODO()
		chanIDs []uint64
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
		if err != nil {
			return fmt.Errorf("unable to fetch disabled "+
				"channels: %w", err)
		}

		chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch disabled channels: %w",
			err)
	}

	return chanIDs, nil
}

// LookupAlias attempts to return the alias as advertised by the target node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) LookupAlias(ctx context.Context,
	pub *btcec.PublicKey) (string, error) {

	var alias string
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbNode, err := db.GetNodeByPubKey(
			ctx, sqlc.GetNodeByPubKeyParams{
				Version: int16(ProtocolV1),
				PubKey:  pub.SerializeCompressed(),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return ErrNodeAliasNotFound
		} else if err != nil {
			return fmt.Errorf("unable to fetch node: %w", err)
		}

		if !dbNode.Alias.Valid {
			return ErrNodeAliasNotFound
		}

		alias = dbNode.Alias.String

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return "", fmt.Errorf("unable to look up alias: %w", err)
	}

	return alias, nil
}
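
// Illustrative sketch (not part of the upstream file): callers that treat a
// missing alias as non-fatal can branch on ErrNodeAliasNotFound, which
// LookupAlias signals both when the node is unknown and when the stored alias
// column is NULL. The hex fallback chosen here is an assumption for
// demonstration only.
func exampleAliasOrDefault(ctx context.Context, s *SQLStore,
	pub *btcec.PublicKey) (string, error) {

	alias, err := s.LookupAlias(ctx, pub)
	switch {
	case errors.Is(err, ErrNodeAliasNotFound):
		// Fall back to a hex encoding of the public key.
		return hex.EncodeToString(pub.SerializeCompressed()), nil

	case err != nil:
		return "", err
	}

	return alias, nil
}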

// SourceNode returns the source node of the graph. The source node is treated
// as the center node within a star-graph. This method may be used to kick off
// a path finding algorithm in order to explore the reachability of another
// node based off the source node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
	error) {

	var node *models.LightningNode
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		_, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
		if err != nil {
			return fmt.Errorf("unable to fetch V1 source node: %w",
				err)
		}

		_, node, err = getNodeByPubKey(ctx, db, nodePub)

		return err
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch source node: %w", err)
	}

	return node, nil
}

// SetSourceNode sets the source node within the graph database. The source
// node is to be used as the center of a star-graph within path finding
// algorithms.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SetSourceNode(ctx context.Context,
	node *models.LightningNode) error {

	return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		id, err := upsertNode(ctx, db, node)
		if err != nil {
			return fmt.Errorf("unable to upsert source node: %w",
				err)
		}

		// Make sure that if a source node for this version is already
		// set, then the ID is the same as the one we are about to set.
		dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
		if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
			return fmt.Errorf("unable to fetch source node: %w",
				err)
		} else if err == nil {
			if dbSourceNodeID != id {
				return fmt.Errorf("v1 source node already "+
					"set to a different node: %d vs %d",
					dbSourceNodeID, id)
			}

			return nil
		}

		return db.AddSourceNode(ctx, id)
	}, sqldb.NoOpReset)
}

// NodeUpdatesInHorizon returns all the known lightning nodes which have an
// update timestamp within the passed range. This method can be used by two
// nodes to quickly determine if they have the same set of up-to-date node
// announcements.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
	endTime time.Time) ([]models.LightningNode, error) {

	ctx := context.TODO()

	var nodes []models.LightningNode
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbNodes, err := db.GetNodesByLastUpdateRange(
			ctx, sqlc.GetNodesByLastUpdateRangeParams{
				StartTime: sqldb.SQLInt64(startTime.Unix()),
				EndTime:   sqldb.SQLInt64(endTime.Unix()),
			},
		)
		if err != nil {
			return fmt.Errorf("unable to fetch nodes: %w", err)
		}

		err = forEachNodeInBatch(
			ctx, s.cfg.QueryCfg, db, dbNodes,
			func(_ int64, node *models.LightningNode) error {
				nodes = append(nodes, *node)

				return nil
			},
		)
		if err != nil {
			return fmt.Errorf("unable to build nodes: %w", err)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch nodes: %w", err)
	}

	return nodes, nil
}
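
// Illustrative sketch (not part of the upstream file): fetching all node
// announcements updated within the last 24 hours, for example to answer a
// peer's gossip timestamp filter. The window length is an arbitrary
// assumption.
func exampleRecentNodeUpdates(s *SQLStore) ([]models.LightningNode, error) {
	endTime := time.Now()
	startTime := endTime.Add(-24 * time.Hour)

	return s.NodeUpdatesInHorizon(startTime, endTime)
}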

// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
// undirected edge between the two target nodes is created. The information
// stored denotes the static attributes of the channel, such as the channelID,
// the keys involved in creation of the channel, and the set of features that
// the channel supports. The chanPoint and chanID are used to uniquely identify
// the edge globally within the database.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddChannelEdge(ctx context.Context,
	edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {

	var alreadyExists bool
	r := &batch.Request[SQLQueries]{
		Opts: batch.NewSchedulerOptions(opts...),
		Reset: func() {
			alreadyExists = false
		},
		Do: func(tx SQLQueries) error {
			_, err := insertChannel(ctx, tx, edge)

			// Silence ErrEdgeAlreadyExist so that the batch can
			// succeed, but propagate the error via local state.
			if errors.Is(err, ErrEdgeAlreadyExist) {
				alreadyExists = true
				return nil
			}

			return err
		},
		OnCommit: func(err error) error {
			switch {
			case err != nil:
				return err
			case alreadyExists:
				return ErrEdgeAlreadyExist
			default:
				s.rejectCache.remove(edge.ChannelID)
				s.chanCache.remove(edge.ChannelID)
				return nil
			}
		},
	}

	return s.chanScheduler.Execute(ctx, r)
}

// HighestChanID returns the "highest" known channel ID in the channel graph.
// This represents the "newest" channel from the PoV of the chain. This method
// can be used by peers to quickly determine if their graphs are in sync.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
	var highestChanID uint64
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		} else if err != nil {
			return fmt.Errorf("unable to fetch highest chan ID: %w",
				err)
		}

		highestChanID = byteOrder.Uint64(chanID)

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
	}

	return highestChanID, nil
}

// UpdateEdgePolicy updates the edge routing policy for a single directed edge
// within the database for the referenced channel. The `flags` attribute within
// the ChannelEdgePolicy determines which of the directed edges are being
// updated. If the flag is 1, then the first node's information is being
// updated, otherwise it's the second node's information. The node ordering is
// determined by the lexicographical ordering of the identity public keys of the
// nodes on either side of the channel.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
	edge *models.ChannelEdgePolicy,
	opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {

	var (
		isUpdate1    bool
		edgeNotFound bool
		from, to     route.Vertex
	)

	r := &batch.Request[SQLQueries]{
		Opts: batch.NewSchedulerOptions(opts...),
		Reset: func() {
			isUpdate1 = false
			edgeNotFound = false
		},
		Do: func(tx SQLQueries) error {
			var err error
			from, to, isUpdate1, err = updateChanEdgePolicy(
				ctx, tx, edge,
			)
			if err != nil {
				log.Errorf("UpdateEdgePolicy failed: %v", err)
			}

			// Silence ErrEdgeNotFound so that the batch can
			// succeed, but propagate the error via local state.
			if errors.Is(err, ErrEdgeNotFound) {
				edgeNotFound = true
				return nil
			}

			return err
		},
		OnCommit: func(err error) error {
			switch {
			case err != nil:
				return err
			case edgeNotFound:
				return ErrEdgeNotFound
			default:
				s.updateEdgeCache(edge, isUpdate1)
				return nil
			}
		},
	}

	err := s.chanScheduler.Execute(ctx, r)

	return from, to, err
}

// updateEdgeCache updates our reject and channel caches with the new
// edge policy information.
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
	isUpdate1 bool) {

	// If an entry for this channel is found in reject cache, we'll modify
	// the entry with the updated timestamp for the direction that was just
	// written. If the edge doesn't exist, we'll load the cache entry lazily
	// during the next query for this edge.
	if entry, ok := s.rejectCache.get(e.ChannelID); ok {
		if isUpdate1 {
			entry.upd1Time = e.LastUpdate.Unix()
		} else {
			entry.upd2Time = e.LastUpdate.Unix()
		}
		s.rejectCache.insert(e.ChannelID, entry)
	}

	// If an entry for this channel is found in channel cache, we'll modify
	// the entry with the updated policy for the direction that was just
	// written. If the edge doesn't exist, we'll defer loading the info and
	// policies and lazily read from disk during the next query.
	if channel, ok := s.chanCache.get(e.ChannelID); ok {
		if isUpdate1 {
			channel.Policy1 = e
		} else {
			channel.Policy2 = e
		}
		s.chanCache.insert(e.ChannelID, channel)
	}
}

// ForEachSourceNodeChannel iterates through all channels of the source node,
// executing the passed callback on each. The call-back is provided with the
// channel's outpoint, whether we have a policy for the channel and the channel
// peer's node information.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
	cb func(chanPoint wire.OutPoint, havePolicy bool,
		otherNode *models.LightningNode) error, reset func()) error {

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
		if err != nil {
			return fmt.Errorf("unable to fetch source node: %w",
				err)
		}

		return forEachNodeChannel(
			ctx, db, s.cfg, nodeID,
			func(info *models.ChannelEdgeInfo,
				outPolicy *models.ChannelEdgePolicy,
				_ *models.ChannelEdgePolicy) error {

				// Fetch the other node.
				var (
					otherNodePub [33]byte
					node1        = info.NodeKey1Bytes
					node2        = info.NodeKey2Bytes
				)
				switch {
				case bytes.Equal(node1[:], nodePub[:]):
					otherNodePub = node2
				case bytes.Equal(node2[:], nodePub[:]):
					otherNodePub = node1
				default:
					return fmt.Errorf("node not " +
						"participating in this channel")
				}

				_, otherNode, err := getNodeByPubKey(
					ctx, db, otherNodePub,
				)
				if err != nil {
					return fmt.Errorf("unable to fetch "+
						"other node(%x): %w",
						otherNodePub, err)
				}

				return cb(
					info.ChannelPoint, outPolicy != nil,
					otherNode,
				)
			},
		)
	}, reset)
}

// ForEachNode iterates through all the stored vertices/nodes in the graph,
// executing the passed callback with each node encountered. If the callback
// returns an error, then the transaction is aborted and the iteration stops
// early. Any operations performed on the NodeTx passed to the call-back are
// executed under the same read transaction and so, methods on the NodeTx object
// _MUST_ only be called from within the call-back.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNode(ctx context.Context,
	cb func(tx NodeRTx) error, reset func()) error {

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		return forEachNodePaginated(
			ctx, s.cfg.QueryCfg, db,
			ProtocolV1,
			func(ctx context.Context, dbNodeID int64,
				node *models.LightningNode) error {

				return cb(newSQLGraphNodeTx(
					db, s.cfg, dbNodeID, node,
				))
			},
		)
	}, reset)
}
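
// Illustrative sketch (not part of the upstream file): walking the whole graph
// with ForEachNode and, for every node, counting its channels through the
// NodeRTx handle. The reset closure re-initialises the tallies because the
// surrounding read transaction may be retried from scratch.
func exampleCountNodesAndChannels(ctx context.Context,
	s *SQLStore) (int, int, error) {

	var nodeCount, chanCount int
	err := s.ForEachNode(ctx, func(tx NodeRTx) error {
		nodeCount++

		return tx.ForEachChannel(func(_ *models.ChannelEdgeInfo,
			_ *models.ChannelEdgePolicy,
			_ *models.ChannelEdgePolicy) error {

			chanCount++

			return nil
		})
	}, func() {
		nodeCount = 0
		chanCount = 0
	})

	return nodeCount, chanCount, err
}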

// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
// SQLStore and a SQL transaction.
type sqlGraphNodeTx struct {
	db   SQLQueries
	id   int64
	node *models.LightningNode
	cfg  *SQLStoreConfig
}

// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
// interface.
var _ NodeRTx = (*sqlGraphNodeTx)(nil)

func newSQLGraphNodeTx(db SQLQueries, cfg *SQLStoreConfig,
	id int64, node *models.LightningNode) *sqlGraphNodeTx {

	return &sqlGraphNodeTx{
		db:   db,
		cfg:  cfg,
		id:   id,
		node: node,
	}
}

// Node returns the raw information of the node.
//
// NOTE: This is a part of the NodeRTx interface.
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
	return s.node
}

// ForEachChannel can be used to iterate over the node's channels under the same
// transaction used to fetch the node.
//
// NOTE: This is a part of the NodeRTx interface.
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
	*models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {

	ctx := context.TODO()

	return forEachNodeChannel(ctx, s.db, s.cfg, s.id, cb)
}

// FetchNode fetches the node with the given pub key under the same transaction
// used to fetch the current node. The returned node is also a NodeRTx and any
// operations on that NodeRTx will also be done under the same transaction.
//
// NOTE: This is a part of the NodeRTx interface.
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
	ctx := context.TODO()

	id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
			nodePub, err)
	}

	return newSQLGraphNodeTx(s.db, s.cfg, id, node), nil
}

// ForEachNodeDirectedChannel iterates through all channels of a given node,
// executing the passed callback on the directed edge representing the channel
// and its incoming policy. If the callback returns an error, then the iteration
// is halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
	cb func(channel *DirectedChannel) error, reset func()) error {

	var ctx = context.TODO()

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
	}, reset)
}

// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
// graph, executing the passed callback with each node encountered. If the
// callback returns an error, then the transaction is aborted and the iteration
// stops early.
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
	cb func(route.Vertex, *lnwire.FeatureVector) error,
	reset func()) error {

	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		return forEachNodeCacheable(
			ctx, s.cfg.QueryCfg, db,
			func(_ int64, nodePub route.Vertex,
				features *lnwire.FeatureVector) error {

				return cb(nodePub, features)
			},
		)
	}, reset)
	if err != nil {
		return fmt.Errorf("unable to fetch nodes: %w", err)
	}

	return nil
}

// ForEachNodeChannel iterates through all channels of the given node,
// executing the passed callback with an edge info structure and the policies
// of each end of the channel. The first edge policy is the outgoing edge *to*
// the connecting node, while the second is the incoming edge *from* the
// connecting node. If the callback returns an error, then the iteration is
// halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
	cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
		*models.ChannelEdgePolicy) error, reset func()) error {

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbNode, err := db.GetNodeByPubKey(
			ctx, sqlc.GetNodeByPubKeyParams{
				Version: int16(ProtocolV1),
				PubKey:  nodePub[:],
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		} else if err != nil {
			return fmt.Errorf("unable to fetch node: %w", err)
		}

		return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
	}, reset)
}

// ChanUpdatesInHorizon returns all the known channel edges which have at least
// one edge that has an update timestamp within the specified horizon.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
	endTime time.Time) ([]ChannelEdge, error) {

	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	var (
		ctx = context.TODO()
		// To ensure we don't return duplicate ChannelEdges, we'll use
		// an additional map to keep track of the edges already seen to
		// prevent re-adding it.
		edgesSeen    = make(map[uint64]struct{})
		edgesToCache = make(map[uint64]ChannelEdge)
		edges        []ChannelEdge
		hits         int
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		rows, err := db.GetChannelsByPolicyLastUpdateRange(
			ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
				Version:   int16(ProtocolV1),
				StartTime: sqldb.SQLInt64(startTime.Unix()),
				EndTime:   sqldb.SQLInt64(endTime.Unix()),
			},
		)
		if err != nil {
			return err
		}

		for _, row := range rows {
			// If we've already retrieved the info and policies for
			// this edge, then we can skip it as we don't need to do
			// so again.
			chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid)
			if _, ok := edgesSeen[chanIDInt]; ok {
				continue
			}

			if channel, ok := s.chanCache.get(chanIDInt); ok {
				hits++
				edgesSeen[chanIDInt] = struct{}{}
				edges = append(edges, channel)

				continue
			}

			node1, node2, err := buildNodes(
				ctx, db, row.GraphNode, row.GraphNode_2,
			)
			if err != nil {
				return err
			}

			channel, err := getAndBuildEdgeInfo(
				ctx, db, s.cfg.ChainHash, row.GraphChannel,
				node1.PubKeyBytes, node2.PubKeyBytes,
			)
			if err != nil {
				return fmt.Errorf("unable to build channel "+
					"info: %w", err)
			}

			dbPol1, dbPol2, err := extractChannelPolicies(row)
			if err != nil {
				return fmt.Errorf("unable to extract channel "+
					"policies: %w", err)
			}

			p1, p2, err := getAndBuildChanPolicies(
				ctx, db, dbPol1, dbPol2, channel.ChannelID,
				node1.PubKeyBytes, node2.PubKeyBytes,
			)
			if err != nil {
				return fmt.Errorf("unable to build channel "+
					"policies: %w", err)
			}

			edgesSeen[chanIDInt] = struct{}{}
			chanEdge := ChannelEdge{
				Info:    channel,
				Policy1: p1,
				Policy2: p2,
				Node1:   node1,
				Node2:   node2,
			}
			edges = append(edges, chanEdge)
			edgesToCache[chanIDInt] = chanEdge
		}

		return nil
	}, func() {
		edgesSeen = make(map[uint64]struct{})
		edgesToCache = make(map[uint64]ChannelEdge)
		edges = nil
	})
	if err != nil {
		return nil, fmt.Errorf("unable to fetch channels: %w", err)
	}

	// Insert any edges loaded from disk into the cache.
	for chanid, channel := range edgesToCache {
		s.chanCache.insert(chanid, channel)
	}

	if len(edges) > 0 {
		log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
			float64(hits)*100/float64(len(edges)), hits, len(edges))
	} else {
		log.Debugf("ChanUpdatesInHorizon returned no edges in "+
			"horizon (%s, %s)", startTime, endTime)
	}

	return edges, nil
}

// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
// data to the call-back.
//
// NOTE: The callback contents MUST not be modified.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNodeCached(ctx context.Context,
	cb func(node route.Vertex, chans map[uint64]*DirectedChannel) error,
	reset func()) error {

	type nodeCachedBatchData struct {
		features      map[int64][]int
		chanBatchData *batchChannelData
		chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
	}

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		// pageQueryFunc is used to query the next page of nodes.
		pageQueryFunc := func(ctx context.Context, lastID int64,
			limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {

			return db.ListNodeIDsAndPubKeys(
				ctx, sqlc.ListNodeIDsAndPubKeysParams{
					Version: int16(ProtocolV1),
					ID:      lastID,
					Limit:   limit,
				},
			)
		}

		// batchDataFunc is then used to batch load the data required
		// for each page of nodes.
		batchDataFunc := func(ctx context.Context,
			nodeIDs []int64) (*nodeCachedBatchData, error) {

			// Batch load node features.
			nodeFeatures, err := batchLoadNodeFeaturesHelper(
				ctx, s.cfg.QueryCfg, db, nodeIDs,
			)
			if err != nil {
				return nil, fmt.Errorf("unable to batch load "+
					"node features: %w", err)
			}

			// Batch load ALL unique channels for ALL nodes in this
			// page.
			allChannels, err := db.ListChannelsForNodeIDs(
				ctx, sqlc.ListChannelsForNodeIDsParams{
					Version:  int16(ProtocolV1),
					Node1Ids: nodeIDs,
					Node2Ids: nodeIDs,
				},
			)
			if err != nil {
				return nil, fmt.Errorf("unable to batch "+
					"fetch channels for nodes: %w", err)
			}

			// Deduplicate channels and collect IDs.
			var (
				allChannelIDs []int64
				allPolicyIDs  []int64
			)
			uniqueChannels := make(
				map[int64]sqlc.ListChannelsForNodeIDsRow,
			)

			for _, channel := range allChannels {
				channelID := channel.GraphChannel.ID

				// Only process each unique channel once.
				_, exists := uniqueChannels[channelID]
				if exists {
					continue
				}

				uniqueChannels[channelID] = channel
				allChannelIDs = append(allChannelIDs, channelID)

				if channel.Policy1ID.Valid {
					allPolicyIDs = append(
						allPolicyIDs,
						channel.Policy1ID.Int64,
					)
				}
				if channel.Policy2ID.Valid {
					allPolicyIDs = append(
						allPolicyIDs,
						channel.Policy2ID.Int64,
					)
				}
			}

			// Batch load channel data for all unique channels.
			channelBatchData, err := batchLoadChannelData(
				ctx, s.cfg.QueryCfg, db, allChannelIDs,
				allPolicyIDs,
			)
			if err != nil {
				return nil, fmt.Errorf("unable to batch "+
					"load channel data: %w", err)
			}

			// Create map of node ID to channels that involve this
			// node.
			nodeIDSet := make(map[int64]bool)
			for _, nodeID := range nodeIDs {
				nodeIDSet[nodeID] = true
			}

			nodeChannelMap := make(
				map[int64][]sqlc.ListChannelsForNodeIDsRow,
			)
			for _, channel := range uniqueChannels {
				// Add channel to both nodes if they're in our
				// current page.
				node1 := channel.GraphChannel.NodeID1
				if nodeIDSet[node1] {
					nodeChannelMap[node1] = append(
						nodeChannelMap[node1], channel,
					)
				}
				node2 := channel.GraphChannel.NodeID2
				if nodeIDSet[node2] {
					nodeChannelMap[node2] = append(
						nodeChannelMap[node2], channel,
					)
				}
			}

			return &nodeCachedBatchData{
				features:      nodeFeatures,
				chanBatchData: channelBatchData,
				chanMap:       nodeChannelMap,
			}, nil
		}

		// processItem is used to process each node in the current page.
		processItem := func(ctx context.Context,
			nodeData sqlc.ListNodeIDsAndPubKeysRow,
			batchData *nodeCachedBatchData) error {

			// Build feature vector for this node.
			fv := lnwire.EmptyFeatureVector()
			features, exists := batchData.features[nodeData.ID]
			if exists {
				for _, bit := range features {
					fv.Set(lnwire.FeatureBit(bit))
				}
			}

			var nodePub route.Vertex
			copy(nodePub[:], nodeData.PubKey)

			nodeChannels := batchData.chanMap[nodeData.ID]

			toNodeCallback := func() route.Vertex {
				return nodePub
			}

			// Build cached channels map for this node.
			channels := make(map[uint64]*DirectedChannel)
			for _, channelRow := range nodeChannels {
				directedChan, err := buildDirectedChannel(
					s.cfg.ChainHash, nodeData.ID, nodePub,
					channelRow, batchData.chanBatchData, fv,
					toNodeCallback,
				)
				if err != nil {
					return err
				}

				channels[directedChan.ChannelID] = directedChan
			}

			return cb(nodePub, channels)
		}

		return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
			ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
			func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
				return node.ID
			},
			func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
				error) {

				return node.ID, nil
×
1264
                        },
×
1265
                        batchDataFunc, processItem,
1266
                )
×
1267
        }, reset)
×
1268
}

// ForEachChannelCacheable iterates through all the channel edges stored
// within the graph and invokes the passed callback for each edge. The
// callback takes two edges since this is a directed graph and both the
// in/out edges are visited. If the callback returns an error, then the
// transaction is aborted and the iteration stops early.
//
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
// pointer for that particular channel edge routing policy will be
// passed into the callback.
//
// NOTE: this method is like ForEachChannel but fetches only the data
// required for the graph cache.
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
	*models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
	reset func()) error {

	ctx := context.TODO()

	handleChannel := func(_ context.Context,
		row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {

		node1, node2, err := buildNodeVertices(
			row.Node1Pubkey, row.Node2Pubkey,
		)
		if err != nil {
			return err
		}

		edge := buildCacheableChannelInfo(
			row.Scid, row.Capacity.Int64, node1, node2,
		)

		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return err
		}

		pol1, pol2, err := buildCachedChanPolicies(
			dbPol1, dbPol2, edge.ChannelID, node1, node2,
		)
		if err != nil {
			return err
		}

		return cb(edge, pol1, pol2)
	}

	extractCursor := func(
		row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {

		return row.ID
	}

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		//nolint:ll
		queryFunc := func(ctx context.Context, lastID int64,
			limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
			error) {

			return db.ListChannelsWithPoliciesForCachePaginated(
				ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
					Version: int16(ProtocolV1),
					ID:      lastID,
					Limit:   limit,
				},
			)
		}

		return sqldb.ExecutePaginatedQuery(
			ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
			extractCursor, handleChannel,
		)
	}, reset)
}

// ForEachChannel iterates through all the channel edges stored within the
// graph and invokes the passed callback for each edge. The callback takes two
// edges since this is a directed graph and both the in/out edges are visited.
// If the callback returns an error, then the transaction is aborted and the
// iteration stops early.
//
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
// for that particular channel edge routing policy will be passed into the
// callback.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachChannel(ctx context.Context,
	cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
		*models.ChannelEdgePolicy) error, reset func()) error {

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
	}, reset)
}
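
// Illustrative usage sketch (not part of the original file): assuming an
// initialized *SQLStore named "store" and a caller-supplied ctx, the total
// advertised capacity of the graph could be tallied by visiting every edge.
// The reset closure re-initializes the accumulator in case the read
// transaction needs to be retried.
//
//	var total btcutil.Amount
//	err := store.ForEachChannel(ctx, func(info *models.ChannelEdgeInfo,
//		_, _ *models.ChannelEdgePolicy) error {
//
//		total += info.Capacity
//		return nil
//	}, func() {
//		total = 0
//	})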

// FilterChannelRange returns the channel ID's of all known channels which were
// mined in a block height within the passed range. The channel IDs are grouped
// by their common block height. This method can be used to quickly share with a
// peer the set of channels we know of within a particular range to catch them
// up after a period of time offline. If withTimestamps is true then the
// timestamp info of the latest received channel update messages of the channel
// will be included in the response.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
	withTimestamps bool) ([]BlockChannelRange, error) {

	var (
		ctx       = context.TODO()
		startSCID = &lnwire.ShortChannelID{
			BlockHeight: startHeight,
		}
		endSCID = lnwire.ShortChannelID{
			BlockHeight: endHeight,
			TxIndex:     math.MaxUint32 & 0x00ffffff,
			TxPosition:  math.MaxUint16,
		}
		chanIDStart = channelIDToBytes(startSCID.ToUint64())
		chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
	)

	// 1) get all channels where channelID is between start and end chan ID.
	// 2) skip if not public (ie, no channel_proof)
	// 3) collect that channel.
	// 4) if timestamps are wanted, fetch both policies for node 1 and node2
	//    and add those timestamps to the collected channel.
	channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbChans, err := db.GetPublicV1ChannelsBySCID(
			ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
				StartScid: chanIDStart,
				EndScid:   chanIDEnd,
			},
		)
		if err != nil {
			return fmt.Errorf("unable to fetch channel range: %w",
				err)
		}

		for _, dbChan := range dbChans {
			cid := lnwire.NewShortChanIDFromInt(
				byteOrder.Uint64(dbChan.Scid),
			)
			chanInfo := NewChannelUpdateInfo(
				cid, time.Time{}, time.Time{},
			)

			if !withTimestamps {
				channelsPerBlock[cid.BlockHeight] = append(
					channelsPerBlock[cid.BlockHeight],
					chanInfo,
				)

				continue
			}

			//nolint:ll
			node1Policy, err := db.GetChannelPolicyByChannelAndNode(
				ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
					Version:   int16(ProtocolV1),
					ChannelID: dbChan.ID,
					NodeID:    dbChan.NodeID1,
				},
			)
			if err != nil && !errors.Is(err, sql.ErrNoRows) {
				return fmt.Errorf("unable to fetch node1 "+
					"policy: %w", err)
			} else if err == nil {
				chanInfo.Node1UpdateTimestamp = time.Unix(
					node1Policy.LastUpdate.Int64, 0,
				)
			}

			//nolint:ll
			node2Policy, err := db.GetChannelPolicyByChannelAndNode(
				ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
					Version:   int16(ProtocolV1),
					ChannelID: dbChan.ID,
					NodeID:    dbChan.NodeID2,
				},
			)
			if err != nil && !errors.Is(err, sql.ErrNoRows) {
				return fmt.Errorf("unable to fetch node2 "+
					"policy: %w", err)
			} else if err == nil {
				chanInfo.Node2UpdateTimestamp = time.Unix(
					node2Policy.LastUpdate.Int64, 0,
				)
			}

			channelsPerBlock[cid.BlockHeight] = append(
				channelsPerBlock[cid.BlockHeight], chanInfo,
			)
		}

		return nil
	}, func() {
		channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
	})
	if err != nil {
		return nil, fmt.Errorf("unable to fetch channel range: %w", err)
	}

	if len(channelsPerBlock) == 0 {
		return nil, nil
	}

	// Return the channel ranges in ascending block height order.
	blocks := slices.Collect(maps.Keys(channelsPerBlock))
	slices.Sort(blocks)

	return fn.Map(blocks, func(block uint32) BlockChannelRange {
		return BlockChannelRange{
			Height:   block,
			Channels: channelsPerBlock[block],
		}
	}), nil
}
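
// Illustrative usage sketch (not part of the original file): when answering a
// peer's channel-range query, a caller holding an assumed *SQLStore named
// "store" and caller-supplied startHeight/endHeight values might fetch all
// known channels in that range, including update timestamps:
//
//	ranges, err := store.FilterChannelRange(startHeight, endHeight, true)
//	if err != nil {
//		return err
//	}
//	for _, blockRange := range ranges {
//		fmt.Printf("height=%d channels=%d\n", blockRange.Height,
//			len(blockRange.Channels))
//	}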

// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
// zombie. This method is used on an ad-hoc basis, when channels need to be
// marked as zombies outside the normal pruning cycle.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
	pubKey1, pubKey2 [33]byte) error {

	ctx := context.TODO()

	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	chanIDB := channelIDToBytes(chanID)

	err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		return db.UpsertZombieChannel(
			ctx, sqlc.UpsertZombieChannelParams{
				Version:  int16(ProtocolV1),
				Scid:     chanIDB,
				NodeKey1: pubKey1[:],
				NodeKey2: pubKey2[:],
			},
		)
	}, sqldb.NoOpReset)
	if err != nil {
		return fmt.Errorf("unable to upsert zombie channel "+
			"(channel_id=%d): %w", chanID, err)
	}

	s.rejectCache.remove(chanID)
	s.chanCache.remove(chanID)

	return nil
}
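
// Illustrative usage sketch (not part of the original file): assuming an
// initialized *SQLStore named "store", a channel ID "chanID", and a
// previously fetched *models.ChannelEdgeInfo named "info", a caller could
// mark the channel as a zombie once it stops receiving fresh updates:
//
//	err := store.MarkEdgeZombie(
//		chanID, info.NodeKey1Bytes, info.NodeKey2Bytes,
//	)
//	if err != nil {
//		return err
//	}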

// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	var (
		ctx     = context.TODO()
		chanIDB = channelIDToBytes(chanID)
	)

	err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		res, err := db.DeleteZombieChannel(
			ctx, sqlc.DeleteZombieChannelParams{
				Scid:    chanIDB,
				Version: int16(ProtocolV1),
			},
		)
		if err != nil {
			return fmt.Errorf("unable to delete zombie channel: %w",
				err)
		}

		rows, err := res.RowsAffected()
		if err != nil {
			return err
		}

		if rows == 0 {
			return ErrZombieEdgeNotFound
		} else if rows > 1 {
			return fmt.Errorf("deleted %d zombie rows, "+
				"expected 1", rows)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return fmt.Errorf("unable to mark edge live "+
			"(channel_id=%d): %w", chanID, err)
	}

	s.rejectCache.remove(chanID)
	s.chanCache.remove(chanID)

	return err
}
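
// Illustrative usage sketch (not part of the original file): once a fresh
// update arrives for a channel that was previously marked as a zombie, a
// caller holding an assumed *SQLStore named "store" can resurrect it. The
// wrapped ErrZombieEdgeNotFound is expected when the channel was never in the
// zombie index and can usually be ignored:
//
//	err := store.MarkEdgeLive(chanID)
//	if err != nil && !errors.Is(err, ErrZombieEdgeNotFound) {
//		return err
//	}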

// IsZombieEdge returns whether the edge is considered zombie. If it is a
// zombie, then the two node public keys corresponding to this edge are also
// returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
	error) {

	var (
		ctx              = context.TODO()
		isZombie         bool
		pubKey1, pubKey2 route.Vertex
		chanIDB          = channelIDToBytes(chanID)
	)

	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		zombie, err := db.GetZombieChannel(
			ctx, sqlc.GetZombieChannelParams{
				Scid:    chanIDB,
				Version: int16(ProtocolV1),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		}
		if err != nil {
			return fmt.Errorf("unable to fetch zombie channel: %w",
				err)
		}

		copy(pubKey1[:], zombie.NodeKey1)
		copy(pubKey2[:], zombie.NodeKey2)
		isZombie = true

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return false, route.Vertex{}, route.Vertex{},
			fmt.Errorf("%w: %w (chanID=%d)",
				ErrCantCheckIfZombieEdgeStr, err, chanID)
	}

	return isZombie, pubKey1, pubKey2, nil
}
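
// Illustrative usage sketch (not part of the original file), assuming an
// initialized *SQLStore named "store": callers can consult the zombie index
// before deciding whether to process further gossip for a channel.
//
//	isZombie, _, _, err := store.IsZombieEdge(chanID)
//	if err != nil {
//		return err
//	}
//	if isZombie {
//		// Skip further processing for this channel.
//		return nil
//	}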

// NumZombies returns the current number of zombie channels in the graph.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) NumZombies() (uint64, error) {
	var (
		ctx        = context.TODO()
		numZombies uint64
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
		if err != nil {
			return fmt.Errorf("unable to count zombie channels: %w",
				err)
		}

		numZombies = uint64(count)

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return 0, fmt.Errorf("unable to count zombies: %w", err)
	}

	return numZombies, nil
}

// DeleteChannelEdges removes edges with the given channel IDs from the
// database and marks them as zombies. This ensures that we're unable to re-add
// them to our database once again. If an edge does not exist within the
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
// true, then when we mark these edges as zombies, we'll set up the keys such
// that we require the node that failed to send the fresh update to be the one
// that resurrects the channel from its zombie state. The markZombie bool
// denotes whether to mark the channel as a zombie.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
	chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {

	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	// Keep track of which channels we end up finding so that we can
	// correctly return ErrEdgeNotFound if we do not find a channel.
	chanLookup := make(map[uint64]struct{}, len(chanIDs))
	for _, chanID := range chanIDs {
		chanLookup[chanID] = struct{}{}
	}

	var (
		ctx     = context.TODO()
		deleted []*models.ChannelEdgeInfo
	)
	err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		chanIDsToDelete := make([]int64, 0, len(chanIDs))
		chanCallBack := func(ctx context.Context,
			row sqlc.GetChannelsBySCIDWithPoliciesRow) error {

			// Deleting the entry from the map indicates that we
			// have found the channel.
			scid := byteOrder.Uint64(row.GraphChannel.Scid)
			delete(chanLookup, scid)

			node1, node2, err := buildNodeVertices(
				row.GraphNode.PubKey, row.GraphNode_2.PubKey,
			)
			if err != nil {
				return err
			}

			info, err := getAndBuildEdgeInfo(
				ctx, db, s.cfg.ChainHash, row.GraphChannel,
				node1, node2,
			)
			if err != nil {
				return err
			}

			deleted = append(deleted, info)
			chanIDsToDelete = append(
				chanIDsToDelete, row.GraphChannel.ID,
			)

			if !markZombie {
				return nil
			}

			nodeKey1, nodeKey2 := info.NodeKey1Bytes,
				info.NodeKey2Bytes
			if strictZombiePruning {
				var e1UpdateTime, e2UpdateTime *time.Time
				if row.Policy1LastUpdate.Valid {
					e1Time := time.Unix(
						row.Policy1LastUpdate.Int64, 0,
					)
					e1UpdateTime = &e1Time
				}
				if row.Policy2LastUpdate.Valid {
					e2Time := time.Unix(
						row.Policy2LastUpdate.Int64, 0,
					)
					e2UpdateTime = &e2Time
				}

				nodeKey1, nodeKey2 = makeZombiePubkeys(
					info, e1UpdateTime, e2UpdateTime,
				)
			}

			err = db.UpsertZombieChannel(
				ctx, sqlc.UpsertZombieChannelParams{
					Version:  int16(ProtocolV1),
					Scid:     channelIDToBytes(scid),
					NodeKey1: nodeKey1[:],
					NodeKey2: nodeKey2[:],
				},
			)
			if err != nil {
				return fmt.Errorf("unable to mark channel as "+
					"zombie: %w", err)
			}

			return nil
		}

		err := s.forEachChanWithPoliciesInSCIDList(
			ctx, db, chanCallBack, chanIDs,
		)
		if err != nil {
			return err
		}

		if len(chanLookup) > 0 {
			return ErrEdgeNotFound
		}

		return s.deleteChannels(ctx, db, chanIDsToDelete)
	}, func() {
		deleted = nil

		// Re-fill the lookup map.
		for _, chanID := range chanIDs {
			chanLookup[chanID] = struct{}{}
		}
	})
	if err != nil {
		return nil, fmt.Errorf("unable to delete channel edges: %w",
			err)
	}

	for _, chanID := range chanIDs {
		s.rejectCache.remove(chanID)
		s.chanCache.remove(chanID)
	}

	return deleted, nil
}
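
// Illustrative usage sketch (not part of the original file): when a channel
// close is observed outside the normal pruning flow, a caller with an assumed
// *SQLStore named "store" might remove the edge and keep it in the zombie
// index so that stale gossip cannot re-add it:
//
//	removed, err := store.DeleteChannelEdges(false, true, chanID)
//	if err != nil && !errors.Is(err, ErrEdgeNotFound) {
//		return err
//	}
//	fmt.Printf("removed %d channel(s)\n", len(removed))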

// FetchChannelEdgesByID attempts to look up the two directed edges for the
// channel identified by the channel ID. If the channel can't be found, then
// ErrEdgeNotFound is returned. A struct which houses the general information
// for the channel itself is returned as well as two structs that contain the
// routing policies for the channel in either direction.
//
// ErrZombieEdge can be returned if the edge is currently marked as a zombie
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
// the ChannelEdgeInfo will only include the public keys of each node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
	*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
	*models.ChannelEdgePolicy, error) {

	var (
		ctx              = context.TODO()
		edge             *models.ChannelEdgeInfo
		policy1, policy2 *models.ChannelEdgePolicy
		chanIDB          = channelIDToBytes(chanID)
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		row, err := db.GetChannelBySCIDWithPolicies(
			ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
				Scid:    chanIDB,
				Version: int16(ProtocolV1),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			// First check if this edge is perhaps in the zombie
			// index.
			zombie, err := db.GetZombieChannel(
				ctx, sqlc.GetZombieChannelParams{
					Scid:    chanIDB,
					Version: int16(ProtocolV1),
				},
			)
			if errors.Is(err, sql.ErrNoRows) {
				return ErrEdgeNotFound
			} else if err != nil {
				return fmt.Errorf("unable to check if "+
					"channel is zombie: %w", err)
			}

			// At this point, we know the channel is a zombie, so
			// we'll return an error indicating this, and we will
			// populate the edge info with the public keys of each
			// party as this is the only information we have about
			// it.
			edge = &models.ChannelEdgeInfo{}
			copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
			copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)

			return ErrZombieEdge
		} else if err != nil {
			return fmt.Errorf("unable to fetch channel: %w", err)
		}

		node1, node2, err := buildNodeVertices(
			row.GraphNode.PubKey, row.GraphNode_2.PubKey,
		)
		if err != nil {
			return err
		}

		edge, err = getAndBuildEdgeInfo(
			ctx, db, s.cfg.ChainHash, row.GraphChannel, node1,
			node2,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel info: %w",
				err)
		}

		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return fmt.Errorf("unable to extract channel "+
				"policies: %w", err)
		}

		policy1, policy2, err = getAndBuildChanPolicies(
			ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel "+
				"policies: %w", err)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		// If we are returning the ErrZombieEdge, then we also need to
		// return the edge info as the method comment indicates that
		// this will be populated when the edge is a zombie.
		return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
			err)
	}

	return edge, policy1, policy2, nil
}
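
// Illustrative usage sketch (not part of the original file), assuming an
// initialized *SQLStore named "store": the zombie case must be handled
// explicitly, since ErrZombieEdge is returned alongside a partially populated
// edge containing only the node public keys.
//
//	info, pol1, pol2, err := store.FetchChannelEdgesByID(chanID)
//	switch {
//	case errors.Is(err, ErrZombieEdge):
//		// info is non-nil but only carries the two node public keys.
//	case errors.Is(err, ErrEdgeNotFound):
//		// The channel is unknown to this store.
//	case err != nil:
//		return err
//	}
//	// Otherwise info, pol1 and pol2 describe the channel; either policy
//	// may still be nil if it was never advertised.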

// FetchChannelEdgesByOutpoint attempts to look up the two directed edges for
// the channel identified by the funding outpoint. If the channel can't be
// found, then ErrEdgeNotFound is returned. A struct which houses the general
// information for the channel itself is returned as well as two structs that
// contain the routing policies for the channel in either direction.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
	*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
	*models.ChannelEdgePolicy, error) {

	var (
		ctx              = context.TODO()
		edge             *models.ChannelEdgeInfo
		policy1, policy2 *models.ChannelEdgePolicy
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		row, err := db.GetChannelByOutpointWithPolicies(
			ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
				Outpoint: op.String(),
				Version:  int16(ProtocolV1),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return ErrEdgeNotFound
		} else if err != nil {
			return fmt.Errorf("unable to fetch channel: %w", err)
		}

		node1, node2, err := buildNodeVertices(
			row.Node1Pubkey, row.Node2Pubkey,
		)
		if err != nil {
			return err
		}

		edge, err = getAndBuildEdgeInfo(
			ctx, db, s.cfg.ChainHash, row.GraphChannel, node1,
			node2,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel info: %w",
				err)
		}

		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return fmt.Errorf("unable to extract channel "+
				"policies: %w", err)
		}

		policy1, policy2, err = getAndBuildChanPolicies(
			ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel "+
				"policies: %w", err)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
			err)
	}

	return edge, policy1, policy2, nil
}

// HasChannelEdge returns true if the database knows of a channel edge with the
// passed channel ID, and false otherwise. If an edge with that ID is found
// within the graph, then two time stamps representing the last time the edge
// was updated for both directed edges are returned along with the boolean. If
// it is not found, then the zombie index is checked and its result is returned
// as the second boolean.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
	bool, error) {

	ctx := context.TODO()

	var (
		exists          bool
		isZombie        bool
		node1LastUpdate time.Time
		node2LastUpdate time.Time
	)

	// We'll query the cache with the shared lock held to allow multiple
	// readers to access values in the cache concurrently if they exist.
	s.cacheMu.RLock()
	if entry, ok := s.rejectCache.get(chanID); ok {
		s.cacheMu.RUnlock()
		node1LastUpdate = time.Unix(entry.upd1Time, 0)
		node2LastUpdate = time.Unix(entry.upd2Time, 0)
		exists, isZombie = entry.flags.unpack()

		return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
	}
	s.cacheMu.RUnlock()

	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	// The item was not found with the shared lock, so we'll acquire the
	// exclusive lock and check the cache again in case another method added
	// the entry to the cache while no lock was held.
	if entry, ok := s.rejectCache.get(chanID); ok {
		node1LastUpdate = time.Unix(entry.upd1Time, 0)
		node2LastUpdate = time.Unix(entry.upd2Time, 0)
		exists, isZombie = entry.flags.unpack()

		return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
	}

	chanIDB := channelIDToBytes(chanID)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		channel, err := db.GetChannelBySCID(
			ctx, sqlc.GetChannelBySCIDParams{
				Scid:    chanIDB,
				Version: int16(ProtocolV1),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			// Check if it is a zombie channel.
			isZombie, err = db.IsZombieChannel(
				ctx, sqlc.IsZombieChannelParams{
					Scid:    chanIDB,
					Version: int16(ProtocolV1),
				},
			)
			if err != nil {
				return fmt.Errorf("could not check if channel "+
					"is zombie: %w", err)
			}

			return nil
		} else if err != nil {
			return fmt.Errorf("unable to fetch channel: %w", err)
		}

		exists = true

		policy1, err := db.GetChannelPolicyByChannelAndNode(
			ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
				Version:   int16(ProtocolV1),
				ChannelID: channel.ID,
				NodeID:    channel.NodeID1,
			},
		)
		if err != nil && !errors.Is(err, sql.ErrNoRows) {
			return fmt.Errorf("unable to fetch channel policy: %w",
				err)
		} else if err == nil {
			node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
		}

		policy2, err := db.GetChannelPolicyByChannelAndNode(
			ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
				Version:   int16(ProtocolV1),
				ChannelID: channel.ID,
				NodeID:    channel.NodeID2,
			},
		)
		if err != nil && !errors.Is(err, sql.ErrNoRows) {
			return fmt.Errorf("unable to fetch channel policy: %w",
				err)
		} else if err == nil {
			node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return time.Time{}, time.Time{}, false, false,
			fmt.Errorf("unable to fetch channel: %w", err)
	}

	s.rejectCache.insert(chanID, rejectCacheEntry{
		upd1Time: node1LastUpdate.Unix(),
		upd2Time: node2LastUpdate.Unix(),
		flags:    packRejectFlags(exists, isZombie),
	})

	return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
}
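
// Illustrative usage sketch (not part of the original file), assuming an
// initialized *SQLStore named "store": gossip handlers typically use this
// check to cheaply decide whether an incoming channel update is worth
// processing at all.
//
//	upd1, upd2, exists, isZombie, err := store.HasChannelEdge(chanID)
//	if err != nil {
//		return err
//	}
//	if !exists || isZombie {
//		// Nothing to update (or the channel was marked as a zombie).
//		return nil
//	}
//	_ = upd1 // Last update time of the node-1 policy, if any.
//	_ = upd2 // Last update time of the node-2 policy, if any.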

// ChannelID attempts to look up the 8-byte compact channel ID which maps to the
// passed channel point (outpoint). If the passed channel doesn't exist within
// the database, then ErrEdgeNotFound is returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
	var (
		ctx       = context.TODO()
		channelID uint64
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		chanID, err := db.GetSCIDByOutpoint(
			ctx, sqlc.GetSCIDByOutpointParams{
				Outpoint: chanPoint.String(),
				Version:  int16(ProtocolV1),
			},
		)
		if errors.Is(err, sql.ErrNoRows) {
			return ErrEdgeNotFound
		} else if err != nil {
			return fmt.Errorf("unable to fetch channel ID: %w",
				err)
		}

		channelID = byteOrder.Uint64(chanID)

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
	}

	return channelID, nil
}
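
// Illustrative usage sketch (not part of the original file), assuming an
// initialized *SQLStore named "store" and a known funding outpoint held in a
// wire.OutPoint named "chanPoint":
//
//	scid, err := store.ChannelID(&chanPoint)
//	if errors.Is(err, ErrEdgeNotFound) {
//		// The outpoint does not belong to a known channel.
//		return nil
//	} else if err != nil {
//		return err
//	}
//	fmt.Println("short channel ID:", lnwire.NewShortChanIDFromInt(scid))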

// IsPublicNode is a helper method that determines whether the node with the
// given public key is seen as a public node in the graph from the graph's
// source node's point of view.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
	ctx := context.TODO()

	var isPublic bool
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		var err error
		isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])

		return err
	}, sqldb.NoOpReset)
	if err != nil {
		return false, fmt.Errorf("unable to check if node is "+
			"public: %w", err)
	}

	return isPublic, nil
}

// FetchChanInfos returns the set of channel edges that correspond to the passed
// channel ID's. If an edge in the query is unknown to the database, it will be
// skipped and the result will contain only those edges that exist at the time
// of the query. This can be used to respond to peer queries that are seeking to
// fill in gaps in their view of the channel graph.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
	var (
		ctx   = context.TODO()
		edges = make(map[uint64]ChannelEdge)
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		chanCallBack := func(ctx context.Context,
			row sqlc.GetChannelsBySCIDWithPoliciesRow) error {

			node1, node2, err := buildNodes(
				ctx, db, row.GraphNode, row.GraphNode_2,
			)
			if err != nil {
				return fmt.Errorf("unable to fetch nodes: %w",
					err)
			}

			edge, err := getAndBuildEdgeInfo(
				ctx, db, s.cfg.ChainHash, row.GraphChannel,
				node1.PubKeyBytes, node2.PubKeyBytes,
			)
			if err != nil {
				return fmt.Errorf("unable to build "+
					"channel info: %w", err)
			}

			dbPol1, dbPol2, err := extractChannelPolicies(row)
			if err != nil {
				return fmt.Errorf("unable to extract channel "+
					"policies: %w", err)
			}

			p1, p2, err := getAndBuildChanPolicies(
				ctx, db, dbPol1, dbPol2, edge.ChannelID,
				node1.PubKeyBytes, node2.PubKeyBytes,
			)
			if err != nil {
				return fmt.Errorf("unable to build channel "+
					"policies: %w", err)
			}

			edges[edge.ChannelID] = ChannelEdge{
				Info:    edge,
				Policy1: p1,
				Policy2: p2,
				Node1:   node1,
				Node2:   node2,
			}

			return nil
		}

		return s.forEachChanWithPoliciesInSCIDList(
			ctx, db, chanCallBack, chanIDs,
		)
	}, func() {
		clear(edges)
	})
	if err != nil {
		return nil, fmt.Errorf("unable to fetch channels: %w", err)
	}

	res := make([]ChannelEdge, 0, len(edges))
	for _, chanID := range chanIDs {
		edge, ok := edges[chanID]
		if !ok {
			continue
		}

		res = append(res, edge)
	}

	return res, nil
}
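
// Illustrative usage sketch (not part of the original file), assuming an
// initialized *SQLStore named "store" and a slice of requested short channel
// IDs named "requestedSCIDs": unknown IDs are silently skipped, so the result
// may be shorter than the request.
//
//	edges, err := store.FetchChanInfos(requestedSCIDs)
//	if err != nil {
//		return err
//	}
//	for _, e := range edges {
//		fmt.Printf("scid=%d capacity=%v\n", e.Info.ChannelID,
//			e.Info.Capacity)
//	}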

// forEachChanWithPoliciesInSCIDList is a wrapper around the
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
// channels in a paginated manner.
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
	db SQLQueries, cb func(ctx context.Context,
		row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
	chanIDs []uint64) error {

	queryWrapper := func(ctx context.Context,
		scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
		error) {

		return db.GetChannelsBySCIDWithPolicies(
			ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
				Version: int16(ProtocolV1),
				Scids:   scids,
			},
		)
	}

	return sqldb.ExecuteBatchQuery(
		ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
		cb,
	)
}

// FilterKnownChanIDs takes a set of channel IDs and returns the subset of chan
// ID's that we don't know and are not known zombies of the passed set. In other
// words, we perform a set difference of our set of chan ID's and the ones
// passed in. This method can be used by callers to determine the set of
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
// known zombies are also returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
	[]ChannelUpdateInfo, error) {

	var (
		ctx          = context.TODO()
		newChanIDs   []uint64
		knownZombies []ChannelUpdateInfo
		infoLookup   = make(
			map[uint64]ChannelUpdateInfo, len(chansInfo),
		)
	)

	// We first build a lookup map of the channel ID's to the
	// ChannelUpdateInfo. This allows us to quickly delete channels that we
	// already know about.
	for _, chanInfo := range chansInfo {
		infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
	}

	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		// The call-back function deletes known channels from
		// infoLookup, so that we can later check which channels are
		// zombies by only looking at the remaining channels in the set.
		cb := func(ctx context.Context,
			channel sqlc.GraphChannel) error {

			delete(infoLookup, byteOrder.Uint64(channel.Scid))

			return nil
		}

		err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
		if err != nil {
			return fmt.Errorf("unable to iterate through "+
				"channels: %w", err)
		}

		// We want to ensure that we deal with the channels in the
		// same order that they were passed in, so we iterate over the
		// original chansInfo slice and then check if that channel is
		// still in the infoLookup map.
		for _, chanInfo := range chansInfo {
			channelID := chanInfo.ShortChannelID.ToUint64()
			if _, ok := infoLookup[channelID]; !ok {
				continue
			}

			isZombie, err := db.IsZombieChannel(
				ctx, sqlc.IsZombieChannelParams{
					Scid:    channelIDToBytes(channelID),
					Version: int16(ProtocolV1),
				},
			)
			if err != nil {
				return fmt.Errorf("unable to fetch zombie "+
					"channel: %w", err)
			}

			if isZombie {
				knownZombies = append(knownZombies, chanInfo)

				continue
			}

			newChanIDs = append(newChanIDs, channelID)
		}

		return nil
	}, func() {
		newChanIDs = nil
		knownZombies = nil
		// Rebuild the infoLookup map in case of a rollback.
		for _, chanInfo := range chansInfo {
			scid := chanInfo.ShortChannelID.ToUint64()
			infoLookup[scid] = chanInfo
		}
	})
	if err != nil {
		return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
	}

	return newChanIDs, knownZombies, nil
}
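
// Illustrative usage sketch (not part of the original file), assuming an
// initialized *SQLStore named "store" and a slice of ChannelUpdateInfo named
// "peerChans" built from a peer's channel-range reply: the returned IDs are
// the channels worth querying the peer for, while known zombies are reported
// separately.
//
//	newIDs, zombies, err := store.FilterKnownChanIDs(peerChans)
//	if err != nil {
//		return err
//	}
//	fmt.Printf("querying %d new channels (%d known zombies)\n",
//		len(newIDs), len(zombies))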

// forEachChanInSCIDList is a helper method that executes a paged query
// against the database to fetch all channels that match the passed
// ChannelUpdateInfo slice. The callback function is called for each channel
// that is found.
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
	cb func(ctx context.Context, channel sqlc.GraphChannel) error,
	chansInfo []ChannelUpdateInfo) error {

	queryWrapper := func(ctx context.Context,
		scids [][]byte) ([]sqlc.GraphChannel, error) {

		return db.GetChannelsBySCIDs(
			ctx, sqlc.GetChannelsBySCIDsParams{
				Version: int16(ProtocolV1),
				Scids:   scids,
			},
		)
	}

	chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
		channelID := chanInfo.ShortChannelID.ToUint64()

		return channelIDToBytes(channelID)
	}

	return sqldb.ExecuteBatchQuery(
		ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
		cb,
	)
}

// PruneGraphNodes is a garbage collection method which attempts to prune out
// any nodes from the channel graph that are currently unconnected. This ensures
// that we only maintain a graph of reachable nodes. In the event that a pruned
// node gains more channels, it will be re-added back to the graph.
//
// NOTE: this prunes nodes across protocol versions. It will never prune the
// source nodes.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
	var ctx = context.TODO()

	var prunedNodes []route.Vertex
	err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		var err error
		prunedNodes, err = s.pruneGraphNodes(ctx, db)

		return err
	}, func() {
		prunedNodes = nil
	})
	if err != nil {
		return nil, fmt.Errorf("unable to prune nodes: %w", err)
	}

	return prunedNodes, nil
}

// PruneGraph prunes newly closed channels from the channel graph in response
// to a new block being solved on the network. Any transactions which spend the
// funding output of any known channels within the graph will be deleted.
// Additionally, the "prune tip", or the last block which has been used to
// prune the graph, is stored so callers can ensure the graph is fully in sync
// with the current UTXO state. A slice of channels that have been closed by
// the target block along with any pruned nodes are returned if the function
// succeeds without error.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
        blockHash *chainhash.Hash, blockHeight uint32) (
        []*models.ChannelEdgeInfo, []route.Vertex, error) {

        ctx := context.TODO()

        s.cacheMu.Lock()
        defer s.cacheMu.Unlock()

        var (
                closedChans []*models.ChannelEdgeInfo
                prunedNodes []route.Vertex
        )
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                var chansToDelete []int64

                // Define the callback function for processing each channel.
                channelCallback := func(ctx context.Context,
                        row sqlc.GetChannelsByOutpointsRow) error {

                        node1, node2, err := buildNodeVertices(
                                row.Node1Pubkey, row.Node2Pubkey,
                        )
                        if err != nil {
                                return err
                        }

                        info, err := getAndBuildEdgeInfo(
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
                                node1, node2,
                        )
                        if err != nil {
                                return err
                        }

                        closedChans = append(closedChans, info)
                        chansToDelete = append(
                                chansToDelete, row.GraphChannel.ID,
                        )

                        return nil
                }

                err := s.forEachChanInOutpoints(
                        ctx, db, spentOutputs, channelCallback,
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch channels by "+
                                "outpoints: %w", err)
                }

                err = s.deleteChannels(ctx, db, chansToDelete)
                if err != nil {
                        return fmt.Errorf("unable to delete channels: %w", err)
                }

                err = db.UpsertPruneLogEntry(
                        ctx, sqlc.UpsertPruneLogEntryParams{
                                BlockHash:   blockHash[:],
                                BlockHeight: int64(blockHeight),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to insert prune log "+
                                "entry: %w", err)
                }

                // Now that we've pruned some channels, we'll also prune any
                // nodes that no longer have any channels.
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
                if err != nil {
                        return fmt.Errorf("unable to prune graph nodes: %w",
                                err)
                }

                return nil
        }, func() {
                prunedNodes = nil
                closedChans = nil
        })
        if err != nil {
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
        }

        for _, channel := range closedChans {
                s.rejectCache.remove(channel.ChannelID)
                s.chanCache.remove(channel.ChannelID)
        }

        return closedChans, prunedNodes, nil
}
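
// Example (illustrative sketch; store, spentOutputs, blockHash and blockHeight
// are assumed identifiers): a block notification handler could feed each
// connected block into PruneGraph:
//
//        closedChans, prunedNodes, err := store.PruneGraph(
//                spentOutputs, &blockHash, blockHeight,
//        )
//        if err != nil {
//                return err
//        }
//        // closedChans and prunedNodes can now be used to update any
//        // dependent caches or subscribers.
//        _, _ = closedChans, prunedNodes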

// forEachChanInOutpoints is a helper function that executes a paginated
// query to fetch channels by their outpoints and applies the given call-back
// to each.
//
// NOTE: this fetches channels for all protocol versions.
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
                row sqlc.GetChannelsByOutpointsRow) error) error {

        // Create a wrapper that uses the transaction's db instance to execute
        // the query.
        queryWrapper := func(ctx context.Context,
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
                error) {

                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
        }

        // Define the conversion function from Outpoint to string.
        outpointToString := func(outpoint *wire.OutPoint) string {
                return outpoint.String()
        }

        return sqldb.ExecuteBatchQuery(
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
                queryWrapper, cb,
        )
}

// deleteChannels deletes the channel rows with the given DB IDs, executing
// the deletions in batches.
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
        dbIDs []int64) error {

        // Create a wrapper that uses the transaction's db instance to execute
        // the query.
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
                return nil, db.DeleteChannels(ctx, ids)
        }

        idConverter := func(id int64) int64 {
                return id
        }

        return sqldb.ExecuteBatchQuery(
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
                queryWrapper, func(ctx context.Context, _ any) error {
                        return nil
                },
        )
}

// ChannelView returns the verifiable edge information for each active channel
// within the known channel graph. The set of UTXOs (along with their scripts)
// returned are the ones that need to be watched on chain to detect channel
// closes on the resident blockchain.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
        var (
                ctx        = context.TODO()
                edgePoints []EdgePoint
        )

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                handleChannel := func(_ context.Context,
                        channel sqlc.ListChannelsPaginatedRow) error {

                        pkScript, err := genMultiSigP2WSH(
                                channel.BitcoinKey1, channel.BitcoinKey2,
                        )
                        if err != nil {
                                return err
                        }

                        op, err := wire.NewOutPointFromString(channel.Outpoint)
                        if err != nil {
                                return err
                        }

                        edgePoints = append(edgePoints, EdgePoint{
                                FundingPkScript: pkScript,
                                OutPoint:        *op,
                        })

                        return nil
                }

                queryFunc := func(ctx context.Context, lastID int64,
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {

                        return db.ListChannelsPaginated(
                                ctx, sqlc.ListChannelsPaginatedParams{
                                        Version: int16(ProtocolV1),
                                        ID:      lastID,
                                        Limit:   limit,
                                },
                        )
                }

                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
                        return row.ID
                }

                return sqldb.ExecutePaginatedQuery(
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
                        extractCursor, handleChannel,
                )
        }, func() {
                edgePoints = nil
        })
        if err != nil {
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
        }

        return edgePoints, nil
}
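
// Example (illustrative sketch; store is an assumed *SQLStore): the EdgePoints
// returned by ChannelView describe the funding outputs that a chain watcher
// needs to monitor for spends:
//
//        edgePoints, err := store.ChannelView()
//        if err != nil {
//                return err
//        }
//        for _, point := range edgePoints {
//                // point.OutPoint and point.FundingPkScript identify the UTXO
//                // (and its script) to watch on-chain.
//                _ = point
//        }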

// PruneTip returns the block height and hash of the latest block that has been
// used to prune channels in the graph. Knowing the "prune tip" allows callers
// to tell if the graph is currently in sync with the current best known UTXO
// state.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
        var (
                ctx       = context.TODO()
                tipHash   chainhash.Hash
                tipHeight uint32
        )
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                pruneTip, err := db.GetPruneTip(ctx)
                if errors.Is(err, sql.ErrNoRows) {
                        return ErrGraphNeverPruned
                } else if err != nil {
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
                }

                tipHash = chainhash.Hash(pruneTip.BlockHash)
                tipHeight = uint32(pruneTip.BlockHeight)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, 0, err
        }

        return &tipHash, tipHeight, nil
}
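
// Example (illustrative sketch; store is an assumed *SQLStore): callers can
// treat ErrGraphNeverPruned as "no prune state yet" rather than as a hard
// failure:
//
//        hash, height, err := store.PruneTip()
//        switch {
//        case errors.Is(err, ErrGraphNeverPruned):
//                // The graph has never been pruned, so it cannot yet be
//                // considered in sync with the chain.
//        case err != nil:
//                return err
//        default:
//                _, _ = hash, height
//        }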

// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
//
// NOTE: this prunes nodes across protocol versions. It will never prune the
// source nodes.
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
        db SQLQueries) ([]route.Vertex, error) {

        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
        if err != nil {
                return nil, fmt.Errorf("unable to delete unconnected "+
                        "nodes: %w", err)
        }

        prunedNodes := make([]route.Vertex, len(nodeKeys))
        for i, nodeKey := range nodeKeys {
                pub, err := route.NewVertexFromBytes(nodeKey)
                if err != nil {
                        return nil, fmt.Errorf("unable to parse pubkey "+
                                "from bytes: %w", err)
                }

                prunedNodes[i] = pub
        }

        return prunedNodes, nil
}

// DisconnectBlockAtHeight is used to indicate that the block specified
// by the passed height has been disconnected from the main chain. This
// will "rewind" the graph back to the height below, deleting channels
// that are no longer confirmed from the graph. The prune log will be
// set to the last prune height valid for the remaining chain.
// Channels that were removed from the graph resulting from the
// disconnected block are returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
        []*models.ChannelEdgeInfo, error) {

        ctx := context.TODO()

        var (
                // Every channel having a ShortChannelID starting at 'height'
                // will no longer be confirmed.
                startShortChanID = lnwire.ShortChannelID{
                        BlockHeight: height,
                }

                // Delete everything after this height from the db up until the
                // SCID alias range.
                endShortChanID = aliasmgr.StartingAlias

                removedChans []*models.ChannelEdgeInfo

                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
        )

        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                rows, err := db.GetChannelsBySCIDRange(
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
                                StartScid: chanIDStart,
                                EndScid:   chanIDEnd,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch channels: %w", err)
                }

                chanIDsToDelete := make([]int64, len(rows))
                for i, row := range rows {
                        node1, node2, err := buildNodeVertices(
                                row.Node1PubKey, row.Node2PubKey,
                        )
                        if err != nil {
                                return err
                        }

                        channel, err := getAndBuildEdgeInfo(
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
                                node1, node2,
                        )
                        if err != nil {
                                return err
                        }

                        chanIDsToDelete[i] = row.GraphChannel.ID
                        removedChans = append(removedChans, channel)
                }

                err = s.deleteChannels(ctx, db, chanIDsToDelete)
                if err != nil {
                        return fmt.Errorf("unable to delete channels: %w", err)
                }

                return db.DeletePruneLogEntriesInRange(
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
                                StartHeight: int64(height),
                                EndHeight:   int64(endShortChanID.BlockHeight),
                        },
                )
        }, func() {
                removedChans = nil
        })
        if err != nil {
                return nil, fmt.Errorf("unable to disconnect block at "+
                        "height: %w", err)
        }

        for _, channel := range removedChans {
                s.rejectCache.remove(channel.ChannelID)
                s.chanCache.remove(channel.ChannelID)
        }

        return removedChans, nil
}
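
// Example (illustrative sketch; store and staleHeight are assumed
// identifiers): during a re-org, a caller might rewind the graph one block at
// a time until it is back on the main chain:
//
//        removed, err := store.DisconnectBlockAtHeight(staleHeight)
//        if err != nil {
//                return err
//        }
//        // The removed channels can be re-validated once the replacement
//        // blocks are connected.
//        _ = removed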

// AddEdgeProof sets the proof of an existing edge in the graph database.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
        proof *models.ChannelAuthProof) error {

        var (
                ctx       = context.TODO()
                scidBytes = channelIDToBytes(scid.ToUint64())
        )

        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                res, err := db.AddV1ChannelProof(
                        ctx, sqlc.AddV1ChannelProofParams{
                                Scid:              scidBytes,
                                Node1Signature:    proof.NodeSig1Bytes,
                                Node2Signature:    proof.NodeSig2Bytes,
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to add edge proof: %w", err)
                }

                n, err := res.RowsAffected()
                if err != nil {
                        return err
                }

                if n == 0 {
                        return fmt.Errorf("no rows affected when adding edge "+
                                "proof for SCID %v", scid)
                } else if n > 1 {
                        return fmt.Errorf("multiple rows affected when adding "+
                                "edge proof for SCID %v: %d rows affected",
                                scid, n)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return fmt.Errorf("unable to add edge proof: %w", err)
        }

        return nil
}
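
// Example (illustrative sketch; store, scid and the signature byte slices are
// assumed identifiers): once both node and bitcoin signatures for a channel
// announcement are available, the proof can be attached to the previously
// stored edge:
//
//        proof := &models.ChannelAuthProof{
//                NodeSig1Bytes:    nodeSig1,
//                NodeSig2Bytes:    nodeSig2,
//                BitcoinSig1Bytes: bitcoinSig1,
//                BitcoinSig2Bytes: bitcoinSig2,
//        }
//        if err := store.AddEdgeProof(scid, proof); err != nil {
//                return err
//        }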

// PutClosedScid stores a SCID for a closed channel in the database. This is so
// that we can ignore channel announcements that we know to be closed without
// having to validate them and fetch a block.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
        var (
                ctx     = context.TODO()
                chanIDB = channelIDToBytes(scid.ToUint64())
        )

        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                return db.InsertClosedChannel(ctx, chanIDB)
        }, sqldb.NoOpReset)
}

// IsClosedScid checks whether a channel identified by the passed in scid is
// closed. This helps avoid having to perform expensive validation checks.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
        var (
                ctx      = context.TODO()
                isClosed bool
                chanIDB  = channelIDToBytes(scid.ToUint64())
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                var err error
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
                if err != nil {
                        return fmt.Errorf("unable to fetch closed channel: %w",
                                err)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return false, fmt.Errorf("unable to fetch closed channel: %w",
                        err)
        }

        return isClosed, nil
}
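
// Example (illustrative sketch; store and scid are assumed identifiers): the
// two methods above can be used together so that announcements for
// known-closed channels are discarded without any expensive validation:
//
//        if err := store.PutClosedScid(scid); err != nil {
//                return err
//        }
//        closed, err := store.IsClosedScid(scid)
//        if err != nil {
//                return err
//        }
//        if closed {
//                // Skip validating this channel announcement entirely.
//        }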

// GraphSession will provide the call-back with access to a NodeTraverser
// instance which can be used to perform queries against the channel graph.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
        reset func()) error {

        var ctx = context.TODO()

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
        }, reset)
}
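
// Example (illustrative sketch; store and sourceNode are assumed identifiers):
// GraphSession lets a caller run several traversal queries against a single
// consistent read transaction:
//
//        err := store.GraphSession(func(graph NodeTraverser) error {
//                return graph.ForEachNodeDirectedChannel(
//                        sourceNode, func(c *DirectedChannel) error {
//                                // Inspect c.ChannelID, c.Capacity and
//                                // c.InPolicy here.
//                                return nil
//                        }, func() {},
//                )
//        }, func() {})
//        if err != nil {
//                return err
//        }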

// sqlNodeTraverser implements the NodeTraverser interface but with a backing
// read only transaction for a consistent view of the graph.
type sqlNodeTraverser struct {
        db    SQLQueries
        chain chainhash.Hash
}

// A compile-time assertion to ensure that sqlNodeTraverser implements the
// NodeTraverser interface.
var _ NodeTraverser = (*sqlNodeTraverser)(nil)

// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
func newSQLNodeTraverser(db SQLQueries,
        chain chainhash.Hash) *sqlNodeTraverser {

        return &sqlNodeTraverser{
                db:    db,
                chain: chain,
        }
}

// ForEachNodeDirectedChannel calls the callback for every channel of the given
// node.
//
// NOTE: Part of the NodeTraverser interface.
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
        cb func(channel *DirectedChannel) error, _ func()) error {

        ctx := context.TODO()

        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
}

// FetchNodeFeatures returns the features of the given node. If the node is
// unknown, assume no additional features are supported.
//
// NOTE: Part of the NodeTraverser interface.
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
        *lnwire.FeatureVector, error) {

        ctx := context.TODO()

        return fetchNodeFeatures(ctx, s.db, nodePub)
}

// forEachNodeDirectedChannel iterates through all channels of a given
// node, executing the passed callback on the directed edge representing the
// channel and its incoming policy. If the node is not found, no error is
// returned.
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {

        toNodeCallback := func() route.Vertex {
                return nodePub
        }

        dbID, err := db.GetNodeIDByPubKey(
                ctx, sqlc.GetNodeIDByPubKeyParams{
                        Version: int16(ProtocolV1),
                        PubKey:  nodePub[:],
                },
        )
        if errors.Is(err, sql.ErrNoRows) {
                return nil
        } else if err != nil {
                return fmt.Errorf("unable to fetch node: %w", err)
        }

        rows, err := db.ListChannelsByNodeID(
                ctx, sqlc.ListChannelsByNodeIDParams{
                        Version: int16(ProtocolV1),
                        NodeID1: dbID,
                },
        )
        if err != nil {
                return fmt.Errorf("unable to fetch channels: %w", err)
        }

        // Exit early if there are no channels for this node so we don't
        // do the unnecessary feature fetching.
        if len(rows) == 0 {
                return nil
        }

        features, err := getNodeFeatures(ctx, db, dbID)
        if err != nil {
                return fmt.Errorf("unable to fetch node features: %w", err)
        }

        for _, row := range rows {
                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return fmt.Errorf("unable to build node vertices: %w",
                                err)
                }

                edge := buildCacheableChannelInfo(
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
                        node1, node2,
                )

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return err
                }

                p1, p2, err := buildCachedChanPolicies(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
                )
                if err != nil {
                        return err
                }

                // Determine the outgoing and incoming policy for this
                // channel and node combo.
                outPolicy, inPolicy := p1, p2
                if p1 != nil && node2 == nodePub {
                        outPolicy, inPolicy = p2, p1
                } else if p2 != nil && node1 != nodePub {
                        outPolicy, inPolicy = p2, p1
                }

                var cachedInPolicy *models.CachedEdgePolicy
                if inPolicy != nil {
                        cachedInPolicy = inPolicy
                        cachedInPolicy.ToNodePubKey = toNodeCallback
                        cachedInPolicy.ToNodeFeatures = features
                }

                directedChannel := &DirectedChannel{
                        ChannelID:    edge.ChannelID,
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
                        OtherNode:    edge.NodeKey2Bytes,
                        Capacity:     edge.Capacity,
                        OutPolicySet: outPolicy != nil,
                        InPolicy:     cachedInPolicy,
                }
                if outPolicy != nil {
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
                                directedChannel.InboundFee = fee
                        })
                }

                if nodePub == edge.NodeKey2Bytes {
                        directedChannel.OtherNode = edge.NodeKey1Bytes
                }

                if err := cb(directedChannel); err != nil {
                        return err
                }
        }

        return nil
}

// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
// and executes the provided callback for each node. It does so via pagination
// along with batch loading of the node feature bits.
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
                features *lnwire.FeatureVector) error) error {

        handleNode := func(_ context.Context,
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
                featureBits map[int64][]int) error {

                fv := lnwire.EmptyFeatureVector()
                if features, exists := featureBits[dbNode.ID]; exists {
                        for _, bit := range features {
                                fv.Set(lnwire.FeatureBit(bit))
                        }
                }

                var pub route.Vertex
                copy(pub[:], dbNode.PubKey)

                return processNode(dbNode.ID, pub, fv)
        }

        queryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {

                return db.ListNodeIDsAndPubKeys(
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
                                Version: int16(ProtocolV1),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
                return row.ID
        }

        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
                return node.ID, nil
        }

        batchQueryFunc := func(ctx context.Context,
                nodeIDs []int64) (map[int64][]int, error) {

                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
                batchQueryFunc, handleNode,
        )
}

// forEachNodeChannel iterates through all channels of a node, executing
// the passed callback on each. The call-back is provided with the channel's
// edge information, the outgoing policy and the incoming policy for the
// channel and node combo.
func forEachNodeChannel(ctx context.Context, db SQLQueries,
        cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
                *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error) error {

        // Get all the V1 channels for this node.
        rows, err := db.ListChannelsByNodeID(
                ctx, sqlc.ListChannelsByNodeIDParams{
                        Version: int16(ProtocolV1),
                        NodeID1: id,
                },
        )
        if err != nil {
                return fmt.Errorf("unable to fetch channels: %w", err)
        }

        // Collect all the channel and policy IDs.
        var (
                chanIDs   = make([]int64, 0, len(rows))
                policyIDs = make([]int64, 0, 2*len(rows))
        )
        for _, row := range rows {
                chanIDs = append(chanIDs, row.GraphChannel.ID)

                if row.Policy1ID.Valid {
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
                }
                if row.Policy2ID.Valid {
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
                }
        }

        batchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
        )
        if err != nil {
                return fmt.Errorf("unable to batch load channel data: %w", err)
        }

        // Call the call-back for each channel and its known policies.
        for _, row := range rows {
                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return fmt.Errorf("unable to build node vertices: %w",
                                err)
                }

                edge, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.GraphChannel, node1, node2,
                        batchData,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel info: %w",
                                err)
                }

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel "+
                                "policies: %w", err)
                }

                // Determine the outgoing and incoming policy for this
                // channel and node combo.
                p1ToNode := row.GraphChannel.NodeID2
                p2ToNode := row.GraphChannel.NodeID1
                outPolicy, inPolicy := p1, p2
                if (p1 != nil && p1ToNode == id) ||
                        (p2 != nil && p2ToNode != id) {

                        outPolicy, inPolicy = p2, p1
                }

                if err := cb(edge, outPolicy, inPolicy); err != nil {
                        return err
                }
        }

        return nil
}

// updateChanEdgePolicy upserts the channel policy info we have stored for
// a channel we already know of.
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
        error) {

        var (
                node1Pub, node2Pub route.Vertex
                isNode1            bool
                chanIDB            = channelIDToBytes(edge.ChannelID)
        )

        // Check that this edge policy refers to a channel that we already
        // know of. We do this explicitly so that we can return the appropriate
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
        // abort the transaction which would abort the entire batch.
        dbChan, err := tx.GetChannelAndNodesBySCID(
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
                        Scid:    chanIDB,
                        Version: int16(ProtocolV1),
                },
        )
        if errors.Is(err, sql.ErrNoRows) {
                return node1Pub, node2Pub, false, ErrEdgeNotFound
        } else if err != nil {
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
                        "fetch channel(%v): %w", edge.ChannelID, err)
        }

        copy(node1Pub[:], dbChan.Node1PubKey)
        copy(node2Pub[:], dbChan.Node2PubKey)

        // Figure out which node this edge is from.
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
        nodeID := dbChan.NodeID1
        if !isNode1 {
                nodeID = dbChan.NodeID2
        }

        var (
                inboundBase sql.NullInt64
                inboundRate sql.NullInt64
        )
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
        })

        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
                Version:     int16(ProtocolV1),
                ChannelID:   dbChan.ID,
                NodeID:      nodeID,
                Timelock:    int32(edge.TimeLockDelta),
                FeePpm:      int64(edge.FeeProportionalMillionths),
                BaseFeeMsat: int64(edge.FeeBaseMSat),
                MinHtlcMsat: int64(edge.MinHTLC),
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
                Disabled: sql.NullBool{
                        Valid: true,
                        Bool:  edge.IsDisabled(),
                },
                MaxHtlcMsat: sql.NullInt64{
                        Valid: edge.MessageFlags.HasMaxHtlc(),
                        Int64: int64(edge.MaxHTLC),
                },
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
                InboundBaseFeeMsat:      inboundBase,
                InboundFeeRateMilliMsat: inboundRate,
                Signature:               edge.SigBytes,
        })
        if err != nil {
                return node1Pub, node2Pub, isNode1,
                        fmt.Errorf("unable to upsert edge policy: %w", err)
        }

        // Convert the flat extra opaque data into a map of TLV types to
        // values.
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
        if err != nil {
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
                        "marshal extra opaque data: %w", err)
        }

        // Update the channel policy's extra signed fields.
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
        if err != nil {
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
                        "policy extra TLVs: %w", err)
        }

        return node1Pub, node2Pub, isNode1, nil
}
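
// Note (illustrative): the direction of the policy upserted above is derived
// purely from the channel flags carried in the update, mirroring the BOLT #7
// encoding in which a cleared direction bit means the update was produced by
// node 1:
//
//        isNode1 := edge.ChannelFlags&lnwire.ChanUpdateDirection == 0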

// getNodeByPubKey attempts to look up a target node by its public key.
func getNodeByPubKey(ctx context.Context, db SQLQueries,
        pubKey route.Vertex) (int64, *models.LightningNode, error) {

        dbNode, err := db.GetNodeByPubKey(
                ctx, sqlc.GetNodeByPubKeyParams{
                        Version: int16(ProtocolV1),
                        PubKey:  pubKey[:],
                },
        )
        if errors.Is(err, sql.ErrNoRows) {
                return 0, nil, ErrGraphNodeNotFound
        } else if err != nil {
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
        }

        node, err := buildNode(ctx, db, &dbNode)
        if err != nil {
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
        }

        return dbNode.ID, node, nil
}

// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
// provided parameters.
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
        node2Pub route.Vertex) *models.CachedEdgeInfo {

        return &models.CachedEdgeInfo{
                ChannelID:     byteOrder.Uint64(scid),
                NodeKey1Bytes: node1Pub,
                NodeKey2Bytes: node2Pub,
                Capacity:      btcutil.Amount(capacity),
        }
}

// buildNode constructs a LightningNode instance from the given database node
// record. The node's features, addresses and extra signed fields are also
// fetched from the database and set on the node.
func buildNode(ctx context.Context, db SQLQueries,
        dbNode *sqlc.GraphNode) (*models.LightningNode, error) {

        // NOTE: buildNode is only used to load the data for a single node, and
        // so no paged queries will be performed. This means that it's ok to
        // pass in the default config values here.
        cfg := sqldb.DefaultQueryConfig()

        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node data: %w",
                        err)
        }

        return buildNodeWithBatchData(dbNode, data)
}

// buildNodeWithBatchData builds a models.LightningNode instance
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
// features/addresses/extra fields, then the corresponding fields are expected
// to be present in the batchNodeData.
func buildNodeWithBatchData(dbNode *sqlc.GraphNode,
        batchData *batchNodeData) (*models.LightningNode, error) {

        if dbNode.Version != int16(ProtocolV1) {
                return nil, fmt.Errorf("unsupported node version: %d",
                        dbNode.Version)
        }

        var pub [33]byte
        copy(pub[:], dbNode.PubKey)

        node := &models.LightningNode{
                PubKeyBytes: pub,
                Features:    lnwire.EmptyFeatureVector(),
                LastUpdate:  time.Unix(0, 0),
        }

        if len(dbNode.Signature) == 0 {
                return node, nil
        }

        node.HaveNodeAnnouncement = true
        node.AuthSigBytes = dbNode.Signature
        node.Alias = dbNode.Alias.String
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)

        var err error
        if dbNode.Color.Valid {
                node.Color, err = DecodeHexColor(dbNode.Color.String)
                if err != nil {
                        return nil, fmt.Errorf("unable to decode color: %w",
                                err)
                }
        }

        // Use preloaded features.
        if features, exists := batchData.features[dbNode.ID]; exists {
                fv := lnwire.EmptyFeatureVector()
                for _, bit := range features {
                        fv.Set(lnwire.FeatureBit(bit))
                }
                node.Features = fv
        }

        // Use preloaded addresses.
        addresses, exists := batchData.addresses[dbNode.ID]
        if exists && len(addresses) > 0 {
                node.Addresses, err = buildNodeAddresses(addresses)
                if err != nil {
                        return nil, fmt.Errorf("unable to build addresses "+
                                "for node(%d): %w", dbNode.ID, err)
                }
        }

        // Use preloaded extra fields.
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
                if err != nil {
                        return nil, fmt.Errorf("unable to serialize extra "+
                                "signed fields: %w", err)
                }
                if len(recs) != 0 {
                        node.ExtraOpaqueData = recs
                }
        }

        return node, nil
}

// forEachNodeInBatch fetches all nodes in the provided batch, builds them
// with the preloaded data, and executes the provided callback for each node.
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, nodes []sqlc.GraphNode,
        cb func(dbID int64, node *models.LightningNode) error) error {

        // Extract node IDs for batch loading.
        nodeIDs := make([]int64, len(nodes))
        for i, node := range nodes {
                nodeIDs[i] = node.ID
        }

        // Batch load all related data for this page.
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
        if err != nil {
                return fmt.Errorf("unable to batch load node data: %w", err)
        }

        for _, dbNode := range nodes {
                node, err := buildNodeWithBatchData(&dbNode, batchData)
                if err != nil {
                        return fmt.Errorf("unable to build node(id=%d): %w",
                                dbNode.ID, err)
                }

                if err := cb(dbNode.ID, node); err != nil {
                        return fmt.Errorf("callback failed for node(id=%d): %w",
                                dbNode.ID, err)
                }
        }

        return nil
}

// getNodeFeatures fetches the feature bits and constructs the feature vector
// for a node with the given DB ID.
func getNodeFeatures(ctx context.Context, db SQLQueries,
        nodeID int64) (*lnwire.FeatureVector, error) {

        rows, err := db.GetNodeFeatures(ctx, nodeID)
        if err != nil {
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
                        nodeID, err)
        }

        features := lnwire.EmptyFeatureVector()
        for _, feature := range rows {
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
        }

        return features, nil
}

// upsertNode upserts the node record into the database. If the node already
// exists, then the node's information is updated. If the node doesn't exist,
// then a new node is created. The node's features, addresses and extra TLV
// types are also updated. The node's DB ID is returned.
func upsertNode(ctx context.Context, db SQLQueries,
        node *models.LightningNode) (int64, error) {

        params := sqlc.UpsertNodeParams{
                Version: int16(ProtocolV1),
                PubKey:  node.PubKeyBytes[:],
        }

        if node.HaveNodeAnnouncement {
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
                params.Alias = sqldb.SQLStr(node.Alias)
                params.Signature = node.AuthSigBytes
        }

        nodeID, err := db.UpsertNode(ctx, params)
        if err != nil {
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
                        err)
        }

        // We can exit here if we don't have the announcement yet.
        if !node.HaveNodeAnnouncement {
                return nodeID, nil
        }

        // Update the node's features.
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
        if err != nil {
                return 0, fmt.Errorf("inserting node features: %w", err)
        }

        // Update the node's addresses.
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
        if err != nil {
                return 0, fmt.Errorf("inserting node addresses: %w", err)
        }

        // Convert the flat extra opaque data into a map of TLV types to
        // values.
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
        if err != nil {
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
                        err)
        }

        // Update the node's extra signed fields.
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
        if err != nil {
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
        }

        return nodeID, nil
}

// upsertNodeFeatures updates the node's features in the node_features table.
// This includes deleting any feature bits no longer present and inserting any
// new feature bits. If the feature bit does not yet exist in the features
// table, then an entry is created in that table first.
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
        features *lnwire.FeatureVector) error {

        // Get any existing features for the node.
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
                return err
        }

        // Copy the node's latest set of feature bits.
        newFeatures := make(map[int32]struct{})
        if features != nil {
                for feature := range features.Features() {
                        newFeatures[int32(feature)] = struct{}{}
                }
        }

        // For any current feature that already exists in the DB, remove it from
        // the in-memory map. For any existing feature that does not exist in
        // the in-memory map, delete it from the database.
        for _, feature := range existingFeatures {
                // The feature is still present, so there are no updates to be
                // made.
                if _, ok := newFeatures[feature.FeatureBit]; ok {
                        delete(newFeatures, feature.FeatureBit)
                        continue
                }

                // The feature is no longer present, so we remove it from the
                // database.
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
                        NodeID:     nodeID,
                        FeatureBit: feature.FeatureBit,
                })
                if err != nil {
                        return fmt.Errorf("unable to delete node(%d) "+
                                "feature(%v): %w", nodeID, feature.FeatureBit,
                                err)
                }
        }

        // Any remaining entries in newFeatures are new features that need to be
        // added to the database for the first time.
        for feature := range newFeatures {
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
                        NodeID:     nodeID,
                        FeatureBit: feature,
                })
                if err != nil {
                        return fmt.Errorf("unable to insert node(%d) "+
                                "feature(%v): %w", nodeID, feature, err)
                }
        }

        return nil
}

// fetchNodeFeatures fetches the features for a node with the given public key.
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {

        rows, err := queries.GetNodeFeaturesByPubKey(
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
                        PubKey:  nodePub[:],
                        Version: int16(ProtocolV1),
                },
        )
        if err != nil {
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
                        nodePub, err)
        }

        features := lnwire.EmptyFeatureVector()
        for _, bit := range rows {
                features.Set(lnwire.FeatureBit(bit))
        }

        return features, nil
}

// dbAddressType is an enum type that represents the different address types
// that we store in the node_addresses table. The address type determines how
// the address is to be serialised/deserialised.
type dbAddressType uint8

const (
        addressTypeIPv4   dbAddressType = 1
        addressTypeIPv6   dbAddressType = 2
        addressTypeTorV2  dbAddressType = 3
        addressTypeTorV3  dbAddressType = 4
        addressTypeOpaque dbAddressType = math.MaxInt8
)

// upsertNodeAddresses updates the node's addresses in the database. This
// includes deleting any existing addresses and inserting the new set of
// addresses. The deletion is necessary since the ordering of the addresses may
// change, and we need to ensure that the database reflects the latest set of
// addresses so that at the time of reconstructing the node announcement, the
// order is preserved and the signature over the message remains valid.
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
        addresses []net.Addr) error {

        // Delete any existing addresses for the node. This is required since
        // even if the new set of addresses is the same, the ordering may have
        // changed for a given address type.
        err := db.DeleteNodeAddresses(ctx, nodeID)
        if err != nil {
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
                        nodeID, err)
        }

        // Copy the node's latest set of addresses.
        newAddresses := map[dbAddressType][]string{
                addressTypeIPv4:   {},
                addressTypeIPv6:   {},
                addressTypeTorV2:  {},
                addressTypeTorV3:  {},
                addressTypeOpaque: {},
        }
        addAddr := func(t dbAddressType, addr net.Addr) {
                newAddresses[t] = append(newAddresses[t], addr.String())
        }

        for _, address := range addresses {
                switch addr := address.(type) {
                case *net.TCPAddr:
                        if ip4 := addr.IP.To4(); ip4 != nil {
                                addAddr(addressTypeIPv4, addr)
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
                                addAddr(addressTypeIPv6, addr)
                        } else {
                                return fmt.Errorf("unhandled IP address: %v",
                                        addr)
                        }

                case *tor.OnionAddr:
                        switch len(addr.OnionService) {
                        case tor.V2Len:
                                addAddr(addressTypeTorV2, addr)
                        case tor.V3Len:
                                addAddr(addressTypeTorV3, addr)
                        default:
                                return fmt.Errorf("invalid length for a tor " +
                                        "address")
                        }

                case *lnwire.OpaqueAddrs:
                        addAddr(addressTypeOpaque, addr)

                default:
                        return fmt.Errorf("unhandled address type: %T", addr)
                }
        }

        // Insert the new set of addresses, preserving the position of each
        // address within its address type.
        for addrType, addrList := range newAddresses {
                for position, addr := range addrList {
                        err := db.InsertNodeAddress(
                                ctx, sqlc.InsertNodeAddressParams{
                                        NodeID:   nodeID,
                                        Type:     int16(addrType),
                                        Address:  addr,
                                        Position: int32(position),
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("unable to insert "+
                                        "node(%d) address(%v): %w", nodeID,
                                        addr, err)
                        }
                }
        }

        return nil
}
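
// Note (illustrative): the classification applied above maps the supported
// net.Addr implementations to address types as follows:
//
//        *net.TCPAddr with an IPv4 IP        -> addressTypeIPv4
//        *net.TCPAddr with an IPv6 IP        -> addressTypeIPv6
//        *tor.OnionAddr of tor.V2Len length  -> addressTypeTorV2
//        *tor.OnionAddr of tor.V3Len length  -> addressTypeTorV3
//        *lnwire.OpaqueAddrs                 -> addressTypeOpaque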

// getNodeAddresses fetches the addresses for a node with the given DB ID.
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
        error) {

        // GetNodeAddresses ensures that the addresses for a given type are
        // returned in the same order as they were inserted.
        rows, err := db.GetNodeAddresses(ctx, id)
        if err != nil {
                return nil, err
        }

        addresses := make([]net.Addr, 0, len(rows))
        for _, row := range rows {
                address := row.Address

                addr, err := parseAddress(dbAddressType(row.Type), address)
                if err != nil {
                        return nil, fmt.Errorf("unable to parse address "+
                                "for node(%d): %v: %w", id, address, err)
                }

                addresses = append(addresses, addr)
        }

        // If we have no addresses, then we'll return nil instead of an
        // empty slice.
        if len(addresses) == 0 {
                addresses = nil
        }

        return addresses, nil
}

// upsertNodeExtraSignedFields updates the node's extra signed fields in the
// database. This includes updating any existing types, inserting any new types,
// and deleting any types that are no longer present.
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
        nodeID int64, extraFields map[uint64][]byte) error {

        // Get any existing extra signed fields for the node.
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
        if err != nil {
                return err
        }

        // Make a lookup map of the existing field types so that we can use it
        // to keep track of any fields we should delete.
        m := make(map[uint64]bool)
        for _, field := range existingFields {
                m[uint64(field.Type)] = true
        }

        // For all the new fields, we'll upsert them and remove them from the
        // map of existing fields.
        for tlvType, value := range extraFields {
                err = db.UpsertNodeExtraType(
                        ctx, sqlc.UpsertNodeExtraTypeParams{
                                NodeID: nodeID,
                                Type:   int64(tlvType),
                                Value:  value,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to upsert node(%d) extra "+
                                "signed field(%v): %w", nodeID, tlvType, err)
                }

                // Remove the field from the map of existing fields if it was
                // present.
                delete(m, tlvType)
        }

        // For all the fields that are left in the map of existing fields, we'll
        // delete them as they are no longer present in the new set of fields.
        for tlvType := range m {
                err = db.DeleteExtraNodeType(
                        ctx, sqlc.DeleteExtraNodeTypeParams{
                                NodeID: nodeID,
                                Type:   int64(tlvType),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to delete node(%d) extra "+
                                "signed field(%v): %w", nodeID, tlvType, err)
                }
        }

        return nil
}
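
// The diff performed by upsertNodeExtraSignedFields, reduced to plain maps as
// an illustrative sketch (exampleDiffExtraTypes is a hypothetical helper, not
// used elsewhere): every type in newFields is (re)written, and any previously
// stored type that is absent from newFields is deleted.
func exampleDiffExtraTypes(existing, newFields map[uint64][]byte) (
        upserts map[uint64][]byte, deletes []uint64) {

        // All new fields are upserted unconditionally.
        upserts = newFields

        // Anything that was stored before but is no longer present gets
        // removed.
        for tlvType := range existing {
                if _, ok := newFields[tlvType]; !ok {
                        deletes = append(deletes, tlvType)
                }
        }

        return upserts, deletes
}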

// srcNodeInfo holds the information about the source node of the graph.
type srcNodeInfo struct {
        // id is the DB level ID of the source node entry in the "nodes" table.
        id int64

        // pub is the public key of the source node.
        pub route.Vertex
}

// getSourceNode returns the DB node ID and pub key of the source node for the
// specified protocol version.
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
        version ProtocolVersion) (int64, route.Vertex, error) {

        s.srcNodeMu.Lock()
        defer s.srcNodeMu.Unlock()

        // If we already have the source node ID and pub key cached, then
        // return them.
        if info, ok := s.srcNodes[version]; ok {
                return info.id, info.pub, nil
        }

        var pubKey route.Vertex

        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
        if err != nil {
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
                        err)
        }

        if len(nodes) == 0 {
                return 0, pubKey, ErrSourceNodeNotSet
        } else if len(nodes) > 1 {
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
                        "protocol %s found", version)
        }

        copy(pubKey[:], nodes[0].PubKey)

        s.srcNodes[version] = &srcNodeInfo{
                id:  nodes[0].NodeID,
                pub: pubKey,
        }

        return nodes[0].NodeID, pubKey, nil
}

// marshalExtraOpaqueData takes a flat byte slice and parses it as a TLV stream.
// This then produces a map from TLV type to value. If the input is not a
// valid TLV stream, then an error is returned.
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
        r := bytes.NewReader(data)

        tlvStream, err := tlv.NewStream()
        if err != nil {
                return nil, err
        }

        // Since ExtraOpaqueData is provided by a potentially malicious peer,
        // pass it into the P2P decoding variant.
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
        if err != nil {
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
        }
        if len(parsedTypes) == 0 {
                return nil, nil
        }

        records := make(map[uint64][]byte)
        for k, v := range parsedTypes {
                records[uint64(k)] = v
        }

        return records, nil
}
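
// Illustrative sketch of how marshalExtraOpaqueData pairs with
// lnwire.CustomRecords (as used elsewhere in this file) to round-trip a
// gossip message's extra TLV bytes. The helper name exampleRoundTripExtras is
// hypothetical.
func exampleRoundTripExtras(extraOpaqueData []byte) ([]byte, error) {
        // Parse the raw bytes into a TLV type -> value map. Invalid streams
        // surface as ErrParsingExtraTLVBytes.
        records, err := marshalExtraOpaqueData(extraOpaqueData)
        if err != nil {
                return nil, err
        }

        // Re-encode the parsed records into a single byte slice, which is the
        // same form written back onto ExtraOpaqueData when edges and nodes are
        // rebuilt from the database.
        return lnwire.CustomRecords(records).Serialize()
}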

// dbChanInfo holds the DB level IDs of a channel and the nodes involved in the
// channel.
type dbChanInfo struct {
        channelID int64
        node1ID   int64
        node2ID   int64
}

// insertChannel inserts a new channel record into the database.
func insertChannel(ctx context.Context, db SQLQueries,
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {

        chanIDB := channelIDToBytes(edge.ChannelID)

        // Make sure that the channel doesn't already exist. We do this
        // explicitly instead of relying on catching a unique constraint error
        // because relying on SQL to throw that error would abort the entire
        // batch of transactions.
        _, err := db.GetChannelBySCID(
                ctx, sqlc.GetChannelBySCIDParams{
                        Scid:    chanIDB,
                        Version: int16(ProtocolV1),
                },
        )
        if err == nil {
                return nil, ErrEdgeAlreadyExist
        } else if !errors.Is(err, sql.ErrNoRows) {
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
        }

        // Make sure that at least a "shell" entry for each node is present in
        // the nodes table.
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
        if err != nil {
                return nil, fmt.Errorf("unable to create shell node: %w", err)
        }

        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
        if err != nil {
                return nil, fmt.Errorf("unable to create shell node: %w", err)
        }

        var capacity sql.NullInt64
        if edge.Capacity != 0 {
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
        }

        createParams := sqlc.CreateChannelParams{
                Version:     int16(ProtocolV1),
                Scid:        chanIDB,
                NodeID1:     node1DBID,
                NodeID2:     node2DBID,
                Outpoint:    edge.ChannelPoint.String(),
                Capacity:    capacity,
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
        }

        if edge.AuthProof != nil {
                proof := edge.AuthProof

                createParams.Node1Signature = proof.NodeSig1Bytes
                createParams.Node2Signature = proof.NodeSig2Bytes
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
        }

        // Insert the new channel record.
        dbChanID, err := db.CreateChannel(ctx, createParams)
        if err != nil {
                return nil, err
        }

        // Insert any channel features.
        for feature := range edge.Features.Features() {
                err = db.InsertChannelFeature(
                        ctx, sqlc.InsertChannelFeatureParams{
                                ChannelID:  dbChanID,
                                FeatureBit: int32(feature),
                        },
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
                                "feature(%v): %w", dbChanID, feature, err)
                }
        }

        // Finally, insert any extra TLV fields in the channel announcement.
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
        if err != nil {
                return nil, fmt.Errorf("unable to marshal extra opaque "+
                        "data: %w", err)
        }

        for tlvType, value := range extra {
                err := db.CreateChannelExtraType(
                        ctx, sqlc.CreateChannelExtraTypeParams{
                                ChannelID: dbChanID,
                                Type:      int64(tlvType),
                                Value:     value,
                        },
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to upsert "+
                                "channel(%d) extra signed field(%v): %w",
                                edge.ChannelID, tlvType, err)
                }
        }

        return &dbChanInfo{
                channelID: dbChanID,
                node1ID:   node1DBID,
                node2ID:   node2DBID,
        }, nil
}

// maybeCreateShellNode checks if a shell node entry exists for the
// given public key. If it does not exist, then a new shell node entry is
// created. The ID of the node is returned. A shell node only has a protocol
// version and public key persisted.
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
        pubKey route.Vertex) (int64, error) {

        dbNode, err := db.GetNodeByPubKey(
                ctx, sqlc.GetNodeByPubKeyParams{
                        PubKey:  pubKey[:],
                        Version: int16(ProtocolV1),
                },
        )
        // The node exists. Return the ID.
        if err == nil {
                return dbNode.ID, nil
        } else if !errors.Is(err, sql.ErrNoRows) {
                return 0, err
        }

        // Otherwise, the node does not exist, so we create a shell entry for
        // it.
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
                Version: int16(ProtocolV1),
                PubKey:  pubKey[:],
        })
        if err != nil {
                return 0, fmt.Errorf("unable to create shell node: %w", err)
        }

        return id, nil
}

// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
// the database. This includes deleting any existing types and then inserting
// the new types.
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
        chanPolicyID int64, extraFields map[uint64][]byte) error {

        // Delete all existing extra signed fields for the channel policy.
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
        if err != nil {
                return fmt.Errorf("unable to delete "+
                        "existing policy extra signed fields for policy %d: %w",
                        chanPolicyID, err)
        }

        // Insert all new extra signed fields for the channel policy.
        for tlvType, value := range extraFields {
                err = db.InsertChanPolicyExtraType(
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
                                ChannelPolicyID: chanPolicyID,
                                Type:            int64(tlvType),
                                Value:           value,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to insert "+
                                "channel_policy(%d) extra signed field(%v): %w",
                                chanPolicyID, tlvType, err)
                }
        }

        return nil
}

// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
// provided channel DB row and also fetches any other required information
// to construct the edge info.
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
        chain chainhash.Hash, dbChan sqlc.GraphChannel, node1,
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {

        // NOTE: getAndBuildEdgeInfo is only used to load the data for a single
        // edge, and so no paged queries will be performed. This means that
        // it's ok to pass in the default config values here.
        cfg := sqldb.DefaultQueryConfig()

        data, err := batchLoadChannelData(ctx, cfg, db, []int64{dbChan.ID}, nil)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load channel data: %w",
                        err)
        }

        return buildEdgeInfoWithBatchData(chain, dbChan, node1, node2, data)
}

// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {

        if dbChan.Version != int16(ProtocolV1) {
                return nil, fmt.Errorf("unsupported channel version: %d",
                        dbChan.Version)
        }

        // Use the pre-loaded features and extra types.
        fv := lnwire.EmptyFeatureVector()
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
                for _, bit := range features {
                        fv.Set(lnwire.FeatureBit(bit))
                }
        }

        var extras map[uint64][]byte
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
        if exists {
                extras = channelExtras
        } else {
                extras = make(map[uint64][]byte)
        }

        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
        if err != nil {
                return nil, err
        }

        recs, err := lnwire.CustomRecords(extras).Serialize()
        if err != nil {
                return nil, fmt.Errorf("unable to serialize extra signed "+
                        "fields: %w", err)
        }
        if recs == nil {
                recs = make([]byte, 0)
        }

        var btcKey1, btcKey2 route.Vertex
        copy(btcKey1[:], dbChan.BitcoinKey1)
        copy(btcKey2[:], dbChan.BitcoinKey2)

        channel := &models.ChannelEdgeInfo{
                ChainHash:        chain,
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
                NodeKey1Bytes:    node1,
                NodeKey2Bytes:    node2,
                BitcoinKey1Bytes: btcKey1,
                BitcoinKey2Bytes: btcKey2,
                ChannelPoint:     *op,
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
                Features:         fv,
                ExtraOpaqueData:  recs,
        }

        // We always set all the signatures at the same time, so we can
        // safely check if one signature is present to determine if we have the
        // rest of the signatures for the auth proof.
        if len(dbChan.Bitcoin1Signature) > 0 {
                channel.AuthProof = &models.ChannelAuthProof{
                        NodeSig1Bytes:    dbChan.Node1Signature,
                        NodeSig2Bytes:    dbChan.Node2Signature,
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
                }
        }

        return channel, nil
}

// buildNodeVertices is a helper that converts raw node public keys
// into route.Vertex instances.
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
        route.Vertex, error) {

        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
        if err != nil {
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
                        "create vertex from node1 pubkey: %w", err)
        }

        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
        if err != nil {
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
                        "create vertex from node2 pubkey: %w", err)
        }

        return node1Vertex, node2Vertex, nil
}

// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy records and
// also retrieves all the extra info required to build the complete
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
// the provided sqlc.GraphChannelPolicy records are nil.
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
        dbPol1, dbPol2 *sqlc.GraphChannelPolicy, channelID uint64, node1,
        node2 route.Vertex) (*models.ChannelEdgePolicy,
        *models.ChannelEdgePolicy, error) {

        if dbPol1 == nil && dbPol2 == nil {
                return nil, nil, nil
        }

        var policyIDs = make([]int64, 0, 2)
        if dbPol1 != nil {
                policyIDs = append(policyIDs, dbPol1.ID)
        }
        if dbPol2 != nil {
                policyIDs = append(policyIDs, dbPol2.ID)
        }

        // NOTE: getAndBuildChanPolicies is only used to load the data for
        // a maximum of two policies, and so no paged queries will be
        // performed (unless the page size is one). So it's ok to use
        // the default config values here.
        cfg := sqldb.DefaultQueryConfig()

        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
        if err != nil {
                return nil, nil, fmt.Errorf("unable to batch load channel "+
                        "data: %w", err)
        }

        pol1, err := buildChanPolicyWithBatchData(
                dbPol1, channelID, node2, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
        }

        pol2, err := buildChanPolicyWithBatchData(
                dbPol2, channelID, node1, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
        }

        return pol1, pol2, nil
}

// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
// then nil is returned for it.
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
        *models.CachedEdgePolicy, error) {

        var p1, p2 *models.CachedEdgePolicy
        if dbPol1 != nil {
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
                if err != nil {
                        return nil, nil, err
                }

                p1 = models.NewCachedPolicy(policy1)
        }
        if dbPol2 != nil {
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
                if err != nil {
                        return nil, nil, err
                }

                p2 = models.NewCachedPolicy(policy2)
        }

        return p1, p2, nil
}

// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
// provided sqlc.GraphChannelPolicy and other required information.
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
        extras map[uint64][]byte,
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {

        recs, err := lnwire.CustomRecords(extras).Serialize()
        if err != nil {
                return nil, fmt.Errorf("unable to serialize extra signed "+
                        "fields: %w", err)
        }

        var inboundFee fn.Option[lnwire.Fee]
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
                dbPolicy.InboundBaseFeeMsat.Valid {

                inboundFee = fn.Some(lnwire.Fee{
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
                })
        }

        return &models.ChannelEdgePolicy{
                SigBytes:  dbPolicy.Signature,
                ChannelID: channelID,
                LastUpdate: time.Unix(
                        dbPolicy.LastUpdate.Int64, 0,
                ),
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
                        dbPolicy.MessageFlags,
                ),
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
                        dbPolicy.ChannelFlags,
                ),
                TimeLockDelta: uint16(dbPolicy.Timelock),
                MinHTLC: lnwire.MilliSatoshi(
                        dbPolicy.MinHtlcMsat,
                ),
                MaxHTLC: lnwire.MilliSatoshi(
                        dbPolicy.MaxHtlcMsat.Int64,
                ),
                FeeBaseMSat: lnwire.MilliSatoshi(
                        dbPolicy.BaseFeeMsat,
                ),
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
                ToNode:                    toNode,
                InboundFee:                inboundFee,
                ExtraOpaqueData:           recs,
        }, nil
}

// buildNodes builds the models.LightningNode instances for the two given
// sqlc.GraphNode records.
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
        dbNode2 sqlc.GraphNode) (*models.LightningNode, *models.LightningNode,
        error) {

        node1, err := buildNode(ctx, db, &dbNode1)
        if err != nil {
                return nil, nil, err
        }

        node2, err := buildNode(ctx, db, &dbNode2)
        if err != nil {
                return nil, nil, err
        }

        return node1, node2, nil
}

// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the
// given row which is expected to be a sqlc type that contains channel policy
// information. It returns two policies, which may be nil if the policy
// information is not present in the row.
//
//nolint:ll,dupl,funlen
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
        *sqlc.GraphChannelPolicy, error) {

        var policy1, policy2 *sqlc.GraphChannelPolicy
        switch r := row.(type) {
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
                if r.Policy1Timelock.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                        }
                }
                if r.Policy2Timelock.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                        }
                }

                return policy1, policy2, nil

        case sqlc.GetChannelsBySCIDWithPoliciesRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.GetChannelByOutpointWithPoliciesRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.GetChannelBySCIDWithPoliciesRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.ListChannelsForNodeIDsRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.ListChannelsByNodeIDRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.ListChannelsWithPoliciesPaginatedRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil
        default:
                return nil, nil, fmt.Errorf("unexpected row type in "+
                        "extractChannelPolicies: %T", r)
        }
}
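
// The type switch above is repetitive because sqlc generates a distinct row
// struct per query, so channels joined with their policies arrive as
// different Go types even though the policy columns are identical. Any new
// query that selects policy columns needs its own case here, which is why the
// function carries the dupl and funlen lint exemptions.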

// channelIDToBytes converts a channel ID (SCID) to a byte array
// representation.
func channelIDToBytes(channelID uint64) []byte {
        var chanIDB [8]byte
        byteOrder.PutUint64(chanIDB[:], channelID)

        return chanIDB[:]
}
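
// Illustrative sketch (exampleScidRoundTrip is a hypothetical helper): the
// encoding written by channelIDToBytes is the inverse of the byteOrder.Uint64
// read used when a ChannelEdgeInfo is rebuilt from its scid column above.
func exampleScidRoundTrip(scid uint64) bool {
        return byteOrder.Uint64(channelIDToBytes(scid)) == scid
}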

// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
        if len(addresses) == 0 {
                return nil, nil
        }

        result := make([]net.Addr, 0, len(addresses))
        for _, addr := range addresses {
                netAddr, err := parseAddress(addr.addrType, addr.address)
                if err != nil {
                        return nil, fmt.Errorf("unable to parse address %s "+
                                "of type %d: %w", addr.address, addr.addrType,
                                err)
                }
                if netAddr != nil {
                        result = append(result, netAddr)
                }
        }

        // If we have no valid addresses, return nil instead of an empty slice.
        if len(result) == 0 {
                return nil, nil
        }

        return result, nil
}

// parseAddress parses the given address string based on the address type
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
// and opaque addresses.
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
        switch addrType {
        case addressTypeIPv4:
                tcp, err := net.ResolveTCPAddr("tcp4", address)
                if err != nil {
                        return nil, err
                }

                tcp.IP = tcp.IP.To4()

                return tcp, nil

        case addressTypeIPv6:
                tcp, err := net.ResolveTCPAddr("tcp6", address)
                if err != nil {
                        return nil, err
                }

                return tcp, nil

        case addressTypeTorV3, addressTypeTorV2:
                service, portStr, err := net.SplitHostPort(address)
                if err != nil {
                        return nil, fmt.Errorf("unable to split tor "+
                                "address: %v", address)
                }

                port, err := strconv.Atoi(portStr)
                if err != nil {
                        return nil, err
                }

                return &tor.OnionAddr{
                        OnionService: service,
                        Port:         port,
                }, nil

        case addressTypeOpaque:
                opaque, err := hex.DecodeString(address)
                if err != nil {
                        return nil, fmt.Errorf("unable to decode opaque "+
                                "address: %v", address)
                }

                return &lnwire.OpaqueAddrs{
                        Payload: opaque,
                }, nil

        default:
                return nil, fmt.Errorf("unknown address type: %v", addrType)
        }
}
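
// Illustrative sketch of how the strings persisted by the address upsert
// logic are mapped back to net.Addr values (the function name and address
// literals are hypothetical examples):
func exampleParseAddresses() ([]net.Addr, error) {
        // An IPv4 host:port string round-trips to a *net.TCPAddr.
        ip4, err := parseAddress(addressTypeIPv4, "203.0.113.1:9735")
        if err != nil {
                return nil, err
        }

        // A Tor v3 address string round-trips to a *tor.OnionAddr with the
        // onion service and port split out.
        onion, err := parseAddress(
                addressTypeTorV3,
                "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd"+
                        ".onion:9735",
        )
        if err != nil {
                return nil, err
        }

        return []net.Addr{ip4, onion}, nil
}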

// batchNodeData holds all the related data for a batch of nodes.
type batchNodeData struct {
        // features is a map from a DB node ID to the feature bits for that
        // node.
        features map[int64][]int

        // addresses is a map from a DB node ID to the node's addresses.
        addresses map[int64][]nodeAddress

        // extraFields is a map from a DB node ID to the extra signed fields
        // for that node.
        extraFields map[int64]map[uint64][]byte
}

// nodeAddress holds the address type, position and address string for a
// node. This is used to batch the fetching of node addresses.
type nodeAddress struct {
        addrType dbAddressType
        position int32
        address  string
}
×
4760

×
4761
// batchLoadNodeData loads all related data for a batch of node IDs using the
×
4762
// provided SQLQueries interface. It returns a batchNodeData instance containing
×
4763
// the node features, addresses and extra signed fields.
×
4764
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
×
4765
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {
×
4766

×
4767
        // Batch load the node features.
4768
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
4769
        if err != nil {
×
4770
                return nil, fmt.Errorf("unable to batch load node "+
×
4771
                        "features: %w", err)
×
4772
        }
×
4773

×
4774
        // Batch load the node addresses.
×
4775
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
×
4776
        if err != nil {
×
4777
                return nil, fmt.Errorf("unable to batch load node "+
×
4778
                        "addresses: %w", err)
×
4779
        }
×
4780

4781
        // Batch load the node extra signed fields.
4782
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
4783
        if err != nil {
4784
                return nil, fmt.Errorf("unable to batch load node extra "+
4785
                        "signed fields: %w", err)
4786
        }
4787

4788
        return &batchNodeData{
×
4789
                features:    features,
×
4790
                addresses:   addrs,
×
4791
                extraFields: extraTypes,
×
4792
        }, nil
×
4793
}
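
// A minimal consumption sketch (illustrative only; variable names are made
// up): the batched address rows for one node can be turned back into
// net.Addr values with parseAddress. Each row carries an explicit position,
// so the rows are sorted first to preserve their original ordering.
//
//	rows := batchData.addresses[nodeID]
//	slices.SortFunc(rows, func(a, b nodeAddress) int {
//	        return int(a.position) - int(b.position)
//	})
//
//	addrs := make([]net.Addr, 0, len(rows))
//	for _, row := range rows {
//	        addr, err := parseAddress(row.addrType, row.address)
//	        if err != nil {
//	                return nil, err
//	        }
//	        addrs = append(addrs, addr)
//	}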

// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
// using the ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
func batchLoadNodeFeaturesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64][]int, error) {

        features := make(map[int64][]int)

        return features, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
                        error) {

                        return db.GetNodeFeaturesBatch(ctx, ids)
                },
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
                        features[feature.NodeID] = append(
                                features[feature.NodeID],
                                int(feature.FeatureBit),
                        )

                        return nil
                },
        )
}
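
// The helper above returns raw feature bits keyed by DB node ID. A minimal
// sketch (illustrative only) of converting one entry into an
// lnwire.FeatureVector downstream:
//
//	raw := lnwire.NewRawFeatureVector()
//	for _, bit := range features[nodeID] {
//	        raw.Set(lnwire.FeatureBit(bit))
//	}
//	fv := lnwire.NewFeatureVector(raw, lnwire.Features)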

// batchLoadNodeAddressesHelper loads node addresses using the
// ExecuteBatchQuery wrapper around the GetNodeAddressesBatch query. It
// returns a map from node ID to a slice of nodeAddress structs.
func batchLoadNodeAddressesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64][]nodeAddress, error) {

        addrs := make(map[int64][]nodeAddress)

        return addrs, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
                        error) {

                        return db.GetNodeAddressesBatch(ctx, ids)
                },
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
                        addrs[addr.NodeID] = append(
                                addrs[addr.NodeID], nodeAddress{
                                        addrType: dbAddressType(addr.Type),
                                        position: addr.Position,
                                        address:  addr.Address,
                                },
                        )

                        return nil
                },
        )
}

// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
// node IDs using the ExecuteBatchQuery wrapper around the
// GetNodeExtraTypesBatch query.
func batchLoadNodeExtraTypesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {

        extraFields := make(map[int64]map[uint64][]byte)

        callback := func(ctx context.Context,
                field sqlc.GraphNodeExtraType) error {

                if extraFields[field.NodeID] == nil {
                        extraFields[field.NodeID] = make(map[uint64][]byte)
                }
                extraFields[field.NodeID][uint64(field.Type)] = field.Value

                return nil
        }

        return extraFields, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) (
                        []sqlc.GraphNodeExtraType, error) {

                        return db.GetNodeExtraTypesBatch(ctx, ids)
                },
                callback,
        )
}

// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy
// instances from the provided sqlc.GraphChannelPolicy records and the
// provided batchChannelData.
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
        channelID uint64, node1, node2 route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
        *models.ChannelEdgePolicy, error) {

        pol1, err := buildChanPolicyWithBatchData(
                dbPol1, channelID, node2, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
        }

        pol2, err := buildChanPolicyWithBatchData(
                dbPol2, channelID, node1, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
        }

        return pol1, pol2, nil
}

// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
        channelID uint64, toNode route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {

        if dbPol == nil {
                return nil, nil
        }

        var dbPolExtras map[uint64][]byte
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
                dbPolExtras = extras
        } else {
                dbPolExtras = make(map[uint64][]byte)
        }

        return buildChanPolicy(*dbPol, channelID, dbPolExtras, toNode)
}

// batchChannelData holds all the related data for a batch of channels.
type batchChannelData struct {
        // chanfeatures is a map from DB channel ID to a slice of feature bits.
        chanfeatures map[int64][]int

        // chanExtraTypes is a map from DB channel ID to a map of TLV type to
        // extra signed field bytes.
        chanExtraTypes map[int64]map[uint64][]byte

        // policyExtras is a map from DB channel policy ID to a map of TLV type
        // to extra signed field bytes.
        policyExtras map[int64]map[uint64][]byte
}

// batchLoadChannelData loads all related data for batches of channels and
// policies.
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, channelIDs []int64,
        policyIDs []int64) (*batchChannelData, error) {

        batchData := &batchChannelData{
                chanfeatures:   make(map[int64][]int),
                chanExtraTypes: make(map[int64]map[uint64][]byte),
                policyExtras:   make(map[int64]map[uint64][]byte),
        }

        // Batch load the channel features and extra types.
        var err error
        if len(channelIDs) > 0 {
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
                        ctx, cfg, db, channelIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "channel features: %w", err)
                }

                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
                        ctx, cfg, db, channelIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "channel extras: %w", err)
                }
        }

        if len(policyIDs) > 0 {
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
                        ctx, cfg, db, policyIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "policy extras: %w", err)
                }
                batchData.policyExtras = policyExtras
        }

        return batchData, nil
}

// batchLoadChannelFeaturesHelper loads channel features for a batch of
// channel IDs using the ExecuteBatchQuery wrapper around the
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
// slice of feature bits.
func batchLoadChannelFeaturesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        channelIDs []int64) (map[int64][]int, error) {

        features := make(map[int64][]int)

        return features, sqldb.ExecuteBatchQuery(
                ctx, cfg, channelIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context,
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {

                        return db.GetChannelFeaturesBatch(ctx, ids)
                },
                func(ctx context.Context,
                        feature sqlc.GraphChannelFeature) error {

                        features[feature.ChannelID] = append(
                                features[feature.ChannelID],
                                int(feature.FeatureBit),
                        )

                        return nil
                },
        )
}

// batchLoadChannelExtrasHelper loads channel extra types for a batch of
// channel IDs using the ExecuteBatchQuery wrapper around the
// GetChannelExtrasBatch query. It returns a map from DB channel ID to a map
// of TLV type to extra signed field bytes.
func batchLoadChannelExtrasHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {

        extras := make(map[int64]map[uint64][]byte)

        cb := func(ctx context.Context,
                extra sqlc.GraphChannelExtraType) error {

                if extras[extra.ChannelID] == nil {
                        extras[extra.ChannelID] = make(map[uint64][]byte)
                }
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value

                return nil
        }

        return extras, sqldb.ExecuteBatchQuery(
                ctx, cfg, channelIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context,
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {

                        return db.GetChannelExtrasBatch(ctx, ids)
                }, cb,
        )
}

// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
// batch of policy IDs using the ExecuteBatchQuery wrapper around the
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID
// to a map of TLV type to extra signed field bytes.
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {

        extras := make(map[int64]map[uint64][]byte)

        return extras, sqldb.ExecuteBatchQuery(
                ctx, cfg, policyIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) (
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {

                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
                },
                func(ctx context.Context,
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {

                        if extras[row.PolicyID] == nil {
                                extras[row.PolicyID] = make(map[uint64][]byte)
                        }
                        extras[row.PolicyID][uint64(row.Type)] = row.Value

                        return nil
                },
        )
}

// forEachNodePaginated executes a paginated query to process each node in the
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
// and applies the provided processNode function to each node.
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, protocol ProtocolVersion,
        processNode func(context.Context, int64,
                *models.LightningNode) error) error {

        pageQueryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.GraphNode, error) {

                return db.ListNodesPaginated(
                        ctx, sqlc.ListNodesPaginatedParams{
                                Version: int16(protocol),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractPageCursor := func(node sqlc.GraphNode) int64 {
                return node.ID
        }

        collectFunc := func(node sqlc.GraphNode) (int64, error) {
                return node.ID, nil
        }

        batchQueryFunc := func(ctx context.Context,
                nodeIDs []int64) (*batchNodeData, error) {

                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
        }

        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
                batchData *batchNodeData) error {

                node, err := buildNodeWithBatchData(&dbNode, batchData)
                if err != nil {
                        return fmt.Errorf("unable to build "+
                                "node(id=%d): %w", dbNode.ID, err)
                }

                return processNode(ctx, dbNode.ID, node)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
                collectFunc, batchQueryFunc, processItem,
        )
}
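
// A minimal calling sketch (illustrative only; the surrounding store, its
// config variable and the logging call are assumptions for the example):
//
//	err := forEachNodePaginated(
//	        ctx, queryCfg, db, ProtocolV1,
//	        func(ctx context.Context, nodeID int64,
//	                node *models.LightningNode) error {
//
//	                log.Debugf("visited node %d (%x)", nodeID,
//	                        node.PubKeyBytes)
//
//	                return nil
//	        },
//	)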

// forEachChannelWithPolicies executes a paginated query to process each
// channel with policies in the graph.
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
        cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
                *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error) error {

        type channelBatchIDs struct {
                channelID int64
                policyIDs []int64
        }

        pageQueryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
                error) {

                return db.ListChannelsWithPoliciesPaginated(
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
                                Version: int16(ProtocolV1),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractPageCursor := func(
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {

                return row.GraphChannel.ID
        }

        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
                channelBatchIDs, error) {

                ids := channelBatchIDs{
                        channelID: row.GraphChannel.ID,
                }

                // Extract policy IDs from the row.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return ids, err
                }

                if dbPol1 != nil {
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
                }
                if dbPol2 != nil {
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
                }

                return ids, nil
        }

        batchDataFunc := func(ctx context.Context,
                allIDs []channelBatchIDs) (*batchChannelData, error) {

                // Separate channel IDs from policy IDs.
                var (
                        channelIDs = make([]int64, len(allIDs))
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
                )

                for i, ids := range allIDs {
                        channelIDs[i] = ids.channelID
                        policyIDs = append(policyIDs, ids.policyIDs...)
                }

                return batchLoadChannelData(
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
                )
        }

        processItem := func(ctx context.Context,
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
                batchData *batchChannelData) error {

                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return err
                }

                edge, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.GraphChannel, node1, node2,
                        batchData,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel info: %w",
                                err)
                }

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return err
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
                )
                if err != nil {
                        return err
                }

                return processChannel(edge, p1, p2)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
                collectFunc, batchDataFunc, processItem,
        )
}

// buildDirectedChannel builds a DirectedChannel instance from the provided
// data.
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
        nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
        channelBatchData *batchChannelData, features *lnwire.FeatureVector,
        toNodeCallback func() route.Vertex) (*DirectedChannel, error) {

        node1, node2, err := buildNodeVertices(
                channelRow.Node1Pubkey, channelRow.Node2Pubkey,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build node vertices: %w", err)
        }

        edge, err := buildEdgeInfoWithBatchData(
                chain, channelRow.GraphChannel, node1, node2, channelBatchData,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build channel info: %w", err)
        }

        dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
        if err != nil {
                return nil, fmt.Errorf("unable to extract channel policies: %w",
                        err)
        }

        p1, p2, err := buildChanPoliciesWithBatchData(
                dbPol1, dbPol2, edge.ChannelID, node1, node2,
                channelBatchData,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build channel policies: %w",
                        err)
        }

        // Determine the outgoing and incoming policy for this specific node.
        p1ToNode := channelRow.GraphChannel.NodeID2
        p2ToNode := channelRow.GraphChannel.NodeID1
        outPolicy, inPolicy := p1, p2
        if (p1 != nil && p1ToNode == nodeID) ||
                (p2 != nil && p2ToNode != nodeID) {

                outPolicy, inPolicy = p2, p1
        }

        // Build the cached incoming policy.
        var cachedInPolicy *models.CachedEdgePolicy
        if inPolicy != nil {
                cachedInPolicy = models.NewCachedPolicy(inPolicy)
                cachedInPolicy.ToNodePubKey = toNodeCallback
                cachedInPolicy.ToNodeFeatures = features
        }

        // Extract the inbound fee from the outgoing policy, if present.
        var inboundFee lnwire.Fee
        if outPolicy != nil {
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
                        inboundFee = fee
                })
        }

        // Build the directed channel.
        directedChannel := &DirectedChannel{
                ChannelID:    edge.ChannelID,
                IsNode1:      nodePub == edge.NodeKey1Bytes,
                OtherNode:    edge.NodeKey2Bytes,
                Capacity:     edge.Capacity,
                OutPolicySet: outPolicy != nil,
                InPolicy:     cachedInPolicy,
                InboundFee:   inboundFee,
        }

        if nodePub == edge.NodeKey2Bytes {
                directedChannel.OtherNode = edge.NodeKey1Bytes
        }

        return directedChannel, nil
}
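
// Orientation worked example (illustrative): p1 is the policy whose toNode
// is node2 and p2 the policy whose toNode is node1. If the node of interest
// is node2, then p1 points at it and is therefore its incoming policy while
// p2 is its outgoing one, which is exactly the swap performed above. If the
// node of interest is node1, the default assignment (out = p1, in = p2)
// already holds.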