• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 16777989842

06 Aug 2025 01:14PM UTC coverage: 66.948% (-0.006%) from 66.954%
16777989842

push

github

web-flow
Merge pull request #10132 from ffranr/update-ffranr-signing-key

scripts: update ffranr release signing key

135677 of 202660 relevant lines covered (66.95%)

21599.01 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// ProtocolVersion enumerates the gossip protocol version that a message
// belongs to.
type ProtocolVersion uint8

const (
	// ProtocolV1 is the gossip protocol version defined in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String renders the protocol version as the letter "V" followed by its
// numeric value, e.g. "V1".
func (v ProtocolVersion) String() string {
	// Convert to the underlying uint8 first so fmt prints the raw number
	// instead of re-entering this String method.
	return "V" + fmt.Sprint(uint8(v))
}
×
47

48
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
49
// execute queries against the SQL graph tables.
50
//
51
//nolint:ll,interfacebloat
52
type SQLQueries interface {
53
        /*
54
                Node queries.
55
        */
56
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
57
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
58
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
59
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
60
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
61
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
62
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
63
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
64
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
65
        DeleteNode(ctx context.Context, id int64) error
66

67
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
68
        GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
69
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
70
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
71

72
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
73
        GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
74
        GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
75
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
76

77
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
78
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
79
        GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
80
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
81
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
82

83
        /*
84
                Source node queries.
85
        */
86
        AddSourceNode(ctx context.Context, nodeID int64) error
87
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
88

89
        /*
90
                Channel queries.
91
        */
92
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
93
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
94
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
95
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
96
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
97
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
98
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
99
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
100
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
101
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
102
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
103
        ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
104
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
105
        ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
106
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
107
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
108
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
109
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
110
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
111
        DeleteChannels(ctx context.Context, ids []int64) error
112

113
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
114
        GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
115
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
116
        GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)
117

118
        /*
119
                Channel Policy table queries.
120
        */
121
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
122
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
123
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
124

125
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
126
        GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
127
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
128

129
        /*
130
                Zombie index queries.
131
        */
132
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
133
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
134
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
135
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
136
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
137

138
        /*
139
                Prune log table queries.
140
        */
141
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
142
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
143
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
144
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
145

146
        /*
147
                Closed SCID table queries.
148
        */
149
        InsertClosedChannel(ctx context.Context, scid []byte) error
150
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
151
}
152

153
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
154
// database operations.
155
type BatchedSQLQueries interface {
156
        SQLQueries
157
        sqldb.BatchedTx[SQLQueries]
158
}
159

160
// SQLStore is an implementation of the V1Store interface that uses a SQL
161
// database as the backend.
162
type SQLStore struct {
163
        cfg *SQLStoreConfig
164
        db  BatchedSQLQueries
165

166
        // cacheMu guards all caches (rejectCache and chanCache). If
167
        // this mutex will be acquired at the same time as the DB mutex then
168
        // the cacheMu MUST be acquired first to prevent deadlock.
169
        cacheMu     sync.RWMutex
170
        rejectCache *rejectCache
171
        chanCache   *channelCache
172

173
        chanScheduler batch.Scheduler[SQLQueries]
174
        nodeScheduler batch.Scheduler[SQLQueries]
175

176
        srcNodes  map[ProtocolVersion]*srcNodeInfo
177
        srcNodeMu sync.Mutex
178
}
179

180
// A compile-time assertion to ensure that SQLStore implements the V1Store
181
// interface.
182
var _ V1Store = (*SQLStore)(nil)
183

184
// SQLStoreConfig holds the configuration for the SQLStore.
185
type SQLStoreConfig struct {
186
        // ChainHash is the genesis hash for the chain that all the gossip
187
        // messages in this store are aimed at.
188
        ChainHash chainhash.Hash
189

190
        // QueryConfig holds configuration values for SQL queries.
191
        QueryCfg *sqldb.QueryConfig
192
}
193

194
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
195
// storage backend.
196
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
197
        options ...StoreOptionModifier) (*SQLStore, error) {
×
198

×
199
        opts := DefaultOptions()
×
200
        for _, o := range options {
×
201
                o(opts)
×
202
        }
×
203

204
        if opts.NoMigration {
×
205
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
206
                        "supported for SQL stores")
×
207
        }
×
208

209
        s := &SQLStore{
×
210
                cfg:         cfg,
×
211
                db:          db,
×
212
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
213
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
214
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
215
        }
×
216

×
217
        s.chanScheduler = batch.NewTimeScheduler(
×
218
                db, &s.cacheMu, opts.BatchCommitInterval,
×
219
        )
×
220
        s.nodeScheduler = batch.NewTimeScheduler(
×
221
                db, nil, opts.BatchCommitInterval,
×
222
        )
×
223

×
224
        return s, nil
×
225
}
226

227
// AddLightningNode adds a vertex/node to the graph database. If the node is not
228
// in the database from before, this will add a new, unconnected one to the
229
// graph. If it is present from before, this will update that node's
230
// information.
231
//
232
// NOTE: part of the V1Store interface.
233
func (s *SQLStore) AddLightningNode(ctx context.Context,
234
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
235

×
236
        r := &batch.Request[SQLQueries]{
×
237
                Opts: batch.NewSchedulerOptions(opts...),
×
238
                Do: func(queries SQLQueries) error {
×
239
                        _, err := upsertNode(ctx, queries, node)
×
240
                        return err
×
241
                },
×
242
        }
243

244
        return s.nodeScheduler.Execute(ctx, r)
×
245
}
246

247
// FetchLightningNode attempts to look up a target node by its identity public
248
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
249
// returned.
250
//
251
// NOTE: part of the V1Store interface.
252
func (s *SQLStore) FetchLightningNode(ctx context.Context,
253
        pubKey route.Vertex) (*models.LightningNode, error) {
×
254

×
255
        var node *models.LightningNode
×
256
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
257
                var err error
×
258
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
259

×
260
                return err
×
261
        }, sqldb.NoOpReset)
×
262
        if err != nil {
×
263
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
264
        }
×
265

266
        return node, nil
×
267
}
268

269
// HasLightningNode determines if the graph has a vertex identified by the
270
// target node identity public key. If the node exists in the database, a
271
// timestamp of when the data for the node was lasted updated is returned along
272
// with a true boolean. Otherwise, an empty time.Time is returned with a false
273
// boolean.
274
//
275
// NOTE: part of the V1Store interface.
276
func (s *SQLStore) HasLightningNode(ctx context.Context,
277
        pubKey [33]byte) (time.Time, bool, error) {
×
278

×
279
        var (
×
280
                exists     bool
×
281
                lastUpdate time.Time
×
282
        )
×
283
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
284
                dbNode, err := db.GetNodeByPubKey(
×
285
                        ctx, sqlc.GetNodeByPubKeyParams{
×
286
                                Version: int16(ProtocolV1),
×
287
                                PubKey:  pubKey[:],
×
288
                        },
×
289
                )
×
290
                if errors.Is(err, sql.ErrNoRows) {
×
291
                        return nil
×
292
                } else if err != nil {
×
293
                        return fmt.Errorf("unable to fetch node: %w", err)
×
294
                }
×
295

296
                exists = true
×
297

×
298
                if dbNode.LastUpdate.Valid {
×
299
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
300
                }
×
301

302
                return nil
×
303
        }, sqldb.NoOpReset)
304
        if err != nil {
×
305
                return time.Time{}, false,
×
306
                        fmt.Errorf("unable to fetch node: %w", err)
×
307
        }
×
308

309
        return lastUpdate, exists, nil
×
310
}
311

312
// AddrsForNode returns all known addresses for the target node public key
313
// that the graph DB is aware of. The returned boolean indicates if the
314
// given node is unknown to the graph DB or not.
315
//
316
// NOTE: part of the V1Store interface.
317
func (s *SQLStore) AddrsForNode(ctx context.Context,
318
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
319

×
320
        var (
×
321
                addresses []net.Addr
×
322
                known     bool
×
323
        )
×
324
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
325
                // First, check if the node exists and get its DB ID if it
×
326
                // does.
×
327
                dbID, err := db.GetNodeIDByPubKey(
×
328
                        ctx, sqlc.GetNodeIDByPubKeyParams{
×
329
                                Version: int16(ProtocolV1),
×
330
                                PubKey:  nodePub.SerializeCompressed(),
×
331
                        },
×
332
                )
×
333
                if errors.Is(err, sql.ErrNoRows) {
×
334
                        return nil
×
335
                }
×
336

337
                known = true
×
338

×
339
                addresses, err = getNodeAddresses(ctx, db, dbID)
×
340
                if err != nil {
×
341
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
342
                                err)
×
343
                }
×
344

345
                return nil
×
346
        }, sqldb.NoOpReset)
347
        if err != nil {
×
348
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
349
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
350
        }
×
351

352
        return known, addresses, nil
×
353
}
354

355
// DeleteLightningNode starts a new database transaction to remove a vertex/node
356
// from the database according to the node's public key.
357
//
358
// NOTE: part of the V1Store interface.
359
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
360
        pubKey route.Vertex) error {
×
361

×
362
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
363
                res, err := db.DeleteNodeByPubKey(
×
364
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
365
                                Version: int16(ProtocolV1),
×
366
                                PubKey:  pubKey[:],
×
367
                        },
×
368
                )
×
369
                if err != nil {
×
370
                        return err
×
371
                }
×
372

373
                rows, err := res.RowsAffected()
×
374
                if err != nil {
×
375
                        return err
×
376
                }
×
377

378
                if rows == 0 {
×
379
                        return ErrGraphNodeNotFound
×
380
                } else if rows > 1 {
×
381
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
382
                }
×
383

384
                return err
×
385
        }, sqldb.NoOpReset)
386
        if err != nil {
×
387
                return fmt.Errorf("unable to delete node: %w", err)
×
388
        }
×
389

390
        return nil
×
391
}
392

393
// FetchNodeFeatures returns the features of the given node. If no features are
394
// known for the node, an empty feature vector is returned.
395
//
396
// NOTE: this is part of the graphdb.NodeTraverser interface.
397
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
398
        *lnwire.FeatureVector, error) {
×
399

×
400
        ctx := context.TODO()
×
401

×
402
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
403
}
×
404

405
// DisabledChannelIDs returns the channel ids of disabled channels.
406
// A channel is disabled when two of the associated ChanelEdgePolicies
407
// have their disabled bit on.
408
//
409
// NOTE: part of the V1Store interface.
410
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
411
        var (
×
412
                ctx     = context.TODO()
×
413
                chanIDs []uint64
×
414
        )
×
415
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
416
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
417
                if err != nil {
×
418
                        return fmt.Errorf("unable to fetch disabled "+
×
419
                                "channels: %w", err)
×
420
                }
×
421

422
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
423

×
424
                return nil
×
425
        }, sqldb.NoOpReset)
426
        if err != nil {
×
427
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
428
                        err)
×
429
        }
×
430

431
        return chanIDs, nil
×
432
}
433

434
// LookupAlias attempts to return the alias as advertised by the target node.
435
//
436
// NOTE: part of the V1Store interface.
437
func (s *SQLStore) LookupAlias(ctx context.Context,
438
        pub *btcec.PublicKey) (string, error) {
×
439

×
440
        var alias string
×
441
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
442
                dbNode, err := db.GetNodeByPubKey(
×
443
                        ctx, sqlc.GetNodeByPubKeyParams{
×
444
                                Version: int16(ProtocolV1),
×
445
                                PubKey:  pub.SerializeCompressed(),
×
446
                        },
×
447
                )
×
448
                if errors.Is(err, sql.ErrNoRows) {
×
449
                        return ErrNodeAliasNotFound
×
450
                } else if err != nil {
×
451
                        return fmt.Errorf("unable to fetch node: %w", err)
×
452
                }
×
453

454
                if !dbNode.Alias.Valid {
×
455
                        return ErrNodeAliasNotFound
×
456
                }
×
457

458
                alias = dbNode.Alias.String
×
459

×
460
                return nil
×
461
        }, sqldb.NoOpReset)
462
        if err != nil {
×
463
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
464
        }
×
465

466
        return alias, nil
×
467
}
468

469
// SourceNode returns the source node of the graph. The source node is treated
470
// as the center node within a star-graph. This method may be used to kick off
471
// a path finding algorithm in order to explore the reachability of another
472
// node based off the source node.
473
//
474
// NOTE: part of the V1Store interface.
475
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
476
        error) {
×
477

×
478
        var node *models.LightningNode
×
479
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
480
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
481
                if err != nil {
×
482
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
483
                                err)
×
484
                }
×
485

486
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
487

×
488
                return err
×
489
        }, sqldb.NoOpReset)
490
        if err != nil {
×
491
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
492
        }
×
493

494
        return node, nil
×
495
}
496

497
// SetSourceNode sets the source node within the graph database. The source
498
// node is to be used as the center of a star-graph within path finding
499
// algorithms.
500
//
501
// NOTE: part of the V1Store interface.
502
func (s *SQLStore) SetSourceNode(ctx context.Context,
503
        node *models.LightningNode) error {
×
504

×
505
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
506
                id, err := upsertNode(ctx, db, node)
×
507
                if err != nil {
×
508
                        return fmt.Errorf("unable to upsert source node: %w",
×
509
                                err)
×
510
                }
×
511

512
                // Make sure that if a source node for this version is already
513
                // set, then the ID is the same as the one we are about to set.
514
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
515
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
516
                        return fmt.Errorf("unable to fetch source node: %w",
×
517
                                err)
×
518
                } else if err == nil {
×
519
                        if dbSourceNodeID != id {
×
520
                                return fmt.Errorf("v1 source node already "+
×
521
                                        "set to a different node: %d vs %d",
×
522
                                        dbSourceNodeID, id)
×
523
                        }
×
524

525
                        return nil
×
526
                }
527

528
                return db.AddSourceNode(ctx, id)
×
529
        }, sqldb.NoOpReset)
530
}
531

532
// NodeUpdatesInHorizon returns all the known lightning node which have an
533
// update timestamp within the passed range. This method can be used by two
534
// nodes to quickly determine if they have the same set of up to date node
535
// announcements.
536
//
537
// NOTE: This is part of the V1Store interface.
538
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
539
        endTime time.Time) ([]models.LightningNode, error) {
×
540

×
541
        ctx := context.TODO()
×
542

×
543
        var nodes []models.LightningNode
×
544
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
545
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
546
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
547
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
548
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
549
                        },
×
550
                )
×
551
                if err != nil {
×
552
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
553
                }
×
554

555
                err = forEachNodeInBatch(
×
556
                        ctx, s.cfg.QueryCfg, db, dbNodes,
×
557
                        func(_ int64, node *models.LightningNode) error {
×
558
                                nodes = append(nodes, *node)
×
559

×
560
                                return nil
×
561
                        },
×
562
                )
563
                if err != nil {
×
564
                        return fmt.Errorf("unable to build nodes: %w", err)
×
565
                }
×
566

567
                return nil
×
568
        }, sqldb.NoOpReset)
569
        if err != nil {
×
570
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
571
        }
×
572

573
        return nodes, nil
×
574
}
575

576
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
577
// undirected edge from the two target nodes are created. The information stored
578
// denotes the static attributes of the channel, such as the channelID, the keys
579
// involved in creation of the channel, and the set of features that the channel
580
// supports. The chanPoint and chanID are used to uniquely identify the edge
581
// globally within the database.
582
//
583
// NOTE: part of the V1Store interface.
584
func (s *SQLStore) AddChannelEdge(ctx context.Context,
585
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
586

×
587
        var alreadyExists bool
×
588
        r := &batch.Request[SQLQueries]{
×
589
                Opts: batch.NewSchedulerOptions(opts...),
×
590
                Reset: func() {
×
591
                        alreadyExists = false
×
592
                },
×
593
                Do: func(tx SQLQueries) error {
×
594
                        _, err := insertChannel(ctx, tx, edge)
×
595

×
596
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
597
                        // succeed, but propagate the error via local state.
×
598
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
599
                                alreadyExists = true
×
600
                                return nil
×
601
                        }
×
602

603
                        return err
×
604
                },
605
                OnCommit: func(err error) error {
×
606
                        switch {
×
607
                        case err != nil:
×
608
                                return err
×
609
                        case alreadyExists:
×
610
                                return ErrEdgeAlreadyExist
×
611
                        default:
×
612
                                s.rejectCache.remove(edge.ChannelID)
×
613
                                s.chanCache.remove(edge.ChannelID)
×
614
                                return nil
×
615
                        }
616
                },
617
        }
618

619
        return s.chanScheduler.Execute(ctx, r)
×
620
}
621

622
// HighestChanID returns the "highest" known channel ID in the channel graph.
623
// This represents the "newest" channel from the PoV of the chain. This method
624
// can be used by peers to quickly determine if their graphs are in sync.
625
//
626
// NOTE: This is part of the V1Store interface.
627
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
628
        var highestChanID uint64
×
629
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
630
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
631
                if errors.Is(err, sql.ErrNoRows) {
×
632
                        return nil
×
633
                } else if err != nil {
×
634
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
635
                                err)
×
636
                }
×
637

638
                highestChanID = byteOrder.Uint64(chanID)
×
639

×
640
                return nil
×
641
        }, sqldb.NoOpReset)
642
        if err != nil {
×
643
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
644
        }
×
645

646
        return highestChanID, nil
×
647
}
648

649
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
650
// within the database for the referenced channel. The `flags` attribute within
651
// the ChannelEdgePolicy determines which of the directed edges are being
652
// updated. If the flag is 1, then the first node's information is being
653
// updated, otherwise it's the second node's information. The node ordering is
654
// determined by the lexicographical ordering of the identity public keys of the
655
// nodes on either side of the channel.
656
//
657
// NOTE: part of the V1Store interface.
658
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
659
        edge *models.ChannelEdgePolicy,
660
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
661

×
662
        var (
×
663
                isUpdate1    bool
×
664
                edgeNotFound bool
×
665
                from, to     route.Vertex
×
666
        )
×
667

×
668
        r := &batch.Request[SQLQueries]{
×
669
                Opts: batch.NewSchedulerOptions(opts...),
×
670
                Reset: func() {
×
671
                        isUpdate1 = false
×
672
                        edgeNotFound = false
×
673
                },
×
674
                Do: func(tx SQLQueries) error {
×
675
                        var err error
×
676
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
677
                                ctx, tx, edge,
×
678
                        )
×
679
                        if err != nil {
×
680
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
681
                        }
×
682

683
                        // Silence ErrEdgeNotFound so that the batch can
684
                        // succeed, but propagate the error via local state.
685
                        if errors.Is(err, ErrEdgeNotFound) {
×
686
                                edgeNotFound = true
×
687
                                return nil
×
688
                        }
×
689

690
                        return err
×
691
                },
692
                OnCommit: func(err error) error {
×
693
                        switch {
×
694
                        case err != nil:
×
695
                                return err
×
696
                        case edgeNotFound:
×
697
                                return ErrEdgeNotFound
×
698
                        default:
×
699
                                s.updateEdgeCache(edge, isUpdate1)
×
700
                                return nil
×
701
                        }
702
                },
703
        }
704

705
        err := s.chanScheduler.Execute(ctx, r)
×
706

×
707
        return from, to, err
×
708
}
709

710
// updateEdgeCache updates our reject and channel caches with the new
711
// edge policy information.
712
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
713
        isUpdate1 bool) {
×
714

×
715
        // If an entry for this channel is found in reject cache, we'll modify
×
716
        // the entry with the updated timestamp for the direction that was just
×
717
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
718
        // during the next query for this edge.
×
719
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
720
                if isUpdate1 {
×
721
                        entry.upd1Time = e.LastUpdate.Unix()
×
722
                } else {
×
723
                        entry.upd2Time = e.LastUpdate.Unix()
×
724
                }
×
725
                s.rejectCache.insert(e.ChannelID, entry)
×
726
        }
727

728
        // If an entry for this channel is found in channel cache, we'll modify
729
        // the entry with the updated policy for the direction that was just
730
        // written. If the edge doesn't exist, we'll defer loading the info and
731
        // policies and lazily read from disk during the next query.
732
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
733
                if isUpdate1 {
×
734
                        channel.Policy1 = e
×
735
                } else {
×
736
                        channel.Policy2 = e
×
737
                }
×
738
                s.chanCache.insert(e.ChannelID, channel)
×
739
        }
740
}
741

742
// ForEachSourceNodeChannel iterates through all channels of the source node,
743
// executing the passed callback on each. The call-back is provided with the
744
// channel's outpoint, whether we have a policy for the channel and the channel
745
// peer's node information.
746
//
747
// NOTE: part of the V1Store interface.
748
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
749
        cb func(chanPoint wire.OutPoint, havePolicy bool,
750
                otherNode *models.LightningNode) error, reset func()) error {
×
751

×
752
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
753
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
754
                if err != nil {
×
755
                        return fmt.Errorf("unable to fetch source node: %w",
×
756
                                err)
×
757
                }
×
758

759
                return forEachNodeChannel(
×
760
                        ctx, db, s.cfg, nodeID,
×
761
                        func(info *models.ChannelEdgeInfo,
×
762
                                outPolicy *models.ChannelEdgePolicy,
×
763
                                _ *models.ChannelEdgePolicy) error {
×
764

×
765
                                // Fetch the other node.
×
766
                                var (
×
767
                                        otherNodePub [33]byte
×
768
                                        node1        = info.NodeKey1Bytes
×
769
                                        node2        = info.NodeKey2Bytes
×
770
                                )
×
771
                                switch {
×
772
                                case bytes.Equal(node1[:], nodePub[:]):
×
773
                                        otherNodePub = node2
×
774
                                case bytes.Equal(node2[:], nodePub[:]):
×
775
                                        otherNodePub = node1
×
776
                                default:
×
777
                                        return fmt.Errorf("node not " +
×
778
                                                "participating in this channel")
×
779
                                }
780

781
                                _, otherNode, err := getNodeByPubKey(
×
782
                                        ctx, db, otherNodePub,
×
783
                                )
×
784
                                if err != nil {
×
785
                                        return fmt.Errorf("unable to fetch "+
×
786
                                                "other node(%x): %w",
×
787
                                                otherNodePub, err)
×
788
                                }
×
789

790
                                return cb(
×
791
                                        info.ChannelPoint, outPolicy != nil,
×
792
                                        otherNode,
×
793
                                )
×
794
                        },
795
                )
796
        }, reset)
797
}
798

799
// ForEachNode iterates through all the stored vertices/nodes in the graph,
800
// executing the passed callback with each node encountered. If the callback
801
// returns an error, then the transaction is aborted and the iteration stops
802
// early.
803
//
804
// NOTE: part of the V1Store interface.
805
func (s *SQLStore) ForEachNode(ctx context.Context,
806
        cb func(node *models.LightningNode) error, reset func()) error {
×
807

×
808
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
809
                return forEachNodePaginated(
×
810
                        ctx, s.cfg.QueryCfg, db,
×
811
                        ProtocolV1, func(_ context.Context, _ int64,
×
812
                                node *models.LightningNode) error {
×
813

×
814
                                return cb(node)
×
815
                        },
×
816
                )
817
        }, reset)
818
}
819

820
// ForEachNodeDirectedChannel iterates through all channels of a given node,
821
// executing the passed callback on the directed edge representing the channel
822
// and its incoming policy. If the callback returns an error, then the iteration
823
// is halted with the error propagated back up to the caller.
824
//
825
// Unknown policies are passed into the callback as nil values.
826
//
827
// NOTE: this is part of the graphdb.NodeTraverser interface.
828
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
829
        cb func(channel *DirectedChannel) error, reset func()) error {
×
830

×
831
        var ctx = context.TODO()
×
832

×
833
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
834
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
835
        }, reset)
×
836
}
837

838
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
839
// graph, executing the passed callback with each node encountered. If the
840
// callback returns an error, then the transaction is aborted and the iteration
841
// stops early.
842
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
843
        cb func(route.Vertex, *lnwire.FeatureVector) error,
844
        reset func()) error {
×
845

×
846
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
847
                return forEachNodeCacheable(
×
848
                        ctx, s.cfg.QueryCfg, db,
×
849
                        func(_ int64, nodePub route.Vertex,
×
850
                                features *lnwire.FeatureVector) error {
×
851

×
852
                                return cb(nodePub, features)
×
853
                        },
×
854
                )
855
        }, reset)
856
        if err != nil {
×
857
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
858
        }
×
859

860
        return nil
×
861
}
862

863
// ForEachNodeChannel iterates through all channels of the given node,
864
// executing the passed callback with an edge info structure and the policies
865
// of each end of the channel. The first edge policy is the outgoing edge *to*
866
// the connecting node, while the second is the incoming edge *from* the
867
// connecting node. If the callback returns an error, then the iteration is
868
// halted with the error propagated back up to the caller.
869
//
870
// Unknown policies are passed into the callback as nil values.
871
//
872
// NOTE: part of the V1Store interface.
873
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
874
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
875
                *models.ChannelEdgePolicy) error, reset func()) error {
×
876

×
877
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
878
                dbNode, err := db.GetNodeByPubKey(
×
879
                        ctx, sqlc.GetNodeByPubKeyParams{
×
880
                                Version: int16(ProtocolV1),
×
881
                                PubKey:  nodePub[:],
×
882
                        },
×
883
                )
×
884
                if errors.Is(err, sql.ErrNoRows) {
×
885
                        return nil
×
886
                } else if err != nil {
×
887
                        return fmt.Errorf("unable to fetch node: %w", err)
×
888
                }
×
889

890
                return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
×
891
        }, reset)
892
}
893

894
// ChanUpdatesInHorizon returns all the known channel edges which have at least
895
// one edge that has an update timestamp within the specified horizon.
896
//
897
// NOTE: This is part of the V1Store interface.
898
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
899
        endTime time.Time) ([]ChannelEdge, error) {
×
900

×
901
        s.cacheMu.Lock()
×
902
        defer s.cacheMu.Unlock()
×
903

×
904
        var (
×
905
                ctx = context.TODO()
×
906
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
907
                // an additional map to keep track of the edges already seen to
×
908
                // prevent re-adding it.
×
909
                edgesSeen    = make(map[uint64]struct{})
×
910
                edgesToCache = make(map[uint64]ChannelEdge)
×
911
                edges        []ChannelEdge
×
912
                hits         int
×
913
        )
×
914
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
915
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
916
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
917
                                Version:   int16(ProtocolV1),
×
918
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
919
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
920
                        },
×
921
                )
×
922
                if err != nil {
×
923
                        return err
×
924
                }
×
925

926
                for _, row := range rows {
×
927
                        // If we've already retrieved the info and policies for
×
928
                        // this edge, then we can skip it as we don't need to do
×
929
                        // so again.
×
930
                        chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid)
×
931
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
932
                                continue
×
933
                        }
934

935
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
936
                                hits++
×
937
                                edgesSeen[chanIDInt] = struct{}{}
×
938
                                edges = append(edges, channel)
×
939

×
940
                                continue
×
941
                        }
942

943
                        node1, node2, err := buildNodes(
×
944
                                ctx, db, row.GraphNode, row.GraphNode_2,
×
945
                        )
×
946
                        if err != nil {
×
947
                                return err
×
948
                        }
×
949

950
                        channel, err := getAndBuildEdgeInfo(
×
951
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
952
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
953
                        )
×
954
                        if err != nil {
×
955
                                return fmt.Errorf("unable to build channel "+
×
956
                                        "info: %w", err)
×
957
                        }
×
958

959
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
960
                        if err != nil {
×
961
                                return fmt.Errorf("unable to extract channel "+
×
962
                                        "policies: %w", err)
×
963
                        }
×
964

965
                        p1, p2, err := getAndBuildChanPolicies(
×
966
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
967
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
968
                        )
×
969
                        if err != nil {
×
970
                                return fmt.Errorf("unable to build channel "+
×
971
                                        "policies: %w", err)
×
972
                        }
×
973

974
                        edgesSeen[chanIDInt] = struct{}{}
×
975
                        chanEdge := ChannelEdge{
×
976
                                Info:    channel,
×
977
                                Policy1: p1,
×
978
                                Policy2: p2,
×
979
                                Node1:   node1,
×
980
                                Node2:   node2,
×
981
                        }
×
982
                        edges = append(edges, chanEdge)
×
983
                        edgesToCache[chanIDInt] = chanEdge
×
984
                }
985

986
                return nil
×
987
        }, func() {
×
988
                edgesSeen = make(map[uint64]struct{})
×
989
                edgesToCache = make(map[uint64]ChannelEdge)
×
990
                edges = nil
×
991
        })
×
992
        if err != nil {
×
993
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
994
        }
×
995

996
        // Insert any edges loaded from disk into the cache.
997
        for chanid, channel := range edgesToCache {
×
998
                s.chanCache.insert(chanid, channel)
×
999
        }
×
1000

1001
        if len(edges) > 0 {
×
1002
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
1003
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1004
        } else {
×
1005
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1006
                        "horizon (%s, %s)", startTime, endTime)
×
1007
        }
×
1008

1009
        return edges, nil
×
1010
}
1011

1012
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1013
// data to the call-back. If withAddrs is true, then the call-back will also be
1014
// provided with the addresses associated with the node. The address retrieval
1015
// result in an additional round-trip to the database, so it should only be used
1016
// if the addresses are actually needed.
1017
//
1018
// NOTE: part of the V1Store interface.
1019
func (s *SQLStore) ForEachNodeCached(ctx context.Context, withAddrs bool,
1020
        cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
1021
                chans map[uint64]*DirectedChannel) error, reset func()) error {
×
1022

×
1023
        type nodeCachedBatchData struct {
×
1024
                features      map[int64][]int
×
1025
                addrs         map[int64][]nodeAddress
×
1026
                chanBatchData *batchChannelData
×
1027
                chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
×
1028
        }
×
1029

×
1030
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1031
                // pageQueryFunc is used to query the next page of nodes.
×
1032
                pageQueryFunc := func(ctx context.Context, lastID int64,
×
1033
                        limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
1034

×
1035
                        return db.ListNodeIDsAndPubKeys(
×
1036
                                ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
1037
                                        Version: int16(ProtocolV1),
×
1038
                                        ID:      lastID,
×
1039
                                        Limit:   limit,
×
1040
                                },
×
1041
                        )
×
1042
                }
×
1043

1044
                // batchDataFunc is then used to batch load the data required
1045
                // for each page of nodes.
1046
                batchDataFunc := func(ctx context.Context,
×
1047
                        nodeIDs []int64) (*nodeCachedBatchData, error) {
×
1048

×
1049
                        // Batch load node features.
×
1050
                        nodeFeatures, err := batchLoadNodeFeaturesHelper(
×
1051
                                ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1052
                        )
×
1053
                        if err != nil {
×
1054
                                return nil, fmt.Errorf("unable to batch load "+
×
1055
                                        "node features: %w", err)
×
1056
                        }
×
1057

1058
                        // Maybe fetch the node's addresses if requested.
1059
                        var nodeAddrs map[int64][]nodeAddress
×
1060
                        if withAddrs {
×
1061
                                nodeAddrs, err = batchLoadNodeAddressesHelper(
×
1062
                                        ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1063
                                )
×
1064
                                if err != nil {
×
1065
                                        return nil, fmt.Errorf("unable to "+
×
1066
                                                "batch load node "+
×
1067
                                                "addresses: %w", err)
×
1068
                                }
×
1069
                        }
1070

1071
                        // Batch load ALL unique channels for ALL nodes in this
1072
                        // page.
1073
                        allChannels, err := db.ListChannelsForNodeIDs(
×
1074
                                ctx, sqlc.ListChannelsForNodeIDsParams{
×
1075
                                        Version:  int16(ProtocolV1),
×
1076
                                        Node1Ids: nodeIDs,
×
1077
                                        Node2Ids: nodeIDs,
×
1078
                                },
×
1079
                        )
×
1080
                        if err != nil {
×
1081
                                return nil, fmt.Errorf("unable to batch "+
×
1082
                                        "fetch channels for nodes: %w", err)
×
1083
                        }
×
1084

1085
                        // Deduplicate channels and collect IDs.
1086
                        var (
×
1087
                                allChannelIDs []int64
×
1088
                                allPolicyIDs  []int64
×
1089
                        )
×
1090
                        uniqueChannels := make(
×
1091
                                map[int64]sqlc.ListChannelsForNodeIDsRow,
×
1092
                        )
×
1093

×
1094
                        for _, channel := range allChannels {
×
1095
                                channelID := channel.GraphChannel.ID
×
1096

×
1097
                                // Only process each unique channel once.
×
1098
                                _, exists := uniqueChannels[channelID]
×
1099
                                if exists {
×
1100
                                        continue
×
1101
                                }
1102

1103
                                uniqueChannels[channelID] = channel
×
1104
                                allChannelIDs = append(allChannelIDs, channelID)
×
1105

×
1106
                                if channel.Policy1ID.Valid {
×
1107
                                        allPolicyIDs = append(
×
1108
                                                allPolicyIDs,
×
1109
                                                channel.Policy1ID.Int64,
×
1110
                                        )
×
1111
                                }
×
1112
                                if channel.Policy2ID.Valid {
×
1113
                                        allPolicyIDs = append(
×
1114
                                                allPolicyIDs,
×
1115
                                                channel.Policy2ID.Int64,
×
1116
                                        )
×
1117
                                }
×
1118
                        }
1119

1120
                        // Batch load channel data for all unique channels.
1121
                        channelBatchData, err := batchLoadChannelData(
×
1122
                                ctx, s.cfg.QueryCfg, db, allChannelIDs,
×
1123
                                allPolicyIDs,
×
1124
                        )
×
1125
                        if err != nil {
×
1126
                                return nil, fmt.Errorf("unable to batch "+
×
1127
                                        "load channel data: %w", err)
×
1128
                        }
×
1129

1130
                        // Create map of node ID to channels that involve this
1131
                        // node.
1132
                        nodeIDSet := make(map[int64]bool)
×
1133
                        for _, nodeID := range nodeIDs {
×
1134
                                nodeIDSet[nodeID] = true
×
1135
                        }
×
1136

1137
                        nodeChannelMap := make(
×
1138
                                map[int64][]sqlc.ListChannelsForNodeIDsRow,
×
1139
                        )
×
1140
                        for _, channel := range uniqueChannels {
×
1141
                                // Add channel to both nodes if they're in our
×
1142
                                // current page.
×
1143
                                node1 := channel.GraphChannel.NodeID1
×
1144
                                if nodeIDSet[node1] {
×
1145
                                        nodeChannelMap[node1] = append(
×
1146
                                                nodeChannelMap[node1], channel,
×
1147
                                        )
×
1148
                                }
×
1149
                                node2 := channel.GraphChannel.NodeID2
×
1150
                                if nodeIDSet[node2] {
×
1151
                                        nodeChannelMap[node2] = append(
×
1152
                                                nodeChannelMap[node2], channel,
×
1153
                                        )
×
1154
                                }
×
1155
                        }
1156

1157
                        return &nodeCachedBatchData{
×
1158
                                features:      nodeFeatures,
×
1159
                                addrs:         nodeAddrs,
×
1160
                                chanBatchData: channelBatchData,
×
1161
                                chanMap:       nodeChannelMap,
×
1162
                        }, nil
×
1163
                }
1164

1165
                // processItem is used to process each node in the current page.
1166
                processItem := func(ctx context.Context,
×
1167
                        nodeData sqlc.ListNodeIDsAndPubKeysRow,
×
1168
                        batchData *nodeCachedBatchData) error {
×
1169

×
1170
                        // Build feature vector for this node.
×
1171
                        fv := lnwire.EmptyFeatureVector()
×
1172
                        features, exists := batchData.features[nodeData.ID]
×
1173
                        if exists {
×
1174
                                for _, bit := range features {
×
1175
                                        fv.Set(lnwire.FeatureBit(bit))
×
1176
                                }
×
1177
                        }
1178

1179
                        var nodePub route.Vertex
×
1180
                        copy(nodePub[:], nodeData.PubKey)
×
1181

×
1182
                        nodeChannels := batchData.chanMap[nodeData.ID]
×
1183

×
1184
                        toNodeCallback := func() route.Vertex {
×
1185
                                return nodePub
×
1186
                        }
×
1187

1188
                        // Build cached channels map for this node.
1189
                        channels := make(map[uint64]*DirectedChannel)
×
1190
                        for _, channelRow := range nodeChannels {
×
1191
                                directedChan, err := buildDirectedChannel(
×
1192
                                        s.cfg.ChainHash, nodeData.ID, nodePub,
×
1193
                                        channelRow, batchData.chanBatchData, fv,
×
1194
                                        toNodeCallback,
×
1195
                                )
×
1196
                                if err != nil {
×
1197
                                        return err
×
1198
                                }
×
1199

1200
                                channels[directedChan.ChannelID] = directedChan
×
1201
                        }
1202

1203
                        addrs, err := buildNodeAddresses(
×
1204
                                batchData.addrs[nodeData.ID],
×
1205
                        )
×
1206
                        if err != nil {
×
1207
                                return fmt.Errorf("unable to build node "+
×
1208
                                        "addresses: %w", err)
×
1209
                        }
×
1210

1211
                        return cb(ctx, nodePub, addrs, channels)
×
1212
                }
1213

1214
                return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
1215
                        ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
×
1216
                        func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
1217
                                return node.ID
×
1218
                        },
×
1219
                        func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
1220
                                error) {
×
1221

×
1222
                                return node.ID, nil
×
1223
                        },
×
1224
                        batchDataFunc, processItem,
1225
                )
1226
        }, reset)
1227
}
1228

1229
// ForEachChannelCacheable iterates through all the channel edges stored
1230
// within the graph and invokes the passed callback for each edge. The
1231
// callback takes two edges as since this is a directed graph, both the
1232
// in/out edges are visited. If the callback returns an error, then the
1233
// transaction is aborted and the iteration stops early.
1234
//
1235
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1236
// pointer for that particular channel edge routing policy will be
1237
// passed into the callback.
1238
//
1239
// NOTE: this method is like ForEachChannel but fetches only the data
1240
// required for the graph cache.
1241
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1242
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1243
        reset func()) error {
×
1244

×
1245
        ctx := context.TODO()
×
1246

×
1247
        handleChannel := func(_ context.Context,
×
1248
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
1249

×
1250
                node1, node2, err := buildNodeVertices(
×
1251
                        row.Node1Pubkey, row.Node2Pubkey,
×
1252
                )
×
1253
                if err != nil {
×
1254
                        return err
×
1255
                }
×
1256

1257
                edge := buildCacheableChannelInfo(
×
1258
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1259
                )
×
1260

×
1261
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1262
                if err != nil {
×
1263
                        return err
×
1264
                }
×
1265

1266
                pol1, pol2, err := buildCachedChanPolicies(
×
1267
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1268
                )
×
1269
                if err != nil {
×
1270
                        return err
×
1271
                }
×
1272

1273
                return cb(edge, pol1, pol2)
×
1274
        }
1275

1276
        extractCursor := func(
×
1277
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {
×
1278

×
1279
                return row.ID
×
1280
        }
×
1281

1282
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1283
                //nolint:ll
×
1284
                queryFunc := func(ctx context.Context, lastID int64,
×
1285
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
×
1286
                        error) {
×
1287

×
1288
                        return db.ListChannelsWithPoliciesForCachePaginated(
×
1289
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
1290
                                        Version: int16(ProtocolV1),
×
1291
                                        ID:      lastID,
×
1292
                                        Limit:   limit,
×
1293
                                },
×
1294
                        )
×
1295
                }
×
1296

1297
                return sqldb.ExecutePaginatedQuery(
×
1298
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
1299
                        extractCursor, handleChannel,
×
1300
                )
×
1301
        }, reset)
1302
}
1303

1304
// ForEachChannel iterates through all the channel edges stored within the
1305
// graph and invokes the passed callback for each edge. The callback takes two
1306
// edges as since this is a directed graph, both the in/out edges are visited.
1307
// If the callback returns an error, then the transaction is aborted and the
1308
// iteration stops early.
1309
//
1310
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1311
// for that particular channel edge routing policy will be passed into the
1312
// callback.
1313
//
1314
// NOTE: part of the V1Store interface.
1315
func (s *SQLStore) ForEachChannel(ctx context.Context,
1316
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1317
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1318

×
1319
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1320
                return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
×
1321
        }, reset)
×
1322
}
1323

1324
// FilterChannelRange returns the channel ID's of all known channels which were
1325
// mined in a block height within the passed range. The channel IDs are grouped
1326
// by their common block height. This method can be used to quickly share with a
1327
// peer the set of channels we know of within a particular range to catch them
1328
// up after a period of time offline. If withTimestamps is true then the
1329
// timestamp info of the latest received channel update messages of the channel
1330
// will be included in the response.
1331
//
1332
// NOTE: This is part of the V1Store interface.
1333
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1334
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1335

×
1336
        var (
×
1337
                ctx       = context.TODO()
×
1338
                startSCID = &lnwire.ShortChannelID{
×
1339
                        BlockHeight: startHeight,
×
1340
                }
×
1341
                endSCID = lnwire.ShortChannelID{
×
1342
                        BlockHeight: endHeight,
×
1343
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1344
                        TxPosition:  math.MaxUint16,
×
1345
                }
×
1346
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1347
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1348
        )
×
1349

×
1350
        // 1) get all channels where channelID is between start and end chan ID.
×
1351
        // 2) skip if not public (ie, no channel_proof)
×
1352
        // 3) collect that channel.
×
1353
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1354
        //    and add those timestamps to the collected channel.
×
1355
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1356
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1357
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1358
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1359
                                StartScid: chanIDStart,
×
1360
                                EndScid:   chanIDEnd,
×
1361
                        },
×
1362
                )
×
1363
                if err != nil {
×
1364
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1365
                                err)
×
1366
                }
×
1367

1368
                for _, dbChan := range dbChans {
×
1369
                        cid := lnwire.NewShortChanIDFromInt(
×
1370
                                byteOrder.Uint64(dbChan.Scid),
×
1371
                        )
×
1372
                        chanInfo := NewChannelUpdateInfo(
×
1373
                                cid, time.Time{}, time.Time{},
×
1374
                        )
×
1375

×
1376
                        if !withTimestamps {
×
1377
                                channelsPerBlock[cid.BlockHeight] = append(
×
1378
                                        channelsPerBlock[cid.BlockHeight],
×
1379
                                        chanInfo,
×
1380
                                )
×
1381

×
1382
                                continue
×
1383
                        }
1384

1385
                        //nolint:ll
1386
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1387
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1388
                                        Version:   int16(ProtocolV1),
×
1389
                                        ChannelID: dbChan.ID,
×
1390
                                        NodeID:    dbChan.NodeID1,
×
1391
                                },
×
1392
                        )
×
1393
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1394
                                return fmt.Errorf("unable to fetch node1 "+
×
1395
                                        "policy: %w", err)
×
1396
                        } else if err == nil {
×
1397
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1398
                                        node1Policy.LastUpdate.Int64, 0,
×
1399
                                )
×
1400
                        }
×
1401

1402
                        //nolint:ll
1403
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1404
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1405
                                        Version:   int16(ProtocolV1),
×
1406
                                        ChannelID: dbChan.ID,
×
1407
                                        NodeID:    dbChan.NodeID2,
×
1408
                                },
×
1409
                        )
×
1410
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1411
                                return fmt.Errorf("unable to fetch node2 "+
×
1412
                                        "policy: %w", err)
×
1413
                        } else if err == nil {
×
1414
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1415
                                        node2Policy.LastUpdate.Int64, 0,
×
1416
                                )
×
1417
                        }
×
1418

1419
                        channelsPerBlock[cid.BlockHeight] = append(
×
1420
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1421
                        )
×
1422
                }
1423

1424
                return nil
×
1425
        }, func() {
×
1426
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1427
        })
×
1428
        if err != nil {
×
1429
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1430
        }
×
1431

1432
        if len(channelsPerBlock) == 0 {
×
1433
                return nil, nil
×
1434
        }
×
1435

1436
        // Return the channel ranges in ascending block height order.
1437
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1438
        slices.Sort(blocks)
×
1439

×
1440
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1441
                return BlockChannelRange{
×
1442
                        Height:   block,
×
1443
                        Channels: channelsPerBlock[block],
×
1444
                }
×
1445
        }), nil
×
1446
}
1447

1448
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1449
// zombie. This method is used on an ad-hoc basis, when channels need to be
1450
// marked as zombies outside the normal pruning cycle.
1451
//
1452
// NOTE: part of the V1Store interface.
1453
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1454
        pubKey1, pubKey2 [33]byte) error {
×
1455

×
1456
        ctx := context.TODO()
×
1457

×
1458
        s.cacheMu.Lock()
×
1459
        defer s.cacheMu.Unlock()
×
1460

×
1461
        chanIDB := channelIDToBytes(chanID)
×
1462

×
1463
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1464
                return db.UpsertZombieChannel(
×
1465
                        ctx, sqlc.UpsertZombieChannelParams{
×
1466
                                Version:  int16(ProtocolV1),
×
1467
                                Scid:     chanIDB,
×
1468
                                NodeKey1: pubKey1[:],
×
1469
                                NodeKey2: pubKey2[:],
×
1470
                        },
×
1471
                )
×
1472
        }, sqldb.NoOpReset)
×
1473
        if err != nil {
×
1474
                return fmt.Errorf("unable to upsert zombie channel "+
×
1475
                        "(channel_id=%d): %w", chanID, err)
×
1476
        }
×
1477

1478
        s.rejectCache.remove(chanID)
×
1479
        s.chanCache.remove(chanID)
×
1480

×
1481
        return nil
×
1482
}
1483

1484
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1485
//
1486
// NOTE: part of the V1Store interface.
1487
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1488
        s.cacheMu.Lock()
×
1489
        defer s.cacheMu.Unlock()
×
1490

×
1491
        var (
×
1492
                ctx     = context.TODO()
×
1493
                chanIDB = channelIDToBytes(chanID)
×
1494
        )
×
1495

×
1496
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1497
                res, err := db.DeleteZombieChannel(
×
1498
                        ctx, sqlc.DeleteZombieChannelParams{
×
1499
                                Scid:    chanIDB,
×
1500
                                Version: int16(ProtocolV1),
×
1501
                        },
×
1502
                )
×
1503
                if err != nil {
×
1504
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1505
                                err)
×
1506
                }
×
1507

1508
                rows, err := res.RowsAffected()
×
1509
                if err != nil {
×
1510
                        return err
×
1511
                }
×
1512

1513
                if rows == 0 {
×
1514
                        return ErrZombieEdgeNotFound
×
1515
                } else if rows > 1 {
×
1516
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1517
                                "expected 1", rows)
×
1518
                }
×
1519

1520
                return nil
×
1521
        }, sqldb.NoOpReset)
1522
        if err != nil {
×
1523
                return fmt.Errorf("unable to mark edge live "+
×
1524
                        "(channel_id=%d): %w", chanID, err)
×
1525
        }
×
1526

1527
        s.rejectCache.remove(chanID)
×
1528
        s.chanCache.remove(chanID)
×
1529

×
1530
        return err
×
1531
}
1532

1533
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1534
// zombie, then the two node public keys corresponding to this edge are also
1535
// returned.
1536
//
1537
// NOTE: part of the V1Store interface.
1538
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1539
        error) {
×
1540

×
1541
        var (
×
1542
                ctx              = context.TODO()
×
1543
                isZombie         bool
×
1544
                pubKey1, pubKey2 route.Vertex
×
1545
                chanIDB          = channelIDToBytes(chanID)
×
1546
        )
×
1547

×
1548
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1549
                zombie, err := db.GetZombieChannel(
×
1550
                        ctx, sqlc.GetZombieChannelParams{
×
1551
                                Scid:    chanIDB,
×
1552
                                Version: int16(ProtocolV1),
×
1553
                        },
×
1554
                )
×
1555
                if errors.Is(err, sql.ErrNoRows) {
×
1556
                        return nil
×
1557
                }
×
1558
                if err != nil {
×
1559
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1560
                                err)
×
1561
                }
×
1562

1563
                copy(pubKey1[:], zombie.NodeKey1)
×
1564
                copy(pubKey2[:], zombie.NodeKey2)
×
1565
                isZombie = true
×
1566

×
1567
                return nil
×
1568
        }, sqldb.NoOpReset)
1569
        if err != nil {
×
1570
                return false, route.Vertex{}, route.Vertex{},
×
1571
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1572
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1573
        }
×
1574

1575
        return isZombie, pubKey1, pubKey2, nil
×
1576
}
1577

1578
// NumZombies returns the current number of zombie channels in the graph.
1579
//
1580
// NOTE: part of the V1Store interface.
1581
func (s *SQLStore) NumZombies() (uint64, error) {
×
1582
        var (
×
1583
                ctx        = context.TODO()
×
1584
                numZombies uint64
×
1585
        )
×
1586
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1587
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1588
                if err != nil {
×
1589
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1590
                                err)
×
1591
                }
×
1592

1593
                numZombies = uint64(count)
×
1594

×
1595
                return nil
×
1596
        }, sqldb.NoOpReset)
1597
        if err != nil {
×
1598
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1599
        }
×
1600

1601
        return numZombies, nil
×
1602
}
1603

1604
// DeleteChannelEdges removes edges with the given channel IDs from the
1605
// database and marks them as zombies. This ensures that we're unable to re-add
1606
// it to our database once again. If an edge does not exist within the
1607
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1608
// true, then when we mark these edges as zombies, we'll set up the keys such
1609
// that we require the node that failed to send the fresh update to be the one
1610
// that resurrects the channel from its zombie state. The markZombie bool
1611
// denotes whether to mark the channel as a zombie.
1612
//
1613
// NOTE: part of the V1Store interface.
1614
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1615
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1616

×
1617
        s.cacheMu.Lock()
×
1618
        defer s.cacheMu.Unlock()
×
1619

×
1620
        // Keep track of which channels we end up finding so that we can
×
1621
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
1622
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
1623
        for _, chanID := range chanIDs {
×
1624
                chanLookup[chanID] = struct{}{}
×
1625
        }
×
1626

1627
        var (
×
1628
                ctx     = context.TODO()
×
1629
                deleted []*models.ChannelEdgeInfo
×
1630
        )
×
1631
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1632
                chanIDsToDelete := make([]int64, 0, len(chanIDs))
×
1633
                chanCallBack := func(ctx context.Context,
×
1634
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1635

×
1636
                        // Deleting the entry from the map indicates that we
×
1637
                        // have found the channel.
×
1638
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1639
                        delete(chanLookup, scid)
×
1640

×
1641
                        node1, node2, err := buildNodeVertices(
×
1642
                                row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1643
                        )
×
1644
                        if err != nil {
×
1645
                                return err
×
1646
                        }
×
1647

1648
                        info, err := getAndBuildEdgeInfo(
×
1649
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
1650
                                node1, node2,
×
1651
                        )
×
1652
                        if err != nil {
×
1653
                                return err
×
1654
                        }
×
1655

1656
                        deleted = append(deleted, info)
×
1657
                        chanIDsToDelete = append(
×
1658
                                chanIDsToDelete, row.GraphChannel.ID,
×
1659
                        )
×
1660

×
1661
                        if !markZombie {
×
1662
                                return nil
×
1663
                        }
×
1664

1665
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1666
                                info.NodeKey2Bytes
×
1667
                        if strictZombiePruning {
×
1668
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1669
                                if row.Policy1LastUpdate.Valid {
×
1670
                                        e1Time := time.Unix(
×
1671
                                                row.Policy1LastUpdate.Int64, 0,
×
1672
                                        )
×
1673
                                        e1UpdateTime = &e1Time
×
1674
                                }
×
1675
                                if row.Policy2LastUpdate.Valid {
×
1676
                                        e2Time := time.Unix(
×
1677
                                                row.Policy2LastUpdate.Int64, 0,
×
1678
                                        )
×
1679
                                        e2UpdateTime = &e2Time
×
1680
                                }
×
1681

1682
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1683
                                        info, e1UpdateTime, e2UpdateTime,
×
1684
                                )
×
1685
                        }
1686

1687
                        err = db.UpsertZombieChannel(
×
1688
                                ctx, sqlc.UpsertZombieChannelParams{
×
1689
                                        Version:  int16(ProtocolV1),
×
1690
                                        Scid:     channelIDToBytes(scid),
×
1691
                                        NodeKey1: nodeKey1[:],
×
1692
                                        NodeKey2: nodeKey2[:],
×
1693
                                },
×
1694
                        )
×
1695
                        if err != nil {
×
1696
                                return fmt.Errorf("unable to mark channel as "+
×
1697
                                        "zombie: %w", err)
×
1698
                        }
×
1699

1700
                        return nil
×
1701
                }
1702

1703
                err := s.forEachChanWithPoliciesInSCIDList(
×
1704
                        ctx, db, chanCallBack, chanIDs,
×
1705
                )
×
1706
                if err != nil {
×
1707
                        return err
×
1708
                }
×
1709

1710
                if len(chanLookup) > 0 {
×
1711
                        return ErrEdgeNotFound
×
1712
                }
×
1713

1714
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1715
        }, func() {
×
1716
                deleted = nil
×
1717

×
1718
                // Re-fill the lookup map.
×
1719
                for _, chanID := range chanIDs {
×
1720
                        chanLookup[chanID] = struct{}{}
×
1721
                }
×
1722
        })
1723
        if err != nil {
×
1724
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1725
                        err)
×
1726
        }
×
1727

1728
        for _, chanID := range chanIDs {
×
1729
                s.rejectCache.remove(chanID)
×
1730
                s.chanCache.remove(chanID)
×
1731
        }
×
1732

1733
        return deleted, nil
×
1734
}
1735

1736
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1737
// channel identified by the channel ID. If the channel can't be found, then
1738
// ErrEdgeNotFound is returned. A struct which houses the general information
1739
// for the channel itself is returned as well as two structs that contain the
1740
// routing policies for the channel in either direction.
1741
//
1742
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1743
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1744
// the ChannelEdgeInfo will only include the public keys of each node.
1745
//
1746
// NOTE: part of the V1Store interface.
1747
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1748
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1749
        *models.ChannelEdgePolicy, error) {
×
1750

×
1751
        var (
×
1752
                ctx              = context.TODO()
×
1753
                edge             *models.ChannelEdgeInfo
×
1754
                policy1, policy2 *models.ChannelEdgePolicy
×
1755
                chanIDB          = channelIDToBytes(chanID)
×
1756
        )
×
1757
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1758
                row, err := db.GetChannelBySCIDWithPolicies(
×
1759
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1760
                                Scid:    chanIDB,
×
1761
                                Version: int16(ProtocolV1),
×
1762
                        },
×
1763
                )
×
1764
                if errors.Is(err, sql.ErrNoRows) {
×
1765
                        // First check if this edge is perhaps in the zombie
×
1766
                        // index.
×
1767
                        zombie, err := db.GetZombieChannel(
×
1768
                                ctx, sqlc.GetZombieChannelParams{
×
1769
                                        Scid:    chanIDB,
×
1770
                                        Version: int16(ProtocolV1),
×
1771
                                },
×
1772
                        )
×
1773
                        if errors.Is(err, sql.ErrNoRows) {
×
1774
                                return ErrEdgeNotFound
×
1775
                        } else if err != nil {
×
1776
                                return fmt.Errorf("unable to check if "+
×
1777
                                        "channel is zombie: %w", err)
×
1778
                        }
×
1779

1780
                        // At this point, we know the channel is a zombie, so
1781
                        // we'll return an error indicating this, and we will
1782
                        // populate the edge info with the public keys of each
1783
                        // party as this is the only information we have about
1784
                        // it.
1785
                        edge = &models.ChannelEdgeInfo{}
×
1786
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1787
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1788

×
1789
                        return ErrZombieEdge
×
1790
                } else if err != nil {
×
1791
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1792
                }
×
1793

1794
                node1, node2, err := buildNodeVertices(
×
1795
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1796
                )
×
1797
                if err != nil {
×
1798
                        return err
×
1799
                }
×
1800

1801
                edge, err = getAndBuildEdgeInfo(
×
1802
                        ctx, db, s.cfg.ChainHash, row.GraphChannel, node1,
×
1803
                        node2,
×
1804
                )
×
1805
                if err != nil {
×
1806
                        return fmt.Errorf("unable to build channel info: %w",
×
1807
                                err)
×
1808
                }
×
1809

1810
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1811
                if err != nil {
×
1812
                        return fmt.Errorf("unable to extract channel "+
×
1813
                                "policies: %w", err)
×
1814
                }
×
1815

1816
                policy1, policy2, err = getAndBuildChanPolicies(
×
1817
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1818
                )
×
1819
                if err != nil {
×
1820
                        return fmt.Errorf("unable to build channel "+
×
1821
                                "policies: %w", err)
×
1822
                }
×
1823

1824
                return nil
×
1825
        }, sqldb.NoOpReset)
1826
        if err != nil {
×
1827
                // If we are returning the ErrZombieEdge, then we also need to
×
1828
                // return the edge info as the method comment indicates that
×
1829
                // this will be populated when the edge is a zombie.
×
1830
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1831
                        err)
×
1832
        }
×
1833

1834
        return edge, policy1, policy2, nil
×
1835
}
1836

1837
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1838
// the channel identified by the funding outpoint. If the channel can't be
1839
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1840
// information for the channel itself is returned as well as two structs that
1841
// contain the routing policies for the channel in either direction.
1842
//
1843
// NOTE: part of the V1Store interface.
1844
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1845
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1846
        *models.ChannelEdgePolicy, error) {
×
1847

×
1848
        var (
×
1849
                ctx              = context.TODO()
×
1850
                edge             *models.ChannelEdgeInfo
×
1851
                policy1, policy2 *models.ChannelEdgePolicy
×
1852
        )
×
1853
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1854
                row, err := db.GetChannelByOutpointWithPolicies(
×
1855
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1856
                                Outpoint: op.String(),
×
1857
                                Version:  int16(ProtocolV1),
×
1858
                        },
×
1859
                )
×
1860
                if errors.Is(err, sql.ErrNoRows) {
×
1861
                        return ErrEdgeNotFound
×
1862
                } else if err != nil {
×
1863
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1864
                }
×
1865

1866
                node1, node2, err := buildNodeVertices(
×
1867
                        row.Node1Pubkey, row.Node2Pubkey,
×
1868
                )
×
1869
                if err != nil {
×
1870
                        return err
×
1871
                }
×
1872

1873
                edge, err = getAndBuildEdgeInfo(
×
1874
                        ctx, db, s.cfg.ChainHash, row.GraphChannel, node1,
×
1875
                        node2,
×
1876
                )
×
1877
                if err != nil {
×
1878
                        return fmt.Errorf("unable to build channel info: %w",
×
1879
                                err)
×
1880
                }
×
1881

1882
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1883
                if err != nil {
×
1884
                        return fmt.Errorf("unable to extract channel "+
×
1885
                                "policies: %w", err)
×
1886
                }
×
1887

1888
                policy1, policy2, err = getAndBuildChanPolicies(
×
1889
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1890
                )
×
1891
                if err != nil {
×
1892
                        return fmt.Errorf("unable to build channel "+
×
1893
                                "policies: %w", err)
×
1894
                }
×
1895

1896
                return nil
×
1897
        }, sqldb.NoOpReset)
1898
        if err != nil {
×
1899
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1900
                        err)
×
1901
        }
×
1902

1903
        return edge, policy1, policy2, nil
×
1904
}
1905

1906
// HasChannelEdge returns true if the database knows of a channel edge with the
1907
// passed channel ID, and false otherwise. If an edge with that ID is found
1908
// within the graph, then two time stamps representing the last time the edge
1909
// was updated for both directed edges are returned along with the boolean. If
1910
// it is not found, then the zombie index is checked and its result is returned
1911
// as the second boolean.
1912
//
1913
// NOTE: part of the V1Store interface.
1914
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1915
        bool, error) {
×
1916

×
1917
        ctx := context.TODO()
×
1918

×
1919
        var (
×
1920
                exists          bool
×
1921
                isZombie        bool
×
1922
                node1LastUpdate time.Time
×
1923
                node2LastUpdate time.Time
×
1924
        )
×
1925

×
1926
        // We'll query the cache with the shared lock held to allow multiple
×
1927
        // readers to access values in the cache concurrently if they exist.
×
1928
        s.cacheMu.RLock()
×
1929
        if entry, ok := s.rejectCache.get(chanID); ok {
×
1930
                s.cacheMu.RUnlock()
×
1931
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
1932
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
1933
                exists, isZombie = entry.flags.unpack()
×
1934

×
1935
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
1936
        }
×
1937
        s.cacheMu.RUnlock()
×
1938

×
1939
        s.cacheMu.Lock()
×
1940
        defer s.cacheMu.Unlock()
×
1941

×
1942
        // The item was not found with the shared lock, so we'll acquire the
×
1943
        // exclusive lock and check the cache again in case another method added
×
1944
        // the entry to the cache while no lock was held.
×
1945
        if entry, ok := s.rejectCache.get(chanID); ok {
×
1946
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
1947
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
1948
                exists, isZombie = entry.flags.unpack()
×
1949

×
1950
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
1951
        }
×
1952

1953
        chanIDB := channelIDToBytes(chanID)
×
1954
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1955
                channel, err := db.GetChannelBySCID(
×
1956
                        ctx, sqlc.GetChannelBySCIDParams{
×
1957
                                Scid:    chanIDB,
×
1958
                                Version: int16(ProtocolV1),
×
1959
                        },
×
1960
                )
×
1961
                if errors.Is(err, sql.ErrNoRows) {
×
1962
                        // Check if it is a zombie channel.
×
1963
                        isZombie, err = db.IsZombieChannel(
×
1964
                                ctx, sqlc.IsZombieChannelParams{
×
1965
                                        Scid:    chanIDB,
×
1966
                                        Version: int16(ProtocolV1),
×
1967
                                },
×
1968
                        )
×
1969
                        if err != nil {
×
1970
                                return fmt.Errorf("could not check if channel "+
×
1971
                                        "is zombie: %w", err)
×
1972
                        }
×
1973

1974
                        return nil
×
1975
                } else if err != nil {
×
1976
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1977
                }
×
1978

1979
                exists = true
×
1980

×
1981
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
1982
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1983
                                Version:   int16(ProtocolV1),
×
1984
                                ChannelID: channel.ID,
×
1985
                                NodeID:    channel.NodeID1,
×
1986
                        },
×
1987
                )
×
1988
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1989
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
1990
                                err)
×
1991
                } else if err == nil {
×
1992
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
1993
                }
×
1994

1995
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
1996
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1997
                                Version:   int16(ProtocolV1),
×
1998
                                ChannelID: channel.ID,
×
1999
                                NodeID:    channel.NodeID2,
×
2000
                        },
×
2001
                )
×
2002
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2003
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2004
                                err)
×
2005
                } else if err == nil {
×
2006
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2007
                }
×
2008

2009
                return nil
×
2010
        }, sqldb.NoOpReset)
2011
        if err != nil {
×
2012
                return time.Time{}, time.Time{}, false, false,
×
2013
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2014
        }
×
2015

2016
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2017
                upd1Time: node1LastUpdate.Unix(),
×
2018
                upd2Time: node2LastUpdate.Unix(),
×
2019
                flags:    packRejectFlags(exists, isZombie),
×
2020
        })
×
2021

×
2022
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2023
}
2024

2025
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2026
// passed channel point (outpoint). If the passed channel doesn't exist within
2027
// the database, then ErrEdgeNotFound is returned.
2028
//
2029
// NOTE: part of the V1Store interface.
2030
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2031
        var (
×
2032
                ctx       = context.TODO()
×
2033
                channelID uint64
×
2034
        )
×
2035
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2036
                chanID, err := db.GetSCIDByOutpoint(
×
2037
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2038
                                Outpoint: chanPoint.String(),
×
2039
                                Version:  int16(ProtocolV1),
×
2040
                        },
×
2041
                )
×
2042
                if errors.Is(err, sql.ErrNoRows) {
×
2043
                        return ErrEdgeNotFound
×
2044
                } else if err != nil {
×
2045
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2046
                                err)
×
2047
                }
×
2048

2049
                channelID = byteOrder.Uint64(chanID)
×
2050

×
2051
                return nil
×
2052
        }, sqldb.NoOpReset)
2053
        if err != nil {
×
2054
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2055
        }
×
2056

2057
        return channelID, nil
×
2058
}
2059

2060
// IsPublicNode is a helper method that determines whether the node with the
2061
// given public key is seen as a public node in the graph from the graph's
2062
// source node's point of view.
2063
//
2064
// NOTE: part of the V1Store interface.
2065
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2066
        ctx := context.TODO()
×
2067

×
2068
        var isPublic bool
×
2069
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2070
                var err error
×
2071
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2072

×
2073
                return err
×
2074
        }, sqldb.NoOpReset)
×
2075
        if err != nil {
×
2076
                return false, fmt.Errorf("unable to check if node is "+
×
2077
                        "public: %w", err)
×
2078
        }
×
2079

2080
        return isPublic, nil
×
2081
}
2082

2083
// FetchChanInfos returns the set of channel edges that correspond to the passed
2084
// channel ID's. If an edge is the query is unknown to the database, it will
2085
// skipped and the result will contain only those edges that exist at the time
2086
// of the query. This can be used to respond to peer queries that are seeking to
2087
// fill in gaps in their view of the channel graph.
2088
//
2089
// NOTE: part of the V1Store interface.
2090
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2091
        var (
×
2092
                ctx   = context.TODO()
×
2093
                edges = make(map[uint64]ChannelEdge)
×
2094
        )
×
2095
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2096
                chanCallBack := func(ctx context.Context,
×
2097
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2098

×
2099
                        node1, node2, err := buildNodes(
×
2100
                                ctx, db, row.GraphNode, row.GraphNode_2,
×
2101
                        )
×
2102
                        if err != nil {
×
2103
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2104
                                        err)
×
2105
                        }
×
2106

2107
                        edge, err := getAndBuildEdgeInfo(
×
2108
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
2109
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2110
                        )
×
2111
                        if err != nil {
×
2112
                                return fmt.Errorf("unable to build "+
×
2113
                                        "channel info: %w", err)
×
2114
                        }
×
2115

2116
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2117
                        if err != nil {
×
2118
                                return fmt.Errorf("unable to extract channel "+
×
2119
                                        "policies: %w", err)
×
2120
                        }
×
2121

2122
                        p1, p2, err := getAndBuildChanPolicies(
×
2123
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2124
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2125
                        )
×
2126
                        if err != nil {
×
2127
                                return fmt.Errorf("unable to build channel "+
×
2128
                                        "policies: %w", err)
×
2129
                        }
×
2130

2131
                        edges[edge.ChannelID] = ChannelEdge{
×
2132
                                Info:    edge,
×
2133
                                Policy1: p1,
×
2134
                                Policy2: p2,
×
2135
                                Node1:   node1,
×
2136
                                Node2:   node2,
×
2137
                        }
×
2138

×
2139
                        return nil
×
2140
                }
2141

2142
                return s.forEachChanWithPoliciesInSCIDList(
×
2143
                        ctx, db, chanCallBack, chanIDs,
×
2144
                )
×
2145
        }, func() {
×
2146
                clear(edges)
×
2147
        })
×
2148
        if err != nil {
×
2149
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2150
        }
×
2151

2152
        res := make([]ChannelEdge, 0, len(edges))
×
2153
        for _, chanID := range chanIDs {
×
2154
                edge, ok := edges[chanID]
×
2155
                if !ok {
×
2156
                        continue
×
2157
                }
2158

2159
                res = append(res, edge)
×
2160
        }
2161

2162
        return res, nil
×
2163
}
2164

2165
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2166
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2167
// channels in a paginated manner.
2168
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2169
        db SQLQueries, cb func(ctx context.Context,
2170
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
2171
        chanIDs []uint64) error {
×
2172

×
2173
        queryWrapper := func(ctx context.Context,
×
2174
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
2175
                error) {
×
2176

×
2177
                return db.GetChannelsBySCIDWithPolicies(
×
2178
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
2179
                                Version: int16(ProtocolV1),
×
2180
                                Scids:   scids,
×
2181
                        },
×
2182
                )
×
2183
        }
×
2184

2185
        return sqldb.ExecuteBatchQuery(
×
2186
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
×
2187
                cb,
×
2188
        )
×
2189
}
2190

2191
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2192
// ID's that we don't know and are not known zombies of the passed set. In other
2193
// words, we perform a set difference of our set of chan ID's and the ones
2194
// passed in. This method can be used by callers to determine the set of
2195
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2196
// known zombies is also returned.
2197
//
2198
// NOTE: part of the V1Store interface.
2199
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2200
        []ChannelUpdateInfo, error) {
×
2201

×
2202
        var (
×
2203
                ctx          = context.TODO()
×
2204
                newChanIDs   []uint64
×
2205
                knownZombies []ChannelUpdateInfo
×
2206
                infoLookup   = make(
×
2207
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2208
                )
×
2209
        )
×
2210

×
2211
        // We first build a lookup map of the channel ID's to the
×
2212
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
2213
        // already know about.
×
2214
        for _, chanInfo := range chansInfo {
×
2215
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
2216
        }
×
2217

2218
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2219
                // The call-back function deletes known channels from
×
2220
                // infoLookup, so that we can later check which channels are
×
2221
                // zombies by only looking at the remaining channels in the set.
×
2222
                cb := func(ctx context.Context,
×
2223
                        channel sqlc.GraphChannel) error {
×
2224

×
2225
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2226

×
2227
                        return nil
×
2228
                }
×
2229

2230
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
2231
                if err != nil {
×
2232
                        return fmt.Errorf("unable to iterate through "+
×
2233
                                "channels: %w", err)
×
2234
                }
×
2235

2236
                // We want to ensure that we deal with the channels in the
2237
                // same order that they were passed in, so we iterate over the
2238
                // original chansInfo slice and then check if that channel is
2239
                // still in the infoLookup map.
2240
                for _, chanInfo := range chansInfo {
×
2241
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2242
                        if _, ok := infoLookup[channelID]; !ok {
×
2243
                                continue
×
2244
                        }
2245

2246
                        isZombie, err := db.IsZombieChannel(
×
2247
                                ctx, sqlc.IsZombieChannelParams{
×
2248
                                        Scid:    channelIDToBytes(channelID),
×
2249
                                        Version: int16(ProtocolV1),
×
2250
                                },
×
2251
                        )
×
2252
                        if err != nil {
×
2253
                                return fmt.Errorf("unable to fetch zombie "+
×
2254
                                        "channel: %w", err)
×
2255
                        }
×
2256

2257
                        if isZombie {
×
2258
                                knownZombies = append(knownZombies, chanInfo)
×
2259

×
2260
                                continue
×
2261
                        }
2262

2263
                        newChanIDs = append(newChanIDs, channelID)
×
2264
                }
2265

2266
                return nil
×
2267
        }, func() {
×
2268
                newChanIDs = nil
×
2269
                knownZombies = nil
×
2270
                // Rebuild the infoLookup map in case of a rollback.
×
2271
                for _, chanInfo := range chansInfo {
×
2272
                        scid := chanInfo.ShortChannelID.ToUint64()
×
2273
                        infoLookup[scid] = chanInfo
×
2274
                }
×
2275
        })
2276
        if err != nil {
×
2277
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2278
        }
×
2279

2280
        return newChanIDs, knownZombies, nil
×
2281
}
2282

2283
// forEachChanInSCIDList is a helper method that executes a paged query
2284
// against the database to fetch all channels that match the passed
2285
// ChannelUpdateInfo slice. The callback function is called for each channel
2286
// that is found.
2287
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2288
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
2289
        chansInfo []ChannelUpdateInfo) error {
×
2290

×
2291
        queryWrapper := func(ctx context.Context,
×
2292
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
2293

×
2294
                return db.GetChannelsBySCIDs(
×
2295
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
2296
                                Version: int16(ProtocolV1),
×
2297
                                Scids:   scids,
×
2298
                        },
×
2299
                )
×
2300
        }
×
2301

2302
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
2303
                channelID := chanInfo.ShortChannelID.ToUint64()
×
2304

×
2305
                return channelIDToBytes(channelID)
×
2306
        }
×
2307

2308
        return sqldb.ExecuteBatchQuery(
×
2309
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
×
2310
                cb,
×
2311
        )
×
2312
}
2313

2314
// PruneGraphNodes is a garbage collection method which attempts to prune out
2315
// any nodes from the channel graph that are currently unconnected. This ensure
2316
// that we only maintain a graph of reachable nodes. In the event that a pruned
2317
// node gains more channels, it will be re-added back to the graph.
2318
//
2319
// NOTE: this prunes nodes across protocol versions. It will never prune the
2320
// source nodes.
2321
//
2322
// NOTE: part of the V1Store interface.
2323
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2324
        var ctx = context.TODO()
×
2325

×
2326
        var prunedNodes []route.Vertex
×
2327
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2328
                var err error
×
2329
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2330

×
2331
                return err
×
2332
        }, func() {
×
2333
                prunedNodes = nil
×
2334
        })
×
2335
        if err != nil {
×
2336
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2337
        }
×
2338

2339
        return prunedNodes, nil
×
2340
}
2341

2342
// PruneGraph prunes newly closed channels from the channel graph in response
2343
// to a new block being solved on the network. Any transactions which spend the
2344
// funding output of any known channels within he graph will be deleted.
2345
// Additionally, the "prune tip", or the last block which has been used to
2346
// prune the graph is stored so callers can ensure the graph is fully in sync
2347
// with the current UTXO state. A slice of channels that have been closed by
2348
// the target block along with any pruned nodes are returned if the function
2349
// succeeds without error.
2350
//
2351
// NOTE: part of the V1Store interface.
2352
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2353
        blockHash *chainhash.Hash, blockHeight uint32) (
2354
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2355

×
2356
        ctx := context.TODO()
×
2357

×
2358
        s.cacheMu.Lock()
×
2359
        defer s.cacheMu.Unlock()
×
2360

×
2361
        var (
×
2362
                closedChans []*models.ChannelEdgeInfo
×
2363
                prunedNodes []route.Vertex
×
2364
        )
×
2365
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2366
                var chansToDelete []int64
×
2367

×
2368
                // Define the callback function for processing each channel.
×
2369
                channelCallback := func(ctx context.Context,
×
2370
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2371

×
2372
                        node1, node2, err := buildNodeVertices(
×
2373
                                row.Node1Pubkey, row.Node2Pubkey,
×
2374
                        )
×
2375
                        if err != nil {
×
2376
                                return err
×
2377
                        }
×
2378

2379
                        info, err := getAndBuildEdgeInfo(
×
2380
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
2381
                                node1, node2,
×
2382
                        )
×
2383
                        if err != nil {
×
2384
                                return err
×
2385
                        }
×
2386

2387
                        closedChans = append(closedChans, info)
×
2388
                        chansToDelete = append(
×
2389
                                chansToDelete, row.GraphChannel.ID,
×
2390
                        )
×
2391

×
2392
                        return nil
×
2393
                }
2394

2395
                err := s.forEachChanInOutpoints(
×
2396
                        ctx, db, spentOutputs, channelCallback,
×
2397
                )
×
2398
                if err != nil {
×
2399
                        return fmt.Errorf("unable to fetch channels by "+
×
2400
                                "outpoints: %w", err)
×
2401
                }
×
2402

2403
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2404
                if err != nil {
×
2405
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2406
                }
×
2407

2408
                err = db.UpsertPruneLogEntry(
×
2409
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2410
                                BlockHash:   blockHash[:],
×
2411
                                BlockHeight: int64(blockHeight),
×
2412
                        },
×
2413
                )
×
2414
                if err != nil {
×
2415
                        return fmt.Errorf("unable to insert prune log "+
×
2416
                                "entry: %w", err)
×
2417
                }
×
2418

2419
                // Now that we've pruned some channels, we'll also prune any
2420
                // nodes that no longer have any channels.
2421
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2422
                if err != nil {
×
2423
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2424
                                err)
×
2425
                }
×
2426

2427
                return nil
×
2428
        }, func() {
×
2429
                prunedNodes = nil
×
2430
                closedChans = nil
×
2431
        })
×
2432
        if err != nil {
×
2433
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2434
        }
×
2435

2436
        for _, channel := range closedChans {
×
2437
                s.rejectCache.remove(channel.ChannelID)
×
2438
                s.chanCache.remove(channel.ChannelID)
×
2439
        }
×
2440

2441
        return closedChans, prunedNodes, nil
×
2442
}
2443

2444
// forEachChanInOutpoints is a helper function that executes a paginated
2445
// query to fetch channels by their outpoints and applies the given call-back
2446
// to each.
2447
//
2448
// NOTE: this fetches channels for all protocol versions.
2449
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2450
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2451
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
2452

×
2453
        // Create a wrapper that uses the transaction's db instance to execute
×
2454
        // the query.
×
2455
        queryWrapper := func(ctx context.Context,
×
2456
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
2457
                error) {
×
2458

×
2459
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
2460
        }
×
2461

2462
        // Define the conversion function from Outpoint to string.
2463
        outpointToString := func(outpoint *wire.OutPoint) string {
×
2464
                return outpoint.String()
×
2465
        }
×
2466

2467
        return sqldb.ExecuteBatchQuery(
×
2468
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
×
2469
                queryWrapper, cb,
×
2470
        )
×
2471
}
2472

2473
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
2474
        dbIDs []int64) error {
×
2475

×
2476
        // Create a wrapper that uses the transaction's db instance to execute
×
2477
        // the query.
×
2478
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
2479
                return nil, db.DeleteChannels(ctx, ids)
×
2480
        }
×
2481

2482
        idConverter := func(id int64) int64 {
×
2483
                return id
×
2484
        }
×
2485

2486
        return sqldb.ExecuteBatchQuery(
×
2487
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
×
2488
                queryWrapper, func(ctx context.Context, _ any) error {
×
2489
                        return nil
×
2490
                },
×
2491
        )
2492
}
2493

2494
// ChannelView returns the verifiable edge information for each active channel
2495
// within the known channel graph. The set of UTXOs (along with their scripts)
2496
// returned are the ones that need to be watched on chain to detect channel
2497
// closes on the resident blockchain.
2498
//
2499
// NOTE: part of the V1Store interface.
2500
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2501
        var (
×
2502
                ctx        = context.TODO()
×
2503
                edgePoints []EdgePoint
×
2504
        )
×
2505

×
2506
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2507
                handleChannel := func(_ context.Context,
×
2508
                        channel sqlc.ListChannelsPaginatedRow) error {
×
2509

×
2510
                        pkScript, err := genMultiSigP2WSH(
×
2511
                                channel.BitcoinKey1, channel.BitcoinKey2,
×
2512
                        )
×
2513
                        if err != nil {
×
2514
                                return err
×
2515
                        }
×
2516

2517
                        op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2518
                        if err != nil {
×
2519
                                return err
×
2520
                        }
×
2521

2522
                        edgePoints = append(edgePoints, EdgePoint{
×
2523
                                FundingPkScript: pkScript,
×
2524
                                OutPoint:        *op,
×
2525
                        })
×
2526

×
2527
                        return nil
×
2528
                }
2529

2530
                queryFunc := func(ctx context.Context, lastID int64,
×
2531
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {
×
2532

×
2533
                        return db.ListChannelsPaginated(
×
2534
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2535
                                        Version: int16(ProtocolV1),
×
2536
                                        ID:      lastID,
×
2537
                                        Limit:   limit,
×
2538
                                },
×
2539
                        )
×
2540
                }
×
2541

2542
                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
×
2543
                        return row.ID
×
2544
                }
×
2545

2546
                return sqldb.ExecutePaginatedQuery(
×
2547
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
2548
                        extractCursor, handleChannel,
×
2549
                )
×
2550
        }, func() {
×
2551
                edgePoints = nil
×
2552
        })
×
2553
        if err != nil {
×
2554
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2555
        }
×
2556

2557
        return edgePoints, nil
×
2558
}
2559

2560
// PruneTip returns the block height and hash of the latest block that has been
2561
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2562
// to tell if the graph is currently in sync with the current best known UTXO
2563
// state.
2564
//
2565
// NOTE: part of the V1Store interface.
2566
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2567
        var (
×
2568
                ctx       = context.TODO()
×
2569
                tipHash   chainhash.Hash
×
2570
                tipHeight uint32
×
2571
        )
×
2572
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2573
                pruneTip, err := db.GetPruneTip(ctx)
×
2574
                if errors.Is(err, sql.ErrNoRows) {
×
2575
                        return ErrGraphNeverPruned
×
2576
                } else if err != nil {
×
2577
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2578
                }
×
2579

2580
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2581
                tipHeight = uint32(pruneTip.BlockHeight)
×
2582

×
2583
                return nil
×
2584
        }, sqldb.NoOpReset)
2585
        if err != nil {
×
2586
                return nil, 0, err
×
2587
        }
×
2588

2589
        return &tipHash, tipHeight, nil
×
2590
}
2591

2592
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2593
//
2594
// NOTE: this prunes nodes across protocol versions. It will never prune the
2595
// source nodes.
2596
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2597
        db SQLQueries) ([]route.Vertex, error) {
×
2598

×
2599
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2600
        if err != nil {
×
2601
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2602
                        "nodes: %w", err)
×
2603
        }
×
2604

2605
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2606
        for i, nodeKey := range nodeKeys {
×
2607
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2608
                if err != nil {
×
2609
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2610
                                "from bytes: %w", err)
×
2611
                }
×
2612

2613
                prunedNodes[i] = pub
×
2614
        }
2615

2616
        return prunedNodes, nil
×
2617
}
2618

2619
// DisconnectBlockAtHeight is used to indicate that the block specified
2620
// by the passed height has been disconnected from the main chain. This
2621
// will "rewind" the graph back to the height below, deleting channels
2622
// that are no longer confirmed from the graph. The prune log will be
2623
// set to the last prune height valid for the remaining chain.
2624
// Channels that were removed from the graph resulting from the
2625
// disconnected block are returned.
2626
//
2627
// NOTE: part of the V1Store interface.
2628
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2629
        []*models.ChannelEdgeInfo, error) {
×
2630

×
2631
        ctx := context.TODO()
×
2632

×
2633
        var (
×
2634
                // Every channel having a ShortChannelID starting at 'height'
×
2635
                // will no longer be confirmed.
×
2636
                startShortChanID = lnwire.ShortChannelID{
×
2637
                        BlockHeight: height,
×
2638
                }
×
2639

×
2640
                // Delete everything after this height from the db up until the
×
2641
                // SCID alias range.
×
2642
                endShortChanID = aliasmgr.StartingAlias
×
2643

×
2644
                removedChans []*models.ChannelEdgeInfo
×
2645

×
2646
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2647
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2648
        )
×
2649

×
2650
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2651
                rows, err := db.GetChannelsBySCIDRange(
×
2652
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2653
                                StartScid: chanIDStart,
×
2654
                                EndScid:   chanIDEnd,
×
2655
                        },
×
2656
                )
×
2657
                if err != nil {
×
2658
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2659
                }
×
2660

2661
                chanIDsToDelete := make([]int64, len(rows))
×
2662
                for i, row := range rows {
×
2663
                        node1, node2, err := buildNodeVertices(
×
2664
                                row.Node1PubKey, row.Node2PubKey,
×
2665
                        )
×
2666
                        if err != nil {
×
2667
                                return err
×
2668
                        }
×
2669

2670
                        channel, err := getAndBuildEdgeInfo(
×
2671
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
2672
                                node1, node2,
×
2673
                        )
×
2674
                        if err != nil {
×
2675
                                return err
×
2676
                        }
×
2677

2678
                        chanIDsToDelete[i] = row.GraphChannel.ID
×
2679
                        removedChans = append(removedChans, channel)
×
2680
                }
2681

2682
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
2683
                if err != nil {
×
2684
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2685
                }
×
2686

2687
                return db.DeletePruneLogEntriesInRange(
×
2688
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2689
                                StartHeight: int64(height),
×
2690
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2691
                        },
×
2692
                )
×
2693
        }, func() {
×
2694
                removedChans = nil
×
2695
        })
×
2696
        if err != nil {
×
2697
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2698
                        "height: %w", err)
×
2699
        }
×
2700

2701
        for _, channel := range removedChans {
×
2702
                s.rejectCache.remove(channel.ChannelID)
×
2703
                s.chanCache.remove(channel.ChannelID)
×
2704
        }
×
2705

2706
        return removedChans, nil
×
2707
}
2708

2709
// AddEdgeProof sets the proof of an existing edge in the graph database.
2710
//
2711
// NOTE: part of the V1Store interface.
2712
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2713
        proof *models.ChannelAuthProof) error {
×
2714

×
2715
        var (
×
2716
                ctx       = context.TODO()
×
2717
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2718
        )
×
2719

×
2720
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2721
                res, err := db.AddV1ChannelProof(
×
2722
                        ctx, sqlc.AddV1ChannelProofParams{
×
2723
                                Scid:              scidBytes,
×
2724
                                Node1Signature:    proof.NodeSig1Bytes,
×
2725
                                Node2Signature:    proof.NodeSig2Bytes,
×
2726
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2727
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2728
                        },
×
2729
                )
×
2730
                if err != nil {
×
2731
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2732
                }
×
2733

2734
                n, err := res.RowsAffected()
×
2735
                if err != nil {
×
2736
                        return err
×
2737
                }
×
2738

2739
                if n == 0 {
×
2740
                        return fmt.Errorf("no rows affected when adding edge "+
×
2741
                                "proof for SCID %v", scid)
×
2742
                } else if n > 1 {
×
2743
                        return fmt.Errorf("multiple rows affected when adding "+
×
2744
                                "edge proof for SCID %v: %d rows affected",
×
2745
                                scid, n)
×
2746
                }
×
2747

2748
                return nil
×
2749
        }, sqldb.NoOpReset)
2750
        if err != nil {
×
2751
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2752
        }
×
2753

2754
        return nil
×
2755
}
2756

2757
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2758
// that we can ignore channel announcements that we know to be closed without
2759
// having to validate them and fetch a block.
2760
//
2761
// NOTE: part of the V1Store interface.
2762
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2763
        var (
×
2764
                ctx     = context.TODO()
×
2765
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2766
        )
×
2767

×
2768
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2769
                return db.InsertClosedChannel(ctx, chanIDB)
×
2770
        }, sqldb.NoOpReset)
×
2771
}
2772

2773
// IsClosedScid checks whether a channel identified by the passed in scid is
2774
// closed. This helps avoid having to perform expensive validation checks.
2775
//
2776
// NOTE: part of the V1Store interface.
2777
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2778
        var (
×
2779
                ctx      = context.TODO()
×
2780
                isClosed bool
×
2781
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2782
        )
×
2783
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2784
                var err error
×
2785
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2786
                if err != nil {
×
2787
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2788
                                err)
×
2789
                }
×
2790

2791
                return nil
×
2792
        }, sqldb.NoOpReset)
2793
        if err != nil {
×
2794
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2795
                        err)
×
2796
        }
×
2797

2798
        return isClosed, nil
×
2799
}
2800

2801
// GraphSession will provide the call-back with access to a NodeTraverser
2802
// instance which can be used to perform queries against the channel graph.
2803
//
2804
// NOTE: part of the V1Store interface.
2805
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
2806
        reset func()) error {
×
2807

×
2808
        var ctx = context.TODO()
×
2809

×
2810
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2811
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2812
        }, reset)
×
2813
}
2814

2815
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2816
// read only transaction for a consistent view of the graph.
2817
type sqlNodeTraverser struct {
2818
        db    SQLQueries
2819
        chain chainhash.Hash
2820
}
2821

2822
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2823
// NodeTraverser interface.
2824
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2825

2826
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2827
func newSQLNodeTraverser(db SQLQueries,
2828
        chain chainhash.Hash) *sqlNodeTraverser {
×
2829

×
2830
        return &sqlNodeTraverser{
×
2831
                db:    db,
×
2832
                chain: chain,
×
2833
        }
×
2834
}
×
2835

2836
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2837
// node.
2838
//
2839
// NOTE: Part of the NodeTraverser interface.
2840
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2841
        cb func(channel *DirectedChannel) error, _ func()) error {
×
2842

×
2843
        ctx := context.TODO()
×
2844

×
2845
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2846
}
×
2847

2848
// FetchNodeFeatures returns the features of the given node. If the node is
2849
// unknown, assume no additional features are supported.
2850
//
2851
// NOTE: Part of the NodeTraverser interface.
2852
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2853
        *lnwire.FeatureVector, error) {
×
2854

×
2855
        ctx := context.TODO()
×
2856

×
2857
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2858
}
×
2859

2860
// forEachNodeDirectedChannel iterates through all channels of a given
2861
// node, executing the passed callback on the directed edge representing the
2862
// channel and its incoming policy. If the node is not found, no error is
2863
// returned.
2864
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2865
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2866

×
2867
        toNodeCallback := func() route.Vertex {
×
2868
                return nodePub
×
2869
        }
×
2870

2871
        dbID, err := db.GetNodeIDByPubKey(
×
2872
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2873
                        Version: int16(ProtocolV1),
×
2874
                        PubKey:  nodePub[:],
×
2875
                },
×
2876
        )
×
2877
        if errors.Is(err, sql.ErrNoRows) {
×
2878
                return nil
×
2879
        } else if err != nil {
×
2880
                return fmt.Errorf("unable to fetch node: %w", err)
×
2881
        }
×
2882

2883
        rows, err := db.ListChannelsByNodeID(
×
2884
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2885
                        Version: int16(ProtocolV1),
×
2886
                        NodeID1: dbID,
×
2887
                },
×
2888
        )
×
2889
        if err != nil {
×
2890
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2891
        }
×
2892

2893
        // Exit early if there are no channels for this node so we don't
2894
        // do the unnecessary feature fetching.
2895
        if len(rows) == 0 {
×
2896
                return nil
×
2897
        }
×
2898

2899
        features, err := getNodeFeatures(ctx, db, dbID)
×
2900
        if err != nil {
×
2901
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2902
        }
×
2903

2904
        for _, row := range rows {
×
2905
                node1, node2, err := buildNodeVertices(
×
2906
                        row.Node1Pubkey, row.Node2Pubkey,
×
2907
                )
×
2908
                if err != nil {
×
2909
                        return fmt.Errorf("unable to build node vertices: %w",
×
2910
                                err)
×
2911
                }
×
2912

2913
                edge := buildCacheableChannelInfo(
×
2914
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
2915
                        node1, node2,
×
2916
                )
×
2917

×
2918
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2919
                if err != nil {
×
2920
                        return err
×
2921
                }
×
2922

2923
                p1, p2, err := buildCachedChanPolicies(
×
2924
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
2925
                )
×
2926
                if err != nil {
×
2927
                        return err
×
2928
                }
×
2929

2930
                // Determine the outgoing and incoming policy for this
2931
                // channel and node combo.
2932
                outPolicy, inPolicy := p1, p2
×
2933
                if p1 != nil && node2 == nodePub {
×
2934
                        outPolicy, inPolicy = p2, p1
×
2935
                } else if p2 != nil && node1 != nodePub {
×
2936
                        outPolicy, inPolicy = p2, p1
×
2937
                }
×
2938

2939
                var cachedInPolicy *models.CachedEdgePolicy
×
2940
                if inPolicy != nil {
×
2941
                        cachedInPolicy = inPolicy
×
2942
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
2943
                        cachedInPolicy.ToNodeFeatures = features
×
2944
                }
×
2945

2946
                directedChannel := &DirectedChannel{
×
2947
                        ChannelID:    edge.ChannelID,
×
2948
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2949
                        OtherNode:    edge.NodeKey2Bytes,
×
2950
                        Capacity:     edge.Capacity,
×
2951
                        OutPolicySet: outPolicy != nil,
×
2952
                        InPolicy:     cachedInPolicy,
×
2953
                }
×
2954
                if outPolicy != nil {
×
2955
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2956
                                directedChannel.InboundFee = fee
×
2957
                        })
×
2958
                }
2959

2960
                if nodePub == edge.NodeKey2Bytes {
×
2961
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
2962
                }
×
2963

2964
                if err := cb(directedChannel); err != nil {
×
2965
                        return err
×
2966
                }
×
2967
        }
2968

2969
        return nil
×
2970
}
2971

2972
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
2973
// and executes the provided callback for each node. It does so via pagination
2974
// along with batch loading of the node feature bits.
2975
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
2976
        db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
2977
                features *lnwire.FeatureVector) error) error {
×
2978

×
2979
        handleNode := func(_ context.Context,
×
2980
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
×
2981
                featureBits map[int64][]int) error {
×
2982

×
2983
                fv := lnwire.EmptyFeatureVector()
×
2984
                if features, exists := featureBits[dbNode.ID]; exists {
×
2985
                        for _, bit := range features {
×
2986
                                fv.Set(lnwire.FeatureBit(bit))
×
2987
                        }
×
2988
                }
2989

2990
                var pub route.Vertex
×
2991
                copy(pub[:], dbNode.PubKey)
×
2992

×
2993
                return processNode(dbNode.ID, pub, fv)
×
2994
        }
2995

2996
        queryFunc := func(ctx context.Context, lastID int64,
×
2997
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
2998

×
2999
                return db.ListNodeIDsAndPubKeys(
×
3000
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
3001
                                Version: int16(ProtocolV1),
×
3002
                                ID:      lastID,
×
3003
                                Limit:   limit,
×
3004
                        },
×
3005
                )
×
3006
        }
×
3007

3008
        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
3009
                return row.ID
×
3010
        }
×
3011

3012
        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
×
3013
                return node.ID, nil
×
3014
        }
×
3015

3016
        batchQueryFunc := func(ctx context.Context,
×
3017
                nodeIDs []int64) (map[int64][]int, error) {
×
3018

×
3019
                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
3020
        }
×
3021

3022
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
3023
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
×
3024
                batchQueryFunc, handleNode,
×
3025
        )
×
3026
}
3027

3028
// forEachNodeChannel iterates through all channels of a node, executing
3029
// the passed callback on each. The call-back is provided with the channel's
3030
// edge information, the outgoing policy and the incoming policy for the
3031
// channel and node combo.
3032
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3033
        cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
3034
                *models.ChannelEdgePolicy,
3035
                *models.ChannelEdgePolicy) error) error {
×
3036

×
3037
        // Get all the V1 channels for this node.
×
3038
        rows, err := db.ListChannelsByNodeID(
×
3039
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3040
                        Version: int16(ProtocolV1),
×
3041
                        NodeID1: id,
×
3042
                },
×
3043
        )
×
3044
        if err != nil {
×
3045
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3046
        }
×
3047

3048
        // Collect all the channel and policy IDs.
3049
        var (
×
3050
                chanIDs   = make([]int64, 0, len(rows))
×
3051
                policyIDs = make([]int64, 0, 2*len(rows))
×
3052
        )
×
3053
        for _, row := range rows {
×
3054
                chanIDs = append(chanIDs, row.GraphChannel.ID)
×
3055

×
3056
                if row.Policy1ID.Valid {
×
3057
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
×
3058
                }
×
3059
                if row.Policy2ID.Valid {
×
3060
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
×
3061
                }
×
3062
        }
3063

3064
        batchData, err := batchLoadChannelData(
×
3065
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
×
3066
        )
×
3067
        if err != nil {
×
3068
                return fmt.Errorf("unable to batch load channel data: %w", err)
×
3069
        }
×
3070

3071
        // Call the call-back for each channel and its known policies.
3072
        for _, row := range rows {
×
3073
                node1, node2, err := buildNodeVertices(
×
3074
                        row.Node1Pubkey, row.Node2Pubkey,
×
3075
                )
×
3076
                if err != nil {
×
3077
                        return fmt.Errorf("unable to build node vertices: %w",
×
3078
                                err)
×
3079
                }
×
3080

3081
                edge, err := buildEdgeInfoWithBatchData(
×
3082
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
3083
                        batchData,
×
3084
                )
×
3085
                if err != nil {
×
3086
                        return fmt.Errorf("unable to build channel info: %w",
×
3087
                                err)
×
3088
                }
×
3089

3090
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3091
                if err != nil {
×
3092
                        return fmt.Errorf("unable to extract channel "+
×
3093
                                "policies: %w", err)
×
3094
                }
×
3095

3096
                p1, p2, err := buildChanPoliciesWithBatchData(
×
3097
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
3098
                )
×
3099
                if err != nil {
×
3100
                        return fmt.Errorf("unable to build channel "+
×
3101
                                "policies: %w", err)
×
3102
                }
×
3103

3104
                // Determine the outgoing and incoming policy for this
3105
                // channel and node combo.
3106
                p1ToNode := row.GraphChannel.NodeID2
×
3107
                p2ToNode := row.GraphChannel.NodeID1
×
3108
                outPolicy, inPolicy := p1, p2
×
3109
                if (p1 != nil && p1ToNode == id) ||
×
3110
                        (p2 != nil && p2ToNode != id) {
×
3111

×
3112
                        outPolicy, inPolicy = p2, p1
×
3113
                }
×
3114

3115
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3116
                        return err
×
3117
                }
×
3118
        }
3119

3120
        return nil
×
3121
}
3122

3123
// updateChanEdgePolicy upserts the channel policy info we have stored for
3124
// a channel we already know of.
3125
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3126
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3127
        error) {
×
3128

×
3129
        var (
×
3130
                node1Pub, node2Pub route.Vertex
×
3131
                isNode1            bool
×
3132
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3133
        )
×
3134

×
3135
        // Check that this edge policy refers to a channel that we already
×
3136
        // know of. We do this explicitly so that we can return the appropriate
×
3137
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3138
        // abort the transaction which would abort the entire batch.
×
3139
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3140
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3141
                        Scid:    chanIDB,
×
3142
                        Version: int16(ProtocolV1),
×
3143
                },
×
3144
        )
×
3145
        if errors.Is(err, sql.ErrNoRows) {
×
3146
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3147
        } else if err != nil {
×
3148
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3149
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3150
        }
×
3151

3152
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3153
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3154

×
3155
        // Figure out which node this edge is from.
×
3156
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3157
        nodeID := dbChan.NodeID1
×
3158
        if !isNode1 {
×
3159
                nodeID = dbChan.NodeID2
×
3160
        }
×
3161

3162
        var (
×
3163
                inboundBase sql.NullInt64
×
3164
                inboundRate sql.NullInt64
×
3165
        )
×
3166
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3167
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3168
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3169
        })
×
3170

3171
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3172
                Version:     int16(ProtocolV1),
×
3173
                ChannelID:   dbChan.ID,
×
3174
                NodeID:      nodeID,
×
3175
                Timelock:    int32(edge.TimeLockDelta),
×
3176
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3177
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3178
                MinHtlcMsat: int64(edge.MinHTLC),
×
3179
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3180
                Disabled: sql.NullBool{
×
3181
                        Valid: true,
×
3182
                        Bool:  edge.IsDisabled(),
×
3183
                },
×
3184
                MaxHtlcMsat: sql.NullInt64{
×
3185
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3186
                        Int64: int64(edge.MaxHTLC),
×
3187
                },
×
3188
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3189
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3190
                InboundBaseFeeMsat:      inboundBase,
×
3191
                InboundFeeRateMilliMsat: inboundRate,
×
3192
                Signature:               edge.SigBytes,
×
3193
        })
×
3194
        if err != nil {
×
3195
                return node1Pub, node2Pub, isNode1,
×
3196
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3197
        }
×
3198

3199
        // Convert the flat extra opaque data into a map of TLV types to
3200
        // values.
3201
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3202
        if err != nil {
×
3203
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3204
                        "marshal extra opaque data: %w", err)
×
3205
        }
×
3206

3207
        // Update the channel policy's extra signed fields.
3208
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3209
        if err != nil {
×
3210
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3211
                        "policy extra TLVs: %w", err)
×
3212
        }
×
3213

3214
        return node1Pub, node2Pub, isNode1, nil
×
3215
}
3216

3217
// getNodeByPubKey attempts to look up a target node by its public key.
3218
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3219
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3220

×
3221
        dbNode, err := db.GetNodeByPubKey(
×
3222
                ctx, sqlc.GetNodeByPubKeyParams{
×
3223
                        Version: int16(ProtocolV1),
×
3224
                        PubKey:  pubKey[:],
×
3225
                },
×
3226
        )
×
3227
        if errors.Is(err, sql.ErrNoRows) {
×
3228
                return 0, nil, ErrGraphNodeNotFound
×
3229
        } else if err != nil {
×
3230
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3231
        }
×
3232

3233
        node, err := buildNode(ctx, db, &dbNode)
×
3234
        if err != nil {
×
3235
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3236
        }
×
3237

3238
        return dbNode.ID, node, nil
×
3239
}
3240

3241
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3242
// provided parameters.
3243
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
3244
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3245

×
3246
        return &models.CachedEdgeInfo{
×
3247
                ChannelID:     byteOrder.Uint64(scid),
×
3248
                NodeKey1Bytes: node1Pub,
×
3249
                NodeKey2Bytes: node2Pub,
×
3250
                Capacity:      btcutil.Amount(capacity),
×
3251
        }
×
3252
}
×
3253

3254
// buildNode constructs a LightningNode instance from the given database node
3255
// record. The node's features, addresses and extra signed fields are also
3256
// fetched from the database and set on the node.
3257
func buildNode(ctx context.Context, db SQLQueries,
3258
        dbNode *sqlc.GraphNode) (*models.LightningNode, error) {
×
3259

×
3260
        // NOTE: buildNode is only used to load the data for a single node, and
×
3261
        // so no paged queries will be performed. This means that it's ok to
×
3262
        // used pass in default config values here.
×
3263
        cfg := sqldb.DefaultQueryConfig()
×
3264

×
3265
        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
×
3266
        if err != nil {
×
3267
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
3268
                        err)
×
3269
        }
×
3270

3271
        return buildNodeWithBatchData(dbNode, data)
×
3272
}
3273

3274
// buildNodeWithBatchData builds a models.LightningNode instance
3275
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
3276
// features/addresses/extra fields, then the corresponding fields are expected
3277
// to be present in the batchNodeData.
3278
func buildNodeWithBatchData(dbNode *sqlc.GraphNode,
3279
        batchData *batchNodeData) (*models.LightningNode, error) {
×
3280

×
3281
        if dbNode.Version != int16(ProtocolV1) {
×
3282
                return nil, fmt.Errorf("unsupported node version: %d",
×
3283
                        dbNode.Version)
×
3284
        }
×
3285

3286
        var pub [33]byte
×
3287
        copy(pub[:], dbNode.PubKey)
×
3288

×
3289
        node := &models.LightningNode{
×
3290
                PubKeyBytes: pub,
×
3291
                Features:    lnwire.EmptyFeatureVector(),
×
3292
                LastUpdate:  time.Unix(0, 0),
×
3293
        }
×
3294

×
3295
        if len(dbNode.Signature) == 0 {
×
3296
                return node, nil
×
3297
        }
×
3298

3299
        node.HaveNodeAnnouncement = true
×
3300
        node.AuthSigBytes = dbNode.Signature
×
3301
        node.Alias = dbNode.Alias.String
×
3302
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3303

×
3304
        var err error
×
3305
        if dbNode.Color.Valid {
×
3306
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3307
                if err != nil {
×
3308
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3309
                                err)
×
3310
                }
×
3311
        }
3312

3313
        // Use preloaded features.
3314
        if features, exists := batchData.features[dbNode.ID]; exists {
×
3315
                fv := lnwire.EmptyFeatureVector()
×
3316
                for _, bit := range features {
×
3317
                        fv.Set(lnwire.FeatureBit(bit))
×
3318
                }
×
3319
                node.Features = fv
×
3320
        }
3321

3322
        // Use preloaded addresses.
3323
        addresses, exists := batchData.addresses[dbNode.ID]
×
3324
        if exists && len(addresses) > 0 {
×
3325
                node.Addresses, err = buildNodeAddresses(addresses)
×
3326
                if err != nil {
×
3327
                        return nil, fmt.Errorf("unable to build addresses "+
×
3328
                                "for node(%d): %w", dbNode.ID, err)
×
3329
                }
×
3330
        }
3331

3332
        // Use preloaded extra fields.
3333
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
×
3334
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
×
3335
                if err != nil {
×
3336
                        return nil, fmt.Errorf("unable to serialize extra "+
×
3337
                                "signed fields: %w", err)
×
3338
                }
×
3339
                if len(recs) != 0 {
×
3340
                        node.ExtraOpaqueData = recs
×
3341
                }
×
3342
        }
3343

3344
        return node, nil
×
3345
}
3346

3347
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
3348
// with the preloaded data, and executes the provided callback for each node.
3349
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
3350
        db SQLQueries, nodes []sqlc.GraphNode,
3351
        cb func(dbID int64, node *models.LightningNode) error) error {
×
3352

×
3353
        // Extract node IDs for batch loading.
×
3354
        nodeIDs := make([]int64, len(nodes))
×
3355
        for i, node := range nodes {
×
3356
                nodeIDs[i] = node.ID
×
3357
        }
×
3358

3359
        // Batch load all related data for this page.
3360
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
3361
        if err != nil {
×
3362
                return fmt.Errorf("unable to batch load node data: %w", err)
×
3363
        }
×
3364

3365
        for _, dbNode := range nodes {
×
3366
                node, err := buildNodeWithBatchData(&dbNode, batchData)
×
3367
                if err != nil {
×
3368
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
3369
                                dbNode.ID, err)
×
3370
                }
×
3371

3372
                if err := cb(dbNode.ID, node); err != nil {
×
3373
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
3374
                                dbNode.ID, err)
×
3375
                }
×
3376
        }
3377

3378
        return nil
×
3379
}
3380

3381
// getNodeFeatures fetches the feature bits and constructs the feature vector
3382
// for a node with the given DB ID.
3383
func getNodeFeatures(ctx context.Context, db SQLQueries,
3384
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3385

×
3386
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3387
        if err != nil {
×
3388
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3389
                        nodeID, err)
×
3390
        }
×
3391

3392
        features := lnwire.EmptyFeatureVector()
×
3393
        for _, feature := range rows {
×
3394
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3395
        }
×
3396

3397
        return features, nil
×
3398
}
3399

3400
// upsertNode upserts the node record into the database. If the node already
3401
// exists, then the node's information is updated. If the node doesn't exist,
3402
// then a new node is created. The node's features, addresses and extra TLV
3403
// types are also updated. The node's DB ID is returned.
3404
func upsertNode(ctx context.Context, db SQLQueries,
3405
        node *models.LightningNode) (int64, error) {
×
3406

×
3407
        params := sqlc.UpsertNodeParams{
×
3408
                Version: int16(ProtocolV1),
×
3409
                PubKey:  node.PubKeyBytes[:],
×
3410
        }
×
3411

×
3412
        if node.HaveNodeAnnouncement {
×
3413
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3414
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3415
                params.Alias = sqldb.SQLStr(node.Alias)
×
3416
                params.Signature = node.AuthSigBytes
×
3417
        }
×
3418

3419
        nodeID, err := db.UpsertNode(ctx, params)
×
3420
        if err != nil {
×
3421
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3422
                        err)
×
3423
        }
×
3424

3425
        // We can exit here if we don't have the announcement yet.
3426
        if !node.HaveNodeAnnouncement {
×
3427
                return nodeID, nil
×
3428
        }
×
3429

3430
        // Update the node's features.
3431
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3432
        if err != nil {
×
3433
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3434
        }
×
3435

3436
        // Update the node's addresses.
3437
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3438
        if err != nil {
×
3439
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3440
        }
×
3441

3442
        // Convert the flat extra opaque data into a map of TLV types to
3443
        // values.
3444
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3445
        if err != nil {
×
3446
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3447
                        err)
×
3448
        }
×
3449

3450
        // Update the node's extra signed fields.
3451
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3452
        if err != nil {
×
3453
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3454
        }
×
3455

3456
        return nodeID, nil
×
3457
}
3458

3459
// upsertNodeFeatures updates the node's features node_features table. This
3460
// includes deleting any feature bits no longer present and inserting any new
3461
// feature bits. If the feature bit does not yet exist in the features table,
3462
// then an entry is created in that table first.
3463
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3464
        features *lnwire.FeatureVector) error {
×
3465

×
3466
        // Get any existing features for the node.
×
3467
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3468
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3469
                return err
×
3470
        }
×
3471

3472
        // Copy the nodes latest set of feature bits.
3473
        newFeatures := make(map[int32]struct{})
×
3474
        if features != nil {
×
3475
                for feature := range features.Features() {
×
3476
                        newFeatures[int32(feature)] = struct{}{}
×
3477
                }
×
3478
        }
3479

3480
        // For any current feature that already exists in the DB, remove it from
3481
        // the in-memory map. For any existing feature that does not exist in
3482
        // the in-memory map, delete it from the database.
3483
        for _, feature := range existingFeatures {
×
3484
                // The feature is still present, so there are no updates to be
×
3485
                // made.
×
3486
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3487
                        delete(newFeatures, feature.FeatureBit)
×
3488
                        continue
×
3489
                }
3490

3491
                // The feature is no longer present, so we remove it from the
3492
                // database.
3493
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3494
                        NodeID:     nodeID,
×
3495
                        FeatureBit: feature.FeatureBit,
×
3496
                })
×
3497
                if err != nil {
×
3498
                        return fmt.Errorf("unable to delete node(%d) "+
×
3499
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3500
                                err)
×
3501
                }
×
3502
        }
3503

3504
        // Any remaining entries in newFeatures are new features that need to be
3505
        // added to the database for the first time.
3506
        for feature := range newFeatures {
×
3507
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3508
                        NodeID:     nodeID,
×
3509
                        FeatureBit: feature,
×
3510
                })
×
3511
                if err != nil {
×
3512
                        return fmt.Errorf("unable to insert node(%d) "+
×
3513
                                "feature(%v): %w", nodeID, feature, err)
×
3514
                }
×
3515
        }
3516

3517
        return nil
×
3518
}
3519

3520
// fetchNodeFeatures fetches the features for a node with the given public key.
3521
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3522
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3523

×
3524
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3525
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3526
                        PubKey:  nodePub[:],
×
3527
                        Version: int16(ProtocolV1),
×
3528
                },
×
3529
        )
×
3530
        if err != nil {
×
3531
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3532
                        nodePub, err)
×
3533
        }
×
3534

3535
        features := lnwire.EmptyFeatureVector()
×
3536
        for _, bit := range rows {
×
3537
                features.Set(lnwire.FeatureBit(bit))
×
3538
        }
×
3539

3540
        return features, nil
×
3541
}
3542

3543
// dbAddressType is an enum type that represents the different address types
// that we store in the node_addresses table. The address type determines how
// the address is to be serialized/deserialized.
//
// NOTE: these values are written to the Type column of the node_addresses
// table and read back when addresses are loaded, so existing values must
// never be renumbered.
type dbAddressType uint8

const (
	// addressTypeIPv4 marks a TCP address whose IP is an IPv4 address.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 marks a TCP address whose IP is an IPv6 address.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 marks a Tor v2 onion service address.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 marks a Tor v3 onion service address.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeOpaque marks an lnwire.OpaqueAddrs value. It is pinned to
	// math.MaxInt8 (127), well away from the sequentially numbered types
	// above.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3555

3556
// upsertNodeAddresses updates the node's addresses in the database. This
3557
// includes deleting any existing addresses and inserting the new set of
3558
// addresses. The deletion is necessary since the ordering of the addresses may
3559
// change, and we need to ensure that the database reflects the latest set of
3560
// addresses so that at the time of reconstructing the node announcement, the
3561
// order is preserved and the signature over the message remains valid.
3562
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3563
        addresses []net.Addr) error {
×
3564

×
3565
        // Delete any existing addresses for the node. This is required since
×
3566
        // even if the new set of addresses is the same, the ordering may have
×
3567
        // changed for a given address type.
×
3568
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3569
        if err != nil {
×
3570
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3571
                        nodeID, err)
×
3572
        }
×
3573

3574
        // Copy the nodes latest set of addresses.
3575
        newAddresses := map[dbAddressType][]string{
×
3576
                addressTypeIPv4:   {},
×
3577
                addressTypeIPv6:   {},
×
3578
                addressTypeTorV2:  {},
×
3579
                addressTypeTorV3:  {},
×
3580
                addressTypeOpaque: {},
×
3581
        }
×
3582
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3583
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3584
        }
×
3585

3586
        for _, address := range addresses {
×
3587
                switch addr := address.(type) {
×
3588
                case *net.TCPAddr:
×
3589
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3590
                                addAddr(addressTypeIPv4, addr)
×
3591
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3592
                                addAddr(addressTypeIPv6, addr)
×
3593
                        } else {
×
3594
                                return fmt.Errorf("unhandled IP address: %v",
×
3595
                                        addr)
×
3596
                        }
×
3597

3598
                case *tor.OnionAddr:
×
3599
                        switch len(addr.OnionService) {
×
3600
                        case tor.V2Len:
×
3601
                                addAddr(addressTypeTorV2, addr)
×
3602
                        case tor.V3Len:
×
3603
                                addAddr(addressTypeTorV3, addr)
×
3604
                        default:
×
3605
                                return fmt.Errorf("invalid length for a tor " +
×
3606
                                        "address")
×
3607
                        }
3608

3609
                case *lnwire.OpaqueAddrs:
×
3610
                        addAddr(addressTypeOpaque, addr)
×
3611

3612
                default:
×
3613
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3614
                }
3615
        }
3616

3617
        // Any remaining entries in newAddresses are new addresses that need to
3618
        // be added to the database for the first time.
3619
        for addrType, addrList := range newAddresses {
×
3620
                for position, addr := range addrList {
×
3621
                        err := db.InsertNodeAddress(
×
3622
                                ctx, sqlc.InsertNodeAddressParams{
×
3623
                                        NodeID:   nodeID,
×
3624
                                        Type:     int16(addrType),
×
3625
                                        Address:  addr,
×
3626
                                        Position: int32(position),
×
3627
                                },
×
3628
                        )
×
3629
                        if err != nil {
×
3630
                                return fmt.Errorf("unable to insert "+
×
3631
                                        "node(%d) address(%v): %w", nodeID,
×
3632
                                        addr, err)
×
3633
                        }
×
3634
                }
3635
        }
3636

3637
        return nil
×
3638
}
3639

3640
// getNodeAddresses fetches the addresses for a node with the given DB ID.
3641
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
3642
        error) {
×
3643

×
3644
        // GetNodeAddresses ensures that the addresses for a given type are
×
3645
        // returned in the same order as they were inserted.
×
3646
        rows, err := db.GetNodeAddresses(ctx, id)
×
3647
        if err != nil {
×
3648
                return nil, err
×
3649
        }
×
3650

3651
        addresses := make([]net.Addr, 0, len(rows))
×
3652
        for _, row := range rows {
×
3653
                address := row.Address
×
3654

×
3655
                addr, err := parseAddress(dbAddressType(row.Type), address)
×
3656
                if err != nil {
×
3657
                        return nil, fmt.Errorf("unable to parse address "+
×
3658
                                "for node(%d): %v: %w", id, address, err)
×
3659
                }
×
3660

3661
                addresses = append(addresses, addr)
×
3662
        }
3663

3664
        // If we have no addresses, then we'll return nil instead of an
3665
        // empty slice.
3666
        if len(addresses) == 0 {
×
3667
                addresses = nil
×
3668
        }
×
3669

3670
        return addresses, nil
×
3671
}
3672

3673
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3674
// database. This includes updating any existing types, inserting any new types,
3675
// and deleting any types that are no longer present.
3676
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3677
        nodeID int64, extraFields map[uint64][]byte) error {
×
3678

×
3679
        // Get any existing extra signed fields for the node.
×
3680
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3681
        if err != nil {
×
3682
                return err
×
3683
        }
×
3684

3685
        // Make a lookup map of the existing field types so that we can use it
3686
        // to keep track of any fields we should delete.
3687
        m := make(map[uint64]bool)
×
3688
        for _, field := range existingFields {
×
3689
                m[uint64(field.Type)] = true
×
3690
        }
×
3691

3692
        // For all the new fields, we'll upsert them and remove them from the
3693
        // map of existing fields.
3694
        for tlvType, value := range extraFields {
×
3695
                err = db.UpsertNodeExtraType(
×
3696
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3697
                                NodeID: nodeID,
×
3698
                                Type:   int64(tlvType),
×
3699
                                Value:  value,
×
3700
                        },
×
3701
                )
×
3702
                if err != nil {
×
3703
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3704
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3705
                }
×
3706

3707
                // Remove the field from the map of existing fields if it was
3708
                // present.
3709
                delete(m, tlvType)
×
3710
        }
3711

3712
        // For all the fields that are left in the map of existing fields, we'll
3713
        // delete them as they are no longer present in the new set of fields.
3714
        for tlvType := range m {
×
3715
                err = db.DeleteExtraNodeType(
×
3716
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3717
                                NodeID: nodeID,
×
3718
                                Type:   int64(tlvType),
×
3719
                        },
×
3720
                )
×
3721
                if err != nil {
×
3722
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3723
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3724
                }
×
3725
        }
3726

3727
        return nil
×
3728
}
3729

3730
// srcNodeInfo holds the information about the source node of the graph.
3731
type srcNodeInfo struct {
3732
        // id is the DB level ID of the source node entry in the "nodes" table.
3733
        id int64
3734

3735
        // pub is the public key of the source node.
3736
        pub route.Vertex
3737
}
3738

3739
// sourceNode returns the DB node ID and pub key of the source node for the
3740
// specified protocol version.
3741
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3742
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3743

×
3744
        s.srcNodeMu.Lock()
×
3745
        defer s.srcNodeMu.Unlock()
×
3746

×
3747
        // If we already have the source node ID and pub key cached, then
×
3748
        // return them.
×
3749
        if info, ok := s.srcNodes[version]; ok {
×
3750
                return info.id, info.pub, nil
×
3751
        }
×
3752

3753
        var pubKey route.Vertex
×
3754

×
3755
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3756
        if err != nil {
×
3757
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3758
                        err)
×
3759
        }
×
3760

3761
        if len(nodes) == 0 {
×
3762
                return 0, pubKey, ErrSourceNodeNotSet
×
3763
        } else if len(nodes) > 1 {
×
3764
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3765
                        "protocol %s found", version)
×
3766
        }
×
3767

3768
        copy(pubKey[:], nodes[0].PubKey)
×
3769

×
3770
        s.srcNodes[version] = &srcNodeInfo{
×
3771
                id:  nodes[0].NodeID,
×
3772
                pub: pubKey,
×
3773
        }
×
3774

×
3775
        return nodes[0].NodeID, pubKey, nil
×
3776
}
3777

3778
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3779
// This then produces a map from TLV type to value. If the input is not a
3780
// valid TLV stream, then an error is returned.
3781
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3782
        r := bytes.NewReader(data)
×
3783

×
3784
        tlvStream, err := tlv.NewStream()
×
3785
        if err != nil {
×
3786
                return nil, err
×
3787
        }
×
3788

3789
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3790
        // pass it into the P2P decoding variant.
3791
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3792
        if err != nil {
×
3793
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3794
        }
×
3795
        if len(parsedTypes) == 0 {
×
3796
                return nil, nil
×
3797
        }
×
3798

3799
        records := make(map[uint64][]byte)
×
3800
        for k, v := range parsedTypes {
×
3801
                records[uint64(k)] = v
×
3802
        }
×
3803

3804
        return records, nil
×
3805
}
3806

3807
// dbChanInfo bundles DB-level row IDs for a channel: the ID of the channel
// row itself plus the IDs of the rows of its two endpoint nodes.
type dbChanInfo struct {
	// channelID is the DB ID of the channel row.
	channelID int64

	// node1ID and node2ID are the DB IDs of the channel's first and second
	// node, respectively.
	node1ID, node2ID int64
}
3814

3815
// insertChannel inserts a new channel record into the database.
3816
func insertChannel(ctx context.Context, db SQLQueries,
3817
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
3818

×
3819
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3820

×
3821
        // Make sure that the channel doesn't already exist. We do this
×
3822
        // explicitly instead of relying on catching a unique constraint error
×
3823
        // because relying on SQL to throw that error would abort the entire
×
3824
        // batch of transactions.
×
3825
        _, err := db.GetChannelBySCID(
×
3826
                ctx, sqlc.GetChannelBySCIDParams{
×
3827
                        Scid:    chanIDB,
×
3828
                        Version: int16(ProtocolV1),
×
3829
                },
×
3830
        )
×
3831
        if err == nil {
×
3832
                return nil, ErrEdgeAlreadyExist
×
3833
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3834
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
×
3835
        }
×
3836

3837
        // Make sure that at least a "shell" entry for each node is present in
3838
        // the nodes table.
3839
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3840
        if err != nil {
×
3841
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3842
        }
×
3843

3844
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3845
        if err != nil {
×
3846
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3847
        }
×
3848

3849
        var capacity sql.NullInt64
×
3850
        if edge.Capacity != 0 {
×
3851
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3852
        }
×
3853

3854
        createParams := sqlc.CreateChannelParams{
×
3855
                Version:     int16(ProtocolV1),
×
3856
                Scid:        chanIDB,
×
3857
                NodeID1:     node1DBID,
×
3858
                NodeID2:     node2DBID,
×
3859
                Outpoint:    edge.ChannelPoint.String(),
×
3860
                Capacity:    capacity,
×
3861
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3862
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3863
        }
×
3864

×
3865
        if edge.AuthProof != nil {
×
3866
                proof := edge.AuthProof
×
3867

×
3868
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3869
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3870
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3871
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3872
        }
×
3873

3874
        // Insert the new channel record.
3875
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3876
        if err != nil {
×
3877
                return nil, err
×
3878
        }
×
3879

3880
        // Insert any channel features.
3881
        for feature := range edge.Features.Features() {
×
3882
                err = db.InsertChannelFeature(
×
3883
                        ctx, sqlc.InsertChannelFeatureParams{
×
3884
                                ChannelID:  dbChanID,
×
3885
                                FeatureBit: int32(feature),
×
3886
                        },
×
3887
                )
×
3888
                if err != nil {
×
3889
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
×
3890
                                "feature(%v): %w", dbChanID, feature, err)
×
3891
                }
×
3892
        }
3893

3894
        // Finally, insert any extra TLV fields in the channel announcement.
3895
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3896
        if err != nil {
×
3897
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
3898
                        "data: %w", err)
×
3899
        }
×
3900

3901
        for tlvType, value := range extra {
×
3902
                err := db.CreateChannelExtraType(
×
3903
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3904
                                ChannelID: dbChanID,
×
3905
                                Type:      int64(tlvType),
×
3906
                                Value:     value,
×
3907
                        },
×
3908
                )
×
3909
                if err != nil {
×
3910
                        return nil, fmt.Errorf("unable to upsert "+
×
3911
                                "channel(%d) extra signed field(%v): %w",
×
3912
                                edge.ChannelID, tlvType, err)
×
3913
                }
×
3914
        }
3915

3916
        return &dbChanInfo{
×
3917
                channelID: dbChanID,
×
3918
                node1ID:   node1DBID,
×
3919
                node2ID:   node2DBID,
×
3920
        }, nil
×
3921
}
3922

3923
// maybeCreateShellNode checks if a shell node entry exists for the
3924
// given public key. If it does not exist, then a new shell node entry is
3925
// created. The ID of the node is returned. A shell node only has a protocol
3926
// version and public key persisted.
3927
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3928
        pubKey route.Vertex) (int64, error) {
×
3929

×
3930
        dbNode, err := db.GetNodeByPubKey(
×
3931
                ctx, sqlc.GetNodeByPubKeyParams{
×
3932
                        PubKey:  pubKey[:],
×
3933
                        Version: int16(ProtocolV1),
×
3934
                },
×
3935
        )
×
3936
        // The node exists. Return the ID.
×
3937
        if err == nil {
×
3938
                return dbNode.ID, nil
×
3939
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3940
                return 0, err
×
3941
        }
×
3942

3943
        // Otherwise, the node does not exist, so we create a shell entry for
3944
        // it.
3945
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3946
                Version: int16(ProtocolV1),
×
3947
                PubKey:  pubKey[:],
×
3948
        })
×
3949
        if err != nil {
×
3950
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3951
        }
×
3952

3953
        return id, nil
×
3954
}
3955

3956
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3957
// the database. This includes deleting any existing types and then inserting
3958
// the new types.
3959
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3960
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3961

×
3962
        // Delete all existing extra signed fields for the channel policy.
×
3963
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
3964
        if err != nil {
×
3965
                return fmt.Errorf("unable to delete "+
×
3966
                        "existing policy extra signed fields for policy %d: %w",
×
3967
                        chanPolicyID, err)
×
3968
        }
×
3969

3970
        // Insert all new extra signed fields for the channel policy.
3971
        for tlvType, value := range extraFields {
×
3972
                err = db.InsertChanPolicyExtraType(
×
3973
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
3974
                                ChannelPolicyID: chanPolicyID,
×
3975
                                Type:            int64(tlvType),
×
3976
                                Value:           value,
×
3977
                        },
×
3978
                )
×
3979
                if err != nil {
×
3980
                        return fmt.Errorf("unable to insert "+
×
3981
                                "channel_policy(%d) extra signed field(%v): %w",
×
3982
                                chanPolicyID, tlvType, err)
×
3983
                }
×
3984
        }
3985

3986
        return nil
×
3987
}
3988

3989
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3990
// provided dbChanRow and also fetches any other required information
3991
// to construct the edge info.
3992
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
3993
        chain chainhash.Hash, dbChan sqlc.GraphChannel, node1,
3994
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
3995

×
3996
        // NOTE: getAndBuildEdgeInfo is only used to load the data for a single
×
3997
        // edge, and so no paged queries will be performed. This means that
×
3998
        // it's ok to used pass in default config values here.
×
3999
        cfg := sqldb.DefaultQueryConfig()
×
4000

×
4001
        data, err := batchLoadChannelData(ctx, cfg, db, []int64{dbChan.ID}, nil)
×
4002
        if err != nil {
×
4003
                return nil, fmt.Errorf("unable to batch load channel data: %w",
×
4004
                        err)
×
4005
        }
×
4006

4007
        return buildEdgeInfoWithBatchData(chain, dbChan, node1, node2, data)
×
4008
}
4009

4010
// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
4011
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
4012
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
4013
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {
×
4014

×
4015
        if dbChan.Version != int16(ProtocolV1) {
×
4016
                return nil, fmt.Errorf("unsupported channel version: %d",
×
4017
                        dbChan.Version)
×
4018
        }
×
4019

4020
        // Use pre-loaded features and extras types.
4021
        fv := lnwire.EmptyFeatureVector()
×
4022
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
×
4023
                for _, bit := range features {
×
4024
                        fv.Set(lnwire.FeatureBit(bit))
×
4025
                }
×
4026
        }
4027

4028
        var extras map[uint64][]byte
×
4029
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
×
4030
        if exists {
×
4031
                extras = channelExtras
×
4032
        } else {
×
4033
                extras = make(map[uint64][]byte)
×
4034
        }
×
4035

4036
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4037
        if err != nil {
×
4038
                return nil, err
×
4039
        }
×
4040

4041
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4042
        if err != nil {
×
4043
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4044
                        "fields: %w", err)
×
4045
        }
×
4046
        if recs == nil {
×
4047
                recs = make([]byte, 0)
×
4048
        }
×
4049

4050
        var btcKey1, btcKey2 route.Vertex
×
4051
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4052
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4053

×
4054
        channel := &models.ChannelEdgeInfo{
×
4055
                ChainHash:        chain,
×
4056
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4057
                NodeKey1Bytes:    node1,
×
4058
                NodeKey2Bytes:    node2,
×
4059
                BitcoinKey1Bytes: btcKey1,
×
4060
                BitcoinKey2Bytes: btcKey2,
×
4061
                ChannelPoint:     *op,
×
4062
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4063
                Features:         fv,
×
4064
                ExtraOpaqueData:  recs,
×
4065
        }
×
4066

×
4067
        // We always set all the signatures at the same time, so we can
×
4068
        // safely check if one signature is present to determine if we have the
×
4069
        // rest of the signatures for the auth proof.
×
4070
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4071
                channel.AuthProof = &models.ChannelAuthProof{
×
4072
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4073
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4074
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4075
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4076
                }
×
4077
        }
×
4078

4079
        return channel, nil
×
4080
}
4081

4082
// buildNodeVertices is a helper that converts raw node public keys
4083
// into route.Vertex instances.
4084
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4085
        route.Vertex, error) {
×
4086

×
4087
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4088
        if err != nil {
×
4089
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4090
                        "create vertex from node1 pubkey: %w", err)
×
4091
        }
×
4092

4093
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4094
        if err != nil {
×
4095
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4096
                        "create vertex from node2 pubkey: %w", err)
×
4097
        }
×
4098

4099
        return node1Vertex, node2Vertex, nil
×
4100
}
4101

4102
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4103
// retrieves all the extra info required to build the complete
4104
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4105
// the provided sqlc.GraphChannelPolicy records are nil.
4106
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4107
        dbPol1, dbPol2 *sqlc.GraphChannelPolicy, channelID uint64, node1,
4108
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4109
        *models.ChannelEdgePolicy, error) {
×
4110

×
4111
        if dbPol1 == nil && dbPol2 == nil {
×
4112
                return nil, nil, nil
×
4113
        }
×
4114

4115
        var policyIDs = make([]int64, 0, 2)
×
4116
        if dbPol1 != nil {
×
4117
                policyIDs = append(policyIDs, dbPol1.ID)
×
4118
        }
×
4119
        if dbPol2 != nil {
×
4120
                policyIDs = append(policyIDs, dbPol2.ID)
×
4121
        }
×
4122

4123
        // NOTE: getAndBuildChanPolicies is only used to load the data for
4124
        // a maximum of two policies, and so no paged queries will be
4125
        // performed (unless the page size is one). So it's ok to use
4126
        // the default config values here.
4127
        cfg := sqldb.DefaultQueryConfig()
×
4128

×
4129
        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
×
4130
        if err != nil {
×
4131
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
4132
                        "data: %w", err)
×
4133
        }
×
4134

4135
        pol1, err := buildChanPolicyWithBatchData(
×
4136
                dbPol1, channelID, node2, batchData,
×
4137
        )
×
4138
        if err != nil {
×
4139
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4140
        }
×
4141

4142
        pol2, err := buildChanPolicyWithBatchData(
×
4143
                dbPol2, channelID, node1, batchData,
×
4144
        )
×
4145
        if err != nil {
×
4146
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4147
        }
×
4148

4149
        return pol1, pol2, nil
×
4150
}
4151

4152
// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
4153
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
4154
// then nil is returned for it.
4155
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4156
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
4157
        *models.CachedEdgePolicy, error) {
×
4158

×
4159
        var p1, p2 *models.CachedEdgePolicy
×
4160
        if dbPol1 != nil {
×
4161
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
×
4162
                if err != nil {
×
4163
                        return nil, nil, err
×
4164
                }
×
4165

4166
                p1 = models.NewCachedPolicy(policy1)
×
4167
        }
4168
        if dbPol2 != nil {
×
4169
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
×
4170
                if err != nil {
×
4171
                        return nil, nil, err
×
4172
                }
×
4173

4174
                p2 = models.NewCachedPolicy(policy2)
×
4175
        }
4176

4177
        return p1, p2, nil
×
4178
}
4179

4180
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4181
// provided sqlc.GraphChannelPolicy and other required information.
4182
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4183
        extras map[uint64][]byte,
4184
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4185

×
4186
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4187
        if err != nil {
×
4188
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4189
                        "fields: %w", err)
×
4190
        }
×
4191

4192
        var inboundFee fn.Option[lnwire.Fee]
×
4193
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4194
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4195

×
4196
                inboundFee = fn.Some(lnwire.Fee{
×
4197
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4198
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4199
                })
×
4200
        }
×
4201

4202
        return &models.ChannelEdgePolicy{
×
4203
                SigBytes:  dbPolicy.Signature,
×
4204
                ChannelID: channelID,
×
4205
                LastUpdate: time.Unix(
×
4206
                        dbPolicy.LastUpdate.Int64, 0,
×
4207
                ),
×
4208
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4209
                        dbPolicy.MessageFlags,
×
4210
                ),
×
4211
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4212
                        dbPolicy.ChannelFlags,
×
4213
                ),
×
4214
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4215
                MinHTLC: lnwire.MilliSatoshi(
×
4216
                        dbPolicy.MinHtlcMsat,
×
4217
                ),
×
4218
                MaxHTLC: lnwire.MilliSatoshi(
×
4219
                        dbPolicy.MaxHtlcMsat.Int64,
×
4220
                ),
×
4221
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4222
                        dbPolicy.BaseFeeMsat,
×
4223
                ),
×
4224
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4225
                ToNode:                    toNode,
×
4226
                InboundFee:                inboundFee,
×
4227
                ExtraOpaqueData:           recs,
×
4228
        }, nil
×
4229
}
4230

4231
// buildNodes builds the models.LightningNode instances for the
4232
// given row which is expected to be a sqlc type that contains node information.
4233
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4234
        dbNode2 sqlc.GraphNode) (*models.LightningNode, *models.LightningNode,
4235
        error) {
×
4236

×
4237
        node1, err := buildNode(ctx, db, &dbNode1)
×
4238
        if err != nil {
×
4239
                return nil, nil, err
×
4240
        }
×
4241

4242
        node2, err := buildNode(ctx, db, &dbNode2)
×
4243
        if err != nil {
×
4244
                return nil, nil, err
×
4245
        }
×
4246

4247
        return node1, node2, nil
×
4248
}
4249

4250
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the
// given row. The row is expected to be one of the sqlc row types handled in
// the switch below, each of which carries the columns of up to two channel
// policies. It returns two policies, either of which is nil if that policy's
// presence column is NULL in the row. An error is returned for any other row
// type.
//
// NOTE: sqlc.ListChannelsWithPoliciesForCachePaginatedRow differs from the
// other row types in two ways:
//   - Presence is detected via the timelock column rather than the policy ID.
//   - Only a subset of fields is populated. ID, Version, ChannelID, NodeID,
//     LastUpdate and Signature are left at their zero values.
//
//nolint:ll,dupl,funlen
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
	*sqlc.GraphChannelPolicy, error) {

	var policy1, policy2 *sqlc.GraphChannelPolicy
	switch r := row.(type) {
	// The cache query only selects the columns needed for cached policies,
	// so the policy ID is not available to test for presence.
	case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
		if r.Policy1Timelock.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
			}
		}
		if r.Policy2Timelock.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
			}
		}

		return policy1, policy2, nil

	// The remaining row types all carry the full policy column set. A
	// non-NULL policy ID marks the policy as present, and ChannelID is
	// taken from the joined graph channel row.
	case sqlc.GetChannelsBySCIDWithPoliciesRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	case sqlc.GetChannelByOutpointWithPoliciesRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	case sqlc.GetChannelBySCIDWithPoliciesRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	case sqlc.ListChannelsForNodeIDsRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	case sqlc.ListChannelsByNodeIDRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	case sqlc.ListChannelsWithPoliciesPaginatedRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil
	default:
		// Any row type not listed above is a programming error in the
		// caller.
		return nil, nil, fmt.Errorf("unexpected row type in "+
			"extractChannelPolicies: %T", r)
	}
}
4605

4606
// channelIDToBytes converts a channel ID (SCID) to a byte array
4607
// representation.
4608
func channelIDToBytes(channelID uint64) []byte {
×
4609
        var chanIDB [8]byte
×
4610
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4611

×
4612
        return chanIDB[:]
×
4613
}
×
4614

4615
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
4616
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
×
4617
        if len(addresses) == 0 {
×
4618
                return nil, nil
×
4619
        }
×
4620

4621
        result := make([]net.Addr, 0, len(addresses))
×
4622
        for _, addr := range addresses {
×
4623
                netAddr, err := parseAddress(addr.addrType, addr.address)
×
4624
                if err != nil {
×
4625
                        return nil, fmt.Errorf("unable to parse address %s "+
×
4626
                                "of type %d: %w", addr.address, addr.addrType,
×
4627
                                err)
×
4628
                }
×
4629
                if netAddr != nil {
×
4630
                        result = append(result, netAddr)
×
4631
                }
×
4632
        }
4633

4634
        // If we have no valid addresses, return nil instead of empty slice.
4635
        if len(result) == 0 {
×
4636
                return nil, nil
×
4637
        }
×
4638

4639
        return result, nil
×
4640
}
4641

4642
// parseAddress parses the given address string based on the address type
4643
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
4644
// and opaque addresses.
4645
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
×
4646
        switch addrType {
×
4647
        case addressTypeIPv4:
×
4648
                tcp, err := net.ResolveTCPAddr("tcp4", address)
×
4649
                if err != nil {
×
4650
                        return nil, err
×
4651
                }
×
4652

4653
                tcp.IP = tcp.IP.To4()
×
4654

×
4655
                return tcp, nil
×
4656

4657
        case addressTypeIPv6:
×
4658
                tcp, err := net.ResolveTCPAddr("tcp6", address)
×
4659
                if err != nil {
×
4660
                        return nil, err
×
4661
                }
×
4662

4663
                return tcp, nil
×
4664

4665
        case addressTypeTorV3, addressTypeTorV2:
×
4666
                service, portStr, err := net.SplitHostPort(address)
×
4667
                if err != nil {
×
4668
                        return nil, fmt.Errorf("unable to split tor "+
×
4669
                                "address: %v", address)
×
4670
                }
×
4671

4672
                port, err := strconv.Atoi(portStr)
×
4673
                if err != nil {
×
4674
                        return nil, err
×
4675
                }
×
4676

4677
                return &tor.OnionAddr{
×
4678
                        OnionService: service,
×
4679
                        Port:         port,
×
4680
                }, nil
×
4681

4682
        case addressTypeOpaque:
×
4683
                opaque, err := hex.DecodeString(address)
×
4684
                if err != nil {
×
4685
                        return nil, fmt.Errorf("unable to decode opaque "+
×
4686
                                "address: %v", address)
×
4687
                }
×
4688

4689
                return &lnwire.OpaqueAddrs{
×
4690
                        Payload: opaque,
×
4691
                }, nil
×
4692

4693
        default:
×
4694
                return nil, fmt.Errorf("unknown address type: %v", addrType)
×
4695
        }
4696
}
4697

4698
// batchNodeData holds all the related data for a batch of nodes, as assembled
// by batchLoadNodeData. Every map is keyed by DB node ID. A node for which no
// rows of a particular kind were returned has no entry in that map, so callers
// should treat a missing key as "none" rather than as an error.
type batchNodeData struct {
	// features is a map from a DB node ID to the feature bits for that
	// node, in the order the feature rows were delivered.
	features map[int64][]int

	// addresses is a map from a DB node ID to the node's addresses, in
	// the order the address rows were delivered.
	addresses map[int64][]nodeAddress

	// extraFields is a map from a DB node ID to the extra signed fields
	// for that node, keyed by the stored extra type (presumably the TLV
	// type — confirm against the schema).
	extraFields map[int64]map[uint64][]byte
}
4711

4712
// nodeAddress holds the address type, position and address string for a
// node. This is used to batch the fetching of node addresses: one value is
// created per address row returned by GetNodeAddressesBatch.
//
// NOTE(review): addrType/address look like the inputs parseAddress expects;
// confirm at the call site that they are decoded that way.
type nodeAddress struct {
	addrType dbAddressType // Stored address type of the row.
	position int32         // Stored position (presumably list order).
	address  string        // Stored, still-encoded address string.
}
4719

4720
// batchLoadNodeData loads all related data for a batch of node IDs using the
4721
// provided SQLQueries interface. It returns a batchNodeData instance containing
4722
// the node features, addresses and extra signed fields.
4723
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
4724
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {
×
4725

×
4726
        // Batch load the node features.
×
4727
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
4728
        if err != nil {
×
4729
                return nil, fmt.Errorf("unable to batch load node "+
×
4730
                        "features: %w", err)
×
4731
        }
×
4732

4733
        // Batch load the node addresses.
4734
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
×
4735
        if err != nil {
×
4736
                return nil, fmt.Errorf("unable to batch load node "+
×
4737
                        "addresses: %w", err)
×
4738
        }
×
4739

4740
        // Batch load the node extra signed fields.
4741
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
×
4742
        if err != nil {
×
4743
                return nil, fmt.Errorf("unable to batch load node extra "+
×
4744
                        "signed fields: %w", err)
×
4745
        }
×
4746

4747
        return &batchNodeData{
×
4748
                features:    features,
×
4749
                addresses:   addrs,
×
4750
                extraFields: extraTypes,
×
4751
        }, nil
×
4752
}
4753

4754
// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
4755
// using ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
4756
func batchLoadNodeFeaturesHelper(ctx context.Context,
4757
        cfg *sqldb.QueryConfig, db SQLQueries,
4758
        nodeIDs []int64) (map[int64][]int, error) {
×
4759

×
4760
        features := make(map[int64][]int)
×
4761

×
4762
        return features, sqldb.ExecuteBatchQuery(
×
4763
                ctx, cfg, nodeIDs,
×
4764
                func(id int64) int64 {
×
4765
                        return id
×
4766
                },
×
4767
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
4768
                        error) {
×
4769

×
4770
                        return db.GetNodeFeaturesBatch(ctx, ids)
×
4771
                },
×
4772
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
×
4773
                        features[feature.NodeID] = append(
×
4774
                                features[feature.NodeID],
×
4775
                                int(feature.FeatureBit),
×
4776
                        )
×
4777

×
4778
                        return nil
×
4779
                },
×
4780
        )
4781
}
4782

4783
// batchLoadNodeAddressesHelper loads node addresses using ExecuteBatchQuery
4784
// wrapper around the GetNodeAddressesBatch query. It returns a map from
4785
// node ID to a slice of nodeAddress structs.
4786
func batchLoadNodeAddressesHelper(ctx context.Context,
4787
        cfg *sqldb.QueryConfig, db SQLQueries,
4788
        nodeIDs []int64) (map[int64][]nodeAddress, error) {
×
4789

×
4790
        addrs := make(map[int64][]nodeAddress)
×
4791

×
4792
        return addrs, sqldb.ExecuteBatchQuery(
×
4793
                ctx, cfg, nodeIDs,
×
4794
                func(id int64) int64 {
×
4795
                        return id
×
4796
                },
×
4797
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
4798
                        error) {
×
4799

×
4800
                        return db.GetNodeAddressesBatch(ctx, ids)
×
4801
                },
×
4802
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
×
4803
                        addrs[addr.NodeID] = append(
×
4804
                                addrs[addr.NodeID], nodeAddress{
×
4805
                                        addrType: dbAddressType(addr.Type),
×
4806
                                        position: addr.Position,
×
4807
                                        address:  addr.Address,
×
4808
                                },
×
4809
                        )
×
4810

×
4811
                        return nil
×
4812
                },
×
4813
        )
4814
}
4815

4816
// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
4817
// node IDs using ExecuteBatchQuery wrapper around the GetNodeExtraTypesBatch
4818
// query.
4819
func batchLoadNodeExtraTypesHelper(ctx context.Context,
4820
        cfg *sqldb.QueryConfig, db SQLQueries,
4821
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {
×
4822

×
4823
        extraFields := make(map[int64]map[uint64][]byte)
×
4824

×
4825
        callback := func(ctx context.Context,
×
4826
                field sqlc.GraphNodeExtraType) error {
×
4827

×
4828
                if extraFields[field.NodeID] == nil {
×
4829
                        extraFields[field.NodeID] = make(map[uint64][]byte)
×
4830
                }
×
4831
                extraFields[field.NodeID][uint64(field.Type)] = field.Value
×
4832

×
4833
                return nil
×
4834
        }
4835

4836
        return extraFields, sqldb.ExecuteBatchQuery(
×
4837
                ctx, cfg, nodeIDs,
×
4838
                func(id int64) int64 {
×
4839
                        return id
×
4840
                },
×
4841
                func(ctx context.Context, ids []int64) (
4842
                        []sqlc.GraphNodeExtraType, error) {
×
4843

×
4844
                        return db.GetNodeExtraTypesBatch(ctx, ids)
×
4845
                },
×
4846
                callback,
4847
        )
4848
}
4849

4850
// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
4851
// from the provided sqlc.GraphChannelPolicy records and the
4852
// provided batchChannelData.
4853
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4854
        channelID uint64, node1, node2 route.Vertex,
4855
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
4856
        *models.ChannelEdgePolicy, error) {
×
4857

×
4858
        pol1, err := buildChanPolicyWithBatchData(
×
4859
                dbPol1, channelID, node2, batchData,
×
4860
        )
×
4861
        if err != nil {
×
4862
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4863
        }
×
4864

4865
        pol2, err := buildChanPolicyWithBatchData(
×
4866
                dbPol2, channelID, node1, batchData,
×
4867
        )
×
4868
        if err != nil {
×
4869
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4870
        }
×
4871

4872
        return pol1, pol2, nil
×
4873
}
4874

4875
// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
4876
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
4877
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
4878
        channelID uint64, toNode route.Vertex,
4879
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {
×
4880

×
4881
        if dbPol == nil {
×
4882
                return nil, nil
×
4883
        }
×
4884

4885
        var dbPol1Extras map[uint64][]byte
×
4886
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
×
4887
                dbPol1Extras = extras
×
4888
        } else {
×
4889
                dbPol1Extras = make(map[uint64][]byte)
×
4890
        }
×
4891

4892
        return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode)
×
4893
}
4894

4895
// batchChannelData holds all the related data for a batch of channels, as
// assembled by batchLoadChannelData. A channel or policy for which no rows
// were returned has no entry in the corresponding map.
type batchChannelData struct {
	// chanfeatures is a map from DB channel ID to a slice of feature bits.
	chanfeatures map[int64][]int

	// chanExtraTypes is a map from DB channel ID to a map of TLV type to
	// extra signed field bytes.
	chanExtraTypes map[int64]map[uint64][]byte

	// policyExtras is a map from DB channel policy ID to a map of TLV type
	// to extra signed field bytes.
	policyExtras map[int64]map[uint64][]byte
}
4908

4909
// batchLoadChannelData loads all related data for batches of channels and
4910
// policies.
4911
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
4912
        db SQLQueries, channelIDs []int64,
4913
        policyIDs []int64) (*batchChannelData, error) {
×
4914

×
4915
        batchData := &batchChannelData{
×
4916
                chanfeatures:   make(map[int64][]int),
×
4917
                chanExtraTypes: make(map[int64]map[uint64][]byte),
×
4918
                policyExtras:   make(map[int64]map[uint64][]byte),
×
4919
        }
×
4920

×
4921
        // Batch load channel features and extras
×
4922
        var err error
×
4923
        if len(channelIDs) > 0 {
×
4924
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
×
4925
                        ctx, cfg, db, channelIDs,
×
4926
                )
×
4927
                if err != nil {
×
4928
                        return nil, fmt.Errorf("unable to batch load "+
×
4929
                                "channel features: %w", err)
×
4930
                }
×
4931

4932
                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
×
4933
                        ctx, cfg, db, channelIDs,
×
4934
                )
×
4935
                if err != nil {
×
4936
                        return nil, fmt.Errorf("unable to batch load "+
×
4937
                                "channel extras: %w", err)
×
4938
                }
×
4939
        }
4940

4941
        if len(policyIDs) > 0 {
×
4942
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
×
4943
                        ctx, cfg, db, policyIDs,
×
4944
                )
×
4945
                if err != nil {
×
4946
                        return nil, fmt.Errorf("unable to batch load "+
×
4947
                                "policy extras: %w", err)
×
4948
                }
×
4949
                batchData.policyExtras = policyExtras
×
4950
        }
4951

4952
        return batchData, nil
×
4953
}
4954

4955
// batchLoadChannelFeaturesHelper loads channel features for a batch of
4956
// channel IDs using ExecuteBatchQuery wrapper around the
4957
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
4958
// slice of feature bits.
4959
func batchLoadChannelFeaturesHelper(ctx context.Context,
4960
        cfg *sqldb.QueryConfig, db SQLQueries,
4961
        channelIDs []int64) (map[int64][]int, error) {
×
4962

×
4963
        features := make(map[int64][]int)
×
4964

×
4965
        return features, sqldb.ExecuteBatchQuery(
×
4966
                ctx, cfg, channelIDs,
×
4967
                func(id int64) int64 {
×
4968
                        return id
×
4969
                },
×
4970
                func(ctx context.Context,
4971
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {
×
4972

×
4973
                        return db.GetChannelFeaturesBatch(ctx, ids)
×
4974
                },
×
4975
                func(ctx context.Context,
4976
                        feature sqlc.GraphChannelFeature) error {
×
4977

×
4978
                        features[feature.ChannelID] = append(
×
4979
                                features[feature.ChannelID],
×
4980
                                int(feature.FeatureBit),
×
4981
                        )
×
4982

×
4983
                        return nil
×
4984
                },
×
4985
        )
4986
}
4987

4988
// batchLoadChannelExtrasHelper loads channel extra types for a batch of
4989
// channel IDs using ExecuteBatchQuery wrapper around the GetChannelExtrasBatch
4990
// query. It returns a map from DB channel ID to a map of TLV type to extra
4991
// signed field bytes.
4992
func batchLoadChannelExtrasHelper(ctx context.Context,
4993
        cfg *sqldb.QueryConfig, db SQLQueries,
4994
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {
×
4995

×
4996
        extras := make(map[int64]map[uint64][]byte)
×
4997

×
4998
        cb := func(ctx context.Context,
×
4999
                extra sqlc.GraphChannelExtraType) error {
×
5000

×
5001
                if extras[extra.ChannelID] == nil {
×
5002
                        extras[extra.ChannelID] = make(map[uint64][]byte)
×
5003
                }
×
5004
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value
×
5005

×
5006
                return nil
×
5007
        }
5008

5009
        return extras, sqldb.ExecuteBatchQuery(
×
5010
                ctx, cfg, channelIDs,
×
5011
                func(id int64) int64 {
×
5012
                        return id
×
5013
                },
×
5014
                func(ctx context.Context,
5015
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {
×
5016

×
5017
                        return db.GetChannelExtrasBatch(ctx, ids)
×
5018
                }, cb,
×
5019
        )
5020
}
5021

5022
// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
5023
// batch of policy IDs using ExecuteBatchQuery wrapper around the
5024
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
5025
// a map of TLV type to extra signed field bytes.
5026
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
5027
        cfg *sqldb.QueryConfig, db SQLQueries,
5028
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5029

×
5030
        extras := make(map[int64]map[uint64][]byte)
×
5031

×
5032
        return extras, sqldb.ExecuteBatchQuery(
×
5033
                ctx, cfg, policyIDs,
×
5034
                func(id int64) int64 {
×
5035
                        return id
×
5036
                },
×
5037
                func(ctx context.Context, ids []int64) (
5038
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {
×
5039

×
5040
                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
×
5041
                },
×
5042
                func(ctx context.Context,
5043
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {
×
5044

×
5045
                        if extras[row.PolicyID] == nil {
×
5046
                                extras[row.PolicyID] = make(map[uint64][]byte)
×
5047
                        }
×
5048
                        extras[row.PolicyID][uint64(row.Type)] = row.Value
×
5049

×
5050
                        return nil
×
5051
                },
5052
        )
5053
}
5054

5055
// forEachNodePaginated executes a paginated query to process each node in the
5056
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
5057
// and applies the provided processNode function to each node.
5058
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
5059
        db SQLQueries, protocol ProtocolVersion,
5060
        processNode func(context.Context, int64,
5061
                *models.LightningNode) error) error {
×
5062

×
5063
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5064
                limit int32) ([]sqlc.GraphNode, error) {
×
5065

×
5066
                return db.ListNodesPaginated(
×
5067
                        ctx, sqlc.ListNodesPaginatedParams{
×
5068
                                Version: int16(protocol),
×
5069
                                ID:      lastID,
×
5070
                                Limit:   limit,
×
5071
                        },
×
5072
                )
×
5073
        }
×
5074

5075
        extractPageCursor := func(node sqlc.GraphNode) int64 {
×
5076
                return node.ID
×
5077
        }
×
5078

5079
        collectFunc := func(node sqlc.GraphNode) (int64, error) {
×
5080
                return node.ID, nil
×
5081
        }
×
5082

5083
        batchQueryFunc := func(ctx context.Context,
×
5084
                nodeIDs []int64) (*batchNodeData, error) {
×
5085

×
5086
                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
5087
        }
×
5088

5089
        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
×
5090
                batchData *batchNodeData) error {
×
5091

×
5092
                node, err := buildNodeWithBatchData(&dbNode, batchData)
×
5093
                if err != nil {
×
5094
                        return fmt.Errorf("unable to build "+
×
5095
                                "node(id=%d): %w", dbNode.ID, err)
×
5096
                }
×
5097

5098
                return processNode(ctx, dbNode.ID, node)
×
5099
        }
5100

5101
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5102
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5103
                collectFunc, batchQueryFunc, processItem,
×
5104
        )
×
5105
}
5106

5107
// forEachChannelWithPolicies executes a paginated query to process each channel
5108
// with policies in the graph.
5109
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
5110
        cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
5111
                *models.ChannelEdgePolicy,
5112
                *models.ChannelEdgePolicy) error) error {
×
5113

×
5114
        type channelBatchIDs struct {
×
5115
                channelID int64
×
5116
                policyIDs []int64
×
5117
        }
×
5118

×
5119
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5120
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
×
5121
                error) {
×
5122

×
5123
                return db.ListChannelsWithPoliciesPaginated(
×
5124
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
5125
                                Version: int16(ProtocolV1),
×
5126
                                ID:      lastID,
×
5127
                                Limit:   limit,
×
5128
                        },
×
5129
                )
×
5130
        }
×
5131

5132
        extractPageCursor := func(
×
5133
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {
×
5134

×
5135
                return row.GraphChannel.ID
×
5136
        }
×
5137

5138
        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
×
5139
                channelBatchIDs, error) {
×
5140

×
5141
                ids := channelBatchIDs{
×
5142
                        channelID: row.GraphChannel.ID,
×
5143
                }
×
5144

×
5145
                // Extract policy IDs from the row.
×
5146
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5147
                if err != nil {
×
5148
                        return ids, err
×
5149
                }
×
5150

5151
                if dbPol1 != nil {
×
5152
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
×
5153
                }
×
5154
                if dbPol2 != nil {
×
5155
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
×
5156
                }
×
5157

5158
                return ids, nil
×
5159
        }
5160

5161
        batchDataFunc := func(ctx context.Context,
×
5162
                allIDs []channelBatchIDs) (*batchChannelData, error) {
×
5163

×
5164
                // Separate channel IDs from policy IDs.
×
5165
                var (
×
5166
                        channelIDs = make([]int64, len(allIDs))
×
5167
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
×
5168
                )
×
5169

×
5170
                for i, ids := range allIDs {
×
5171
                        channelIDs[i] = ids.channelID
×
5172
                        policyIDs = append(policyIDs, ids.policyIDs...)
×
5173
                }
×
5174

5175
                return batchLoadChannelData(
×
5176
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
5177
                )
×
5178
        }
5179

5180
        processItem := func(ctx context.Context,
×
5181
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
×
5182
                batchData *batchChannelData) error {
×
5183

×
5184
                node1, node2, err := buildNodeVertices(
×
5185
                        row.Node1Pubkey, row.Node2Pubkey,
×
5186
                )
×
5187
                if err != nil {
×
5188
                        return err
×
5189
                }
×
5190

5191
                edge, err := buildEdgeInfoWithBatchData(
×
5192
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
5193
                        batchData,
×
5194
                )
×
5195
                if err != nil {
×
5196
                        return fmt.Errorf("unable to build channel info: %w",
×
5197
                                err)
×
5198
                }
×
5199

5200
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5201
                if err != nil {
×
5202
                        return err
×
5203
                }
×
5204

5205
                p1, p2, err := buildChanPoliciesWithBatchData(
×
5206
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
5207
                )
×
5208
                if err != nil {
×
5209
                        return err
×
5210
                }
×
5211

5212
                return processChannel(edge, p1, p2)
×
5213
        }
5214

5215
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5216
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5217
                collectFunc, batchDataFunc, processItem,
×
5218
        )
×
5219
}
5220

5221
// buildDirectedChannel builds a DirectedChannel instance from the provided
5222
// data.
5223
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
5224
        nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
5225
        channelBatchData *batchChannelData, features *lnwire.FeatureVector,
5226
        toNodeCallback func() route.Vertex) (*DirectedChannel, error) {
×
5227

×
5228
        node1, node2, err := buildNodeVertices(
×
5229
                channelRow.Node1Pubkey, channelRow.Node2Pubkey,
×
5230
        )
×
5231
        if err != nil {
×
5232
                return nil, fmt.Errorf("unable to build node vertices: %w", err)
×
5233
        }
×
5234

5235
        edge, err := buildEdgeInfoWithBatchData(
×
5236
                chain, channelRow.GraphChannel, node1, node2, channelBatchData,
×
5237
        )
×
5238
        if err != nil {
×
5239
                return nil, fmt.Errorf("unable to build channel info: %w", err)
×
5240
        }
×
5241

5242
        dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
×
5243
        if err != nil {
×
5244
                return nil, fmt.Errorf("unable to extract channel policies: %w",
×
5245
                        err)
×
5246
        }
×
5247

5248
        p1, p2, err := buildChanPoliciesWithBatchData(
×
5249
                dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
5250
                channelBatchData,
×
5251
        )
×
5252
        if err != nil {
×
5253
                return nil, fmt.Errorf("unable to build channel policies: %w",
×
5254
                        err)
×
5255
        }
×
5256

5257
        // Determine outgoing and incoming policy for this specific node.
5258
        p1ToNode := channelRow.GraphChannel.NodeID2
×
5259
        p2ToNode := channelRow.GraphChannel.NodeID1
×
5260
        outPolicy, inPolicy := p1, p2
×
5261
        if (p1 != nil && p1ToNode == nodeID) ||
×
5262
                (p2 != nil && p2ToNode != nodeID) {
×
5263

×
5264
                outPolicy, inPolicy = p2, p1
×
5265
        }
×
5266

5267
        // Build cached policy.
5268
        var cachedInPolicy *models.CachedEdgePolicy
×
5269
        if inPolicy != nil {
×
5270
                cachedInPolicy = models.NewCachedPolicy(inPolicy)
×
5271
                cachedInPolicy.ToNodePubKey = toNodeCallback
×
5272
                cachedInPolicy.ToNodeFeatures = features
×
5273
        }
×
5274

5275
        // Extract inbound fee.
5276
        var inboundFee lnwire.Fee
×
5277
        if outPolicy != nil {
×
5278
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
5279
                        inboundFee = fee
×
5280
                })
×
5281
        }
5282

5283
        // Build directed channel.
5284
        directedChannel := &DirectedChannel{
×
5285
                ChannelID:    edge.ChannelID,
×
5286
                IsNode1:      nodePub == edge.NodeKey1Bytes,
×
5287
                OtherNode:    edge.NodeKey2Bytes,
×
5288
                Capacity:     edge.Capacity,
×
5289
                OutPolicySet: outPolicy != nil,
×
5290
                InPolicy:     cachedInPolicy,
×
5291
                InboundFee:   inboundFee,
×
5292
        }
×
5293

×
5294
        if nodePub == edge.NodeKey2Bytes {
×
5295
                directedChannel.OtherNode = edge.NodeKey1Bytes
×
5296
        }
×
5297

5298
        return directedChannel, nil
×
5299
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc