• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 16650483894

31 Jul 2025 01:33PM UTC coverage: 67.025% (-0.02%) from 67.047%
16650483894

Pull #10118

github

web-flow
Merge 31978b99f into 37523b6cb
Pull Request #10118: [4] sqldb+graph/db: add and use new pagination helper

6 of 239 new or added lines in 2 files covered. (2.51%)

109 existing lines in 20 files now uncovered.

135527 of 202203 relevant lines covered (67.03%)

21691.01 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// ProtocolVersion identifies the gossip protocol version a message belongs
// to. It is passed to the SQL queries as the int16 "Version" column value.
type ProtocolVersion uint8

// The gossip protocol versions understood by the store.
const (
	// ProtocolV1 is the gossip protocol version defined in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String renders the protocol version as "V<n>", e.g. "V1".
//
// The value is converted to its underlying uint8 before formatting, which
// makes it obvious that formatting cannot recurse back into String.
func (v ProtocolVersion) String() string {
	return fmt.Sprintf("V%d", uint8(v))
}
47

48
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
// execute queries against the SQL graph tables. The method signatures mirror
// the sqlc-generated ones exactly, so any sqlc querier (or a transaction-scoped
// wrapper around one) satisfies it. Methods are grouped by the table(s) they
// operate on.
//
//nolint:ll,interfacebloat
type SQLQueries interface {
	/*
		Node queries. Single-node lookups are keyed by the
		(protocol version, public key) pair carried in their params.
	*/
	UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
	GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
	GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
	GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
	ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
	ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
	IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
	DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
	DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
	DeleteNode(ctx context.Context, id int64) error

	// Node extra (TLV) types, keyed by the node's DB ID.
	GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
	GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
	UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
	DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error

	// Node addresses, keyed by the node's DB ID.
	InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
	GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
	GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
	DeleteNodeAddresses(ctx context.Context, nodeID int64) error

	// Node feature bits.
	InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
	GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
	GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
	GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
	DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error

	/*
		Source node queries. Source nodes are looked up per
		protocol version.
	*/
	AddSourceNode(ctx context.Context, nodeID int64) error
	GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)

	/*
		Channel queries.
	*/
	CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
	AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
	GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
	GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
	GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
	GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
	GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
	GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
	GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
	HighestSCID(ctx context.Context, version int16) ([]byte, error)
	ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
	ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
	ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
	ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
	GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
	GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
	GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
	GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
	DeleteChannels(ctx context.Context, ids []int64) error

	// Channel extra (TLV) types and feature bits, keyed by channel DB ID.
	CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
	GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
	InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
	GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)

	/*
		Channel Policy table queries.
	*/
	UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
	GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
	GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)

	// Channel policy extra (TLV) types, keyed by policy DB ID.
	InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
	GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
	DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error

	/*
		Zombie index queries.
	*/
	UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
	GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
	CountZombieChannels(ctx context.Context, version int16) (int64, error)
	DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
	IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)

	/*
		Prune log table queries.
	*/
	GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
	GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
	UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
	DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error

	/*
		Closed SCID table queries.
	*/
	InsertClosedChannel(ctx context.Context, scid []byte) error
	IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
}
151

152
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
153
// database operations.
154
type BatchedSQLQueries interface {
155
        SQLQueries
156
        sqldb.BatchedTx[SQLQueries]
157
}
158

159
// SQLStore is an implementation of the V1Store interface that uses a SQL
160
// database as the backend.
161
type SQLStore struct {
162
        cfg *SQLStoreConfig
163
        db  BatchedSQLQueries
164

165
        // cacheMu guards all caches (rejectCache and chanCache). If
166
        // this mutex will be acquired at the same time as the DB mutex then
167
        // the cacheMu MUST be acquired first to prevent deadlock.
168
        cacheMu     sync.RWMutex
169
        rejectCache *rejectCache
170
        chanCache   *channelCache
171

172
        chanScheduler batch.Scheduler[SQLQueries]
173
        nodeScheduler batch.Scheduler[SQLQueries]
174

175
        srcNodes  map[ProtocolVersion]*srcNodeInfo
176
        srcNodeMu sync.Mutex
177
}
178

179
// A compile-time assertion to ensure that SQLStore implements the V1Store
180
// interface.
181
var _ V1Store = (*SQLStore)(nil)
182

183
// SQLStoreConfig holds the configuration for the SQLStore.
184
type SQLStoreConfig struct {
185
        // ChainHash is the genesis hash for the chain that all the gossip
186
        // messages in this store are aimed at.
187
        ChainHash chainhash.Hash
188

189
        // QueryConfig holds configuration values for SQL queries.
190
        QueryCfg *sqldb.QueryConfig
191
}
192

193
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
194
// storage backend.
195
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
196
        options ...StoreOptionModifier) (*SQLStore, error) {
×
197

×
198
        opts := DefaultOptions()
×
199
        for _, o := range options {
×
200
                o(opts)
×
201
        }
×
202

203
        if opts.NoMigration {
×
204
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
205
                        "supported for SQL stores")
×
206
        }
×
207

208
        s := &SQLStore{
×
209
                cfg:         cfg,
×
210
                db:          db,
×
211
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
212
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
213
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
214
        }
×
215

×
216
        s.chanScheduler = batch.NewTimeScheduler(
×
217
                db, &s.cacheMu, opts.BatchCommitInterval,
×
218
        )
×
219
        s.nodeScheduler = batch.NewTimeScheduler(
×
220
                db, nil, opts.BatchCommitInterval,
×
221
        )
×
222

×
223
        return s, nil
×
224
}
225

226
// AddLightningNode adds a vertex/node to the graph database. If the node is not
227
// in the database from before, this will add a new, unconnected one to the
228
// graph. If it is present from before, this will update that node's
229
// information.
230
//
231
// NOTE: part of the V1Store interface.
232
func (s *SQLStore) AddLightningNode(ctx context.Context,
233
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
234

×
235
        r := &batch.Request[SQLQueries]{
×
236
                Opts: batch.NewSchedulerOptions(opts...),
×
237
                Do: func(queries SQLQueries) error {
×
238
                        _, err := upsertNode(ctx, queries, node)
×
239
                        return err
×
240
                },
×
241
        }
242

243
        return s.nodeScheduler.Execute(ctx, r)
×
244
}
245

246
// FetchLightningNode attempts to look up a target node by its identity public
247
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
248
// returned.
249
//
250
// NOTE: part of the V1Store interface.
251
func (s *SQLStore) FetchLightningNode(ctx context.Context,
252
        pubKey route.Vertex) (*models.LightningNode, error) {
×
253

×
254
        var node *models.LightningNode
×
255
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
256
                var err error
×
257
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
258

×
259
                return err
×
260
        }, sqldb.NoOpReset)
×
261
        if err != nil {
×
262
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
263
        }
×
264

265
        return node, nil
×
266
}
267

268
// HasLightningNode determines if the graph has a vertex identified by the
269
// target node identity public key. If the node exists in the database, a
270
// timestamp of when the data for the node was lasted updated is returned along
271
// with a true boolean. Otherwise, an empty time.Time is returned with a false
272
// boolean.
273
//
274
// NOTE: part of the V1Store interface.
275
func (s *SQLStore) HasLightningNode(ctx context.Context,
276
        pubKey [33]byte) (time.Time, bool, error) {
×
277

×
278
        var (
×
279
                exists     bool
×
280
                lastUpdate time.Time
×
281
        )
×
282
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
283
                dbNode, err := db.GetNodeByPubKey(
×
284
                        ctx, sqlc.GetNodeByPubKeyParams{
×
285
                                Version: int16(ProtocolV1),
×
286
                                PubKey:  pubKey[:],
×
287
                        },
×
288
                )
×
289
                if errors.Is(err, sql.ErrNoRows) {
×
290
                        return nil
×
291
                } else if err != nil {
×
292
                        return fmt.Errorf("unable to fetch node: %w", err)
×
293
                }
×
294

295
                exists = true
×
296

×
297
                if dbNode.LastUpdate.Valid {
×
298
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
299
                }
×
300

301
                return nil
×
302
        }, sqldb.NoOpReset)
303
        if err != nil {
×
304
                return time.Time{}, false,
×
305
                        fmt.Errorf("unable to fetch node: %w", err)
×
306
        }
×
307

308
        return lastUpdate, exists, nil
×
309
}
310

311
// AddrsForNode returns all known addresses for the target node public key
312
// that the graph DB is aware of. The returned boolean indicates if the
313
// given node is unknown to the graph DB or not.
314
//
315
// NOTE: part of the V1Store interface.
316
func (s *SQLStore) AddrsForNode(ctx context.Context,
317
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
318

×
319
        var (
×
320
                addresses []net.Addr
×
321
                known     bool
×
322
        )
×
323
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
324
                // First, check if the node exists and get its DB ID if it
×
325
                // does.
×
326
                dbID, err := db.GetNodeIDByPubKey(
×
327
                        ctx, sqlc.GetNodeIDByPubKeyParams{
×
328
                                Version: int16(ProtocolV1),
×
329
                                PubKey:  nodePub.SerializeCompressed(),
×
330
                        },
×
331
                )
×
332
                if errors.Is(err, sql.ErrNoRows) {
×
333
                        return nil
×
334
                }
×
335

336
                known = true
×
337

×
338
                addresses, err = getNodeAddresses(ctx, db, dbID)
×
339
                if err != nil {
×
340
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
341
                                err)
×
342
                }
×
343

344
                return nil
×
345
        }, sqldb.NoOpReset)
346
        if err != nil {
×
347
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
348
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
349
        }
×
350

351
        return known, addresses, nil
×
352
}
353

354
// DeleteLightningNode starts a new database transaction to remove a vertex/node
355
// from the database according to the node's public key.
356
//
357
// NOTE: part of the V1Store interface.
358
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
359
        pubKey route.Vertex) error {
×
360

×
361
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
362
                res, err := db.DeleteNodeByPubKey(
×
363
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
364
                                Version: int16(ProtocolV1),
×
365
                                PubKey:  pubKey[:],
×
366
                        },
×
367
                )
×
368
                if err != nil {
×
369
                        return err
×
370
                }
×
371

372
                rows, err := res.RowsAffected()
×
373
                if err != nil {
×
374
                        return err
×
375
                }
×
376

377
                if rows == 0 {
×
378
                        return ErrGraphNodeNotFound
×
379
                } else if rows > 1 {
×
380
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
381
                }
×
382

383
                return err
×
384
        }, sqldb.NoOpReset)
385
        if err != nil {
×
386
                return fmt.Errorf("unable to delete node: %w", err)
×
387
        }
×
388

389
        return nil
×
390
}
391

392
// FetchNodeFeatures returns the features of the given node. If no features are
393
// known for the node, an empty feature vector is returned.
394
//
395
// NOTE: this is part of the graphdb.NodeTraverser interface.
396
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
397
        *lnwire.FeatureVector, error) {
×
398

×
399
        ctx := context.TODO()
×
400

×
401
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
402
}
×
403

404
// DisabledChannelIDs returns the channel ids of disabled channels.
405
// A channel is disabled when two of the associated ChanelEdgePolicies
406
// have their disabled bit on.
407
//
408
// NOTE: part of the V1Store interface.
409
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
410
        var (
×
411
                ctx     = context.TODO()
×
412
                chanIDs []uint64
×
413
        )
×
414
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
415
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
416
                if err != nil {
×
417
                        return fmt.Errorf("unable to fetch disabled "+
×
418
                                "channels: %w", err)
×
419
                }
×
420

421
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
422

×
423
                return nil
×
424
        }, sqldb.NoOpReset)
425
        if err != nil {
×
426
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
427
                        err)
×
428
        }
×
429

430
        return chanIDs, nil
×
431
}
432

433
// LookupAlias attempts to return the alias as advertised by the target node.
434
//
435
// NOTE: part of the V1Store interface.
436
func (s *SQLStore) LookupAlias(ctx context.Context,
437
        pub *btcec.PublicKey) (string, error) {
×
438

×
439
        var alias string
×
440
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
441
                dbNode, err := db.GetNodeByPubKey(
×
442
                        ctx, sqlc.GetNodeByPubKeyParams{
×
443
                                Version: int16(ProtocolV1),
×
444
                                PubKey:  pub.SerializeCompressed(),
×
445
                        },
×
446
                )
×
447
                if errors.Is(err, sql.ErrNoRows) {
×
448
                        return ErrNodeAliasNotFound
×
449
                } else if err != nil {
×
450
                        return fmt.Errorf("unable to fetch node: %w", err)
×
451
                }
×
452

453
                if !dbNode.Alias.Valid {
×
454
                        return ErrNodeAliasNotFound
×
455
                }
×
456

457
                alias = dbNode.Alias.String
×
458

×
459
                return nil
×
460
        }, sqldb.NoOpReset)
461
        if err != nil {
×
462
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
463
        }
×
464

465
        return alias, nil
×
466
}
467

468
// SourceNode returns the source node of the graph. The source node is treated
469
// as the center node within a star-graph. This method may be used to kick off
470
// a path finding algorithm in order to explore the reachability of another
471
// node based off the source node.
472
//
473
// NOTE: part of the V1Store interface.
474
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
475
        error) {
×
476

×
477
        var node *models.LightningNode
×
478
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
479
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
480
                if err != nil {
×
481
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
482
                                err)
×
483
                }
×
484

485
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
486

×
487
                return err
×
488
        }, sqldb.NoOpReset)
489
        if err != nil {
×
490
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
491
        }
×
492

493
        return node, nil
×
494
}
495

496
// SetSourceNode sets the source node within the graph database. The source
497
// node is to be used as the center of a star-graph within path finding
498
// algorithms.
499
//
500
// NOTE: part of the V1Store interface.
501
func (s *SQLStore) SetSourceNode(ctx context.Context,
502
        node *models.LightningNode) error {
×
503

×
504
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
505
                id, err := upsertNode(ctx, db, node)
×
506
                if err != nil {
×
507
                        return fmt.Errorf("unable to upsert source node: %w",
×
508
                                err)
×
509
                }
×
510

511
                // Make sure that if a source node for this version is already
512
                // set, then the ID is the same as the one we are about to set.
513
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
514
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
515
                        return fmt.Errorf("unable to fetch source node: %w",
×
516
                                err)
×
517
                } else if err == nil {
×
518
                        if dbSourceNodeID != id {
×
519
                                return fmt.Errorf("v1 source node already "+
×
520
                                        "set to a different node: %d vs %d",
×
521
                                        dbSourceNodeID, id)
×
522
                        }
×
523

524
                        return nil
×
525
                }
526

527
                return db.AddSourceNode(ctx, id)
×
528
        }, sqldb.NoOpReset)
529
}
530

531
// NodeUpdatesInHorizon returns all the known lightning node which have an
532
// update timestamp within the passed range. This method can be used by two
533
// nodes to quickly determine if they have the same set of up to date node
534
// announcements.
535
//
536
// NOTE: This is part of the V1Store interface.
537
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
538
        endTime time.Time) ([]models.LightningNode, error) {
×
539

×
540
        ctx := context.TODO()
×
541

×
542
        var nodes []models.LightningNode
×
543
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
544
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
545
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
546
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
547
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
548
                        },
×
549
                )
×
550
                if err != nil {
×
551
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
552
                }
×
553

554
                err = forEachNodeInBatch(
×
NEW
555
                        ctx, s.cfg.QueryCfg, db, dbNodes,
×
556
                        func(_ int64, node *models.LightningNode) error {
×
557
                                nodes = append(nodes, *node)
×
558

×
559
                                return nil
×
560
                        },
×
561
                )
562
                if err != nil {
×
563
                        return fmt.Errorf("unable to build nodes: %w", err)
×
564
                }
×
565

566
                return nil
×
567
        }, sqldb.NoOpReset)
568
        if err != nil {
×
569
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
570
        }
×
571

572
        return nodes, nil
×
573
}
574

575
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
576
// undirected edge from the two target nodes are created. The information stored
577
// denotes the static attributes of the channel, such as the channelID, the keys
578
// involved in creation of the channel, and the set of features that the channel
579
// supports. The chanPoint and chanID are used to uniquely identify the edge
580
// globally within the database.
581
//
582
// NOTE: part of the V1Store interface.
583
func (s *SQLStore) AddChannelEdge(ctx context.Context,
584
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
585

×
586
        var alreadyExists bool
×
587
        r := &batch.Request[SQLQueries]{
×
588
                Opts: batch.NewSchedulerOptions(opts...),
×
589
                Reset: func() {
×
590
                        alreadyExists = false
×
591
                },
×
592
                Do: func(tx SQLQueries) error {
×
593
                        _, err := insertChannel(ctx, tx, edge)
×
594

×
595
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
596
                        // succeed, but propagate the error via local state.
×
597
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
598
                                alreadyExists = true
×
599
                                return nil
×
600
                        }
×
601

602
                        return err
×
603
                },
604
                OnCommit: func(err error) error {
×
605
                        switch {
×
606
                        case err != nil:
×
607
                                return err
×
608
                        case alreadyExists:
×
609
                                return ErrEdgeAlreadyExist
×
610
                        default:
×
611
                                s.rejectCache.remove(edge.ChannelID)
×
612
                                s.chanCache.remove(edge.ChannelID)
×
613
                                return nil
×
614
                        }
615
                },
616
        }
617

618
        return s.chanScheduler.Execute(ctx, r)
×
619
}
620

621
// HighestChanID returns the "highest" known channel ID in the channel graph.
622
// This represents the "newest" channel from the PoV of the chain. This method
623
// can be used by peers to quickly determine if their graphs are in sync.
624
//
625
// NOTE: This is part of the V1Store interface.
626
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
627
        var highestChanID uint64
×
628
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
629
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
630
                if errors.Is(err, sql.ErrNoRows) {
×
631
                        return nil
×
632
                } else if err != nil {
×
633
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
634
                                err)
×
635
                }
×
636

637
                highestChanID = byteOrder.Uint64(chanID)
×
638

×
639
                return nil
×
640
        }, sqldb.NoOpReset)
641
        if err != nil {
×
642
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
643
        }
×
644

645
        return highestChanID, nil
×
646
}
647

648
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
649
// within the database for the referenced channel. The `flags` attribute within
650
// the ChannelEdgePolicy determines which of the directed edges are being
651
// updated. If the flag is 1, then the first node's information is being
652
// updated, otherwise it's the second node's information. The node ordering is
653
// determined by the lexicographical ordering of the identity public keys of the
654
// nodes on either side of the channel.
655
//
656
// NOTE: part of the V1Store interface.
657
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
658
        edge *models.ChannelEdgePolicy,
659
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
660

×
661
        var (
×
662
                isUpdate1    bool
×
663
                edgeNotFound bool
×
664
                from, to     route.Vertex
×
665
        )
×
666

×
667
        r := &batch.Request[SQLQueries]{
×
668
                Opts: batch.NewSchedulerOptions(opts...),
×
669
                Reset: func() {
×
670
                        isUpdate1 = false
×
671
                        edgeNotFound = false
×
672
                },
×
673
                Do: func(tx SQLQueries) error {
×
674
                        var err error
×
675
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
676
                                ctx, tx, edge,
×
677
                        )
×
678
                        if err != nil {
×
679
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
680
                        }
×
681

682
                        // Silence ErrEdgeNotFound so that the batch can
683
                        // succeed, but propagate the error via local state.
684
                        if errors.Is(err, ErrEdgeNotFound) {
×
685
                                edgeNotFound = true
×
686
                                return nil
×
687
                        }
×
688

689
                        return err
×
690
                },
691
                OnCommit: func(err error) error {
×
692
                        switch {
×
693
                        case err != nil:
×
694
                                return err
×
695
                        case edgeNotFound:
×
696
                                return ErrEdgeNotFound
×
697
                        default:
×
698
                                s.updateEdgeCache(edge, isUpdate1)
×
699
                                return nil
×
700
                        }
701
                },
702
        }
703

704
        err := s.chanScheduler.Execute(ctx, r)
×
705

×
706
        return from, to, err
×
707
}
708

709
// updateEdgeCache updates our reject and channel caches with the new
710
// edge policy information.
711
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
712
        isUpdate1 bool) {
×
713

×
714
        // If an entry for this channel is found in reject cache, we'll modify
×
715
        // the entry with the updated timestamp for the direction that was just
×
716
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
717
        // during the next query for this edge.
×
718
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
719
                if isUpdate1 {
×
720
                        entry.upd1Time = e.LastUpdate.Unix()
×
721
                } else {
×
722
                        entry.upd2Time = e.LastUpdate.Unix()
×
723
                }
×
724
                s.rejectCache.insert(e.ChannelID, entry)
×
725
        }
726

727
        // If an entry for this channel is found in channel cache, we'll modify
728
        // the entry with the updated policy for the direction that was just
729
        // written. If the edge doesn't exist, we'll defer loading the info and
730
        // policies and lazily read from disk during the next query.
731
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
732
                if isUpdate1 {
×
733
                        channel.Policy1 = e
×
734
                } else {
×
735
                        channel.Policy2 = e
×
736
                }
×
737
                s.chanCache.insert(e.ChannelID, channel)
×
738
        }
739
}
740

741
// ForEachSourceNodeChannel iterates through all channels of the source node,
742
// executing the passed callback on each. The call-back is provided with the
743
// channel's outpoint, whether we have a policy for the channel and the channel
744
// peer's node information.
745
//
746
// NOTE: part of the V1Store interface.
747
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
748
        cb func(chanPoint wire.OutPoint, havePolicy bool,
749
                otherNode *models.LightningNode) error, reset func()) error {
×
750

×
751
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
752
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
753
                if err != nil {
×
754
                        return fmt.Errorf("unable to fetch source node: %w",
×
755
                                err)
×
756
                }
×
757

758
                return forEachNodeChannel(
×
759
                        ctx, db, s.cfg.ChainHash, nodeID,
×
760
                        func(info *models.ChannelEdgeInfo,
×
761
                                outPolicy *models.ChannelEdgePolicy,
×
762
                                _ *models.ChannelEdgePolicy) error {
×
763

×
764
                                // Fetch the other node.
×
765
                                var (
×
766
                                        otherNodePub [33]byte
×
767
                                        node1        = info.NodeKey1Bytes
×
768
                                        node2        = info.NodeKey2Bytes
×
769
                                )
×
770
                                switch {
×
771
                                case bytes.Equal(node1[:], nodePub[:]):
×
772
                                        otherNodePub = node2
×
773
                                case bytes.Equal(node2[:], nodePub[:]):
×
774
                                        otherNodePub = node1
×
775
                                default:
×
776
                                        return fmt.Errorf("node not " +
×
777
                                                "participating in this channel")
×
778
                                }
779

780
                                _, otherNode, err := getNodeByPubKey(
×
781
                                        ctx, db, otherNodePub,
×
782
                                )
×
783
                                if err != nil {
×
784
                                        return fmt.Errorf("unable to fetch "+
×
785
                                                "other node(%x): %w",
×
786
                                                otherNodePub, err)
×
787
                                }
×
788

789
                                return cb(
×
790
                                        info.ChannelPoint, outPolicy != nil,
×
791
                                        otherNode,
×
792
                                )
×
793
                        },
794
                )
795
        }, reset)
796
}
797

798
// ForEachNode iterates through all the stored vertices/nodes in the graph,
799
// executing the passed callback with each node encountered. If the callback
800
// returns an error, then the transaction is aborted and the iteration stops
801
// early. Any operations performed on the NodeTx passed to the call-back are
802
// executed under the same read transaction and so, methods on the NodeTx object
803
// _MUST_ only be called from within the call-back.
804
//
805
// NOTE: part of the V1Store interface.
806
func (s *SQLStore) ForEachNode(ctx context.Context,
807
        cb func(tx NodeRTx) error, reset func()) error {
×
808

×
809
        var lastID int64
×
810

×
811
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
812
                nodeCB := func(dbID int64, node *models.LightningNode) error {
×
813
                        err := cb(newSQLGraphNodeTx(
×
814
                                db, s.cfg.ChainHash, dbID, node,
×
815
                        ))
×
816
                        if err != nil {
×
817
                                return fmt.Errorf("callback failed for "+
×
818
                                        "node(id=%d): %w", dbID, err)
×
819
                        }
×
820
                        lastID = dbID
×
821

×
822
                        return nil
×
823
                }
824

825
                for {
×
826
                        nodes, err := db.ListNodesPaginated(
×
827
                                ctx, sqlc.ListNodesPaginatedParams{
×
828
                                        Version: int16(ProtocolV1),
×
829
                                        ID:      lastID,
×
NEW
830
                                        Limit:   s.cfg.QueryCfg.MaxPageSize,
×
831
                                },
×
832
                        )
×
833
                        if err != nil {
×
834
                                return fmt.Errorf("unable to fetch nodes: %w",
×
835
                                        err)
×
836
                        }
×
837

838
                        if len(nodes) == 0 {
×
839
                                break
×
840
                        }
841

842
                        err = forEachNodeInBatch(
×
NEW
843
                                ctx, s.cfg.QueryCfg, db, nodes, nodeCB,
×
844
                        )
×
845
                        if err != nil {
×
846
                                return fmt.Errorf("unable to iterate over "+
×
847
                                        "nodes: %w", err)
×
848
                        }
×
849
                }
850

851
                return nil
×
852
        }, reset)
853
}
854

855
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
856
// SQLStore and a SQL transaction.
857
type sqlGraphNodeTx struct {
858
        db    SQLQueries
859
        id    int64
860
        node  *models.LightningNode
861
        chain chainhash.Hash
862
}
863

864
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
865
// interface.
866
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
867

868
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
869
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
870

×
871
        return &sqlGraphNodeTx{
×
872
                db:    db,
×
873
                chain: chain,
×
874
                id:    id,
×
875
                node:  node,
×
876
        }
×
877
}
×
878

879
// Node returns the raw information of the node.
//
// The returned pointer is the one stored in this wrapper (no copy is made),
// so callers share it with any other holder of this NodeRTx.
//
// NOTE: This is a part of the NodeRTx interface.
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
	return s.node
}
×
885

886
// ForEachChannel can be used to iterate over the node's channels under the same
887
// transaction used to fetch the node.
888
//
889
// NOTE: This is a part of the NodeRTx interface.
890
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
891
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
892

×
893
        ctx := context.TODO()
×
894

×
895
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
896
}
×
897

898
// FetchNode fetches the node with the given pub key under the same transaction
899
// used to fetch the current node. The returned node is also a NodeRTx and any
900
// operations on that NodeRTx will also be done under the same transaction.
901
//
902
// NOTE: This is a part of the NodeRTx interface.
903
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
904
        ctx := context.TODO()
×
905

×
906
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
907
        if err != nil {
×
908
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
909
                        nodePub, err)
×
910
        }
×
911

912
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
913
}
914

915
// ForEachNodeDirectedChannel iterates through all channels of a given node,
916
// executing the passed callback on the directed edge representing the channel
917
// and its incoming policy. If the callback returns an error, then the iteration
918
// is halted with the error propagated back up to the caller.
919
//
920
// Unknown policies are passed into the callback as nil values.
921
//
922
// NOTE: this is part of the graphdb.NodeTraverser interface.
923
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
924
        cb func(channel *DirectedChannel) error, reset func()) error {
×
925

×
926
        var ctx = context.TODO()
×
927

×
928
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
929
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
930
        }, reset)
×
931
}
932

933
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
934
// graph, executing the passed callback with each node encountered. If the
935
// callback returns an error, then the transaction is aborted and the iteration
936
// stops early.
937
//
938
// NOTE: This is a part of the V1Store interface.
939
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
940
        cb func(route.Vertex, *lnwire.FeatureVector) error,
941
        reset func()) error {
×
942

×
943
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
944
                return forEachNodeCacheable(
×
NEW
945
                        ctx, s.cfg.QueryCfg, db,
×
NEW
946
                        func(nodeID int64, nodePub route.Vertex) error {
×
NEW
947
                                features, err := getNodeFeatures(
×
NEW
948
                                        ctx, db, nodeID,
×
NEW
949
                                )
×
NEW
950
                                if err != nil {
×
NEW
951
                                        return fmt.Errorf("unable to fetch "+
×
NEW
952
                                                "node features: %w", err)
×
NEW
953
                                }
×
954

NEW
955
                                return cb(nodePub, features)
×
956
                        },
957
                )
958
        }, reset)
959
        if err != nil {
×
960
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
961
        }
×
962

963
        return nil
×
964
}
965

966
// ForEachNodeChannel iterates through all channels of the given node,
967
// executing the passed callback with an edge info structure and the policies
968
// of each end of the channel. The first edge policy is the outgoing edge *to*
969
// the connecting node, while the second is the incoming edge *from* the
970
// connecting node. If the callback returns an error, then the iteration is
971
// halted with the error propagated back up to the caller.
972
//
973
// Unknown policies are passed into the callback as nil values.
974
//
975
// NOTE: part of the V1Store interface.
976
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
977
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
978
                *models.ChannelEdgePolicy) error, reset func()) error {
×
979

×
980
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
981
                dbNode, err := db.GetNodeByPubKey(
×
982
                        ctx, sqlc.GetNodeByPubKeyParams{
×
983
                                Version: int16(ProtocolV1),
×
984
                                PubKey:  nodePub[:],
×
985
                        },
×
986
                )
×
987
                if errors.Is(err, sql.ErrNoRows) {
×
988
                        return nil
×
989
                } else if err != nil {
×
990
                        return fmt.Errorf("unable to fetch node: %w", err)
×
991
                }
×
992

993
                return forEachNodeChannel(
×
994
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
995
                )
×
996
        }, reset)
997
}
998

999
// ChanUpdatesInHorizon returns all the known channel edges which have at least
1000
// one edge that has an update timestamp within the specified horizon.
1001
//
1002
// NOTE: This is part of the V1Store interface.
1003
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
1004
        endTime time.Time) ([]ChannelEdge, error) {
×
1005

×
1006
        s.cacheMu.Lock()
×
1007
        defer s.cacheMu.Unlock()
×
1008

×
1009
        var (
×
1010
                ctx = context.TODO()
×
1011
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
1012
                // an additional map to keep track of the edges already seen to
×
1013
                // prevent re-adding it.
×
1014
                edgesSeen    = make(map[uint64]struct{})
×
1015
                edgesToCache = make(map[uint64]ChannelEdge)
×
1016
                edges        []ChannelEdge
×
1017
                hits         int
×
1018
        )
×
1019
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1020
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1021
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1022
                                Version:   int16(ProtocolV1),
×
1023
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
1024
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
1025
                        },
×
1026
                )
×
1027
                if err != nil {
×
1028
                        return err
×
1029
                }
×
1030

1031
                for _, row := range rows {
×
1032
                        // If we've already retrieved the info and policies for
×
1033
                        // this edge, then we can skip it as we don't need to do
×
1034
                        // so again.
×
1035
                        chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid)
×
1036
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1037
                                continue
×
1038
                        }
1039

1040
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1041
                                hits++
×
1042
                                edgesSeen[chanIDInt] = struct{}{}
×
1043
                                edges = append(edges, channel)
×
1044

×
1045
                                continue
×
1046
                        }
1047

1048
                        node1, node2, err := buildNodes(
×
1049
                                ctx, db, row.GraphNode, row.GraphNode_2,
×
1050
                        )
×
1051
                        if err != nil {
×
1052
                                return err
×
1053
                        }
×
1054

1055
                        channel, err := getAndBuildEdgeInfo(
×
1056
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
1057
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1058
                        )
×
1059
                        if err != nil {
×
1060
                                return fmt.Errorf("unable to build channel "+
×
1061
                                        "info: %w", err)
×
1062
                        }
×
1063

1064
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1065
                        if err != nil {
×
1066
                                return fmt.Errorf("unable to extract channel "+
×
1067
                                        "policies: %w", err)
×
1068
                        }
×
1069

1070
                        p1, p2, err := getAndBuildChanPolicies(
×
1071
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1072
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1073
                        )
×
1074
                        if err != nil {
×
1075
                                return fmt.Errorf("unable to build channel "+
×
1076
                                        "policies: %w", err)
×
1077
                        }
×
1078

1079
                        edgesSeen[chanIDInt] = struct{}{}
×
1080
                        chanEdge := ChannelEdge{
×
1081
                                Info:    channel,
×
1082
                                Policy1: p1,
×
1083
                                Policy2: p2,
×
1084
                                Node1:   node1,
×
1085
                                Node2:   node2,
×
1086
                        }
×
1087
                        edges = append(edges, chanEdge)
×
1088
                        edgesToCache[chanIDInt] = chanEdge
×
1089
                }
1090

1091
                return nil
×
1092
        }, func() {
×
1093
                edgesSeen = make(map[uint64]struct{})
×
1094
                edgesToCache = make(map[uint64]ChannelEdge)
×
1095
                edges = nil
×
1096
        })
×
1097
        if err != nil {
×
1098
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1099
        }
×
1100

1101
        // Insert any edges loaded from disk into the cache.
1102
        for chanid, channel := range edgesToCache {
×
1103
                s.chanCache.insert(chanid, channel)
×
1104
        }
×
1105

1106
        if len(edges) > 0 {
×
1107
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
1108
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1109
        } else {
×
1110
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1111
                        "horizon (%s, %s)", startTime, endTime)
×
1112
        }
×
1113

1114
        return edges, nil
×
1115
}
1116

1117
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1118
// data to the call-back.
1119
//
1120
// NOTE: The callback contents MUST not be modified.
1121
//
1122
// NOTE: part of the V1Store interface.
1123
func (s *SQLStore) ForEachNodeCached(ctx context.Context,
1124
        cb func(node route.Vertex, chans map[uint64]*DirectedChannel) error,
1125
        reset func()) error {
×
1126

×
NEW
1127
        handleNode := func(db SQLQueries, nodeID int64,
×
NEW
1128
                nodePub route.Vertex) error {
×
NEW
1129

×
NEW
1130
                features, err := getNodeFeatures(ctx, db, nodeID)
×
NEW
1131
                if err != nil {
×
NEW
1132
                        return fmt.Errorf("unable to fetch node(id=%d) "+
×
NEW
1133
                                "features: %w", nodeID, err)
×
NEW
1134
                }
×
1135

NEW
1136
                toNodeCallback := func() route.Vertex {
×
NEW
1137
                        return nodePub
×
NEW
1138
                }
×
1139

NEW
1140
                rows, err := db.ListChannelsByNodeID(
×
NEW
1141
                        ctx, sqlc.ListChannelsByNodeIDParams{
×
NEW
1142
                                Version: int16(ProtocolV1),
×
NEW
1143
                                NodeID1: nodeID,
×
NEW
1144
                        },
×
NEW
1145
                )
×
NEW
1146
                if err != nil {
×
NEW
1147
                        return fmt.Errorf("unable to fetch channels of "+
×
NEW
1148
                                "node(id=%d): %w", nodeID, err)
×
NEW
1149
                }
×
1150

NEW
1151
                channels := make(map[uint64]*DirectedChannel, len(rows))
×
NEW
1152
                for _, row := range rows {
×
NEW
1153
                        node1, node2, err := buildNodeVertices(
×
NEW
1154
                                row.Node1Pubkey, row.Node2Pubkey,
×
NEW
1155
                        )
×
NEW
1156
                        if err != nil {
×
NEW
1157
                                return err
×
NEW
1158
                        }
×
1159

NEW
1160
                        e, err := getAndBuildEdgeInfo(
×
NEW
1161
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
NEW
1162
                                node1, node2,
×
NEW
1163
                        )
×
1164
                        if err != nil {
×
NEW
1165
                                return fmt.Errorf("unable to build channel "+
×
NEW
1166
                                        "info: %w", err)
×
UNCOV
1167
                        }
×
1168

NEW
1169
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
NEW
1170
                        if err != nil {
×
NEW
1171
                                return fmt.Errorf("unable to extract channel "+
×
NEW
1172
                                        "policies: %w", err)
×
UNCOV
1173
                        }
×
1174

NEW
1175
                        p1, p2, err := getAndBuildChanPolicies(
×
NEW
1176
                                ctx, db, dbPol1, dbPol2, e.ChannelID, node1,
×
NEW
1177
                                node2,
×
1178
                        )
×
1179
                        if err != nil {
×
NEW
1180
                                return fmt.Errorf("unable to build channel "+
×
NEW
1181
                                        "policies: %w", err)
×
UNCOV
1182
                        }
×
1183

1184
                        // Determine the outgoing and incoming policy
1185
                        // for this channel and node combo.
NEW
1186
                        outPolicy, inPolicy := p1, p2
×
NEW
1187
                        if p1 != nil && p1.ToNode == nodePub {
×
NEW
1188
                                outPolicy, inPolicy = p2, p1
×
NEW
1189
                        } else if p2 != nil && p2.ToNode != nodePub {
×
NEW
1190
                                outPolicy, inPolicy = p2, p1
×
NEW
1191
                        }
×
1192

NEW
1193
                        var cachedInPolicy *models.CachedEdgePolicy
×
NEW
1194
                        if inPolicy != nil {
×
NEW
1195
                                cachedInPolicy = models.NewCachedPolicy(
×
NEW
1196
                                        inPolicy,
×
1197
                                )
×
NEW
1198
                                cachedInPolicy.ToNodePubKey = toNodeCallback
×
NEW
1199
                                cachedInPolicy.ToNodeFeatures = features
×
NEW
1200
                        }
×
1201

NEW
1202
                        var inboundFee lnwire.Fee
×
NEW
1203
                        if outPolicy != nil {
×
NEW
1204
                                outPolicy.InboundFee.WhenSome(
×
NEW
1205
                                        func(fee lnwire.Fee) {
×
NEW
1206
                                                inboundFee = fee
×
NEW
1207
                                        },
×
1208
                                )
1209
                        }
1210

NEW
1211
                        directedChannel := &DirectedChannel{
×
NEW
1212
                                ChannelID:    e.ChannelID,
×
NEW
1213
                                IsNode1:      nodePub == e.NodeKey1Bytes,
×
NEW
1214
                                OtherNode:    e.NodeKey2Bytes,
×
NEW
1215
                                Capacity:     e.Capacity,
×
NEW
1216
                                OutPolicySet: outPolicy != nil,
×
NEW
1217
                                InPolicy:     cachedInPolicy,
×
NEW
1218
                                InboundFee:   inboundFee,
×
NEW
1219
                        }
×
1220

×
NEW
1221
                        if nodePub == e.NodeKey2Bytes {
×
NEW
1222
                                directedChannel.OtherNode = e.NodeKey1Bytes
×
NEW
1223
                        }
×
1224

NEW
1225
                        channels[e.ChannelID] = directedChannel
×
1226
                }
1227

NEW
1228
                return cb(nodePub, channels)
×
1229
        }
1230

NEW
1231
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1232
                return forEachNodeCacheable(
×
NEW
1233
                        ctx, s.cfg.QueryCfg, db,
×
NEW
1234
                        func(nodeID int64, nodePub route.Vertex) error {
×
NEW
1235
                                return handleNode(db, nodeID, nodePub)
×
NEW
1236
                        },
×
1237
                )
1238
        }, reset)
1239
}
1240

1241
// ForEachChannelCacheable iterates through all the channel edges stored
1242
// within the graph and invokes the passed callback for each edge. The
1243
// callback takes two edges as since this is a directed graph, both the
1244
// in/out edges are visited. If the callback returns an error, then the
1245
// transaction is aborted and the iteration stops early.
1246
//
1247
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1248
// pointer for that particular channel edge routing policy will be
1249
// passed into the callback.
1250
//
1251
// NOTE: this method is like ForEachChannel but fetches only the data
1252
// required for the graph cache.
1253
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1254
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1255
        reset func()) error {
×
1256

×
1257
        ctx := context.TODO()
×
1258

×
NEW
1259
        handleChannel := func(_ context.Context,
×
1260
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
1261

×
1262
                node1, node2, err := buildNodeVertices(
×
1263
                        row.Node1Pubkey, row.Node2Pubkey,
×
1264
                )
×
1265
                if err != nil {
×
1266
                        return err
×
1267
                }
×
1268

1269
                edge := buildCacheableChannelInfo(
×
1270
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1271
                )
×
1272

×
1273
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1274
                if err != nil {
×
1275
                        return err
×
1276
                }
×
1277

NEW
1278
                pol1, pol2, err := buildCachedChanPolicies(
×
NEW
1279
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
NEW
1280
                )
×
NEW
1281
                if err != nil {
×
NEW
1282
                        return err
×
1283
                }
×
1284

NEW
1285
                return cb(edge, pol1, pol2)
×
1286
        }
1287

NEW
1288
        extractCursor := func(
×
NEW
1289
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {
×
1290

×
NEW
1291
                return row.ID
×
UNCOV
1292
        }
×
1293

1294
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1295
                //nolint:ll
×
NEW
1296
                queryFunc := func(ctx context.Context, lastID int64,
×
NEW
1297
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
×
NEW
1298
                        error) {
×
NEW
1299

×
NEW
1300
                        return db.ListChannelsWithPoliciesForCachePaginated(
×
1301
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
1302
                                        Version: int16(ProtocolV1),
×
1303
                                        ID:      lastID,
×
NEW
1304
                                        Limit:   limit,
×
1305
                                },
×
1306
                        )
×
1307
                }
×
1308

NEW
1309
                return sqldb.ExecutePaginatedQuery(
×
NEW
1310
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
NEW
1311
                        extractCursor, handleChannel,
×
NEW
1312
                )
×
1313
        }, reset)
1314
}
1315

1316
// ForEachChannel iterates through all the channel edges stored within the
1317
// graph and invokes the passed callback for each edge. The callback takes two
1318
// edges as since this is a directed graph, both the in/out edges are visited.
1319
// If the callback returns an error, then the transaction is aborted and the
1320
// iteration stops early.
1321
//
1322
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1323
// for that particular channel edge routing policy will be passed into the
1324
// callback.
1325
//
1326
// NOTE: part of the V1Store interface.
1327
func (s *SQLStore) ForEachChannel(ctx context.Context,
1328
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1329
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1330

×
1331
        handleChannel := func(db SQLQueries, batchData *batchChannelData,
×
1332
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1333

×
1334
                node1, node2, err := buildNodeVertices(
×
1335
                        row.Node1Pubkey, row.Node2Pubkey,
×
1336
                )
×
1337
                if err != nil {
×
1338
                        return fmt.Errorf("unable to build node vertices: %w",
×
1339
                                err)
×
1340
                }
×
1341

1342
                edge, err := buildEdgeInfoWithBatchData(
×
1343
                        s.cfg.ChainHash, row.GraphChannel, node1, node2,
×
1344
                        batchData,
×
1345
                )
×
1346
                if err != nil {
×
1347
                        return fmt.Errorf("unable to build channel info: %w",
×
1348
                                err)
×
1349
                }
×
1350

1351
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1352
                if err != nil {
×
1353
                        return fmt.Errorf("unable to extract channel "+
×
1354
                                "policies: %w", err)
×
1355
                }
×
1356

1357
                p1, p2, err := buildChanPoliciesWithBatchData(
×
1358
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
1359
                )
×
1360
                if err != nil {
×
1361
                        return fmt.Errorf("unable to build channel "+
×
1362
                                "policies: %w", err)
×
1363
                }
×
1364

1365
                err = cb(edge, p1, p2)
×
1366
                if err != nil {
×
1367
                        return fmt.Errorf("callback failed for channel "+
×
1368
                                "id=%d: %w", edge.ChannelID, err)
×
1369
                }
×
1370

1371
                return nil
×
1372
        }
1373

1374
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1375
                lastID := int64(-1)
×
1376
                for {
×
1377
                        //nolint:ll
×
1378
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1379
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1380
                                        Version: int16(ProtocolV1),
×
1381
                                        ID:      lastID,
×
NEW
1382
                                        Limit:   s.cfg.QueryCfg.MaxPageSize,
×
1383
                                },
×
1384
                        )
×
1385
                        if err != nil {
×
1386
                                return err
×
1387
                        }
×
1388

1389
                        if len(rows) == 0 {
×
1390
                                break
×
1391
                        }
1392

1393
                        // Collect the channel & policy IDs that we want to
1394
                        // do a batch collection for.
1395
                        var (
×
1396
                                channelIDs = make([]int64, len(rows))
×
1397
                                policyIDs  = make([]int64, 0, len(rows)*2)
×
1398
                        )
×
1399
                        for i, row := range rows {
×
1400
                                channelIDs[i] = row.GraphChannel.ID
×
1401

×
1402
                                // Extract policy IDs from the row
×
1403
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1404
                                        row,
×
1405
                                )
×
1406
                                if err != nil {
×
1407
                                        return fmt.Errorf("unable to extract "+
×
1408
                                                "channel policies: %w", err)
×
1409
                                }
×
1410

1411
                                if dbPol1 != nil {
×
1412
                                        policyIDs = append(policyIDs, dbPol1.ID)
×
1413
                                }
×
1414

1415
                                if dbPol2 != nil {
×
1416
                                        policyIDs = append(policyIDs, dbPol2.ID)
×
1417
                                }
×
1418
                        }
1419

1420
                        batchData, err := batchLoadChannelData(
×
NEW
1421
                                ctx, s.cfg.QueryCfg, db, channelIDs,
×
1422
                                policyIDs,
×
1423
                        )
×
1424
                        if err != nil {
×
1425
                                return fmt.Errorf("unable to batch load "+
×
1426
                                        "channel data: %w", err)
×
1427
                        }
×
1428

1429
                        for _, row := range rows {
×
1430
                                err := handleChannel(db, batchData, row)
×
1431
                                if err != nil {
×
1432
                                        return err
×
1433
                                }
×
1434

1435
                                lastID = row.GraphChannel.ID
×
1436
                        }
1437
                }
1438

1439
                return nil
×
1440
        }, reset)
1441
}
1442

1443
// FilterChannelRange returns the channel ID's of all known channels which were
1444
// mined in a block height within the passed range. The channel IDs are grouped
1445
// by their common block height. This method can be used to quickly share with a
1446
// peer the set of channels we know of within a particular range to catch them
1447
// up after a period of time offline. If withTimestamps is true then the
1448
// timestamp info of the latest received channel update messages of the channel
1449
// will be included in the response.
1450
//
1451
// NOTE: This is part of the V1Store interface.
1452
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1453
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1454

×
1455
        var (
×
1456
                ctx       = context.TODO()
×
1457
                startSCID = &lnwire.ShortChannelID{
×
1458
                        BlockHeight: startHeight,
×
1459
                }
×
1460
                endSCID = lnwire.ShortChannelID{
×
1461
                        BlockHeight: endHeight,
×
1462
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1463
                        TxPosition:  math.MaxUint16,
×
1464
                }
×
1465
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1466
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1467
        )
×
1468

×
1469
        // 1) get all channels where channelID is between start and end chan ID.
×
1470
        // 2) skip if not public (ie, no channel_proof)
×
1471
        // 3) collect that channel.
×
1472
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1473
        //    and add those timestamps to the collected channel.
×
1474
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1475
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1476
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1477
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1478
                                StartScid: chanIDStart,
×
1479
                                EndScid:   chanIDEnd,
×
1480
                        },
×
1481
                )
×
1482
                if err != nil {
×
1483
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1484
                                err)
×
1485
                }
×
1486

1487
                for _, dbChan := range dbChans {
×
1488
                        cid := lnwire.NewShortChanIDFromInt(
×
1489
                                byteOrder.Uint64(dbChan.Scid),
×
1490
                        )
×
1491
                        chanInfo := NewChannelUpdateInfo(
×
1492
                                cid, time.Time{}, time.Time{},
×
1493
                        )
×
1494

×
1495
                        if !withTimestamps {
×
1496
                                channelsPerBlock[cid.BlockHeight] = append(
×
1497
                                        channelsPerBlock[cid.BlockHeight],
×
1498
                                        chanInfo,
×
1499
                                )
×
1500

×
1501
                                continue
×
1502
                        }
1503

1504
                        //nolint:ll
1505
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1506
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1507
                                        Version:   int16(ProtocolV1),
×
1508
                                        ChannelID: dbChan.ID,
×
1509
                                        NodeID:    dbChan.NodeID1,
×
1510
                                },
×
1511
                        )
×
1512
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1513
                                return fmt.Errorf("unable to fetch node1 "+
×
1514
                                        "policy: %w", err)
×
1515
                        } else if err == nil {
×
1516
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1517
                                        node1Policy.LastUpdate.Int64, 0,
×
1518
                                )
×
1519
                        }
×
1520

1521
                        //nolint:ll
1522
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1523
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1524
                                        Version:   int16(ProtocolV1),
×
1525
                                        ChannelID: dbChan.ID,
×
1526
                                        NodeID:    dbChan.NodeID2,
×
1527
                                },
×
1528
                        )
×
1529
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1530
                                return fmt.Errorf("unable to fetch node2 "+
×
1531
                                        "policy: %w", err)
×
1532
                        } else if err == nil {
×
1533
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1534
                                        node2Policy.LastUpdate.Int64, 0,
×
1535
                                )
×
1536
                        }
×
1537

1538
                        channelsPerBlock[cid.BlockHeight] = append(
×
1539
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1540
                        )
×
1541
                }
1542

1543
                return nil
×
1544
        }, func() {
×
1545
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1546
        })
×
1547
        if err != nil {
×
1548
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1549
        }
×
1550

1551
        if len(channelsPerBlock) == 0 {
×
1552
                return nil, nil
×
1553
        }
×
1554

1555
        // Return the channel ranges in ascending block height order.
1556
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1557
        slices.Sort(blocks)
×
1558

×
1559
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1560
                return BlockChannelRange{
×
1561
                        Height:   block,
×
1562
                        Channels: channelsPerBlock[block],
×
1563
                }
×
1564
        }), nil
×
1565
}
1566

1567
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1568
// zombie. This method is used on an ad-hoc basis, when channels need to be
1569
// marked as zombies outside the normal pruning cycle.
1570
//
1571
// NOTE: part of the V1Store interface.
1572
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1573
        pubKey1, pubKey2 [33]byte) error {
×
1574

×
1575
        ctx := context.TODO()
×
1576

×
1577
        s.cacheMu.Lock()
×
1578
        defer s.cacheMu.Unlock()
×
1579

×
1580
        chanIDB := channelIDToBytes(chanID)
×
1581

×
1582
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1583
                return db.UpsertZombieChannel(
×
1584
                        ctx, sqlc.UpsertZombieChannelParams{
×
1585
                                Version:  int16(ProtocolV1),
×
1586
                                Scid:     chanIDB,
×
1587
                                NodeKey1: pubKey1[:],
×
1588
                                NodeKey2: pubKey2[:],
×
1589
                        },
×
1590
                )
×
1591
        }, sqldb.NoOpReset)
×
1592
        if err != nil {
×
1593
                return fmt.Errorf("unable to upsert zombie channel "+
×
1594
                        "(channel_id=%d): %w", chanID, err)
×
1595
        }
×
1596

1597
        s.rejectCache.remove(chanID)
×
1598
        s.chanCache.remove(chanID)
×
1599

×
1600
        return nil
×
1601
}
1602

1603
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1604
//
1605
// NOTE: part of the V1Store interface.
1606
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1607
        s.cacheMu.Lock()
×
1608
        defer s.cacheMu.Unlock()
×
1609

×
1610
        var (
×
1611
                ctx     = context.TODO()
×
1612
                chanIDB = channelIDToBytes(chanID)
×
1613
        )
×
1614

×
1615
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1616
                res, err := db.DeleteZombieChannel(
×
1617
                        ctx, sqlc.DeleteZombieChannelParams{
×
1618
                                Scid:    chanIDB,
×
1619
                                Version: int16(ProtocolV1),
×
1620
                        },
×
1621
                )
×
1622
                if err != nil {
×
1623
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1624
                                err)
×
1625
                }
×
1626

1627
                rows, err := res.RowsAffected()
×
1628
                if err != nil {
×
1629
                        return err
×
1630
                }
×
1631

1632
                if rows == 0 {
×
1633
                        return ErrZombieEdgeNotFound
×
1634
                } else if rows > 1 {
×
1635
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1636
                                "expected 1", rows)
×
1637
                }
×
1638

1639
                return nil
×
1640
        }, sqldb.NoOpReset)
1641
        if err != nil {
×
1642
                return fmt.Errorf("unable to mark edge live "+
×
1643
                        "(channel_id=%d): %w", chanID, err)
×
1644
        }
×
1645

1646
        s.rejectCache.remove(chanID)
×
1647
        s.chanCache.remove(chanID)
×
1648

×
1649
        return err
×
1650
}
1651

1652
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1653
// zombie, then the two node public keys corresponding to this edge are also
1654
// returned.
1655
//
1656
// NOTE: part of the V1Store interface.
1657
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1658
        error) {
×
1659

×
1660
        var (
×
1661
                ctx              = context.TODO()
×
1662
                isZombie         bool
×
1663
                pubKey1, pubKey2 route.Vertex
×
1664
                chanIDB          = channelIDToBytes(chanID)
×
1665
        )
×
1666

×
1667
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1668
                zombie, err := db.GetZombieChannel(
×
1669
                        ctx, sqlc.GetZombieChannelParams{
×
1670
                                Scid:    chanIDB,
×
1671
                                Version: int16(ProtocolV1),
×
1672
                        },
×
1673
                )
×
1674
                if errors.Is(err, sql.ErrNoRows) {
×
1675
                        return nil
×
1676
                }
×
1677
                if err != nil {
×
1678
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1679
                                err)
×
1680
                }
×
1681

1682
                copy(pubKey1[:], zombie.NodeKey1)
×
1683
                copy(pubKey2[:], zombie.NodeKey2)
×
1684
                isZombie = true
×
1685

×
1686
                return nil
×
1687
        }, sqldb.NoOpReset)
1688
        if err != nil {
×
1689
                return false, route.Vertex{}, route.Vertex{},
×
1690
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1691
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1692
        }
×
1693

1694
        return isZombie, pubKey1, pubKey2, nil
×
1695
}
1696

1697
// NumZombies returns the current number of zombie channels in the graph.
1698
//
1699
// NOTE: part of the V1Store interface.
1700
func (s *SQLStore) NumZombies() (uint64, error) {
×
1701
        var (
×
1702
                ctx        = context.TODO()
×
1703
                numZombies uint64
×
1704
        )
×
1705
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1706
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1707
                if err != nil {
×
1708
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1709
                                err)
×
1710
                }
×
1711

1712
                numZombies = uint64(count)
×
1713

×
1714
                return nil
×
1715
        }, sqldb.NoOpReset)
1716
        if err != nil {
×
1717
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1718
        }
×
1719

1720
        return numZombies, nil
×
1721
}
1722

1723
// DeleteChannelEdges removes edges with the given channel IDs from the
1724
// database and marks them as zombies. This ensures that we're unable to re-add
1725
// it to our database once again. If an edge does not exist within the
1726
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1727
// true, then when we mark these edges as zombies, we'll set up the keys such
1728
// that we require the node that failed to send the fresh update to be the one
1729
// that resurrects the channel from its zombie state. The markZombie bool
1730
// denotes whether to mark the channel as a zombie.
1731
//
1732
// NOTE: part of the V1Store interface.
1733
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1734
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1735

×
1736
        s.cacheMu.Lock()
×
1737
        defer s.cacheMu.Unlock()
×
1738

×
1739
        // Keep track of which channels we end up finding so that we can
×
1740
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
1741
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
1742
        for _, chanID := range chanIDs {
×
1743
                chanLookup[chanID] = struct{}{}
×
1744
        }
×
1745

1746
        var (
×
1747
                ctx     = context.TODO()
×
1748
                deleted []*models.ChannelEdgeInfo
×
1749
        )
×
1750
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1751
                chanIDsToDelete := make([]int64, 0, len(chanIDs))
×
1752
                chanCallBack := func(ctx context.Context,
×
1753
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1754

×
1755
                        // Deleting the entry from the map indicates that we
×
1756
                        // have found the channel.
×
1757
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1758
                        delete(chanLookup, scid)
×
1759

×
1760
                        node1, node2, err := buildNodeVertices(
×
1761
                                row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1762
                        )
×
1763
                        if err != nil {
×
1764
                                return err
×
1765
                        }
×
1766

1767
                        info, err := getAndBuildEdgeInfo(
×
1768
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
1769
                                node1, node2,
×
1770
                        )
×
1771
                        if err != nil {
×
1772
                                return err
×
1773
                        }
×
1774

1775
                        deleted = append(deleted, info)
×
1776
                        chanIDsToDelete = append(
×
1777
                                chanIDsToDelete, row.GraphChannel.ID,
×
1778
                        )
×
1779

×
1780
                        if !markZombie {
×
1781
                                return nil
×
1782
                        }
×
1783

1784
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1785
                                info.NodeKey2Bytes
×
1786
                        if strictZombiePruning {
×
1787
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1788
                                if row.Policy1LastUpdate.Valid {
×
1789
                                        e1Time := time.Unix(
×
1790
                                                row.Policy1LastUpdate.Int64, 0,
×
1791
                                        )
×
1792
                                        e1UpdateTime = &e1Time
×
1793
                                }
×
1794
                                if row.Policy2LastUpdate.Valid {
×
1795
                                        e2Time := time.Unix(
×
1796
                                                row.Policy2LastUpdate.Int64, 0,
×
1797
                                        )
×
1798
                                        e2UpdateTime = &e2Time
×
1799
                                }
×
1800

1801
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1802
                                        info, e1UpdateTime, e2UpdateTime,
×
1803
                                )
×
1804
                        }
1805

1806
                        err = db.UpsertZombieChannel(
×
1807
                                ctx, sqlc.UpsertZombieChannelParams{
×
1808
                                        Version:  int16(ProtocolV1),
×
1809
                                        Scid:     channelIDToBytes(scid),
×
1810
                                        NodeKey1: nodeKey1[:],
×
1811
                                        NodeKey2: nodeKey2[:],
×
1812
                                },
×
1813
                        )
×
1814
                        if err != nil {
×
1815
                                return fmt.Errorf("unable to mark channel as "+
×
1816
                                        "zombie: %w", err)
×
1817
                        }
×
1818

1819
                        return nil
×
1820
                }
1821

1822
                err := s.forEachChanWithPoliciesInSCIDList(
×
1823
                        ctx, db, chanCallBack, chanIDs,
×
1824
                )
×
1825
                if err != nil {
×
1826
                        return err
×
1827
                }
×
1828

1829
                if len(chanLookup) > 0 {
×
1830
                        return ErrEdgeNotFound
×
1831
                }
×
1832

1833
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1834
        }, func() {
×
1835
                deleted = nil
×
1836

×
1837
                // Re-fill the lookup map.
×
1838
                for _, chanID := range chanIDs {
×
1839
                        chanLookup[chanID] = struct{}{}
×
1840
                }
×
1841
        })
1842
        if err != nil {
×
1843
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1844
                        err)
×
1845
        }
×
1846

1847
        for _, chanID := range chanIDs {
×
1848
                s.rejectCache.remove(chanID)
×
1849
                s.chanCache.remove(chanID)
×
1850
        }
×
1851

1852
        return deleted, nil
×
1853
}
1854

1855
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1856
// channel identified by the channel ID. If the channel can't be found, then
1857
// ErrEdgeNotFound is returned. A struct which houses the general information
1858
// for the channel itself is returned as well as two structs that contain the
1859
// routing policies for the channel in either direction.
1860
//
1861
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1862
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1863
// the ChannelEdgeInfo will only include the public keys of each node.
1864
//
1865
// NOTE: part of the V1Store interface.
1866
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1867
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1868
        *models.ChannelEdgePolicy, error) {
×
1869

×
1870
        var (
×
1871
                ctx              = context.TODO()
×
1872
                edge             *models.ChannelEdgeInfo
×
1873
                policy1, policy2 *models.ChannelEdgePolicy
×
1874
                chanIDB          = channelIDToBytes(chanID)
×
1875
        )
×
1876
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1877
                row, err := db.GetChannelBySCIDWithPolicies(
×
1878
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1879
                                Scid:    chanIDB,
×
1880
                                Version: int16(ProtocolV1),
×
1881
                        },
×
1882
                )
×
1883
                if errors.Is(err, sql.ErrNoRows) {
×
1884
                        // First check if this edge is perhaps in the zombie
×
1885
                        // index.
×
1886
                        zombie, err := db.GetZombieChannel(
×
1887
                                ctx, sqlc.GetZombieChannelParams{
×
1888
                                        Scid:    chanIDB,
×
1889
                                        Version: int16(ProtocolV1),
×
1890
                                },
×
1891
                        )
×
1892
                        if errors.Is(err, sql.ErrNoRows) {
×
1893
                                return ErrEdgeNotFound
×
1894
                        } else if err != nil {
×
1895
                                return fmt.Errorf("unable to check if "+
×
1896
                                        "channel is zombie: %w", err)
×
1897
                        }
×
1898

1899
                        // At this point, we know the channel is a zombie, so
1900
                        // we'll return an error indicating this, and we will
1901
                        // populate the edge info with the public keys of each
1902
                        // party as this is the only information we have about
1903
                        // it.
1904
                        edge = &models.ChannelEdgeInfo{}
×
1905
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1906
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1907

×
1908
                        return ErrZombieEdge
×
1909
                } else if err != nil {
×
1910
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1911
                }
×
1912

1913
                node1, node2, err := buildNodeVertices(
×
1914
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1915
                )
×
1916
                if err != nil {
×
1917
                        return err
×
1918
                }
×
1919

1920
                edge, err = getAndBuildEdgeInfo(
×
1921
                        ctx, db, s.cfg.ChainHash, row.GraphChannel, node1,
×
1922
                        node2,
×
1923
                )
×
1924
                if err != nil {
×
1925
                        return fmt.Errorf("unable to build channel info: %w",
×
1926
                                err)
×
1927
                }
×
1928

1929
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1930
                if err != nil {
×
1931
                        return fmt.Errorf("unable to extract channel "+
×
1932
                                "policies: %w", err)
×
1933
                }
×
1934

1935
                policy1, policy2, err = getAndBuildChanPolicies(
×
1936
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1937
                )
×
1938
                if err != nil {
×
1939
                        return fmt.Errorf("unable to build channel "+
×
1940
                                "policies: %w", err)
×
1941
                }
×
1942

1943
                return nil
×
1944
        }, sqldb.NoOpReset)
1945
        if err != nil {
×
1946
                // If we are returning the ErrZombieEdge, then we also need to
×
1947
                // return the edge info as the method comment indicates that
×
1948
                // this will be populated when the edge is a zombie.
×
1949
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1950
                        err)
×
1951
        }
×
1952

1953
        return edge, policy1, policy2, nil
×
1954
}
1955

1956
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1957
// the channel identified by the funding outpoint. If the channel can't be
1958
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1959
// information for the channel itself is returned as well as two structs that
1960
// contain the routing policies for the channel in either direction.
1961
//
1962
// NOTE: part of the V1Store interface.
1963
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1964
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1965
        *models.ChannelEdgePolicy, error) {
×
1966

×
1967
        var (
×
1968
                ctx              = context.TODO()
×
1969
                edge             *models.ChannelEdgeInfo
×
1970
                policy1, policy2 *models.ChannelEdgePolicy
×
1971
        )
×
1972
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1973
                row, err := db.GetChannelByOutpointWithPolicies(
×
1974
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1975
                                Outpoint: op.String(),
×
1976
                                Version:  int16(ProtocolV1),
×
1977
                        },
×
1978
                )
×
1979
                if errors.Is(err, sql.ErrNoRows) {
×
1980
                        return ErrEdgeNotFound
×
1981
                } else if err != nil {
×
1982
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1983
                }
×
1984

1985
                node1, node2, err := buildNodeVertices(
×
1986
                        row.Node1Pubkey, row.Node2Pubkey,
×
1987
                )
×
1988
                if err != nil {
×
1989
                        return err
×
1990
                }
×
1991

1992
                edge, err = getAndBuildEdgeInfo(
×
1993
                        ctx, db, s.cfg.ChainHash, row.GraphChannel, node1,
×
1994
                        node2,
×
1995
                )
×
1996
                if err != nil {
×
1997
                        return fmt.Errorf("unable to build channel info: %w",
×
1998
                                err)
×
1999
                }
×
2000

2001
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2002
                if err != nil {
×
2003
                        return fmt.Errorf("unable to extract channel "+
×
2004
                                "policies: %w", err)
×
2005
                }
×
2006

2007
                policy1, policy2, err = getAndBuildChanPolicies(
×
2008
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
2009
                )
×
2010
                if err != nil {
×
2011
                        return fmt.Errorf("unable to build channel "+
×
2012
                                "policies: %w", err)
×
2013
                }
×
2014

2015
                return nil
×
2016
        }, sqldb.NoOpReset)
2017
        if err != nil {
×
2018
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2019
                        err)
×
2020
        }
×
2021

2022
        return edge, policy1, policy2, nil
×
2023
}
2024

2025
// HasChannelEdge returns true if the database knows of a channel edge with the
2026
// passed channel ID, and false otherwise. If an edge with that ID is found
2027
// within the graph, then two time stamps representing the last time the edge
2028
// was updated for both directed edges are returned along with the boolean. If
2029
// it is not found, then the zombie index is checked and its result is returned
2030
// as the second boolean.
2031
//
2032
// NOTE: part of the V1Store interface.
2033
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
2034
        bool, error) {
×
2035

×
2036
        ctx := context.TODO()
×
2037

×
2038
        var (
×
2039
                exists          bool
×
2040
                isZombie        bool
×
2041
                node1LastUpdate time.Time
×
2042
                node2LastUpdate time.Time
×
2043
        )
×
2044

×
2045
        // We'll query the cache with the shared lock held to allow multiple
×
2046
        // readers to access values in the cache concurrently if they exist.
×
2047
        s.cacheMu.RLock()
×
2048
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2049
                s.cacheMu.RUnlock()
×
2050
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2051
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2052
                exists, isZombie = entry.flags.unpack()
×
2053

×
2054
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2055
        }
×
2056
        s.cacheMu.RUnlock()
×
2057

×
2058
        s.cacheMu.Lock()
×
2059
        defer s.cacheMu.Unlock()
×
2060

×
2061
        // The item was not found with the shared lock, so we'll acquire the
×
2062
        // exclusive lock and check the cache again in case another method added
×
2063
        // the entry to the cache while no lock was held.
×
2064
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2065
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2066
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2067
                exists, isZombie = entry.flags.unpack()
×
2068

×
2069
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2070
        }
×
2071

2072
        chanIDB := channelIDToBytes(chanID)
×
2073
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2074
                channel, err := db.GetChannelBySCID(
×
2075
                        ctx, sqlc.GetChannelBySCIDParams{
×
2076
                                Scid:    chanIDB,
×
2077
                                Version: int16(ProtocolV1),
×
2078
                        },
×
2079
                )
×
2080
                if errors.Is(err, sql.ErrNoRows) {
×
2081
                        // Check if it is a zombie channel.
×
2082
                        isZombie, err = db.IsZombieChannel(
×
2083
                                ctx, sqlc.IsZombieChannelParams{
×
2084
                                        Scid:    chanIDB,
×
2085
                                        Version: int16(ProtocolV1),
×
2086
                                },
×
2087
                        )
×
2088
                        if err != nil {
×
2089
                                return fmt.Errorf("could not check if channel "+
×
2090
                                        "is zombie: %w", err)
×
2091
                        }
×
2092

2093
                        return nil
×
2094
                } else if err != nil {
×
2095
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2096
                }
×
2097

2098
                exists = true
×
2099

×
2100
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2101
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2102
                                Version:   int16(ProtocolV1),
×
2103
                                ChannelID: channel.ID,
×
2104
                                NodeID:    channel.NodeID1,
×
2105
                        },
×
2106
                )
×
2107
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2108
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2109
                                err)
×
2110
                } else if err == nil {
×
2111
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2112
                }
×
2113

2114
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2115
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2116
                                Version:   int16(ProtocolV1),
×
2117
                                ChannelID: channel.ID,
×
2118
                                NodeID:    channel.NodeID2,
×
2119
                        },
×
2120
                )
×
2121
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2122
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2123
                                err)
×
2124
                } else if err == nil {
×
2125
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2126
                }
×
2127

2128
                return nil
×
2129
        }, sqldb.NoOpReset)
2130
        if err != nil {
×
2131
                return time.Time{}, time.Time{}, false, false,
×
2132
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2133
        }
×
2134

2135
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2136
                upd1Time: node1LastUpdate.Unix(),
×
2137
                upd2Time: node2LastUpdate.Unix(),
×
2138
                flags:    packRejectFlags(exists, isZombie),
×
2139
        })
×
2140

×
2141
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2142
}
2143

2144
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2145
// passed channel point (outpoint). If the passed channel doesn't exist within
2146
// the database, then ErrEdgeNotFound is returned.
2147
//
2148
// NOTE: part of the V1Store interface.
2149
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2150
        var (
×
2151
                ctx       = context.TODO()
×
2152
                channelID uint64
×
2153
        )
×
2154
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2155
                chanID, err := db.GetSCIDByOutpoint(
×
2156
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2157
                                Outpoint: chanPoint.String(),
×
2158
                                Version:  int16(ProtocolV1),
×
2159
                        },
×
2160
                )
×
2161
                if errors.Is(err, sql.ErrNoRows) {
×
2162
                        return ErrEdgeNotFound
×
2163
                } else if err != nil {
×
2164
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2165
                                err)
×
2166
                }
×
2167

2168
                channelID = byteOrder.Uint64(chanID)
×
2169

×
2170
                return nil
×
2171
        }, sqldb.NoOpReset)
2172
        if err != nil {
×
2173
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2174
        }
×
2175

2176
        return channelID, nil
×
2177
}
2178

2179
// IsPublicNode is a helper method that determines whether the node with the
2180
// given public key is seen as a public node in the graph from the graph's
2181
// source node's point of view.
2182
//
2183
// NOTE: part of the V1Store interface.
2184
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2185
        ctx := context.TODO()
×
2186

×
2187
        var isPublic bool
×
2188
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2189
                var err error
×
2190
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2191

×
2192
                return err
×
2193
        }, sqldb.NoOpReset)
×
2194
        if err != nil {
×
2195
                return false, fmt.Errorf("unable to check if node is "+
×
2196
                        "public: %w", err)
×
2197
        }
×
2198

2199
        return isPublic, nil
×
2200
}
2201

2202
// FetchChanInfos returns the set of channel edges that correspond to the passed
2203
// channel ID's. If an edge is the query is unknown to the database, it will
2204
// skipped and the result will contain only those edges that exist at the time
2205
// of the query. This can be used to respond to peer queries that are seeking to
2206
// fill in gaps in their view of the channel graph.
2207
//
2208
// NOTE: part of the V1Store interface.
2209
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2210
        var (
×
2211
                ctx   = context.TODO()
×
2212
                edges = make(map[uint64]ChannelEdge)
×
2213
        )
×
2214
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2215
                chanCallBack := func(ctx context.Context,
×
2216
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2217

×
2218
                        node1, node2, err := buildNodes(
×
2219
                                ctx, db, row.GraphNode, row.GraphNode_2,
×
2220
                        )
×
2221
                        if err != nil {
×
2222
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2223
                                        err)
×
2224
                        }
×
2225

2226
                        edge, err := getAndBuildEdgeInfo(
×
2227
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
2228
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2229
                        )
×
2230
                        if err != nil {
×
2231
                                return fmt.Errorf("unable to build "+
×
2232
                                        "channel info: %w", err)
×
2233
                        }
×
2234

2235
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2236
                        if err != nil {
×
2237
                                return fmt.Errorf("unable to extract channel "+
×
2238
                                        "policies: %w", err)
×
2239
                        }
×
2240

2241
                        p1, p2, err := getAndBuildChanPolicies(
×
2242
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2243
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2244
                        )
×
2245
                        if err != nil {
×
2246
                                return fmt.Errorf("unable to build channel "+
×
2247
                                        "policies: %w", err)
×
2248
                        }
×
2249

2250
                        edges[edge.ChannelID] = ChannelEdge{
×
2251
                                Info:    edge,
×
2252
                                Policy1: p1,
×
2253
                                Policy2: p2,
×
2254
                                Node1:   node1,
×
2255
                                Node2:   node2,
×
2256
                        }
×
2257

×
2258
                        return nil
×
2259
                }
2260

2261
                return s.forEachChanWithPoliciesInSCIDList(
×
2262
                        ctx, db, chanCallBack, chanIDs,
×
2263
                )
×
2264
        }, func() {
×
2265
                clear(edges)
×
2266
        })
×
2267
        if err != nil {
×
2268
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2269
        }
×
2270

2271
        res := make([]ChannelEdge, 0, len(edges))
×
2272
        for _, chanID := range chanIDs {
×
2273
                edge, ok := edges[chanID]
×
2274
                if !ok {
×
2275
                        continue
×
2276
                }
2277

2278
                res = append(res, edge)
×
2279
        }
2280

2281
        return res, nil
×
2282
}
2283

2284
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2285
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2286
// channels in a paginated manner.
2287
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2288
        db SQLQueries, cb func(ctx context.Context,
2289
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
2290
        chanIDs []uint64) error {
×
2291

×
2292
        queryWrapper := func(ctx context.Context,
×
2293
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
2294
                error) {
×
2295

×
2296
                return db.GetChannelsBySCIDWithPolicies(
×
2297
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
2298
                                Version: int16(ProtocolV1),
×
2299
                                Scids:   scids,
×
2300
                        },
×
2301
                )
×
2302
        }
×
2303

NEW
2304
        return sqldb.ExecuteBatchQuery(
×
NEW
2305
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
×
NEW
2306
                cb,
×
UNCOV
2307
        )
×
2308
}
2309

2310
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2311
// ID's that we don't know and are not known zombies of the passed set. In other
2312
// words, we perform a set difference of our set of chan ID's and the ones
2313
// passed in. This method can be used by callers to determine the set of
2314
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2315
// known zombies is also returned.
2316
//
2317
// NOTE: part of the V1Store interface.
2318
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2319
        []ChannelUpdateInfo, error) {
×
2320

×
2321
        var (
×
2322
                ctx          = context.TODO()
×
2323
                newChanIDs   []uint64
×
2324
                knownZombies []ChannelUpdateInfo
×
2325
                infoLookup   = make(
×
2326
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2327
                )
×
2328
        )
×
2329

×
2330
        // We first build a lookup map of the channel ID's to the
×
2331
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
2332
        // already know about.
×
2333
        for _, chanInfo := range chansInfo {
×
2334
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
2335
        }
×
2336

2337
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2338
                // The call-back function deletes known channels from
×
2339
                // infoLookup, so that we can later check which channels are
×
2340
                // zombies by only looking at the remaining channels in the set.
×
2341
                cb := func(ctx context.Context,
×
2342
                        channel sqlc.GraphChannel) error {
×
2343

×
2344
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2345

×
2346
                        return nil
×
2347
                }
×
2348

2349
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
2350
                if err != nil {
×
2351
                        return fmt.Errorf("unable to iterate through "+
×
2352
                                "channels: %w", err)
×
2353
                }
×
2354

2355
                // We want to ensure that we deal with the channels in the
2356
                // same order that they were passed in, so we iterate over the
2357
                // original chansInfo slice and then check if that channel is
2358
                // still in the infoLookup map.
2359
                for _, chanInfo := range chansInfo {
×
2360
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2361
                        if _, ok := infoLookup[channelID]; !ok {
×
2362
                                continue
×
2363
                        }
2364

2365
                        isZombie, err := db.IsZombieChannel(
×
2366
                                ctx, sqlc.IsZombieChannelParams{
×
2367
                                        Scid:    channelIDToBytes(channelID),
×
2368
                                        Version: int16(ProtocolV1),
×
2369
                                },
×
2370
                        )
×
2371
                        if err != nil {
×
2372
                                return fmt.Errorf("unable to fetch zombie "+
×
2373
                                        "channel: %w", err)
×
2374
                        }
×
2375

2376
                        if isZombie {
×
2377
                                knownZombies = append(knownZombies, chanInfo)
×
2378

×
2379
                                continue
×
2380
                        }
2381

2382
                        newChanIDs = append(newChanIDs, channelID)
×
2383
                }
2384

2385
                return nil
×
2386
        }, func() {
×
2387
                newChanIDs = nil
×
2388
                knownZombies = nil
×
2389
                // Rebuild the infoLookup map in case of a rollback.
×
2390
                for _, chanInfo := range chansInfo {
×
2391
                        scid := chanInfo.ShortChannelID.ToUint64()
×
2392
                        infoLookup[scid] = chanInfo
×
2393
                }
×
2394
        })
2395
        if err != nil {
×
2396
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2397
        }
×
2398

2399
        return newChanIDs, knownZombies, nil
×
2400
}
2401

2402
// forEachChanInSCIDList is a helper method that executes a paged query
2403
// against the database to fetch all channels that match the passed
2404
// ChannelUpdateInfo slice. The callback function is called for each channel
2405
// that is found.
2406
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2407
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
2408
        chansInfo []ChannelUpdateInfo) error {
×
2409

×
2410
        queryWrapper := func(ctx context.Context,
×
2411
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
2412

×
2413
                return db.GetChannelsBySCIDs(
×
2414
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
2415
                                Version: int16(ProtocolV1),
×
2416
                                Scids:   scids,
×
2417
                        },
×
2418
                )
×
2419
        }
×
2420

2421
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
2422
                channelID := chanInfo.ShortChannelID.ToUint64()
×
2423

×
2424
                return channelIDToBytes(channelID)
×
2425
        }
×
2426

NEW
2427
        return sqldb.ExecuteBatchQuery(
×
NEW
2428
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
×
NEW
2429
                cb,
×
UNCOV
2430
        )
×
2431
}
2432

2433
// PruneGraphNodes is a garbage collection method which attempts to prune out
2434
// any nodes from the channel graph that are currently unconnected. This ensure
2435
// that we only maintain a graph of reachable nodes. In the event that a pruned
2436
// node gains more channels, it will be re-added back to the graph.
2437
//
2438
// NOTE: this prunes nodes across protocol versions. It will never prune the
2439
// source nodes.
2440
//
2441
// NOTE: part of the V1Store interface.
2442
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2443
        var ctx = context.TODO()
×
2444

×
2445
        var prunedNodes []route.Vertex
×
2446
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2447
                var err error
×
2448
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2449

×
2450
                return err
×
2451
        }, func() {
×
2452
                prunedNodes = nil
×
2453
        })
×
2454
        if err != nil {
×
2455
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2456
        }
×
2457

2458
        return prunedNodes, nil
×
2459
}
2460

2461
// PruneGraph prunes newly closed channels from the channel graph in response
2462
// to a new block being solved on the network. Any transactions which spend the
2463
// funding output of any known channels within he graph will be deleted.
2464
// Additionally, the "prune tip", or the last block which has been used to
2465
// prune the graph is stored so callers can ensure the graph is fully in sync
2466
// with the current UTXO state. A slice of channels that have been closed by
2467
// the target block along with any pruned nodes are returned if the function
2468
// succeeds without error.
2469
//
2470
// NOTE: part of the V1Store interface.
2471
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2472
        blockHash *chainhash.Hash, blockHeight uint32) (
2473
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2474

×
2475
        ctx := context.TODO()
×
2476

×
2477
        s.cacheMu.Lock()
×
2478
        defer s.cacheMu.Unlock()
×
2479

×
2480
        var (
×
2481
                closedChans []*models.ChannelEdgeInfo
×
2482
                prunedNodes []route.Vertex
×
2483
        )
×
2484
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2485
                var chansToDelete []int64
×
2486

×
2487
                // Define the callback function for processing each channel.
×
2488
                channelCallback := func(ctx context.Context,
×
2489
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2490

×
2491
                        node1, node2, err := buildNodeVertices(
×
2492
                                row.Node1Pubkey, row.Node2Pubkey,
×
2493
                        )
×
2494
                        if err != nil {
×
2495
                                return err
×
2496
                        }
×
2497

2498
                        info, err := getAndBuildEdgeInfo(
×
2499
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
2500
                                node1, node2,
×
2501
                        )
×
2502
                        if err != nil {
×
2503
                                return err
×
2504
                        }
×
2505

2506
                        closedChans = append(closedChans, info)
×
2507
                        chansToDelete = append(
×
2508
                                chansToDelete, row.GraphChannel.ID,
×
2509
                        )
×
2510

×
2511
                        return nil
×
2512
                }
2513

2514
                err := s.forEachChanInOutpoints(
×
2515
                        ctx, db, spentOutputs, channelCallback,
×
2516
                )
×
2517
                if err != nil {
×
2518
                        return fmt.Errorf("unable to fetch channels by "+
×
2519
                                "outpoints: %w", err)
×
2520
                }
×
2521

2522
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2523
                if err != nil {
×
2524
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2525
                }
×
2526

2527
                err = db.UpsertPruneLogEntry(
×
2528
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2529
                                BlockHash:   blockHash[:],
×
2530
                                BlockHeight: int64(blockHeight),
×
2531
                        },
×
2532
                )
×
2533
                if err != nil {
×
2534
                        return fmt.Errorf("unable to insert prune log "+
×
2535
                                "entry: %w", err)
×
2536
                }
×
2537

2538
                // Now that we've pruned some channels, we'll also prune any
2539
                // nodes that no longer have any channels.
2540
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2541
                if err != nil {
×
2542
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2543
                                err)
×
2544
                }
×
2545

2546
                return nil
×
2547
        }, func() {
×
2548
                prunedNodes = nil
×
2549
                closedChans = nil
×
2550
        })
×
2551
        if err != nil {
×
2552
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2553
        }
×
2554

2555
        for _, channel := range closedChans {
×
2556
                s.rejectCache.remove(channel.ChannelID)
×
2557
                s.chanCache.remove(channel.ChannelID)
×
2558
        }
×
2559

2560
        return closedChans, prunedNodes, nil
×
2561
}
2562

2563
// forEachChanInOutpoints is a helper function that executes a paginated
2564
// query to fetch channels by their outpoints and applies the given call-back
2565
// to each.
2566
//
2567
// NOTE: this fetches channels for all protocol versions.
2568
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2569
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2570
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
2571

×
2572
        // Create a wrapper that uses the transaction's db instance to execute
×
2573
        // the query.
×
2574
        queryWrapper := func(ctx context.Context,
×
2575
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
2576
                error) {
×
2577

×
2578
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
2579
        }
×
2580

2581
        // Define the conversion function from Outpoint to string.
2582
        outpointToString := func(outpoint *wire.OutPoint) string {
×
2583
                return outpoint.String()
×
2584
        }
×
2585

NEW
2586
        return sqldb.ExecuteBatchQuery(
×
NEW
2587
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
×
2588
                queryWrapper, cb,
×
2589
        )
×
2590
}
2591

2592
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
2593
        dbIDs []int64) error {
×
2594

×
2595
        // Create a wrapper that uses the transaction's db instance to execute
×
2596
        // the query.
×
2597
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
2598
                return nil, db.DeleteChannels(ctx, ids)
×
2599
        }
×
2600

2601
        idConverter := func(id int64) int64 {
×
2602
                return id
×
2603
        }
×
2604

NEW
2605
        return sqldb.ExecuteBatchQuery(
×
NEW
2606
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
×
2607
                queryWrapper, func(ctx context.Context, _ any) error {
×
2608
                        return nil
×
2609
                },
×
2610
        )
2611
}
2612

2613
// ChannelView returns the verifiable edge information for each active channel
2614
// within the known channel graph. The set of UTXOs (along with their scripts)
2615
// returned are the ones that need to be watched on chain to detect channel
2616
// closes on the resident blockchain.
2617
//
2618
// NOTE: part of the V1Store interface.
2619
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2620
        var (
×
2621
                ctx        = context.TODO()
×
2622
                edgePoints []EdgePoint
×
2623
        )
×
2624

×
NEW
2625
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2626
                handleChannel := func(_ context.Context,
×
NEW
2627
                        channel sqlc.ListChannelsPaginatedRow) error {
×
2628

×
NEW
2629
                        pkScript, err := genMultiSigP2WSH(
×
NEW
2630
                                channel.BitcoinKey1, channel.BitcoinKey2,
×
NEW
2631
                        )
×
NEW
2632
                        if err != nil {
×
NEW
2633
                                return err
×
NEW
2634
                        }
×
2635

NEW
2636
                        op, err := wire.NewOutPointFromString(channel.Outpoint)
×
NEW
2637
                        if err != nil {
×
NEW
2638
                                return err
×
NEW
2639
                        }
×
2640

NEW
2641
                        edgePoints = append(edgePoints, EdgePoint{
×
NEW
2642
                                FundingPkScript: pkScript,
×
NEW
2643
                                OutPoint:        *op,
×
NEW
2644
                        })
×
2645

×
NEW
2646
                        return nil
×
2647
                }
2648

NEW
2649
                queryFunc := func(ctx context.Context, lastID int64,
×
NEW
2650
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {
×
NEW
2651

×
NEW
2652
                        return db.ListChannelsPaginated(
×
2653
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2654
                                        Version: int16(ProtocolV1),
×
2655
                                        ID:      lastID,
×
NEW
2656
                                        Limit:   limit,
×
2657
                                },
×
2658
                        )
×
NEW
2659
                }
×
2660

NEW
2661
                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
×
NEW
2662
                        return row.ID
×
UNCOV
2663
                }
×
2664

NEW
2665
                return sqldb.ExecutePaginatedQuery(
×
NEW
2666
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
NEW
2667
                        extractCursor, handleChannel,
×
NEW
2668
                )
×
2669
        }, func() {
×
2670
                edgePoints = nil
×
2671
        })
×
2672
        if err != nil {
×
2673
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2674
        }
×
2675

2676
        return edgePoints, nil
×
2677
}
2678

2679
// PruneTip returns the block height and hash of the latest block that has been
2680
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2681
// to tell if the graph is currently in sync with the current best known UTXO
2682
// state.
2683
//
2684
// NOTE: part of the V1Store interface.
2685
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2686
        var (
×
2687
                ctx       = context.TODO()
×
2688
                tipHash   chainhash.Hash
×
2689
                tipHeight uint32
×
2690
        )
×
2691
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2692
                pruneTip, err := db.GetPruneTip(ctx)
×
2693
                if errors.Is(err, sql.ErrNoRows) {
×
2694
                        return ErrGraphNeverPruned
×
2695
                } else if err != nil {
×
2696
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2697
                }
×
2698

2699
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2700
                tipHeight = uint32(pruneTip.BlockHeight)
×
2701

×
2702
                return nil
×
2703
        }, sqldb.NoOpReset)
2704
        if err != nil {
×
2705
                return nil, 0, err
×
2706
        }
×
2707

2708
        return &tipHash, tipHeight, nil
×
2709
}
2710

2711
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2712
//
2713
// NOTE: this prunes nodes across protocol versions. It will never prune the
2714
// source nodes.
2715
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2716
        db SQLQueries) ([]route.Vertex, error) {
×
2717

×
2718
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2719
        if err != nil {
×
2720
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2721
                        "nodes: %w", err)
×
2722
        }
×
2723

2724
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2725
        for i, nodeKey := range nodeKeys {
×
2726
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2727
                if err != nil {
×
2728
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2729
                                "from bytes: %w", err)
×
2730
                }
×
2731

2732
                prunedNodes[i] = pub
×
2733
        }
2734

2735
        return prunedNodes, nil
×
2736
}
2737

2738
// DisconnectBlockAtHeight is used to indicate that the block specified
2739
// by the passed height has been disconnected from the main chain. This
2740
// will "rewind" the graph back to the height below, deleting channels
2741
// that are no longer confirmed from the graph. The prune log will be
2742
// set to the last prune height valid for the remaining chain.
2743
// Channels that were removed from the graph resulting from the
2744
// disconnected block are returned.
2745
//
2746
// NOTE: part of the V1Store interface.
2747
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2748
        []*models.ChannelEdgeInfo, error) {
×
2749

×
2750
        ctx := context.TODO()
×
2751

×
2752
        var (
×
2753
                // Every channel having a ShortChannelID starting at 'height'
×
2754
                // will no longer be confirmed.
×
2755
                startShortChanID = lnwire.ShortChannelID{
×
2756
                        BlockHeight: height,
×
2757
                }
×
2758

×
2759
                // Delete everything after this height from the db up until the
×
2760
                // SCID alias range.
×
2761
                endShortChanID = aliasmgr.StartingAlias
×
2762

×
2763
                removedChans []*models.ChannelEdgeInfo
×
2764

×
2765
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2766
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2767
        )
×
2768

×
2769
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2770
                rows, err := db.GetChannelsBySCIDRange(
×
2771
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2772
                                StartScid: chanIDStart,
×
2773
                                EndScid:   chanIDEnd,
×
2774
                        },
×
2775
                )
×
2776
                if err != nil {
×
2777
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2778
                }
×
2779

2780
                chanIDsToDelete := make([]int64, len(rows))
×
2781
                for i, row := range rows {
×
2782
                        node1, node2, err := buildNodeVertices(
×
2783
                                row.Node1PubKey, row.Node2PubKey,
×
2784
                        )
×
2785
                        if err != nil {
×
2786
                                return err
×
2787
                        }
×
2788

2789
                        channel, err := getAndBuildEdgeInfo(
×
2790
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
2791
                                node1, node2,
×
2792
                        )
×
2793
                        if err != nil {
×
2794
                                return err
×
2795
                        }
×
2796

2797
                        chanIDsToDelete[i] = row.GraphChannel.ID
×
2798
                        removedChans = append(removedChans, channel)
×
2799
                }
2800

2801
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
2802
                if err != nil {
×
2803
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2804
                }
×
2805

2806
                return db.DeletePruneLogEntriesInRange(
×
2807
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2808
                                StartHeight: int64(height),
×
2809
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2810
                        },
×
2811
                )
×
2812
        }, func() {
×
2813
                removedChans = nil
×
2814
        })
×
2815
        if err != nil {
×
2816
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2817
                        "height: %w", err)
×
2818
        }
×
2819

2820
        for _, channel := range removedChans {
×
2821
                s.rejectCache.remove(channel.ChannelID)
×
2822
                s.chanCache.remove(channel.ChannelID)
×
2823
        }
×
2824

2825
        return removedChans, nil
×
2826
}
2827

2828
// AddEdgeProof sets the proof of an existing edge in the graph database.
2829
//
2830
// NOTE: part of the V1Store interface.
2831
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2832
        proof *models.ChannelAuthProof) error {
×
2833

×
2834
        var (
×
2835
                ctx       = context.TODO()
×
2836
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2837
        )
×
2838

×
2839
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2840
                res, err := db.AddV1ChannelProof(
×
2841
                        ctx, sqlc.AddV1ChannelProofParams{
×
2842
                                Scid:              scidBytes,
×
2843
                                Node1Signature:    proof.NodeSig1Bytes,
×
2844
                                Node2Signature:    proof.NodeSig2Bytes,
×
2845
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2846
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2847
                        },
×
2848
                )
×
2849
                if err != nil {
×
2850
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2851
                }
×
2852

2853
                n, err := res.RowsAffected()
×
2854
                if err != nil {
×
2855
                        return err
×
2856
                }
×
2857

2858
                if n == 0 {
×
2859
                        return fmt.Errorf("no rows affected when adding edge "+
×
2860
                                "proof for SCID %v", scid)
×
2861
                } else if n > 1 {
×
2862
                        return fmt.Errorf("multiple rows affected when adding "+
×
2863
                                "edge proof for SCID %v: %d rows affected",
×
2864
                                scid, n)
×
2865
                }
×
2866

2867
                return nil
×
2868
        }, sqldb.NoOpReset)
2869
        if err != nil {
×
2870
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2871
        }
×
2872

2873
        return nil
×
2874
}
2875

2876
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2877
// that we can ignore channel announcements that we know to be closed without
2878
// having to validate them and fetch a block.
2879
//
2880
// NOTE: part of the V1Store interface.
2881
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2882
        var (
×
2883
                ctx     = context.TODO()
×
2884
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2885
        )
×
2886

×
2887
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2888
                return db.InsertClosedChannel(ctx, chanIDB)
×
2889
        }, sqldb.NoOpReset)
×
2890
}
2891

2892
// IsClosedScid checks whether a channel identified by the passed in scid is
2893
// closed. This helps avoid having to perform expensive validation checks.
2894
//
2895
// NOTE: part of the V1Store interface.
2896
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2897
        var (
×
2898
                ctx      = context.TODO()
×
2899
                isClosed bool
×
2900
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2901
        )
×
2902
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2903
                var err error
×
2904
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2905
                if err != nil {
×
2906
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2907
                                err)
×
2908
                }
×
2909

2910
                return nil
×
2911
        }, sqldb.NoOpReset)
2912
        if err != nil {
×
2913
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2914
                        err)
×
2915
        }
×
2916

2917
        return isClosed, nil
×
2918
}
2919

2920
// GraphSession will provide the call-back with access to a NodeTraverser
2921
// instance which can be used to perform queries against the channel graph.
2922
//
2923
// NOTE: part of the V1Store interface.
2924
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
2925
        reset func()) error {
×
2926

×
2927
        var ctx = context.TODO()
×
2928

×
2929
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2930
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2931
        }, reset)
×
2932
}
2933

2934
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2935
// read only transaction for a consistent view of the graph.
2936
type sqlNodeTraverser struct {
2937
        db    SQLQueries
2938
        chain chainhash.Hash
2939
}
2940

2941
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2942
// NodeTraverser interface.
2943
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2944

2945
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2946
func newSQLNodeTraverser(db SQLQueries,
2947
        chain chainhash.Hash) *sqlNodeTraverser {
×
2948

×
2949
        return &sqlNodeTraverser{
×
2950
                db:    db,
×
2951
                chain: chain,
×
2952
        }
×
2953
}
×
2954

2955
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2956
// node.
2957
//
2958
// NOTE: Part of the NodeTraverser interface.
2959
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2960
        cb func(channel *DirectedChannel) error, _ func()) error {
×
2961

×
2962
        ctx := context.TODO()
×
2963

×
2964
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2965
}
×
2966

2967
// FetchNodeFeatures returns the features of the given node. If the node is
2968
// unknown, assume no additional features are supported.
2969
//
2970
// NOTE: Part of the NodeTraverser interface.
2971
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2972
        *lnwire.FeatureVector, error) {
×
2973

×
2974
        ctx := context.TODO()
×
2975

×
2976
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2977
}
×
2978

2979
// forEachNodeDirectedChannel iterates through all channels of a given
2980
// node, executing the passed callback on the directed edge representing the
2981
// channel and its incoming policy. If the node is not found, no error is
2982
// returned.
2983
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2984
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2985

×
2986
        toNodeCallback := func() route.Vertex {
×
2987
                return nodePub
×
2988
        }
×
2989

2990
        dbID, err := db.GetNodeIDByPubKey(
×
2991
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2992
                        Version: int16(ProtocolV1),
×
2993
                        PubKey:  nodePub[:],
×
2994
                },
×
2995
        )
×
2996
        if errors.Is(err, sql.ErrNoRows) {
×
2997
                return nil
×
2998
        } else if err != nil {
×
2999
                return fmt.Errorf("unable to fetch node: %w", err)
×
3000
        }
×
3001

3002
        rows, err := db.ListChannelsByNodeID(
×
3003
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3004
                        Version: int16(ProtocolV1),
×
3005
                        NodeID1: dbID,
×
3006
                },
×
3007
        )
×
3008
        if err != nil {
×
3009
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3010
        }
×
3011

3012
        // Exit early if there are no channels for this node so we don't
3013
        // do the unnecessary feature fetching.
3014
        if len(rows) == 0 {
×
3015
                return nil
×
3016
        }
×
3017

3018
        features, err := getNodeFeatures(ctx, db, dbID)
×
3019
        if err != nil {
×
3020
                return fmt.Errorf("unable to fetch node features: %w", err)
×
3021
        }
×
3022

3023
        for _, row := range rows {
×
3024
                node1, node2, err := buildNodeVertices(
×
3025
                        row.Node1Pubkey, row.Node2Pubkey,
×
3026
                )
×
3027
                if err != nil {
×
3028
                        return fmt.Errorf("unable to build node vertices: %w",
×
3029
                                err)
×
3030
                }
×
3031

3032
                edge := buildCacheableChannelInfo(
×
3033
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
3034
                        node1, node2,
×
3035
                )
×
3036

×
3037
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3038
                if err != nil {
×
3039
                        return err
×
3040
                }
×
3041

NEW
3042
                p1, p2, err := buildCachedChanPolicies(
×
NEW
3043
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
NEW
3044
                )
×
NEW
3045
                if err != nil {
×
NEW
3046
                        return err
×
UNCOV
3047
                }
×
3048

3049
                // Determine the outgoing and incoming policy for this
3050
                // channel and node combo.
3051
                outPolicy, inPolicy := p1, p2
×
3052
                if p1 != nil && node2 == nodePub {
×
3053
                        outPolicy, inPolicy = p2, p1
×
3054
                } else if p2 != nil && node1 != nodePub {
×
3055
                        outPolicy, inPolicy = p2, p1
×
3056
                }
×
3057

3058
                var cachedInPolicy *models.CachedEdgePolicy
×
3059
                if inPolicy != nil {
×
3060
                        cachedInPolicy = inPolicy
×
3061
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
3062
                        cachedInPolicy.ToNodeFeatures = features
×
3063
                }
×
3064

3065
                directedChannel := &DirectedChannel{
×
3066
                        ChannelID:    edge.ChannelID,
×
3067
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
3068
                        OtherNode:    edge.NodeKey2Bytes,
×
3069
                        Capacity:     edge.Capacity,
×
3070
                        OutPolicySet: outPolicy != nil,
×
3071
                        InPolicy:     cachedInPolicy,
×
3072
                }
×
3073
                if outPolicy != nil {
×
3074
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3075
                                directedChannel.InboundFee = fee
×
3076
                        })
×
3077
                }
3078

3079
                if nodePub == edge.NodeKey2Bytes {
×
3080
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
3081
                }
×
3082

3083
                if err := cb(directedChannel); err != nil {
×
3084
                        return err
×
3085
                }
×
3086
        }
3087

3088
        return nil
×
3089
}
3090

3091
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
3092
// and executes the provided callback for each node.
3093
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
3094
        db SQLQueries,
3095
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
3096

×
NEW
3097
        handleNode := func(_ context.Context,
×
NEW
3098
                node sqlc.ListNodeIDsAndPubKeysRow) error {
×
NEW
3099

×
NEW
3100
                var pub route.Vertex
×
NEW
3101
                copy(pub[:], node.PubKey)
×
NEW
3102

×
NEW
3103
                return cb(node.ID, pub)
×
NEW
3104
        }
×
3105

NEW
3106
        queryFunc := func(ctx context.Context, lastID int64,
×
NEW
3107
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
3108

×
NEW
3109
                return db.ListNodeIDsAndPubKeys(
×
3110
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
3111
                                Version: int16(ProtocolV1),
×
3112
                                ID:      lastID,
×
NEW
3113
                                Limit:   limit,
×
3114
                        },
×
3115
                )
×
NEW
3116
        }
×
3117

NEW
3118
        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
NEW
3119
                return row.ID
×
UNCOV
3120
        }
×
3121

NEW
3122
        return sqldb.ExecutePaginatedQuery(
×
NEW
3123
                ctx, cfg, int64(-1), queryFunc, extractCursor, handleNode,
×
NEW
3124
        )
×
3125
}
3126

3127
// forEachNodeChannel iterates through all channels of a node, executing
3128
// the passed callback on each. The call-back is provided with the channel's
3129
// edge information, the outgoing policy and the incoming policy for the
3130
// channel and node combo.
3131
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3132
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
3133
                *models.ChannelEdgePolicy,
3134
                *models.ChannelEdgePolicy) error) error {
×
3135

×
3136
        // Get all the V1 channels for this node.Add commentMore actions
×
3137
        rows, err := db.ListChannelsByNodeID(
×
3138
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3139
                        Version: int16(ProtocolV1),
×
3140
                        NodeID1: id,
×
3141
                },
×
3142
        )
×
3143
        if err != nil {
×
3144
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3145
        }
×
3146

3147
        // Call the call-back for each channel and its known policies.
3148
        for _, row := range rows {
×
3149
                node1, node2, err := buildNodeVertices(
×
3150
                        row.Node1Pubkey, row.Node2Pubkey,
×
3151
                )
×
3152
                if err != nil {
×
3153
                        return fmt.Errorf("unable to build node vertices: %w",
×
3154
                                err)
×
3155
                }
×
3156

3157
                edge, err := getAndBuildEdgeInfo(
×
3158
                        ctx, db, chain, row.GraphChannel, node1, node2,
×
3159
                )
×
3160
                if err != nil {
×
3161
                        return fmt.Errorf("unable to build channel info: %w",
×
3162
                                err)
×
3163
                }
×
3164

3165
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3166
                if err != nil {
×
3167
                        return fmt.Errorf("unable to extract channel "+
×
3168
                                "policies: %w", err)
×
3169
                }
×
3170

3171
                p1, p2, err := getAndBuildChanPolicies(
×
3172
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3173
                )
×
3174
                if err != nil {
×
3175
                        return fmt.Errorf("unable to build channel "+
×
3176
                                "policies: %w", err)
×
3177
                }
×
3178

3179
                // Determine the outgoing and incoming policy for this
3180
                // channel and node combo.
3181
                p1ToNode := row.GraphChannel.NodeID2
×
3182
                p2ToNode := row.GraphChannel.NodeID1
×
3183
                outPolicy, inPolicy := p1, p2
×
3184
                if (p1 != nil && p1ToNode == id) ||
×
3185
                        (p2 != nil && p2ToNode != id) {
×
3186

×
3187
                        outPolicy, inPolicy = p2, p1
×
3188
                }
×
3189

3190
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3191
                        return err
×
3192
                }
×
3193
        }
3194

3195
        return nil
×
3196
}
3197

3198
// updateChanEdgePolicy upserts the channel policy info we have stored for
3199
// a channel we already know of.
3200
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3201
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3202
        error) {
×
3203

×
3204
        var (
×
3205
                node1Pub, node2Pub route.Vertex
×
3206
                isNode1            bool
×
3207
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3208
        )
×
3209

×
3210
        // Check that this edge policy refers to a channel that we already
×
3211
        // know of. We do this explicitly so that we can return the appropriate
×
3212
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3213
        // abort the transaction which would abort the entire batch.
×
3214
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3215
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3216
                        Scid:    chanIDB,
×
3217
                        Version: int16(ProtocolV1),
×
3218
                },
×
3219
        )
×
3220
        if errors.Is(err, sql.ErrNoRows) {
×
3221
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3222
        } else if err != nil {
×
3223
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3224
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3225
        }
×
3226

3227
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3228
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3229

×
3230
        // Figure out which node this edge is from.
×
3231
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3232
        nodeID := dbChan.NodeID1
×
3233
        if !isNode1 {
×
3234
                nodeID = dbChan.NodeID2
×
3235
        }
×
3236

3237
        var (
×
3238
                inboundBase sql.NullInt64
×
3239
                inboundRate sql.NullInt64
×
3240
        )
×
3241
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3242
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3243
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3244
        })
×
3245

3246
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3247
                Version:     int16(ProtocolV1),
×
3248
                ChannelID:   dbChan.ID,
×
3249
                NodeID:      nodeID,
×
3250
                Timelock:    int32(edge.TimeLockDelta),
×
3251
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3252
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3253
                MinHtlcMsat: int64(edge.MinHTLC),
×
3254
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3255
                Disabled: sql.NullBool{
×
3256
                        Valid: true,
×
3257
                        Bool:  edge.IsDisabled(),
×
3258
                },
×
3259
                MaxHtlcMsat: sql.NullInt64{
×
3260
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3261
                        Int64: int64(edge.MaxHTLC),
×
3262
                },
×
3263
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3264
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3265
                InboundBaseFeeMsat:      inboundBase,
×
3266
                InboundFeeRateMilliMsat: inboundRate,
×
3267
                Signature:               edge.SigBytes,
×
3268
        })
×
3269
        if err != nil {
×
3270
                return node1Pub, node2Pub, isNode1,
×
3271
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3272
        }
×
3273

3274
        // Convert the flat extra opaque data into a map of TLV types to
3275
        // values.
3276
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3277
        if err != nil {
×
3278
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3279
                        "marshal extra opaque data: %w", err)
×
3280
        }
×
3281

3282
        // Update the channel policy's extra signed fields.
3283
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3284
        if err != nil {
×
3285
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3286
                        "policy extra TLVs: %w", err)
×
3287
        }
×
3288

3289
        return node1Pub, node2Pub, isNode1, nil
×
3290
}
3291

3292
// getNodeByPubKey attempts to look up a target node by its public key.
3293
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3294
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3295

×
3296
        dbNode, err := db.GetNodeByPubKey(
×
3297
                ctx, sqlc.GetNodeByPubKeyParams{
×
3298
                        Version: int16(ProtocolV1),
×
3299
                        PubKey:  pubKey[:],
×
3300
                },
×
3301
        )
×
3302
        if errors.Is(err, sql.ErrNoRows) {
×
3303
                return 0, nil, ErrGraphNodeNotFound
×
3304
        } else if err != nil {
×
3305
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3306
        }
×
3307

3308
        node, err := buildNode(ctx, db, &dbNode)
×
3309
        if err != nil {
×
3310
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3311
        }
×
3312

3313
        return dbNode.ID, node, nil
×
3314
}
3315

3316
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3317
// provided parameters.
3318
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
3319
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3320

×
3321
        return &models.CachedEdgeInfo{
×
3322
                ChannelID:     byteOrder.Uint64(scid),
×
3323
                NodeKey1Bytes: node1Pub,
×
3324
                NodeKey2Bytes: node2Pub,
×
3325
                Capacity:      btcutil.Amount(capacity),
×
3326
        }
×
3327
}
×
3328

3329
// buildNode constructs a LightningNode instance from the given database node
3330
// record. The node's features, addresses and extra signed fields are also
3331
// fetched from the database and set on the node.
3332
func buildNode(ctx context.Context, db SQLQueries,
3333
        dbNode *sqlc.GraphNode) (*models.LightningNode, error) {
×
3334

×
3335
        // NOTE: buildNode is only used to load the data for a single node, and
×
3336
        // so no paged queries will be performed. This means that it's ok to
×
3337
        // used pass in default config values here.
×
NEW
3338
        cfg := sqldb.DefaultQueryConfig()
×
3339

×
3340
        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
×
3341
        if err != nil {
×
3342
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
3343
                        err)
×
3344
        }
×
3345

3346
        return buildNodeWithBatchData(dbNode, data)
×
3347
}
3348

3349
// buildNodeWithBatchData builds a models.LightningNode instance
3350
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
3351
// features/addresses/extra fields, then the corresponding fields are expected
3352
// to be present in the batchNodeData.
3353
func buildNodeWithBatchData(dbNode *sqlc.GraphNode,
3354
        batchData *batchNodeData) (*models.LightningNode, error) {
×
3355

×
3356
        if dbNode.Version != int16(ProtocolV1) {
×
3357
                return nil, fmt.Errorf("unsupported node version: %d",
×
3358
                        dbNode.Version)
×
3359
        }
×
3360

3361
        var pub [33]byte
×
3362
        copy(pub[:], dbNode.PubKey)
×
3363

×
3364
        node := &models.LightningNode{
×
3365
                PubKeyBytes: pub,
×
3366
                Features:    lnwire.EmptyFeatureVector(),
×
3367
                LastUpdate:  time.Unix(0, 0),
×
3368
        }
×
3369

×
3370
        if len(dbNode.Signature) == 0 {
×
3371
                return node, nil
×
3372
        }
×
3373

3374
        node.HaveNodeAnnouncement = true
×
3375
        node.AuthSigBytes = dbNode.Signature
×
3376
        node.Alias = dbNode.Alias.String
×
3377
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3378

×
3379
        var err error
×
3380
        if dbNode.Color.Valid {
×
3381
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3382
                if err != nil {
×
3383
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3384
                                err)
×
3385
                }
×
3386
        }
3387

3388
        // Use preloaded features.
3389
        if features, exists := batchData.features[dbNode.ID]; exists {
×
3390
                fv := lnwire.EmptyFeatureVector()
×
3391
                for _, bit := range features {
×
3392
                        fv.Set(lnwire.FeatureBit(bit))
×
3393
                }
×
3394
                node.Features = fv
×
3395
        }
3396

3397
        // Use preloaded addresses.
3398
        addresses, exists := batchData.addresses[dbNode.ID]
×
3399
        if exists && len(addresses) > 0 {
×
3400
                node.Addresses, err = buildNodeAddresses(addresses)
×
3401
                if err != nil {
×
3402
                        return nil, fmt.Errorf("unable to build addresses "+
×
3403
                                "for node(%d): %w", dbNode.ID, err)
×
3404
                }
×
3405
        }
3406

3407
        // Use preloaded extra fields.
3408
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
×
3409
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
×
3410
                if err != nil {
×
3411
                        return nil, fmt.Errorf("unable to serialize extra "+
×
3412
                                "signed fields: %w", err)
×
3413
                }
×
3414
                if len(recs) != 0 {
×
3415
                        node.ExtraOpaqueData = recs
×
3416
                }
×
3417
        }
3418

3419
        return node, nil
×
3420
}
3421

3422
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
3423
// with the preloaded data, and executes the provided callback for each node.
3424
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
3425
        db SQLQueries, nodes []sqlc.GraphNode,
3426
        cb func(dbID int64, node *models.LightningNode) error) error {
×
3427

×
3428
        // Extract node IDs for batch loading.
×
3429
        nodeIDs := make([]int64, len(nodes))
×
3430
        for i, node := range nodes {
×
3431
                nodeIDs[i] = node.ID
×
3432
        }
×
3433

3434
        // Batch load all related data for this page.
3435
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
3436
        if err != nil {
×
3437
                return fmt.Errorf("unable to batch load node data: %w", err)
×
3438
        }
×
3439

3440
        for _, dbNode := range nodes {
×
3441
                node, err := buildNodeWithBatchData(&dbNode, batchData)
×
3442
                if err != nil {
×
3443
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
3444
                                dbNode.ID, err)
×
3445
                }
×
3446

3447
                if err := cb(dbNode.ID, node); err != nil {
×
3448
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
3449
                                dbNode.ID, err)
×
3450
                }
×
3451
        }
3452

3453
        return nil
×
3454
}
3455

3456
// getNodeFeatures fetches the feature bits and constructs the feature vector
3457
// for a node with the given DB ID.
3458
func getNodeFeatures(ctx context.Context, db SQLQueries,
3459
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3460

×
3461
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3462
        if err != nil {
×
3463
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3464
                        nodeID, err)
×
3465
        }
×
3466

3467
        features := lnwire.EmptyFeatureVector()
×
3468
        for _, feature := range rows {
×
3469
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3470
        }
×
3471

3472
        return features, nil
×
3473
}
3474

3475
// upsertNode upserts the node record into the database. If the node already
3476
// exists, then the node's information is updated. If the node doesn't exist,
3477
// then a new node is created. The node's features, addresses and extra TLV
3478
// types are also updated. The node's DB ID is returned.
3479
func upsertNode(ctx context.Context, db SQLQueries,
3480
        node *models.LightningNode) (int64, error) {
×
3481

×
3482
        params := sqlc.UpsertNodeParams{
×
3483
                Version: int16(ProtocolV1),
×
3484
                PubKey:  node.PubKeyBytes[:],
×
3485
        }
×
3486

×
3487
        if node.HaveNodeAnnouncement {
×
3488
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3489
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3490
                params.Alias = sqldb.SQLStr(node.Alias)
×
3491
                params.Signature = node.AuthSigBytes
×
3492
        }
×
3493

3494
        nodeID, err := db.UpsertNode(ctx, params)
×
3495
        if err != nil {
×
3496
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3497
                        err)
×
3498
        }
×
3499

3500
        // We can exit here if we don't have the announcement yet.
3501
        if !node.HaveNodeAnnouncement {
×
3502
                return nodeID, nil
×
3503
        }
×
3504

3505
        // Update the node's features.
3506
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3507
        if err != nil {
×
3508
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3509
        }
×
3510

3511
        // Update the node's addresses.
3512
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3513
        if err != nil {
×
3514
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3515
        }
×
3516

3517
        // Convert the flat extra opaque data into a map of TLV types to
3518
        // values.
3519
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3520
        if err != nil {
×
3521
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3522
                        err)
×
3523
        }
×
3524

3525
        // Update the node's extra signed fields.
3526
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3527
        if err != nil {
×
3528
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3529
        }
×
3530

3531
        return nodeID, nil
×
3532
}
3533

3534
// upsertNodeFeatures updates the node's features node_features table. This
3535
// includes deleting any feature bits no longer present and inserting any new
3536
// feature bits. If the feature bit does not yet exist in the features table,
3537
// then an entry is created in that table first.
3538
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3539
        features *lnwire.FeatureVector) error {
×
3540

×
3541
        // Get any existing features for the node.
×
3542
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3543
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3544
                return err
×
3545
        }
×
3546

3547
        // Copy the nodes latest set of feature bits.
3548
        newFeatures := make(map[int32]struct{})
×
3549
        if features != nil {
×
3550
                for feature := range features.Features() {
×
3551
                        newFeatures[int32(feature)] = struct{}{}
×
3552
                }
×
3553
        }
3554

3555
        // For any current feature that already exists in the DB, remove it from
3556
        // the in-memory map. For any existing feature that does not exist in
3557
        // the in-memory map, delete it from the database.
3558
        for _, feature := range existingFeatures {
×
3559
                // The feature is still present, so there are no updates to be
×
3560
                // made.
×
3561
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3562
                        delete(newFeatures, feature.FeatureBit)
×
3563
                        continue
×
3564
                }
3565

3566
                // The feature is no longer present, so we remove it from the
3567
                // database.
3568
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3569
                        NodeID:     nodeID,
×
3570
                        FeatureBit: feature.FeatureBit,
×
3571
                })
×
3572
                if err != nil {
×
3573
                        return fmt.Errorf("unable to delete node(%d) "+
×
3574
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3575
                                err)
×
3576
                }
×
3577
        }
3578

3579
        // Any remaining entries in newFeatures are new features that need to be
3580
        // added to the database for the first time.
3581
        for feature := range newFeatures {
×
3582
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3583
                        NodeID:     nodeID,
×
3584
                        FeatureBit: feature,
×
3585
                })
×
3586
                if err != nil {
×
3587
                        return fmt.Errorf("unable to insert node(%d) "+
×
3588
                                "feature(%v): %w", nodeID, feature, err)
×
3589
                }
×
3590
        }
3591

3592
        return nil
×
3593
}
3594

3595
// fetchNodeFeatures fetches the features for a node with the given public key.
3596
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3597
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3598

×
3599
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3600
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3601
                        PubKey:  nodePub[:],
×
3602
                        Version: int16(ProtocolV1),
×
3603
                },
×
3604
        )
×
3605
        if err != nil {
×
3606
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3607
                        nodePub, err)
×
3608
        }
×
3609

3610
        features := lnwire.EmptyFeatureVector()
×
3611
        for _, bit := range rows {
×
3612
                features.Set(lnwire.FeatureBit(bit))
×
3613
        }
×
3614

3615
        return features, nil
×
3616
}
3617

3618
// dbAddressType is an enum type that represents the different address types
// that we store in the node_addresses table. The address type determines how
// the address is to be serialized and deserialized.
//
// NOTE: these values are written to and read back from the database, so the
// value of an existing type must never be changed or reused.
type dbAddressType uint8

const (
	// addressTypeIPv4 is a TCP address whose IP has an IPv4 form.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 is a TCP address whose IP only has an IPv6 form.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 is a Tor v2 onion service address.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 is a Tor v3 onion service address.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeOpaque is an address of a type that is not otherwise
	// understood and is stored as opaque data.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3630

3631
// upsertNodeAddresses updates the node's addresses in the database. This
3632
// includes deleting any existing addresses and inserting the new set of
3633
// addresses. The deletion is necessary since the ordering of the addresses may
3634
// change, and we need to ensure that the database reflects the latest set of
3635
// addresses so that at the time of reconstructing the node announcement, the
3636
// order is preserved and the signature over the message remains valid.
3637
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3638
        addresses []net.Addr) error {
×
3639

×
3640
        // Delete any existing addresses for the node. This is required since
×
3641
        // even if the new set of addresses is the same, the ordering may have
×
3642
        // changed for a given address type.
×
3643
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3644
        if err != nil {
×
3645
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3646
                        nodeID, err)
×
3647
        }
×
3648

3649
        // Copy the nodes latest set of addresses.
3650
        newAddresses := map[dbAddressType][]string{
×
3651
                addressTypeIPv4:   {},
×
3652
                addressTypeIPv6:   {},
×
3653
                addressTypeTorV2:  {},
×
3654
                addressTypeTorV3:  {},
×
3655
                addressTypeOpaque: {},
×
3656
        }
×
3657
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3658
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3659
        }
×
3660

3661
        for _, address := range addresses {
×
3662
                switch addr := address.(type) {
×
3663
                case *net.TCPAddr:
×
3664
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3665
                                addAddr(addressTypeIPv4, addr)
×
3666
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3667
                                addAddr(addressTypeIPv6, addr)
×
3668
                        } else {
×
3669
                                return fmt.Errorf("unhandled IP address: %v",
×
3670
                                        addr)
×
3671
                        }
×
3672

3673
                case *tor.OnionAddr:
×
3674
                        switch len(addr.OnionService) {
×
3675
                        case tor.V2Len:
×
3676
                                addAddr(addressTypeTorV2, addr)
×
3677
                        case tor.V3Len:
×
3678
                                addAddr(addressTypeTorV3, addr)
×
3679
                        default:
×
3680
                                return fmt.Errorf("invalid length for a tor " +
×
3681
                                        "address")
×
3682
                        }
3683

3684
                case *lnwire.OpaqueAddrs:
×
3685
                        addAddr(addressTypeOpaque, addr)
×
3686

3687
                default:
×
3688
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3689
                }
3690
        }
3691

3692
        // Any remaining entries in newAddresses are new addresses that need to
3693
        // be added to the database for the first time.
3694
        for addrType, addrList := range newAddresses {
×
3695
                for position, addr := range addrList {
×
3696
                        err := db.InsertNodeAddress(
×
3697
                                ctx, sqlc.InsertNodeAddressParams{
×
3698
                                        NodeID:   nodeID,
×
3699
                                        Type:     int16(addrType),
×
3700
                                        Address:  addr,
×
3701
                                        Position: int32(position),
×
3702
                                },
×
3703
                        )
×
3704
                        if err != nil {
×
3705
                                return fmt.Errorf("unable to insert "+
×
3706
                                        "node(%d) address(%v): %w", nodeID,
×
3707
                                        addr, err)
×
3708
                        }
×
3709
                }
3710
        }
3711

3712
        return nil
×
3713
}
3714

3715
// getNodeAddresses fetches the addresses for a node with the given DB ID.
3716
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
3717
        error) {
×
3718

×
3719
        // GetNodeAddresses ensures that the addresses for a given type are
×
3720
        // returned in the same order as they were inserted.
×
3721
        rows, err := db.GetNodeAddresses(ctx, id)
×
3722
        if err != nil {
×
3723
                return nil, err
×
3724
        }
×
3725

3726
        addresses := make([]net.Addr, 0, len(rows))
×
3727
        for _, row := range rows {
×
3728
                address := row.Address
×
3729

×
3730
                addr, err := parseAddress(dbAddressType(row.Type), address)
×
3731
                if err != nil {
×
3732
                        return nil, fmt.Errorf("unable to parse address "+
×
3733
                                "for node(%d): %v: %w", id, address, err)
×
3734
                }
×
3735

3736
                addresses = append(addresses, addr)
×
3737
        }
3738

3739
        // If we have no addresses, then we'll return nil instead of an
3740
        // empty slice.
3741
        if len(addresses) == 0 {
×
3742
                addresses = nil
×
3743
        }
×
3744

3745
        return addresses, nil
×
3746
}
3747

3748
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3749
// database. This includes updating any existing types, inserting any new types,
3750
// and deleting any types that are no longer present.
3751
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3752
        nodeID int64, extraFields map[uint64][]byte) error {
×
3753

×
3754
        // Get any existing extra signed fields for the node.
×
3755
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3756
        if err != nil {
×
3757
                return err
×
3758
        }
×
3759

3760
        // Make a lookup map of the existing field types so that we can use it
3761
        // to keep track of any fields we should delete.
3762
        m := make(map[uint64]bool)
×
3763
        for _, field := range existingFields {
×
3764
                m[uint64(field.Type)] = true
×
3765
        }
×
3766

3767
        // For all the new fields, we'll upsert them and remove them from the
3768
        // map of existing fields.
3769
        for tlvType, value := range extraFields {
×
3770
                err = db.UpsertNodeExtraType(
×
3771
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3772
                                NodeID: nodeID,
×
3773
                                Type:   int64(tlvType),
×
3774
                                Value:  value,
×
3775
                        },
×
3776
                )
×
3777
                if err != nil {
×
3778
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3779
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3780
                }
×
3781

3782
                // Remove the field from the map of existing fields if it was
3783
                // present.
3784
                delete(m, tlvType)
×
3785
        }
3786

3787
        // For all the fields that are left in the map of existing fields, we'll
3788
        // delete them as they are no longer present in the new set of fields.
3789
        for tlvType := range m {
×
3790
                err = db.DeleteExtraNodeType(
×
3791
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3792
                                NodeID: nodeID,
×
3793
                                Type:   int64(tlvType),
×
3794
                        },
×
3795
                )
×
3796
                if err != nil {
×
3797
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3798
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3799
                }
×
3800
        }
3801

3802
        return nil
×
3803
}
3804

3805
// srcNodeInfo holds the information about the source node of the graph.
3806
type srcNodeInfo struct {
3807
        // id is the DB level ID of the source node entry in the "nodes" table.
3808
        id int64
3809

3810
        // pub is the public key of the source node.
3811
        pub route.Vertex
3812
}
3813

3814
// sourceNode returns the DB node ID and pub key of the source node for the
3815
// specified protocol version.
3816
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3817
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3818

×
3819
        s.srcNodeMu.Lock()
×
3820
        defer s.srcNodeMu.Unlock()
×
3821

×
3822
        // If we already have the source node ID and pub key cached, then
×
3823
        // return them.
×
3824
        if info, ok := s.srcNodes[version]; ok {
×
3825
                return info.id, info.pub, nil
×
3826
        }
×
3827

3828
        var pubKey route.Vertex
×
3829

×
3830
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3831
        if err != nil {
×
3832
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3833
                        err)
×
3834
        }
×
3835

3836
        if len(nodes) == 0 {
×
3837
                return 0, pubKey, ErrSourceNodeNotSet
×
3838
        } else if len(nodes) > 1 {
×
3839
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3840
                        "protocol %s found", version)
×
3841
        }
×
3842

3843
        copy(pubKey[:], nodes[0].PubKey)
×
3844

×
3845
        s.srcNodes[version] = &srcNodeInfo{
×
3846
                id:  nodes[0].NodeID,
×
3847
                pub: pubKey,
×
3848
        }
×
3849

×
3850
        return nodes[0].NodeID, pubKey, nil
×
3851
}
3852

3853
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3854
// This then produces a map from TLV type to value. If the input is not a
3855
// valid TLV stream, then an error is returned.
3856
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3857
        r := bytes.NewReader(data)
×
3858

×
3859
        tlvStream, err := tlv.NewStream()
×
3860
        if err != nil {
×
3861
                return nil, err
×
3862
        }
×
3863

3864
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3865
        // pass it into the P2P decoding variant.
3866
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3867
        if err != nil {
×
3868
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3869
        }
×
3870
        if len(parsedTypes) == 0 {
×
3871
                return nil, nil
×
3872
        }
×
3873

3874
        records := make(map[uint64][]byte)
×
3875
        for k, v := range parsedTypes {
×
3876
                records[uint64(k)] = v
×
3877
        }
×
3878

3879
        return records, nil
×
3880
}
3881

3882
// dbChanInfo bundles the DB-level IDs of a channel record and of the two node
// records it connects.
type dbChanInfo struct {
	// channelID is the DB ID of the channel record.
	channelID int64

	// node1ID and node2ID are the DB IDs of the channel's first and second
	// node respectively.
	node1ID, node2ID int64
}
3889

3890
// insertChannel inserts a new channel record into the database.
3891
func insertChannel(ctx context.Context, db SQLQueries,
3892
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
3893

×
3894
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3895

×
3896
        // Make sure that the channel doesn't already exist. We do this
×
3897
        // explicitly instead of relying on catching a unique constraint error
×
3898
        // because relying on SQL to throw that error would abort the entire
×
3899
        // batch of transactions.
×
3900
        _, err := db.GetChannelBySCID(
×
3901
                ctx, sqlc.GetChannelBySCIDParams{
×
3902
                        Scid:    chanIDB,
×
3903
                        Version: int16(ProtocolV1),
×
3904
                },
×
3905
        )
×
3906
        if err == nil {
×
3907
                return nil, ErrEdgeAlreadyExist
×
3908
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3909
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
×
3910
        }
×
3911

3912
        // Make sure that at least a "shell" entry for each node is present in
3913
        // the nodes table.
3914
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3915
        if err != nil {
×
3916
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3917
        }
×
3918

3919
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3920
        if err != nil {
×
3921
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3922
        }
×
3923

3924
        var capacity sql.NullInt64
×
3925
        if edge.Capacity != 0 {
×
3926
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3927
        }
×
3928

3929
        createParams := sqlc.CreateChannelParams{
×
3930
                Version:     int16(ProtocolV1),
×
3931
                Scid:        chanIDB,
×
3932
                NodeID1:     node1DBID,
×
3933
                NodeID2:     node2DBID,
×
3934
                Outpoint:    edge.ChannelPoint.String(),
×
3935
                Capacity:    capacity,
×
3936
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3937
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3938
        }
×
3939

×
3940
        if edge.AuthProof != nil {
×
3941
                proof := edge.AuthProof
×
3942

×
3943
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3944
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3945
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3946
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3947
        }
×
3948

3949
        // Insert the new channel record.
3950
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3951
        if err != nil {
×
3952
                return nil, err
×
3953
        }
×
3954

3955
        // Insert any channel features.
3956
        for feature := range edge.Features.Features() {
×
3957
                err = db.InsertChannelFeature(
×
3958
                        ctx, sqlc.InsertChannelFeatureParams{
×
3959
                                ChannelID:  dbChanID,
×
3960
                                FeatureBit: int32(feature),
×
3961
                        },
×
3962
                )
×
3963
                if err != nil {
×
3964
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
×
3965
                                "feature(%v): %w", dbChanID, feature, err)
×
3966
                }
×
3967
        }
3968

3969
        // Finally, insert any extra TLV fields in the channel announcement.
3970
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3971
        if err != nil {
×
3972
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
3973
                        "data: %w", err)
×
3974
        }
×
3975

3976
        for tlvType, value := range extra {
×
3977
                err := db.CreateChannelExtraType(
×
3978
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3979
                                ChannelID: dbChanID,
×
3980
                                Type:      int64(tlvType),
×
3981
                                Value:     value,
×
3982
                        },
×
3983
                )
×
3984
                if err != nil {
×
3985
                        return nil, fmt.Errorf("unable to upsert "+
×
3986
                                "channel(%d) extra signed field(%v): %w",
×
3987
                                edge.ChannelID, tlvType, err)
×
3988
                }
×
3989
        }
3990

3991
        return &dbChanInfo{
×
3992
                channelID: dbChanID,
×
3993
                node1ID:   node1DBID,
×
3994
                node2ID:   node2DBID,
×
3995
        }, nil
×
3996
}
3997

3998
// maybeCreateShellNode checks if a shell node entry exists for the
3999
// given public key. If it does not exist, then a new shell node entry is
4000
// created. The ID of the node is returned. A shell node only has a protocol
4001
// version and public key persisted.
4002
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
4003
        pubKey route.Vertex) (int64, error) {
×
4004

×
4005
        dbNode, err := db.GetNodeByPubKey(
×
4006
                ctx, sqlc.GetNodeByPubKeyParams{
×
4007
                        PubKey:  pubKey[:],
×
4008
                        Version: int16(ProtocolV1),
×
4009
                },
×
4010
        )
×
4011
        // The node exists. Return the ID.
×
4012
        if err == nil {
×
4013
                return dbNode.ID, nil
×
4014
        } else if !errors.Is(err, sql.ErrNoRows) {
×
4015
                return 0, err
×
4016
        }
×
4017

4018
        // Otherwise, the node does not exist, so we create a shell entry for
4019
        // it.
4020
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
4021
                Version: int16(ProtocolV1),
×
4022
                PubKey:  pubKey[:],
×
4023
        })
×
4024
        if err != nil {
×
4025
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
4026
        }
×
4027

4028
        return id, nil
×
4029
}
4030

4031
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
4032
// the database. This includes deleting any existing types and then inserting
4033
// the new types.
4034
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
4035
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
4036

×
4037
        // Delete all existing extra signed fields for the channel policy.
×
4038
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
4039
        if err != nil {
×
4040
                return fmt.Errorf("unable to delete "+
×
4041
                        "existing policy extra signed fields for policy %d: %w",
×
4042
                        chanPolicyID, err)
×
4043
        }
×
4044

4045
        // Insert all new extra signed fields for the channel policy.
4046
        for tlvType, value := range extraFields {
×
4047
                err = db.InsertChanPolicyExtraType(
×
4048
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
4049
                                ChannelPolicyID: chanPolicyID,
×
4050
                                Type:            int64(tlvType),
×
4051
                                Value:           value,
×
4052
                        },
×
4053
                )
×
4054
                if err != nil {
×
4055
                        return fmt.Errorf("unable to insert "+
×
4056
                                "channel_policy(%d) extra signed field(%v): %w",
×
4057
                                chanPolicyID, tlvType, err)
×
4058
                }
×
4059
        }
4060

4061
        return nil
×
4062
}
4063

4064
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
4065
// provided dbChanRow and also fetches any other required information
4066
// to construct the edge info.
4067
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
4068
        chain chainhash.Hash, dbChan sqlc.GraphChannel, node1,
4069
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
4070

×
4071
        // NOTE: getAndBuildEdgeInfo is only used to load the data for a single
×
4072
        // edge, and so no paged queries will be performed. This means that
×
4073
        // it's ok to used pass in default config values here.
×
NEW
4074
        cfg := sqldb.DefaultQueryConfig()
×
4075

×
4076
        data, err := batchLoadChannelData(ctx, cfg, db, []int64{dbChan.ID}, nil)
×
4077
        if err != nil {
×
4078
                return nil, fmt.Errorf("unable to batch load channel data: %w",
×
4079
                        err)
×
4080
        }
×
4081

4082
        return buildEdgeInfoWithBatchData(chain, dbChan, node1, node2, data)
×
4083
}
4084

4085
// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
4086
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
4087
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
4088
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {
×
4089

×
4090
        if dbChan.Version != int16(ProtocolV1) {
×
4091
                return nil, fmt.Errorf("unsupported channel version: %d",
×
4092
                        dbChan.Version)
×
4093
        }
×
4094

4095
        // Use pre-loaded features and extras types.
4096
        fv := lnwire.EmptyFeatureVector()
×
4097
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
×
4098
                for _, bit := range features {
×
4099
                        fv.Set(lnwire.FeatureBit(bit))
×
4100
                }
×
4101
        }
4102

4103
        var extras map[uint64][]byte
×
4104
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
×
4105
        if exists {
×
4106
                extras = channelExtras
×
4107
        } else {
×
4108
                extras = make(map[uint64][]byte)
×
4109
        }
×
4110

4111
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4112
        if err != nil {
×
4113
                return nil, err
×
4114
        }
×
4115

4116
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4117
        if err != nil {
×
4118
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4119
                        "fields: %w", err)
×
4120
        }
×
4121
        if recs == nil {
×
4122
                recs = make([]byte, 0)
×
4123
        }
×
4124

4125
        var btcKey1, btcKey2 route.Vertex
×
4126
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4127
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4128

×
4129
        channel := &models.ChannelEdgeInfo{
×
4130
                ChainHash:        chain,
×
4131
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4132
                NodeKey1Bytes:    node1,
×
4133
                NodeKey2Bytes:    node2,
×
4134
                BitcoinKey1Bytes: btcKey1,
×
4135
                BitcoinKey2Bytes: btcKey2,
×
4136
                ChannelPoint:     *op,
×
4137
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4138
                Features:         fv,
×
4139
                ExtraOpaqueData:  recs,
×
4140
        }
×
4141

×
4142
        // We always set all the signatures at the same time, so we can
×
4143
        // safely check if one signature is present to determine if we have the
×
4144
        // rest of the signatures for the auth proof.
×
4145
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4146
                channel.AuthProof = &models.ChannelAuthProof{
×
4147
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4148
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4149
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4150
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4151
                }
×
4152
        }
×
4153

4154
        return channel, nil
×
4155
}
4156

4157
// buildNodeVertices is a helper that converts raw node public keys
4158
// into route.Vertex instances.
4159
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4160
        route.Vertex, error) {
×
4161

×
4162
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4163
        if err != nil {
×
4164
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4165
                        "create vertex from node1 pubkey: %w", err)
×
4166
        }
×
4167

4168
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4169
        if err != nil {
×
4170
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4171
                        "create vertex from node2 pubkey: %w", err)
×
4172
        }
×
4173

4174
        return node1Vertex, node2Vertex, nil
×
4175
}
4176

4177
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4178
// retrieves all the extra info required to build the complete
4179
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4180
// the provided sqlc.GraphChannelPolicy records are nil.
4181
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4182
        dbPol1, dbPol2 *sqlc.GraphChannelPolicy, channelID uint64, node1,
4183
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4184
        *models.ChannelEdgePolicy, error) {
×
4185

×
4186
        if dbPol1 == nil && dbPol2 == nil {
×
4187
                return nil, nil, nil
×
4188
        }
×
4189

4190
        var policyIDs = make([]int64, 0, 2)
×
4191
        if dbPol1 != nil {
×
4192
                policyIDs = append(policyIDs, dbPol1.ID)
×
4193
        }
×
4194
        if dbPol2 != nil {
×
4195
                policyIDs = append(policyIDs, dbPol2.ID)
×
4196
        }
×
4197

4198
        // NOTE: getAndBuildChanPolicies is only used to load the data for
4199
        // a maximum of two policies, and so no paged queries will be
4200
        // performed (unless the page size is one). So it's ok to use
4201
        // the default config values here.
NEW
4202
        cfg := sqldb.DefaultQueryConfig()
×
4203

×
4204
        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
×
4205
        if err != nil {
×
4206
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
4207
                        "data: %w", err)
×
4208
        }
×
4209

4210
        pol1, err := buildChanPolicyWithBatchData(
×
4211
                dbPol1, channelID, node2, batchData,
×
4212
        )
×
4213
        if err != nil {
×
4214
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4215
        }
×
4216

4217
        pol2, err := buildChanPolicyWithBatchData(
×
4218
                dbPol2, channelID, node1, batchData,
×
4219
        )
×
4220
        if err != nil {
×
4221
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4222
        }
×
4223

4224
        return pol1, pol2, nil
×
4225
}
4226

4227
// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
4228
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
4229
// then nil is returned for it.
4230
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4231
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
NEW
4232
        *models.CachedEdgePolicy, error) {
×
NEW
4233

×
NEW
4234
        var p1, p2 *models.CachedEdgePolicy
×
NEW
4235
        if dbPol1 != nil {
×
NEW
4236
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
×
NEW
4237
                if err != nil {
×
NEW
4238
                        return nil, nil, err
×
NEW
4239
                }
×
4240

NEW
4241
                p1 = models.NewCachedPolicy(policy1)
×
4242
        }
NEW
4243
        if dbPol2 != nil {
×
NEW
4244
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
×
NEW
4245
                if err != nil {
×
NEW
4246
                        return nil, nil, err
×
NEW
4247
                }
×
4248

NEW
4249
                p2 = models.NewCachedPolicy(policy2)
×
4250
        }
4251

NEW
4252
        return p1, p2, nil
×
4253
}
4254

4255
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4256
// provided sqlc.GraphChannelPolicy and other required information.
4257
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4258
        extras map[uint64][]byte,
4259
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4260

×
4261
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4262
        if err != nil {
×
4263
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4264
                        "fields: %w", err)
×
4265
        }
×
4266

4267
        var inboundFee fn.Option[lnwire.Fee]
×
4268
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4269
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4270

×
4271
                inboundFee = fn.Some(lnwire.Fee{
×
4272
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4273
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4274
                })
×
4275
        }
×
4276

4277
        return &models.ChannelEdgePolicy{
×
4278
                SigBytes:  dbPolicy.Signature,
×
4279
                ChannelID: channelID,
×
4280
                LastUpdate: time.Unix(
×
4281
                        dbPolicy.LastUpdate.Int64, 0,
×
4282
                ),
×
4283
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4284
                        dbPolicy.MessageFlags,
×
4285
                ),
×
4286
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4287
                        dbPolicy.ChannelFlags,
×
4288
                ),
×
4289
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4290
                MinHTLC: lnwire.MilliSatoshi(
×
4291
                        dbPolicy.MinHtlcMsat,
×
4292
                ),
×
4293
                MaxHTLC: lnwire.MilliSatoshi(
×
4294
                        dbPolicy.MaxHtlcMsat.Int64,
×
4295
                ),
×
4296
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4297
                        dbPolicy.BaseFeeMsat,
×
4298
                ),
×
4299
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4300
                ToNode:                    toNode,
×
4301
                InboundFee:                inboundFee,
×
4302
                ExtraOpaqueData:           recs,
×
4303
        }, nil
×
4304
}
4305

4306
// buildNodes builds the models.LightningNode instances for the
4307
// given row which is expected to be a sqlc type that contains node information.
4308
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4309
        dbNode2 sqlc.GraphNode) (*models.LightningNode, *models.LightningNode,
4310
        error) {
×
4311

×
4312
        node1, err := buildNode(ctx, db, &dbNode1)
×
4313
        if err != nil {
×
4314
                return nil, nil, err
×
4315
        }
×
4316

4317
        node2, err := buildNode(ctx, db, &dbNode2)
×
4318
        if err != nil {
×
4319
                return nil, nil, err
×
4320
        }
×
4321

4322
        return node1, node2, nil
×
4323
}
4324

4325
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the give
4326
// row which is expected to be a sqlc type that contains channel policy
4327
// information. It returns two policies, which may be nil if the policy
4328
// information is not present in the row.
4329
//
4330
//nolint:ll,dupl,funlen
4331
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4332
        *sqlc.GraphChannelPolicy, error) {
×
4333

×
4334
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4335
        switch r := row.(type) {
×
4336
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
×
4337
                if r.Policy1Timelock.Valid {
×
4338
                        policy1 = &sqlc.GraphChannelPolicy{
×
4339
                                Timelock:                r.Policy1Timelock.Int32,
×
4340
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4341
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4342
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4343
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4344
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4345
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4346
                                Disabled:                r.Policy1Disabled,
×
4347
                                MessageFlags:            r.Policy1MessageFlags,
×
4348
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4349
                        }
×
4350
                }
×
4351
                if r.Policy2Timelock.Valid {
×
4352
                        policy2 = &sqlc.GraphChannelPolicy{
×
4353
                                Timelock:                r.Policy2Timelock.Int32,
×
4354
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4355
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4356
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4357
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4358
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4359
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4360
                                Disabled:                r.Policy2Disabled,
×
4361
                                MessageFlags:            r.Policy2MessageFlags,
×
4362
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4363
                        }
×
4364
                }
×
4365

4366
                return policy1, policy2, nil
×
4367

4368
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
4369
                if r.Policy1ID.Valid {
×
4370
                        policy1 = &sqlc.GraphChannelPolicy{
×
4371
                                ID:                      r.Policy1ID.Int64,
×
4372
                                Version:                 r.Policy1Version.Int16,
×
4373
                                ChannelID:               r.GraphChannel.ID,
×
4374
                                NodeID:                  r.Policy1NodeID.Int64,
×
4375
                                Timelock:                r.Policy1Timelock.Int32,
×
4376
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4377
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4378
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4379
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4380
                                LastUpdate:              r.Policy1LastUpdate,
×
4381
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4382
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4383
                                Disabled:                r.Policy1Disabled,
×
4384
                                MessageFlags:            r.Policy1MessageFlags,
×
4385
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4386
                                Signature:               r.Policy1Signature,
×
4387
                        }
×
4388
                }
×
4389
                if r.Policy2ID.Valid {
×
4390
                        policy2 = &sqlc.GraphChannelPolicy{
×
4391
                                ID:                      r.Policy2ID.Int64,
×
4392
                                Version:                 r.Policy2Version.Int16,
×
4393
                                ChannelID:               r.GraphChannel.ID,
×
4394
                                NodeID:                  r.Policy2NodeID.Int64,
×
4395
                                Timelock:                r.Policy2Timelock.Int32,
×
4396
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4397
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4398
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4399
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4400
                                LastUpdate:              r.Policy2LastUpdate,
×
4401
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4402
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4403
                                Disabled:                r.Policy2Disabled,
×
4404
                                MessageFlags:            r.Policy2MessageFlags,
×
4405
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4406
                                Signature:               r.Policy2Signature,
×
4407
                        }
×
4408
                }
×
4409

4410
                return policy1, policy2, nil
×
4411

4412
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4413
                if r.Policy1ID.Valid {
×
4414
                        policy1 = &sqlc.GraphChannelPolicy{
×
4415
                                ID:                      r.Policy1ID.Int64,
×
4416
                                Version:                 r.Policy1Version.Int16,
×
4417
                                ChannelID:               r.GraphChannel.ID,
×
4418
                                NodeID:                  r.Policy1NodeID.Int64,
×
4419
                                Timelock:                r.Policy1Timelock.Int32,
×
4420
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4421
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4422
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4423
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4424
                                LastUpdate:              r.Policy1LastUpdate,
×
4425
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4426
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4427
                                Disabled:                r.Policy1Disabled,
×
4428
                                MessageFlags:            r.Policy1MessageFlags,
×
4429
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4430
                                Signature:               r.Policy1Signature,
×
4431
                        }
×
4432
                }
×
4433
                if r.Policy2ID.Valid {
×
4434
                        policy2 = &sqlc.GraphChannelPolicy{
×
4435
                                ID:                      r.Policy2ID.Int64,
×
4436
                                Version:                 r.Policy2Version.Int16,
×
4437
                                ChannelID:               r.GraphChannel.ID,
×
4438
                                NodeID:                  r.Policy2NodeID.Int64,
×
4439
                                Timelock:                r.Policy2Timelock.Int32,
×
4440
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4441
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4442
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4443
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4444
                                LastUpdate:              r.Policy2LastUpdate,
×
4445
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4446
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4447
                                Disabled:                r.Policy2Disabled,
×
4448
                                MessageFlags:            r.Policy2MessageFlags,
×
4449
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4450
                                Signature:               r.Policy2Signature,
×
4451
                        }
×
4452
                }
×
4453

4454
                return policy1, policy2, nil
×
4455

4456
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4457
                if r.Policy1ID.Valid {
×
4458
                        policy1 = &sqlc.GraphChannelPolicy{
×
4459
                                ID:                      r.Policy1ID.Int64,
×
4460
                                Version:                 r.Policy1Version.Int16,
×
4461
                                ChannelID:               r.GraphChannel.ID,
×
4462
                                NodeID:                  r.Policy1NodeID.Int64,
×
4463
                                Timelock:                r.Policy1Timelock.Int32,
×
4464
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4465
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4466
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4467
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4468
                                LastUpdate:              r.Policy1LastUpdate,
×
4469
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4470
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4471
                                Disabled:                r.Policy1Disabled,
×
4472
                                MessageFlags:            r.Policy1MessageFlags,
×
4473
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4474
                                Signature:               r.Policy1Signature,
×
4475
                        }
×
4476
                }
×
4477
                if r.Policy2ID.Valid {
×
4478
                        policy2 = &sqlc.GraphChannelPolicy{
×
4479
                                ID:                      r.Policy2ID.Int64,
×
4480
                                Version:                 r.Policy2Version.Int16,
×
4481
                                ChannelID:               r.GraphChannel.ID,
×
4482
                                NodeID:                  r.Policy2NodeID.Int64,
×
4483
                                Timelock:                r.Policy2Timelock.Int32,
×
4484
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4485
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4486
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4487
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4488
                                LastUpdate:              r.Policy2LastUpdate,
×
4489
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4490
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4491
                                Disabled:                r.Policy2Disabled,
×
4492
                                MessageFlags:            r.Policy2MessageFlags,
×
4493
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4494
                                Signature:               r.Policy2Signature,
×
4495
                        }
×
4496
                }
×
4497

4498
                return policy1, policy2, nil
×
4499

4500
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4501
                if r.Policy1ID.Valid {
×
4502
                        policy1 = &sqlc.GraphChannelPolicy{
×
4503
                                ID:                      r.Policy1ID.Int64,
×
4504
                                Version:                 r.Policy1Version.Int16,
×
4505
                                ChannelID:               r.GraphChannel.ID,
×
4506
                                NodeID:                  r.Policy1NodeID.Int64,
×
4507
                                Timelock:                r.Policy1Timelock.Int32,
×
4508
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4509
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4510
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4511
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4512
                                LastUpdate:              r.Policy1LastUpdate,
×
4513
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4514
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4515
                                Disabled:                r.Policy1Disabled,
×
4516
                                MessageFlags:            r.Policy1MessageFlags,
×
4517
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4518
                                Signature:               r.Policy1Signature,
×
4519
                        }
×
4520
                }
×
4521
                if r.Policy2ID.Valid {
×
4522
                        policy2 = &sqlc.GraphChannelPolicy{
×
4523
                                ID:                      r.Policy2ID.Int64,
×
4524
                                Version:                 r.Policy2Version.Int16,
×
4525
                                ChannelID:               r.GraphChannel.ID,
×
4526
                                NodeID:                  r.Policy2NodeID.Int64,
×
4527
                                Timelock:                r.Policy2Timelock.Int32,
×
4528
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4529
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4530
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4531
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4532
                                LastUpdate:              r.Policy2LastUpdate,
×
4533
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4534
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4535
                                Disabled:                r.Policy2Disabled,
×
4536
                                MessageFlags:            r.Policy2MessageFlags,
×
4537
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4538
                                Signature:               r.Policy2Signature,
×
4539
                        }
×
4540
                }
×
4541

4542
                return policy1, policy2, nil
×
4543

4544
        case sqlc.ListChannelsByNodeIDRow:
×
4545
                if r.Policy1ID.Valid {
×
4546
                        policy1 = &sqlc.GraphChannelPolicy{
×
4547
                                ID:                      r.Policy1ID.Int64,
×
4548
                                Version:                 r.Policy1Version.Int16,
×
4549
                                ChannelID:               r.GraphChannel.ID,
×
4550
                                NodeID:                  r.Policy1NodeID.Int64,
×
4551
                                Timelock:                r.Policy1Timelock.Int32,
×
4552
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4553
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4554
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4555
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4556
                                LastUpdate:              r.Policy1LastUpdate,
×
4557
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4558
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4559
                                Disabled:                r.Policy1Disabled,
×
4560
                                MessageFlags:            r.Policy1MessageFlags,
×
4561
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4562
                                Signature:               r.Policy1Signature,
×
4563
                        }
×
4564
                }
×
4565
                if r.Policy2ID.Valid {
×
4566
                        policy2 = &sqlc.GraphChannelPolicy{
×
4567
                                ID:                      r.Policy2ID.Int64,
×
4568
                                Version:                 r.Policy2Version.Int16,
×
4569
                                ChannelID:               r.GraphChannel.ID,
×
4570
                                NodeID:                  r.Policy2NodeID.Int64,
×
4571
                                Timelock:                r.Policy2Timelock.Int32,
×
4572
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4573
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4574
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4575
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4576
                                LastUpdate:              r.Policy2LastUpdate,
×
4577
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4578
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4579
                                Disabled:                r.Policy2Disabled,
×
4580
                                MessageFlags:            r.Policy2MessageFlags,
×
4581
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4582
                                Signature:               r.Policy2Signature,
×
4583
                        }
×
4584
                }
×
4585

4586
                return policy1, policy2, nil
×
4587

4588
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4589
                if r.Policy1ID.Valid {
×
4590
                        policy1 = &sqlc.GraphChannelPolicy{
×
4591
                                ID:                      r.Policy1ID.Int64,
×
4592
                                Version:                 r.Policy1Version.Int16,
×
4593
                                ChannelID:               r.GraphChannel.ID,
×
4594
                                NodeID:                  r.Policy1NodeID.Int64,
×
4595
                                Timelock:                r.Policy1Timelock.Int32,
×
4596
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4597
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4598
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4599
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4600
                                LastUpdate:              r.Policy1LastUpdate,
×
4601
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4602
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4603
                                Disabled:                r.Policy1Disabled,
×
4604
                                MessageFlags:            r.Policy1MessageFlags,
×
4605
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4606
                                Signature:               r.Policy1Signature,
×
4607
                        }
×
4608
                }
×
4609
                if r.Policy2ID.Valid {
×
4610
                        policy2 = &sqlc.GraphChannelPolicy{
×
4611
                                ID:                      r.Policy2ID.Int64,
×
4612
                                Version:                 r.Policy2Version.Int16,
×
4613
                                ChannelID:               r.GraphChannel.ID,
×
4614
                                NodeID:                  r.Policy2NodeID.Int64,
×
4615
                                Timelock:                r.Policy2Timelock.Int32,
×
4616
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4617
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4618
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4619
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4620
                                LastUpdate:              r.Policy2LastUpdate,
×
4621
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4622
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4623
                                Disabled:                r.Policy2Disabled,
×
4624
                                MessageFlags:            r.Policy2MessageFlags,
×
4625
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4626
                                Signature:               r.Policy2Signature,
×
4627
                        }
×
4628
                }
×
4629

4630
                return policy1, policy2, nil
×
4631
        default:
×
4632
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4633
                        "extractChannelPolicies: %T", r)
×
4634
        }
4635
}
4636

4637
// channelIDToBytes converts a channel ID (SCID) to a byte array
4638
// representation.
4639
func channelIDToBytes(channelID uint64) []byte {
×
4640
        var chanIDB [8]byte
×
4641
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4642

×
4643
        return chanIDB[:]
×
4644
}
×
4645

4646
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
4647
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
×
4648
        if len(addresses) == 0 {
×
4649
                return nil, nil
×
4650
        }
×
4651

4652
        result := make([]net.Addr, 0, len(addresses))
×
4653
        for _, addr := range addresses {
×
4654
                netAddr, err := parseAddress(addr.addrType, addr.address)
×
4655
                if err != nil {
×
4656
                        return nil, fmt.Errorf("unable to parse address %s "+
×
4657
                                "of type %d: %w", addr.address, addr.addrType,
×
4658
                                err)
×
4659
                }
×
4660
                if netAddr != nil {
×
4661
                        result = append(result, netAddr)
×
4662
                }
×
4663
        }
4664

4665
        // If we have no valid addresses, return nil instead of empty slice.
4666
        if len(result) == 0 {
×
4667
                return nil, nil
×
4668
        }
×
4669

4670
        return result, nil
×
4671
}
4672

4673
// parseAddress parses the given address string based on the address type
4674
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
4675
// and opaque addresses.
4676
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
×
4677
        switch addrType {
×
4678
        case addressTypeIPv4:
×
4679
                tcp, err := net.ResolveTCPAddr("tcp4", address)
×
4680
                if err != nil {
×
4681
                        return nil, err
×
4682
                }
×
4683

4684
                tcp.IP = tcp.IP.To4()
×
4685

×
4686
                return tcp, nil
×
4687

4688
        case addressTypeIPv6:
×
4689
                tcp, err := net.ResolveTCPAddr("tcp6", address)
×
4690
                if err != nil {
×
4691
                        return nil, err
×
4692
                }
×
4693

4694
                return tcp, nil
×
4695

4696
        case addressTypeTorV3, addressTypeTorV2:
×
4697
                service, portStr, err := net.SplitHostPort(address)
×
4698
                if err != nil {
×
4699
                        return nil, fmt.Errorf("unable to split tor "+
×
4700
                                "address: %v", address)
×
4701
                }
×
4702

4703
                port, err := strconv.Atoi(portStr)
×
4704
                if err != nil {
×
4705
                        return nil, err
×
4706
                }
×
4707

4708
                return &tor.OnionAddr{
×
4709
                        OnionService: service,
×
4710
                        Port:         port,
×
4711
                }, nil
×
4712

4713
        case addressTypeOpaque:
×
4714
                opaque, err := hex.DecodeString(address)
×
4715
                if err != nil {
×
4716
                        return nil, fmt.Errorf("unable to decode opaque "+
×
4717
                                "address: %v", address)
×
4718
                }
×
4719

4720
                return &lnwire.OpaqueAddrs{
×
4721
                        Payload: opaque,
×
4722
                }, nil
×
4723

4724
        default:
×
4725
                return nil, fmt.Errorf("unknown address type: %v", addrType)
×
4726
        }
4727
}
4728

4729
// batchNodeData holds all the related data for a batch of nodes.
4730
type batchNodeData struct {
4731
        // features is a map from a DB node ID to the feature bits for that
4732
        // node.
4733
        features map[int64][]int
4734

4735
        // addresses is a map from a DB node ID to the node's addresses.
4736
        addresses map[int64][]nodeAddress
4737

4738
        // extraFields is a map from a DB node ID to the extra signed fields
4739
        // for that node.
4740
        extraFields map[int64]map[uint64][]byte
4741
}
4742

4743
// nodeAddress holds the address type, position and address string for a
4744
// node. This is used to batch the fetching of node addresses.
4745
type nodeAddress struct {
4746
        addrType dbAddressType
4747
        position int32
4748
        address  string
4749
}
4750

4751
// batchLoadNodeData loads all related data for a batch of node IDs using the
4752
// provided SQLQueries interface. It returns a batchNodeData instance containing
4753
// the node features, addresses and extra signed fields.
4754
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
4755
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {
×
4756

×
4757
        // Batch load the node features.
×
4758
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
4759
        if err != nil {
×
4760
                return nil, fmt.Errorf("unable to batch load node "+
×
4761
                        "features: %w", err)
×
4762
        }
×
4763

4764
        // Batch load the node addresses.
4765
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
×
4766
        if err != nil {
×
4767
                return nil, fmt.Errorf("unable to batch load node "+
×
4768
                        "addresses: %w", err)
×
4769
        }
×
4770

4771
        // Batch load the node extra signed fields.
4772
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
×
4773
        if err != nil {
×
4774
                return nil, fmt.Errorf("unable to batch load node extra "+
×
4775
                        "signed fields: %w", err)
×
4776
        }
×
4777

4778
        return &batchNodeData{
×
4779
                features:    features,
×
4780
                addresses:   addrs,
×
4781
                extraFields: extraTypes,
×
4782
        }, nil
×
4783
}
4784

4785
// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
4786
// using ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
4787
func batchLoadNodeFeaturesHelper(ctx context.Context,
4788
        cfg *sqldb.QueryConfig, db SQLQueries,
4789
        nodeIDs []int64) (map[int64][]int, error) {
×
4790

×
4791
        features := make(map[int64][]int)
×
4792

×
NEW
4793
        return features, sqldb.ExecuteBatchQuery(
×
4794
                ctx, cfg, nodeIDs,
×
4795
                func(id int64) int64 {
×
4796
                        return id
×
4797
                },
×
4798
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
4799
                        error) {
×
4800

×
4801
                        return db.GetNodeFeaturesBatch(ctx, ids)
×
4802
                },
×
4803
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
×
4804
                        features[feature.NodeID] = append(
×
4805
                                features[feature.NodeID],
×
4806
                                int(feature.FeatureBit),
×
4807
                        )
×
4808

×
4809
                        return nil
×
4810
                },
×
4811
        )
4812
}
4813

4814
// batchLoadNodeAddressesHelper loads node addresses using ExecuteBatchQuery
4815
// wrapper around the GetNodeAddressesBatch query. It returns a map from
4816
// node ID to a slice of nodeAddress structs.
4817
func batchLoadNodeAddressesHelper(ctx context.Context,
4818
        cfg *sqldb.QueryConfig, db SQLQueries,
4819
        nodeIDs []int64) (map[int64][]nodeAddress, error) {
×
4820

×
4821
        addrs := make(map[int64][]nodeAddress)
×
4822

×
NEW
4823
        return addrs, sqldb.ExecuteBatchQuery(
×
4824
                ctx, cfg, nodeIDs,
×
4825
                func(id int64) int64 {
×
4826
                        return id
×
4827
                },
×
4828
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
4829
                        error) {
×
4830

×
4831
                        return db.GetNodeAddressesBatch(ctx, ids)
×
4832
                },
×
4833
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
×
4834
                        addrs[addr.NodeID] = append(
×
4835
                                addrs[addr.NodeID], nodeAddress{
×
4836
                                        addrType: dbAddressType(addr.Type),
×
4837
                                        position: addr.Position,
×
4838
                                        address:  addr.Address,
×
4839
                                },
×
4840
                        )
×
4841

×
4842
                        return nil
×
4843
                },
×
4844
        )
4845
}
4846

4847
// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
4848
// node IDs using ExecuteBatchQuery wrapper around the GetNodeExtraTypesBatch
4849
// query.
4850
func batchLoadNodeExtraTypesHelper(ctx context.Context,
4851
        cfg *sqldb.QueryConfig, db SQLQueries,
4852
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {
×
4853

×
4854
        extraFields := make(map[int64]map[uint64][]byte)
×
4855

×
4856
        callback := func(ctx context.Context,
×
4857
                field sqlc.GraphNodeExtraType) error {
×
4858

×
4859
                if extraFields[field.NodeID] == nil {
×
4860
                        extraFields[field.NodeID] = make(map[uint64][]byte)
×
4861
                }
×
4862
                extraFields[field.NodeID][uint64(field.Type)] = field.Value
×
4863

×
4864
                return nil
×
4865
        }
4866

NEW
4867
        return extraFields, sqldb.ExecuteBatchQuery(
×
4868
                ctx, cfg, nodeIDs,
×
4869
                func(id int64) int64 {
×
4870
                        return id
×
4871
                },
×
4872
                func(ctx context.Context, ids []int64) (
4873
                        []sqlc.GraphNodeExtraType, error) {
×
4874

×
4875
                        return db.GetNodeExtraTypesBatch(ctx, ids)
×
4876
                },
×
4877
                callback,
4878
        )
4879
}
4880

4881
// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
4882
// from the provided sqlc.GraphChannelPolicy records and the
4883
// provided batchChannelData.
4884
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4885
        channelID uint64, node1, node2 route.Vertex,
4886
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
4887
        *models.ChannelEdgePolicy, error) {
×
4888

×
4889
        pol1, err := buildChanPolicyWithBatchData(
×
4890
                dbPol1, channelID, node2, batchData,
×
4891
        )
×
4892
        if err != nil {
×
4893
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4894
        }
×
4895

4896
        pol2, err := buildChanPolicyWithBatchData(
×
4897
                dbPol2, channelID, node1, batchData,
×
4898
        )
×
4899
        if err != nil {
×
4900
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4901
        }
×
4902

4903
        return pol1, pol2, nil
×
4904
}
4905

4906
// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
4907
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
4908
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
4909
        channelID uint64, toNode route.Vertex,
4910
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {
×
4911

×
4912
        if dbPol == nil {
×
4913
                return nil, nil
×
4914
        }
×
4915

4916
        var dbPol1Extras map[uint64][]byte
×
4917
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
×
4918
                dbPol1Extras = extras
×
4919
        } else {
×
4920
                dbPol1Extras = make(map[uint64][]byte)
×
4921
        }
×
4922

4923
        return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode)
×
4924
}
4925

4926
// batchChannelData holds all the related data for a batch of channels.
4927
type batchChannelData struct {
4928
        // chanFeatures is a map from DB channel ID to a slice of feature bits.
4929
        chanfeatures map[int64][]int
4930

4931
        // chanExtras is a map from DB channel ID to a map of TLV type to
4932
        // extra signed field bytes.
4933
        chanExtraTypes map[int64]map[uint64][]byte
4934

4935
        // policyExtras is a map from DB channel policy ID to a map of TLV type
4936
        // to extra signed field bytes.
4937
        policyExtras map[int64]map[uint64][]byte
4938
}
4939

4940
// batchLoadChannelData loads all related data for batches of channels and
4941
// policies.
4942
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
4943
        db SQLQueries, channelIDs []int64,
4944
        policyIDs []int64) (*batchChannelData, error) {
×
4945

×
4946
        batchData := &batchChannelData{
×
4947
                chanfeatures:   make(map[int64][]int),
×
4948
                chanExtraTypes: make(map[int64]map[uint64][]byte),
×
4949
                policyExtras:   make(map[int64]map[uint64][]byte),
×
4950
        }
×
4951

×
4952
        // Batch load channel features and extras
×
4953
        var err error
×
4954
        if len(channelIDs) > 0 {
×
4955
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
×
4956
                        ctx, cfg, db, channelIDs,
×
4957
                )
×
4958
                if err != nil {
×
4959
                        return nil, fmt.Errorf("unable to batch load "+
×
4960
                                "channel features: %w", err)
×
4961
                }
×
4962

4963
                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
×
4964
                        ctx, cfg, db, channelIDs,
×
4965
                )
×
4966
                if err != nil {
×
4967
                        return nil, fmt.Errorf("unable to batch load "+
×
4968
                                "channel extras: %w", err)
×
4969
                }
×
4970
        }
4971

4972
        if len(policyIDs) > 0 {
×
4973
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
×
4974
                        ctx, cfg, db, policyIDs,
×
4975
                )
×
4976
                if err != nil {
×
4977
                        return nil, fmt.Errorf("unable to batch load "+
×
4978
                                "policy extras: %w", err)
×
4979
                }
×
4980
                batchData.policyExtras = policyExtras
×
4981
        }
4982

4983
        return batchData, nil
×
4984
}
4985

4986
// batchLoadChannelFeaturesHelper loads channel features for a batch of
4987
// channel IDs using ExecuteBatchQuery wrapper around the
4988
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
4989
// slice of feature bits.
4990
func batchLoadChannelFeaturesHelper(ctx context.Context,
4991
        cfg *sqldb.QueryConfig, db SQLQueries,
4992
        channelIDs []int64) (map[int64][]int, error) {
×
4993

×
4994
        features := make(map[int64][]int)
×
4995

×
NEW
4996
        return features, sqldb.ExecuteBatchQuery(
×
4997
                ctx, cfg, channelIDs,
×
4998
                func(id int64) int64 {
×
4999
                        return id
×
5000
                },
×
5001
                func(ctx context.Context,
5002
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {
×
5003

×
5004
                        return db.GetChannelFeaturesBatch(ctx, ids)
×
5005
                },
×
5006
                func(ctx context.Context,
5007
                        feature sqlc.GraphChannelFeature) error {
×
5008

×
5009
                        features[feature.ChannelID] = append(
×
5010
                                features[feature.ChannelID],
×
5011
                                int(feature.FeatureBit),
×
5012
                        )
×
5013

×
5014
                        return nil
×
5015
                },
×
5016
        )
5017
}
5018

5019
// batchLoadChannelExtrasHelper loads channel extra types for a batch of
5020
// channel IDs using ExecuteBatchQuery wrapper around the GetChannelExtrasBatch
5021
// query. It returns a map from DB channel ID to a map of TLV type to extra
5022
// signed field bytes.
5023
func batchLoadChannelExtrasHelper(ctx context.Context,
5024
        cfg *sqldb.QueryConfig, db SQLQueries,
5025
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5026

×
5027
        extras := make(map[int64]map[uint64][]byte)
×
5028

×
5029
        cb := func(ctx context.Context,
×
5030
                extra sqlc.GraphChannelExtraType) error {
×
5031

×
5032
                if extras[extra.ChannelID] == nil {
×
5033
                        extras[extra.ChannelID] = make(map[uint64][]byte)
×
5034
                }
×
5035
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value
×
5036

×
5037
                return nil
×
5038
        }
5039

NEW
5040
        return extras, sqldb.ExecuteBatchQuery(
×
5041
                ctx, cfg, channelIDs,
×
5042
                func(id int64) int64 {
×
5043
                        return id
×
5044
                },
×
5045
                func(ctx context.Context,
5046
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {
×
5047

×
5048
                        return db.GetChannelExtrasBatch(ctx, ids)
×
5049
                }, cb,
×
5050
        )
5051
}
5052

5053
// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
5054
// batch of policy IDs using ExecuteBatchQuery wrapper around the
5055
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
5056
// a map of TLV type to extra signed field bytes.
5057
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
5058
        cfg *sqldb.QueryConfig, db SQLQueries,
5059
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5060

×
5061
        extras := make(map[int64]map[uint64][]byte)
×
5062

×
NEW
5063
        return extras, sqldb.ExecuteBatchQuery(
×
5064
                ctx, cfg, policyIDs,
×
5065
                func(id int64) int64 {
×
5066
                        return id
×
5067
                },
×
5068
                func(ctx context.Context, ids []int64) (
5069
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {
×
5070

×
5071
                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
×
5072
                },
×
5073
                func(ctx context.Context,
5074
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {
×
5075

×
5076
                        if extras[row.PolicyID] == nil {
×
5077
                                extras[row.PolicyID] = make(map[uint64][]byte)
×
5078
                        }
×
5079
                        extras[row.PolicyID][uint64(row.Type)] = row.Value
×
5080

×
5081
                        return nil
×
5082
                },
5083
        )
5084
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc