• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 16629470050

30 Jul 2025 05:27PM UTC coverage: 67.181% (+0.1%) from 67.054%
16629470050

push

github

web-flow
Merge pull request #10113 from ellemouton/graphPerf2

[1] graph/db: add some SQL performance improvements

0 of 123 new or added lines in 2 files covered. (0.0%)

45 existing lines in 17 files now uncovered.

135567 of 201794 relevant lines covered (67.18%)

21670.7 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// pageSize caps how many records a single paginated query may return. The
// current value is a starting point and can be tuned once benchmarks exist.
const pageSize = 2000

// ProtocolVersion enumerates the gossip protocol version that a message
// belongs to.
type ProtocolVersion uint8

const (
	// ProtocolV1 identifies the gossip protocol specified in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String renders the protocol version as the letter "V" followed by its
// decimal value, e.g. "V1".
func (v ProtocolVersion) String() string {
	return fmt.Sprintf("V%d", uint8(v))
}
×
51

52
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
53
// execute queries against the SQL graph tables.
54
//
55
//nolint:ll,interfacebloat
56
type SQLQueries interface {
57
        /*
58
                Node queries.
59
        */
60
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
61
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
62
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
63
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
64
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
65
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
66
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
67
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
68
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
69
        DeleteNode(ctx context.Context, id int64) error
70

71
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
72
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
73
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
74

75
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
76
        GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
77
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
78

79
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
80
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
81
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
82
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
83

84
        /*
85
                Source node queries.
86
        */
87
        AddSourceNode(ctx context.Context, nodeID int64) error
88
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
89

90
        /*
91
                Channel queries.
92
        */
93
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
94
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
95
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
96
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
97
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
98
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
99
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
100
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
101
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
102
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
103
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
104
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
105
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
106
        ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
107
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
108
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
109
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
110
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
111
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
112
        DeleteChannels(ctx context.Context, ids []int64) error
113

114
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
115
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
116

117
        /*
118
                Channel Policy table queries.
119
        */
120
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
121
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
122
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
123

124
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
125
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
126
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
127

128
        /*
129
                Zombie index queries.
130
        */
131
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
132
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
133
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
134
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
135
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
136

137
        /*
138
                Prune log table queries.
139
        */
140
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
141
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
142
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
143
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
144

145
        /*
146
                Closed SCID table queries.
147
        */
148
        InsertClosedChannel(ctx context.Context, scid []byte) error
149
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
150
}
151

152
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
153
// database operations.
154
type BatchedSQLQueries interface {
155
        SQLQueries
156
        sqldb.BatchedTx[SQLQueries]
157
}
158

159
// SQLStore is an implementation of the V1Store interface that uses a SQL
160
// database as the backend.
161
type SQLStore struct {
162
        cfg *SQLStoreConfig
163
        db  BatchedSQLQueries
164

165
        // cacheMu guards all caches (rejectCache and chanCache). If
166
        // this mutex will be acquired at the same time as the DB mutex then
167
        // the cacheMu MUST be acquired first to prevent deadlock.
168
        cacheMu     sync.RWMutex
169
        rejectCache *rejectCache
170
        chanCache   *channelCache
171

172
        chanScheduler batch.Scheduler[SQLQueries]
173
        nodeScheduler batch.Scheduler[SQLQueries]
174

175
        srcNodes  map[ProtocolVersion]*srcNodeInfo
176
        srcNodeMu sync.Mutex
177
}
178

179
// A compile-time assertion to ensure that SQLStore implements the V1Store
180
// interface.
181
var _ V1Store = (*SQLStore)(nil)
182

183
// SQLStoreConfig holds the configuration for the SQLStore.
184
type SQLStoreConfig struct {
185
        // ChainHash is the genesis hash for the chain that all the gossip
186
        // messages in this store are aimed at.
187
        ChainHash chainhash.Hash
188

189
        // PaginationCfg is the configuration for paginated queries.
190
        PaginationCfg *sqldb.PagedQueryConfig
191
}
192

193
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
194
// storage backend.
195
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
196
        options ...StoreOptionModifier) (*SQLStore, error) {
×
197

×
198
        opts := DefaultOptions()
×
199
        for _, o := range options {
×
200
                o(opts)
×
201
        }
×
202

203
        if opts.NoMigration {
×
204
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
205
                        "supported for SQL stores")
×
206
        }
×
207

208
        s := &SQLStore{
×
209
                cfg:         cfg,
×
210
                db:          db,
×
211
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
212
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
213
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
214
        }
×
215

×
216
        s.chanScheduler = batch.NewTimeScheduler(
×
217
                db, &s.cacheMu, opts.BatchCommitInterval,
×
218
        )
×
219
        s.nodeScheduler = batch.NewTimeScheduler(
×
220
                db, nil, opts.BatchCommitInterval,
×
221
        )
×
222

×
223
        return s, nil
×
224
}
225

226
// AddLightningNode adds a vertex/node to the graph database. If the node is not
227
// in the database from before, this will add a new, unconnected one to the
228
// graph. If it is present from before, this will update that node's
229
// information.
230
//
231
// NOTE: part of the V1Store interface.
232
func (s *SQLStore) AddLightningNode(ctx context.Context,
233
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
234

×
235
        r := &batch.Request[SQLQueries]{
×
236
                Opts: batch.NewSchedulerOptions(opts...),
×
237
                Do: func(queries SQLQueries) error {
×
238
                        _, err := upsertNode(ctx, queries, node)
×
239
                        return err
×
240
                },
×
241
        }
242

243
        return s.nodeScheduler.Execute(ctx, r)
×
244
}
245

246
// FetchLightningNode attempts to look up a target node by its identity public
247
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
248
// returned.
249
//
250
// NOTE: part of the V1Store interface.
251
func (s *SQLStore) FetchLightningNode(ctx context.Context,
252
        pubKey route.Vertex) (*models.LightningNode, error) {
×
253

×
254
        var node *models.LightningNode
×
255
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
256
                var err error
×
257
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
258

×
259
                return err
×
260
        }, sqldb.NoOpReset)
×
261
        if err != nil {
×
262
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
263
        }
×
264

265
        return node, nil
×
266
}
267

268
// HasLightningNode determines if the graph has a vertex identified by the
269
// target node identity public key. If the node exists in the database, a
270
// timestamp of when the data for the node was lasted updated is returned along
271
// with a true boolean. Otherwise, an empty time.Time is returned with a false
272
// boolean.
273
//
274
// NOTE: part of the V1Store interface.
275
func (s *SQLStore) HasLightningNode(ctx context.Context,
276
        pubKey [33]byte) (time.Time, bool, error) {
×
277

×
278
        var (
×
279
                exists     bool
×
280
                lastUpdate time.Time
×
281
        )
×
282
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
283
                dbNode, err := db.GetNodeByPubKey(
×
284
                        ctx, sqlc.GetNodeByPubKeyParams{
×
285
                                Version: int16(ProtocolV1),
×
286
                                PubKey:  pubKey[:],
×
287
                        },
×
288
                )
×
289
                if errors.Is(err, sql.ErrNoRows) {
×
290
                        return nil
×
291
                } else if err != nil {
×
292
                        return fmt.Errorf("unable to fetch node: %w", err)
×
293
                }
×
294

295
                exists = true
×
296

×
297
                if dbNode.LastUpdate.Valid {
×
298
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
299
                }
×
300

301
                return nil
×
302
        }, sqldb.NoOpReset)
303
        if err != nil {
×
304
                return time.Time{}, false,
×
305
                        fmt.Errorf("unable to fetch node: %w", err)
×
306
        }
×
307

308
        return lastUpdate, exists, nil
×
309
}
310

311
// AddrsForNode returns all known addresses for the target node public key
312
// that the graph DB is aware of. The returned boolean indicates if the
313
// given node is unknown to the graph DB or not.
314
//
315
// NOTE: part of the V1Store interface.
316
func (s *SQLStore) AddrsForNode(ctx context.Context,
317
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
318

×
319
        var (
×
320
                addresses []net.Addr
×
321
                known     bool
×
322
        )
×
323
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
324
                // First, check if the node exists and get its DB ID if it
×
NEW
325
                // does.
×
NEW
326
                dbID, err := db.GetNodeIDByPubKey(
×
NEW
327
                        ctx, sqlc.GetNodeIDByPubKeyParams{
×
NEW
328
                                Version: int16(ProtocolV1),
×
NEW
329
                                PubKey:  nodePub.SerializeCompressed(),
×
NEW
330
                        },
×
331
                )
×
NEW
332
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
333
                        return nil
×
NEW
334
                }
×
335

NEW
336
                known = true
×
NEW
337

×
NEW
338
                addresses, err = getNodeAddresses(ctx, db, dbID)
×
339
                if err != nil {
×
340
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
341
                                err)
×
342
                }
×
343

344
                return nil
×
345
        }, sqldb.NoOpReset)
346
        if err != nil {
×
347
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
348
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
349
        }
×
350

351
        return known, addresses, nil
×
352
}
353

354
// DeleteLightningNode starts a new database transaction to remove a vertex/node
355
// from the database according to the node's public key.
356
//
357
// NOTE: part of the V1Store interface.
358
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
359
        pubKey route.Vertex) error {
×
360

×
361
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
362
                res, err := db.DeleteNodeByPubKey(
×
363
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
364
                                Version: int16(ProtocolV1),
×
365
                                PubKey:  pubKey[:],
×
366
                        },
×
367
                )
×
368
                if err != nil {
×
369
                        return err
×
370
                }
×
371

372
                rows, err := res.RowsAffected()
×
373
                if err != nil {
×
374
                        return err
×
375
                }
×
376

377
                if rows == 0 {
×
378
                        return ErrGraphNodeNotFound
×
379
                } else if rows > 1 {
×
380
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
381
                }
×
382

383
                return err
×
384
        }, sqldb.NoOpReset)
385
        if err != nil {
×
386
                return fmt.Errorf("unable to delete node: %w", err)
×
387
        }
×
388

389
        return nil
×
390
}
391

392
// FetchNodeFeatures returns the features of the given node. If no features are
393
// known for the node, an empty feature vector is returned.
394
//
395
// NOTE: this is part of the graphdb.NodeTraverser interface.
396
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
397
        *lnwire.FeatureVector, error) {
×
398

×
399
        ctx := context.TODO()
×
400

×
401
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
402
}
×
403

404
// DisabledChannelIDs returns the channel ids of disabled channels.
405
// A channel is disabled when two of the associated ChanelEdgePolicies
406
// have their disabled bit on.
407
//
408
// NOTE: part of the V1Store interface.
409
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
410
        var (
×
411
                ctx     = context.TODO()
×
412
                chanIDs []uint64
×
413
        )
×
414
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
415
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
416
                if err != nil {
×
417
                        return fmt.Errorf("unable to fetch disabled "+
×
418
                                "channels: %w", err)
×
419
                }
×
420

421
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
422

×
423
                return nil
×
424
        }, sqldb.NoOpReset)
425
        if err != nil {
×
426
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
427
                        err)
×
428
        }
×
429

430
        return chanIDs, nil
×
431
}
432

433
// LookupAlias attempts to return the alias as advertised by the target node.
434
//
435
// NOTE: part of the V1Store interface.
436
func (s *SQLStore) LookupAlias(ctx context.Context,
437
        pub *btcec.PublicKey) (string, error) {
×
438

×
439
        var alias string
×
440
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
441
                dbNode, err := db.GetNodeByPubKey(
×
442
                        ctx, sqlc.GetNodeByPubKeyParams{
×
443
                                Version: int16(ProtocolV1),
×
444
                                PubKey:  pub.SerializeCompressed(),
×
445
                        },
×
446
                )
×
447
                if errors.Is(err, sql.ErrNoRows) {
×
448
                        return ErrNodeAliasNotFound
×
449
                } else if err != nil {
×
450
                        return fmt.Errorf("unable to fetch node: %w", err)
×
451
                }
×
452

453
                if !dbNode.Alias.Valid {
×
454
                        return ErrNodeAliasNotFound
×
455
                }
×
456

457
                alias = dbNode.Alias.String
×
458

×
459
                return nil
×
460
        }, sqldb.NoOpReset)
461
        if err != nil {
×
462
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
463
        }
×
464

465
        return alias, nil
×
466
}
467

468
// SourceNode returns the source node of the graph. The source node is treated
469
// as the center node within a star-graph. This method may be used to kick off
470
// a path finding algorithm in order to explore the reachability of another
471
// node based off the source node.
472
//
473
// NOTE: part of the V1Store interface.
474
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
475
        error) {
×
476

×
477
        var node *models.LightningNode
×
478
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
479
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
480
                if err != nil {
×
481
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
482
                                err)
×
483
                }
×
484

485
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
486

×
487
                return err
×
488
        }, sqldb.NoOpReset)
489
        if err != nil {
×
490
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
491
        }
×
492

493
        return node, nil
×
494
}
495

496
// SetSourceNode sets the source node within the graph database. The source
497
// node is to be used as the center of a star-graph within path finding
498
// algorithms.
499
//
500
// NOTE: part of the V1Store interface.
501
func (s *SQLStore) SetSourceNode(ctx context.Context,
502
        node *models.LightningNode) error {
×
503

×
504
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
505
                id, err := upsertNode(ctx, db, node)
×
506
                if err != nil {
×
507
                        return fmt.Errorf("unable to upsert source node: %w",
×
508
                                err)
×
509
                }
×
510

511
                // Make sure that if a source node for this version is already
512
                // set, then the ID is the same as the one we are about to set.
513
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
514
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
515
                        return fmt.Errorf("unable to fetch source node: %w",
×
516
                                err)
×
517
                } else if err == nil {
×
518
                        if dbSourceNodeID != id {
×
519
                                return fmt.Errorf("v1 source node already "+
×
520
                                        "set to a different node: %d vs %d",
×
521
                                        dbSourceNodeID, id)
×
522
                        }
×
523

524
                        return nil
×
525
                }
526

527
                return db.AddSourceNode(ctx, id)
×
528
        }, sqldb.NoOpReset)
529
}
530

531
// NodeUpdatesInHorizon returns all the known lightning node which have an
532
// update timestamp within the passed range. This method can be used by two
533
// nodes to quickly determine if they have the same set of up to date node
534
// announcements.
535
//
536
// NOTE: This is part of the V1Store interface.
537
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
538
        endTime time.Time) ([]models.LightningNode, error) {
×
539

×
540
        ctx := context.TODO()
×
541

×
542
        var nodes []models.LightningNode
×
543
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
544
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
545
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
546
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
547
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
548
                        },
×
549
                )
×
550
                if err != nil {
×
551
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
552
                }
×
553

554
                for _, dbNode := range dbNodes {
×
555
                        node, err := buildNode(ctx, db, &dbNode)
×
556
                        if err != nil {
×
557
                                return fmt.Errorf("unable to build node: %w",
×
558
                                        err)
×
559
                        }
×
560

561
                        nodes = append(nodes, *node)
×
562
                }
563

564
                return nil
×
565
        }, sqldb.NoOpReset)
566
        if err != nil {
×
567
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
568
        }
×
569

570
        return nodes, nil
×
571
}
572

573
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
574
// undirected edge from the two target nodes are created. The information stored
575
// denotes the static attributes of the channel, such as the channelID, the keys
576
// involved in creation of the channel, and the set of features that the channel
577
// supports. The chanPoint and chanID are used to uniquely identify the edge
578
// globally within the database.
579
//
580
// NOTE: part of the V1Store interface.
581
func (s *SQLStore) AddChannelEdge(ctx context.Context,
582
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
583

×
584
        var alreadyExists bool
×
585
        r := &batch.Request[SQLQueries]{
×
586
                Opts: batch.NewSchedulerOptions(opts...),
×
587
                Reset: func() {
×
588
                        alreadyExists = false
×
589
                },
×
590
                Do: func(tx SQLQueries) error {
×
591
                        _, err := insertChannel(ctx, tx, edge)
×
592

×
593
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
594
                        // succeed, but propagate the error via local state.
×
595
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
596
                                alreadyExists = true
×
597
                                return nil
×
598
                        }
×
599

600
                        return err
×
601
                },
602
                OnCommit: func(err error) error {
×
603
                        switch {
×
604
                        case err != nil:
×
605
                                return err
×
606
                        case alreadyExists:
×
607
                                return ErrEdgeAlreadyExist
×
608
                        default:
×
609
                                s.rejectCache.remove(edge.ChannelID)
×
610
                                s.chanCache.remove(edge.ChannelID)
×
611
                                return nil
×
612
                        }
613
                },
614
        }
615

616
        return s.chanScheduler.Execute(ctx, r)
×
617
}
618

619
// HighestChanID returns the "highest" known channel ID in the channel graph.
620
// This represents the "newest" channel from the PoV of the chain. This method
621
// can be used by peers to quickly determine if their graphs are in sync.
622
//
623
// NOTE: This is part of the V1Store interface.
624
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
625
        var highestChanID uint64
×
626
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
627
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
628
                if errors.Is(err, sql.ErrNoRows) {
×
629
                        return nil
×
630
                } else if err != nil {
×
631
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
632
                                err)
×
633
                }
×
634

635
                highestChanID = byteOrder.Uint64(chanID)
×
636

×
637
                return nil
×
638
        }, sqldb.NoOpReset)
639
        if err != nil {
×
640
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
641
        }
×
642

643
        return highestChanID, nil
×
644
}
645

646
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
647
// within the database for the referenced channel. The `flags` attribute within
648
// the ChannelEdgePolicy determines which of the directed edges are being
649
// updated. If the flag is 1, then the first node's information is being
650
// updated, otherwise it's the second node's information. The node ordering is
651
// determined by the lexicographical ordering of the identity public keys of the
652
// nodes on either side of the channel.
653
//
654
// NOTE: part of the V1Store interface.
655
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
656
        edge *models.ChannelEdgePolicy,
657
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
658

×
659
        var (
×
660
                isUpdate1    bool
×
661
                edgeNotFound bool
×
662
                from, to     route.Vertex
×
663
        )
×
664

×
665
        r := &batch.Request[SQLQueries]{
×
666
                Opts: batch.NewSchedulerOptions(opts...),
×
667
                Reset: func() {
×
668
                        isUpdate1 = false
×
669
                        edgeNotFound = false
×
670
                },
×
671
                Do: func(tx SQLQueries) error {
×
672
                        var err error
×
673
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
674
                                ctx, tx, edge,
×
675
                        )
×
676
                        if err != nil {
×
677
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
678
                        }
×
679

680
                        // Silence ErrEdgeNotFound so that the batch can
681
                        // succeed, but propagate the error via local state.
682
                        if errors.Is(err, ErrEdgeNotFound) {
×
683
                                edgeNotFound = true
×
684
                                return nil
×
685
                        }
×
686

687
                        return err
×
688
                },
689
                OnCommit: func(err error) error {
×
690
                        switch {
×
691
                        case err != nil:
×
692
                                return err
×
693
                        case edgeNotFound:
×
694
                                return ErrEdgeNotFound
×
695
                        default:
×
696
                                s.updateEdgeCache(edge, isUpdate1)
×
697
                                return nil
×
698
                        }
699
                },
700
        }
701

702
        err := s.chanScheduler.Execute(ctx, r)
×
703

×
704
        return from, to, err
×
705
}
706

707
// updateEdgeCache updates our reject and channel caches with the new
708
// edge policy information.
709
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
710
        isUpdate1 bool) {
×
711

×
712
        // If an entry for this channel is found in reject cache, we'll modify
×
713
        // the entry with the updated timestamp for the direction that was just
×
714
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
715
        // during the next query for this edge.
×
716
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
717
                if isUpdate1 {
×
718
                        entry.upd1Time = e.LastUpdate.Unix()
×
719
                } else {
×
720
                        entry.upd2Time = e.LastUpdate.Unix()
×
721
                }
×
722
                s.rejectCache.insert(e.ChannelID, entry)
×
723
        }
724

725
        // If an entry for this channel is found in channel cache, we'll modify
726
        // the entry with the updated policy for the direction that was just
727
        // written. If the edge doesn't exist, we'll defer loading the info and
728
        // policies and lazily read from disk during the next query.
729
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
730
                if isUpdate1 {
×
731
                        channel.Policy1 = e
×
732
                } else {
×
733
                        channel.Policy2 = e
×
734
                }
×
735
                s.chanCache.insert(e.ChannelID, channel)
×
736
        }
737
}
738

739
// ForEachSourceNodeChannel iterates through all channels of the source node,
740
// executing the passed callback on each. The call-back is provided with the
741
// channel's outpoint, whether we have a policy for the channel and the channel
742
// peer's node information.
743
//
744
// NOTE: part of the V1Store interface.
745
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
746
        cb func(chanPoint wire.OutPoint, havePolicy bool,
747
                otherNode *models.LightningNode) error, reset func()) error {
×
748

×
749
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
750
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
751
                if err != nil {
×
752
                        return fmt.Errorf("unable to fetch source node: %w",
×
753
                                err)
×
754
                }
×
755

756
                return forEachNodeChannel(
×
757
                        ctx, db, s.cfg.ChainHash, nodeID,
×
758
                        func(info *models.ChannelEdgeInfo,
×
759
                                outPolicy *models.ChannelEdgePolicy,
×
760
                                _ *models.ChannelEdgePolicy) error {
×
761

×
762
                                // Fetch the other node.
×
763
                                var (
×
764
                                        otherNodePub [33]byte
×
765
                                        node1        = info.NodeKey1Bytes
×
766
                                        node2        = info.NodeKey2Bytes
×
767
                                )
×
768
                                switch {
×
769
                                case bytes.Equal(node1[:], nodePub[:]):
×
770
                                        otherNodePub = node2
×
771
                                case bytes.Equal(node2[:], nodePub[:]):
×
772
                                        otherNodePub = node1
×
773
                                default:
×
774
                                        return fmt.Errorf("node not " +
×
775
                                                "participating in this channel")
×
776
                                }
777

778
                                _, otherNode, err := getNodeByPubKey(
×
779
                                        ctx, db, otherNodePub,
×
780
                                )
×
781
                                if err != nil {
×
782
                                        return fmt.Errorf("unable to fetch "+
×
783
                                                "other node(%x): %w",
×
784
                                                otherNodePub, err)
×
785
                                }
×
786

787
                                return cb(
×
788
                                        info.ChannelPoint, outPolicy != nil,
×
789
                                        otherNode,
×
790
                                )
×
791
                        },
792
                )
793
        }, reset)
794
}
795

796
// ForEachNode iterates through all the stored vertices/nodes in the graph,
797
// executing the passed callback with each node encountered. If the callback
798
// returns an error, then the transaction is aborted and the iteration stops
799
// early. Any operations performed on the NodeTx passed to the call-back are
800
// executed under the same read transaction and so, methods on the NodeTx object
801
// _MUST_ only be called from within the call-back.
802
//
803
// NOTE: part of the V1Store interface.
804
func (s *SQLStore) ForEachNode(ctx context.Context,
805
        cb func(tx NodeRTx) error, reset func()) error {
×
806

×
807
        var lastID int64 = 0
×
808
        handleNode := func(db SQLQueries, dbNode sqlc.GraphNode) error {
×
809
                node, err := buildNode(ctx, db, &dbNode)
×
810
                if err != nil {
×
811
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
812
                                dbNode.ID, err)
×
813
                }
×
814

815
                err = cb(
×
816
                        newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
×
817
                )
×
818
                if err != nil {
×
819
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
820
                                dbNode.ID, err)
×
821
                }
×
822

823
                return nil
×
824
        }
825

826
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
827
                for {
×
828
                        nodes, err := db.ListNodesPaginated(
×
829
                                ctx, sqlc.ListNodesPaginatedParams{
×
830
                                        Version: int16(ProtocolV1),
×
831
                                        ID:      lastID,
×
832
                                        Limit:   pageSize,
×
833
                                },
×
834
                        )
×
835
                        if err != nil {
×
836
                                return fmt.Errorf("unable to fetch nodes: %w",
×
837
                                        err)
×
838
                        }
×
839

840
                        if len(nodes) == 0 {
×
841
                                break
×
842
                        }
843

844
                        for _, dbNode := range nodes {
×
845
                                err = handleNode(db, dbNode)
×
846
                                if err != nil {
×
847
                                        return err
×
848
                                }
×
849

850
                                lastID = dbNode.ID
×
851
                        }
852
                }
853

854
                return nil
×
855
        }, reset)
856
}
857

858
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
859
// SQLStore and a SQL transaction.
860
type sqlGraphNodeTx struct {
861
        db    SQLQueries
862
        id    int64
863
        node  *models.LightningNode
864
        chain chainhash.Hash
865
}
866

867
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
868
// interface.
869
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
870

871
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
872
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
873

×
874
        return &sqlGraphNodeTx{
×
875
                db:    db,
×
876
                chain: chain,
×
877
                id:    id,
×
878
                node:  node,
×
879
        }
×
880
}
×
881

882
// Node returns the raw information of the node.
883
//
884
// NOTE: This is a part of the NodeRTx interface.
885
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
886
        return s.node
×
887
}
×
888

889
// ForEachChannel can be used to iterate over the node's channels under the same
890
// transaction used to fetch the node.
891
//
892
// NOTE: This is a part of the NodeRTx interface.
893
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
894
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
895

×
896
        ctx := context.TODO()
×
897

×
898
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
899
}
×
900

901
// FetchNode fetches the node with the given pub key under the same transaction
902
// used to fetch the current node. The returned node is also a NodeRTx and any
903
// operations on that NodeRTx will also be done under the same transaction.
904
//
905
// NOTE: This is a part of the NodeRTx interface.
906
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
907
        ctx := context.TODO()
×
908

×
909
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
910
        if err != nil {
×
911
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
912
                        nodePub, err)
×
913
        }
×
914

915
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
916
}
917

918
// ForEachNodeDirectedChannel iterates through all channels of a given node,
919
// executing the passed callback on the directed edge representing the channel
920
// and its incoming policy. If the callback returns an error, then the iteration
921
// is halted with the error propagated back up to the caller.
922
//
923
// Unknown policies are passed into the callback as nil values.
924
//
925
// NOTE: this is part of the graphdb.NodeTraverser interface.
926
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
927
        cb func(channel *DirectedChannel) error, reset func()) error {
×
928

×
929
        var ctx = context.TODO()
×
930

×
931
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
932
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
933
        }, reset)
×
934
}
935

936
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
937
// graph, executing the passed callback with each node encountered. If the
938
// callback returns an error, then the transaction is aborted and the iteration
939
// stops early.
940
//
941
// NOTE: This is a part of the V1Store interface.
942
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
943
        cb func(route.Vertex, *lnwire.FeatureVector) error,
944
        reset func()) error {
×
945

×
946
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
947
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
948
                        nodePub route.Vertex) error {
×
949

×
950
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
951
                        if err != nil {
×
952
                                return fmt.Errorf("unable to fetch node "+
×
953
                                        "features: %w", err)
×
954
                        }
×
955

956
                        return cb(nodePub, features)
×
957
                })
958
        }, reset)
959
        if err != nil {
×
960
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
961
        }
×
962

963
        return nil
×
964
}
965

966
// ForEachNodeChannel iterates through all channels of the given node,
967
// executing the passed callback with an edge info structure and the policies
968
// of each end of the channel. The first edge policy is the outgoing edge *to*
969
// the connecting node, while the second is the incoming edge *from* the
970
// connecting node. If the callback returns an error, then the iteration is
971
// halted with the error propagated back up to the caller.
972
//
973
// Unknown policies are passed into the callback as nil values.
974
//
975
// NOTE: part of the V1Store interface.
976
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
977
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
978
                *models.ChannelEdgePolicy) error, reset func()) error {
×
979

×
980
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
981
                dbNode, err := db.GetNodeByPubKey(
×
982
                        ctx, sqlc.GetNodeByPubKeyParams{
×
983
                                Version: int16(ProtocolV1),
×
984
                                PubKey:  nodePub[:],
×
985
                        },
×
986
                )
×
987
                if errors.Is(err, sql.ErrNoRows) {
×
988
                        return nil
×
989
                } else if err != nil {
×
990
                        return fmt.Errorf("unable to fetch node: %w", err)
×
991
                }
×
992

993
                return forEachNodeChannel(
×
994
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
995
                )
×
996
        }, reset)
997
}
998

999
// ChanUpdatesInHorizon returns all the known channel edges which have at least
1000
// one edge that has an update timestamp within the specified horizon.
1001
//
1002
// NOTE: This is part of the V1Store interface.
1003
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
1004
        endTime time.Time) ([]ChannelEdge, error) {
×
1005

×
1006
        s.cacheMu.Lock()
×
1007
        defer s.cacheMu.Unlock()
×
1008

×
1009
        var (
×
1010
                ctx = context.TODO()
×
1011
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
1012
                // an additional map to keep track of the edges already seen to
×
1013
                // prevent re-adding it.
×
1014
                edgesSeen    = make(map[uint64]struct{})
×
1015
                edgesToCache = make(map[uint64]ChannelEdge)
×
1016
                edges        []ChannelEdge
×
1017
                hits         int
×
1018
        )
×
1019
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1020
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1021
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1022
                                Version:   int16(ProtocolV1),
×
1023
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
1024
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
1025
                        },
×
1026
                )
×
1027
                if err != nil {
×
1028
                        return err
×
1029
                }
×
1030

1031
                for _, row := range rows {
×
1032
                        // If we've already retrieved the info and policies for
×
1033
                        // this edge, then we can skip it as we don't need to do
×
1034
                        // so again.
×
1035
                        chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid)
×
1036
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1037
                                continue
×
1038
                        }
1039

1040
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1041
                                hits++
×
1042
                                edgesSeen[chanIDInt] = struct{}{}
×
1043
                                edges = append(edges, channel)
×
1044

×
1045
                                continue
×
1046
                        }
1047

1048
                        node1, node2, err := buildNodes(
×
1049
                                ctx, db, row.GraphNode, row.GraphNode_2,
×
1050
                        )
×
1051
                        if err != nil {
×
1052
                                return err
×
1053
                        }
×
1054

1055
                        channel, err := getAndBuildEdgeInfo(
×
1056
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1057
                                row.GraphChannel, node1.PubKeyBytes,
×
1058
                                node2.PubKeyBytes,
×
1059
                        )
×
1060
                        if err != nil {
×
1061
                                return fmt.Errorf("unable to build channel "+
×
1062
                                        "info: %w", err)
×
1063
                        }
×
1064

1065
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1066
                        if err != nil {
×
1067
                                return fmt.Errorf("unable to extract channel "+
×
1068
                                        "policies: %w", err)
×
1069
                        }
×
1070

1071
                        p1, p2, err := getAndBuildChanPolicies(
×
1072
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1073
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1074
                        )
×
1075
                        if err != nil {
×
1076
                                return fmt.Errorf("unable to build channel "+
×
1077
                                        "policies: %w", err)
×
1078
                        }
×
1079

1080
                        edgesSeen[chanIDInt] = struct{}{}
×
1081
                        chanEdge := ChannelEdge{
×
1082
                                Info:    channel,
×
1083
                                Policy1: p1,
×
1084
                                Policy2: p2,
×
1085
                                Node1:   node1,
×
1086
                                Node2:   node2,
×
1087
                        }
×
1088
                        edges = append(edges, chanEdge)
×
1089
                        edgesToCache[chanIDInt] = chanEdge
×
1090
                }
1091

1092
                return nil
×
1093
        }, func() {
×
1094
                edgesSeen = make(map[uint64]struct{})
×
1095
                edgesToCache = make(map[uint64]ChannelEdge)
×
1096
                edges = nil
×
1097
        })
×
1098
        if err != nil {
×
1099
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1100
        }
×
1101

1102
        // Insert any edges loaded from disk into the cache.
1103
        for chanid, channel := range edgesToCache {
×
1104
                s.chanCache.insert(chanid, channel)
×
1105
        }
×
1106

1107
        if len(edges) > 0 {
×
1108
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
1109
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1110
        } else {
×
1111
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1112
                        "horizon (%s, %s)", startTime, endTime)
×
1113
        }
×
1114

1115
        return edges, nil
×
1116
}
1117

1118
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1119
// data to the call-back.
1120
//
1121
// NOTE: The callback contents MUST not be modified.
1122
//
1123
// NOTE: part of the V1Store interface.
1124
func (s *SQLStore) ForEachNodeCached(ctx context.Context,
1125
        cb func(node route.Vertex, chans map[uint64]*DirectedChannel) error,
1126
        reset func()) error {
×
1127

×
1128
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1129
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
1130
                        nodePub route.Vertex) error {
×
1131

×
1132
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
1133
                        if err != nil {
×
1134
                                return fmt.Errorf("unable to fetch "+
×
1135
                                        "node(id=%d) features: %w", nodeID, err)
×
1136
                        }
×
1137

1138
                        toNodeCallback := func() route.Vertex {
×
1139
                                return nodePub
×
1140
                        }
×
1141

1142
                        rows, err := db.ListChannelsByNodeID(
×
1143
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
1144
                                        Version: int16(ProtocolV1),
×
1145
                                        NodeID1: nodeID,
×
1146
                                },
×
1147
                        )
×
1148
                        if err != nil {
×
1149
                                return fmt.Errorf("unable to fetch channels "+
×
1150
                                        "of node(id=%d): %w", nodeID, err)
×
1151
                        }
×
1152

1153
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
1154
                        for _, row := range rows {
×
1155
                                node1, node2, err := buildNodeVertices(
×
1156
                                        row.Node1Pubkey, row.Node2Pubkey,
×
1157
                                )
×
1158
                                if err != nil {
×
1159
                                        return err
×
1160
                                }
×
1161

1162
                                e, err := getAndBuildEdgeInfo(
×
1163
                                        ctx, db, s.cfg.ChainHash,
×
1164
                                        row.GraphChannel.ID, row.GraphChannel,
×
1165
                                        node1, node2,
×
1166
                                )
×
1167
                                if err != nil {
×
1168
                                        return fmt.Errorf("unable to build "+
×
1169
                                                "channel info: %w", err)
×
1170
                                }
×
1171

1172
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1173
                                        row,
×
1174
                                )
×
1175
                                if err != nil {
×
1176
                                        return fmt.Errorf("unable to "+
×
1177
                                                "extract channel "+
×
1178
                                                "policies: %w", err)
×
1179
                                }
×
1180

1181
                                p1, p2, err := getAndBuildChanPolicies(
×
1182
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
1183
                                        node1, node2,
×
1184
                                )
×
1185
                                if err != nil {
×
1186
                                        return fmt.Errorf("unable to "+
×
1187
                                                "build channel policies: %w",
×
1188
                                                err)
×
1189
                                }
×
1190

1191
                                // Determine the outgoing and incoming policy
1192
                                // for this channel and node combo.
1193
                                outPolicy, inPolicy := p1, p2
×
1194
                                if p1 != nil && p1.ToNode == nodePub {
×
1195
                                        outPolicy, inPolicy = p2, p1
×
1196
                                } else if p2 != nil && p2.ToNode != nodePub {
×
1197
                                        outPolicy, inPolicy = p2, p1
×
1198
                                }
×
1199

1200
                                var cachedInPolicy *models.CachedEdgePolicy
×
1201
                                if inPolicy != nil {
×
1202
                                        cachedInPolicy = models.NewCachedPolicy(
×
1203
                                                inPolicy,
×
1204
                                        )
×
1205
                                        cachedInPolicy.ToNodePubKey =
×
1206
                                                toNodeCallback
×
1207
                                        cachedInPolicy.ToNodeFeatures =
×
1208
                                                features
×
1209
                                }
×
1210

1211
                                var inboundFee lnwire.Fee
×
1212
                                if outPolicy != nil {
×
1213
                                        outPolicy.InboundFee.WhenSome(
×
1214
                                                func(fee lnwire.Fee) {
×
1215
                                                        inboundFee = fee
×
1216
                                                },
×
1217
                                        )
1218
                                }
1219

1220
                                directedChannel := &DirectedChannel{
×
1221
                                        ChannelID: e.ChannelID,
×
1222
                                        IsNode1: nodePub ==
×
1223
                                                e.NodeKey1Bytes,
×
1224
                                        OtherNode:    e.NodeKey2Bytes,
×
1225
                                        Capacity:     e.Capacity,
×
1226
                                        OutPolicySet: outPolicy != nil,
×
1227
                                        InPolicy:     cachedInPolicy,
×
1228
                                        InboundFee:   inboundFee,
×
1229
                                }
×
1230

×
1231
                                if nodePub == e.NodeKey2Bytes {
×
1232
                                        directedChannel.OtherNode =
×
1233
                                                e.NodeKey1Bytes
×
1234
                                }
×
1235

1236
                                channels[e.ChannelID] = directedChannel
×
1237
                        }
1238

1239
                        return cb(nodePub, channels)
×
1240
                })
1241
        }, reset)
1242
}
1243

1244
// ForEachChannelCacheable iterates through all the channel edges stored
1245
// within the graph and invokes the passed callback for each edge. The
1246
// callback takes two edges as since this is a directed graph, both the
1247
// in/out edges are visited. If the callback returns an error, then the
1248
// transaction is aborted and the iteration stops early.
1249
//
1250
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1251
// pointer for that particular channel edge routing policy will be
1252
// passed into the callback.
1253
//
1254
// NOTE: this method is like ForEachChannel but fetches only the data
1255
// required for the graph cache.
1256
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1257
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1258
        reset func()) error {
×
1259

×
1260
        ctx := context.TODO()
×
1261

×
NEW
1262
        handleChannel := func(
×
NEW
1263
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
1264

×
1265
                node1, node2, err := buildNodeVertices(
×
1266
                        row.Node1Pubkey, row.Node2Pubkey,
×
1267
                )
×
1268
                if err != nil {
×
1269
                        return err
×
1270
                }
×
1271

1272
                edge := buildCacheableChannelInfo(
×
NEW
1273
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1274
                )
×
1275

×
1276
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1277
                if err != nil {
×
1278
                        return err
×
1279
                }
×
1280

1281
                var pol1, pol2 *models.CachedEdgePolicy
×
1282
                if dbPol1 != nil {
×
1283
                        policy1, err := buildChanPolicy(
×
1284
                                *dbPol1, edge.ChannelID, nil, node2,
×
1285
                        )
×
1286
                        if err != nil {
×
1287
                                return err
×
1288
                        }
×
1289

1290
                        pol1 = models.NewCachedPolicy(policy1)
×
1291
                }
1292
                if dbPol2 != nil {
×
1293
                        policy2, err := buildChanPolicy(
×
1294
                                *dbPol2, edge.ChannelID, nil, node1,
×
1295
                        )
×
1296
                        if err != nil {
×
1297
                                return err
×
1298
                        }
×
1299

1300
                        pol2 = models.NewCachedPolicy(policy2)
×
1301
                }
1302

1303
                if err := cb(edge, pol1, pol2); err != nil {
×
1304
                        return err
×
1305
                }
×
1306

1307
                return nil
×
1308
        }
1309

1310
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1311
                lastID := int64(-1)
×
1312
                for {
×
1313
                        //nolint:ll
×
NEW
1314
                        rows, err := db.ListChannelsWithPoliciesForCachePaginated(
×
NEW
1315
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
1316
                                        Version: int16(ProtocolV1),
×
1317
                                        ID:      lastID,
×
1318
                                        Limit:   pageSize,
×
1319
                                },
×
1320
                        )
×
1321
                        if err != nil {
×
1322
                                return err
×
1323
                        }
×
1324

1325
                        if len(rows) == 0 {
×
1326
                                break
×
1327
                        }
1328

1329
                        for _, row := range rows {
×
NEW
1330
                                err := handleChannel(row)
×
1331
                                if err != nil {
×
1332
                                        return err
×
1333
                                }
×
1334

NEW
1335
                                lastID = row.ID
×
1336
                        }
1337
                }
1338

1339
                return nil
×
1340
        }, reset)
1341
}
1342

1343
// ForEachChannel iterates through all the channel edges stored within the
1344
// graph and invokes the passed callback for each edge. The callback takes two
1345
// edges as since this is a directed graph, both the in/out edges are visited.
1346
// If the callback returns an error, then the transaction is aborted and the
1347
// iteration stops early.
1348
//
1349
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1350
// for that particular channel edge routing policy will be passed into the
1351
// callback.
1352
//
1353
// NOTE: part of the V1Store interface.
1354
func (s *SQLStore) ForEachChannel(ctx context.Context,
1355
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1356
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1357

×
1358
        handleChannel := func(db SQLQueries,
×
1359
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1360

×
1361
                node1, node2, err := buildNodeVertices(
×
1362
                        row.Node1Pubkey, row.Node2Pubkey,
×
1363
                )
×
1364
                if err != nil {
×
1365
                        return fmt.Errorf("unable to build node vertices: %w",
×
1366
                                err)
×
1367
                }
×
1368

1369
                edge, err := getAndBuildEdgeInfo(
×
1370
                        ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1371
                        row.GraphChannel, node1, node2,
×
1372
                )
×
1373
                if err != nil {
×
1374
                        return fmt.Errorf("unable to build channel info: %w",
×
1375
                                err)
×
1376
                }
×
1377

1378
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1379
                if err != nil {
×
1380
                        return fmt.Errorf("unable to extract channel "+
×
1381
                                "policies: %w", err)
×
1382
                }
×
1383

1384
                p1, p2, err := getAndBuildChanPolicies(
×
1385
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1386
                )
×
1387
                if err != nil {
×
1388
                        return fmt.Errorf("unable to build channel "+
×
1389
                                "policies: %w", err)
×
1390
                }
×
1391

1392
                err = cb(edge, p1, p2)
×
1393
                if err != nil {
×
1394
                        return fmt.Errorf("callback failed for channel "+
×
1395
                                "id=%d: %w", edge.ChannelID, err)
×
1396
                }
×
1397

1398
                return nil
×
1399
        }
1400

1401
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1402
                lastID := int64(-1)
×
1403
                for {
×
1404
                        //nolint:ll
×
1405
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1406
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1407
                                        Version: int16(ProtocolV1),
×
1408
                                        ID:      lastID,
×
1409
                                        Limit:   pageSize,
×
1410
                                },
×
1411
                        )
×
1412
                        if err != nil {
×
1413
                                return err
×
1414
                        }
×
1415

1416
                        if len(rows) == 0 {
×
1417
                                break
×
1418
                        }
1419

1420
                        for _, row := range rows {
×
1421
                                err := handleChannel(db, row)
×
1422
                                if err != nil {
×
1423
                                        return err
×
1424
                                }
×
1425

1426
                                lastID = row.GraphChannel.ID
×
1427
                        }
1428
                }
1429

1430
                return nil
×
1431
        }, reset)
1432
}
1433

1434
// FilterChannelRange returns the channel ID's of all known channels which were
1435
// mined in a block height within the passed range. The channel IDs are grouped
1436
// by their common block height. This method can be used to quickly share with a
1437
// peer the set of channels we know of within a particular range to catch them
1438
// up after a period of time offline. If withTimestamps is true then the
1439
// timestamp info of the latest received channel update messages of the channel
1440
// will be included in the response.
1441
//
1442
// NOTE: This is part of the V1Store interface.
1443
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1444
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1445

×
1446
        var (
×
1447
                ctx       = context.TODO()
×
1448
                startSCID = &lnwire.ShortChannelID{
×
1449
                        BlockHeight: startHeight,
×
1450
                }
×
1451
                endSCID = lnwire.ShortChannelID{
×
1452
                        BlockHeight: endHeight,
×
1453
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1454
                        TxPosition:  math.MaxUint16,
×
1455
                }
×
1456
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1457
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1458
        )
×
1459

×
1460
        // 1) get all channels where channelID is between start and end chan ID.
×
1461
        // 2) skip if not public (ie, no channel_proof)
×
1462
        // 3) collect that channel.
×
1463
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1464
        //    and add those timestamps to the collected channel.
×
1465
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1466
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1467
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1468
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1469
                                StartScid: chanIDStart,
×
1470
                                EndScid:   chanIDEnd,
×
1471
                        },
×
1472
                )
×
1473
                if err != nil {
×
1474
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1475
                                err)
×
1476
                }
×
1477

1478
                for _, dbChan := range dbChans {
×
1479
                        cid := lnwire.NewShortChanIDFromInt(
×
1480
                                byteOrder.Uint64(dbChan.Scid),
×
1481
                        )
×
1482
                        chanInfo := NewChannelUpdateInfo(
×
1483
                                cid, time.Time{}, time.Time{},
×
1484
                        )
×
1485

×
1486
                        if !withTimestamps {
×
1487
                                channelsPerBlock[cid.BlockHeight] = append(
×
1488
                                        channelsPerBlock[cid.BlockHeight],
×
1489
                                        chanInfo,
×
1490
                                )
×
1491

×
1492
                                continue
×
1493
                        }
1494

1495
                        //nolint:ll
1496
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1497
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1498
                                        Version:   int16(ProtocolV1),
×
1499
                                        ChannelID: dbChan.ID,
×
1500
                                        NodeID:    dbChan.NodeID1,
×
1501
                                },
×
1502
                        )
×
1503
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1504
                                return fmt.Errorf("unable to fetch node1 "+
×
1505
                                        "policy: %w", err)
×
1506
                        } else if err == nil {
×
1507
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1508
                                        node1Policy.LastUpdate.Int64, 0,
×
1509
                                )
×
1510
                        }
×
1511

1512
                        //nolint:ll
1513
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1514
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1515
                                        Version:   int16(ProtocolV1),
×
1516
                                        ChannelID: dbChan.ID,
×
1517
                                        NodeID:    dbChan.NodeID2,
×
1518
                                },
×
1519
                        )
×
1520
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1521
                                return fmt.Errorf("unable to fetch node2 "+
×
1522
                                        "policy: %w", err)
×
1523
                        } else if err == nil {
×
1524
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1525
                                        node2Policy.LastUpdate.Int64, 0,
×
1526
                                )
×
1527
                        }
×
1528

1529
                        channelsPerBlock[cid.BlockHeight] = append(
×
1530
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1531
                        )
×
1532
                }
1533

1534
                return nil
×
1535
        }, func() {
×
1536
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1537
        })
×
1538
        if err != nil {
×
1539
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1540
        }
×
1541

1542
        if len(channelsPerBlock) == 0 {
×
1543
                return nil, nil
×
1544
        }
×
1545

1546
        // Return the channel ranges in ascending block height order.
1547
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1548
        slices.Sort(blocks)
×
1549

×
1550
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1551
                return BlockChannelRange{
×
1552
                        Height:   block,
×
1553
                        Channels: channelsPerBlock[block],
×
1554
                }
×
1555
        }), nil
×
1556
}
1557

1558
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1559
// zombie. This method is used on an ad-hoc basis, when channels need to be
1560
// marked as zombies outside the normal pruning cycle.
1561
//
1562
// NOTE: part of the V1Store interface.
1563
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1564
        pubKey1, pubKey2 [33]byte) error {
×
1565

×
1566
        ctx := context.TODO()
×
1567

×
1568
        s.cacheMu.Lock()
×
1569
        defer s.cacheMu.Unlock()
×
1570

×
1571
        chanIDB := channelIDToBytes(chanID)
×
1572

×
1573
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1574
                return db.UpsertZombieChannel(
×
1575
                        ctx, sqlc.UpsertZombieChannelParams{
×
1576
                                Version:  int16(ProtocolV1),
×
1577
                                Scid:     chanIDB,
×
1578
                                NodeKey1: pubKey1[:],
×
1579
                                NodeKey2: pubKey2[:],
×
1580
                        },
×
1581
                )
×
1582
        }, sqldb.NoOpReset)
×
1583
        if err != nil {
×
1584
                return fmt.Errorf("unable to upsert zombie channel "+
×
1585
                        "(channel_id=%d): %w", chanID, err)
×
1586
        }
×
1587

1588
        s.rejectCache.remove(chanID)
×
1589
        s.chanCache.remove(chanID)
×
1590

×
1591
        return nil
×
1592
}
1593

1594
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1595
//
1596
// NOTE: part of the V1Store interface.
1597
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1598
        s.cacheMu.Lock()
×
1599
        defer s.cacheMu.Unlock()
×
1600

×
1601
        var (
×
1602
                ctx     = context.TODO()
×
1603
                chanIDB = channelIDToBytes(chanID)
×
1604
        )
×
1605

×
1606
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1607
                res, err := db.DeleteZombieChannel(
×
1608
                        ctx, sqlc.DeleteZombieChannelParams{
×
1609
                                Scid:    chanIDB,
×
1610
                                Version: int16(ProtocolV1),
×
1611
                        },
×
1612
                )
×
1613
                if err != nil {
×
1614
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1615
                                err)
×
1616
                }
×
1617

1618
                rows, err := res.RowsAffected()
×
1619
                if err != nil {
×
1620
                        return err
×
1621
                }
×
1622

1623
                if rows == 0 {
×
1624
                        return ErrZombieEdgeNotFound
×
1625
                } else if rows > 1 {
×
1626
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1627
                                "expected 1", rows)
×
1628
                }
×
1629

1630
                return nil
×
1631
        }, sqldb.NoOpReset)
1632
        if err != nil {
×
1633
                return fmt.Errorf("unable to mark edge live "+
×
1634
                        "(channel_id=%d): %w", chanID, err)
×
1635
        }
×
1636

1637
        s.rejectCache.remove(chanID)
×
1638
        s.chanCache.remove(chanID)
×
1639

×
1640
        return err
×
1641
}
1642

1643
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1644
// zombie, then the two node public keys corresponding to this edge are also
1645
// returned.
1646
//
1647
// NOTE: part of the V1Store interface.
1648
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1649
        error) {
×
1650

×
1651
        var (
×
1652
                ctx              = context.TODO()
×
1653
                isZombie         bool
×
1654
                pubKey1, pubKey2 route.Vertex
×
1655
                chanIDB          = channelIDToBytes(chanID)
×
1656
        )
×
1657

×
1658
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1659
                zombie, err := db.GetZombieChannel(
×
1660
                        ctx, sqlc.GetZombieChannelParams{
×
1661
                                Scid:    chanIDB,
×
1662
                                Version: int16(ProtocolV1),
×
1663
                        },
×
1664
                )
×
1665
                if errors.Is(err, sql.ErrNoRows) {
×
1666
                        return nil
×
1667
                }
×
1668
                if err != nil {
×
1669
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1670
                                err)
×
1671
                }
×
1672

1673
                copy(pubKey1[:], zombie.NodeKey1)
×
1674
                copy(pubKey2[:], zombie.NodeKey2)
×
1675
                isZombie = true
×
1676

×
1677
                return nil
×
1678
        }, sqldb.NoOpReset)
1679
        if err != nil {
×
1680
                return false, route.Vertex{}, route.Vertex{},
×
1681
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1682
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1683
        }
×
1684

1685
        return isZombie, pubKey1, pubKey2, nil
×
1686
}
1687

1688
// NumZombies returns the current number of zombie channels in the graph.
1689
//
1690
// NOTE: part of the V1Store interface.
1691
func (s *SQLStore) NumZombies() (uint64, error) {
×
1692
        var (
×
1693
                ctx        = context.TODO()
×
1694
                numZombies uint64
×
1695
        )
×
1696
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1697
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1698
                if err != nil {
×
1699
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1700
                                err)
×
1701
                }
×
1702

1703
                numZombies = uint64(count)
×
1704

×
1705
                return nil
×
1706
        }, sqldb.NoOpReset)
1707
        if err != nil {
×
1708
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1709
        }
×
1710

1711
        return numZombies, nil
×
1712
}
1713

1714
// DeleteChannelEdges removes edges with the given channel IDs from the
1715
// database and marks them as zombies. This ensures that we're unable to re-add
1716
// it to our database once again. If an edge does not exist within the
1717
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1718
// true, then when we mark these edges as zombies, we'll set up the keys such
1719
// that we require the node that failed to send the fresh update to be the one
1720
// that resurrects the channel from its zombie state. The markZombie bool
1721
// denotes whether to mark the channel as a zombie.
1722
//
1723
// NOTE: part of the V1Store interface.
1724
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1725
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1726

×
1727
        s.cacheMu.Lock()
×
1728
        defer s.cacheMu.Unlock()
×
1729

×
1730
        // Keep track of which channels we end up finding so that we can
×
1731
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
1732
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
1733
        for _, chanID := range chanIDs {
×
1734
                chanLookup[chanID] = struct{}{}
×
1735
        }
×
1736

1737
        var (
×
1738
                ctx     = context.TODO()
×
1739
                deleted []*models.ChannelEdgeInfo
×
1740
        )
×
1741
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1742
                chanIDsToDelete := make([]int64, 0, len(chanIDs))
×
1743
                chanCallBack := func(ctx context.Context,
×
1744
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1745

×
1746
                        // Deleting the entry from the map indicates that we
×
1747
                        // have found the channel.
×
1748
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1749
                        delete(chanLookup, scid)
×
1750

×
1751
                        node1, node2, err := buildNodeVertices(
×
1752
                                row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1753
                        )
×
1754
                        if err != nil {
×
1755
                                return err
×
1756
                        }
×
1757

1758
                        info, err := getAndBuildEdgeInfo(
×
1759
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1760
                                row.GraphChannel, node1, node2,
×
1761
                        )
×
1762
                        if err != nil {
×
1763
                                return err
×
1764
                        }
×
1765

1766
                        deleted = append(deleted, info)
×
1767
                        chanIDsToDelete = append(
×
1768
                                chanIDsToDelete, row.GraphChannel.ID,
×
1769
                        )
×
1770

×
1771
                        if !markZombie {
×
1772
                                return nil
×
1773
                        }
×
1774

1775
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1776
                                info.NodeKey2Bytes
×
1777
                        if strictZombiePruning {
×
1778
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1779
                                if row.Policy1LastUpdate.Valid {
×
1780
                                        e1Time := time.Unix(
×
1781
                                                row.Policy1LastUpdate.Int64, 0,
×
1782
                                        )
×
1783
                                        e1UpdateTime = &e1Time
×
1784
                                }
×
1785
                                if row.Policy2LastUpdate.Valid {
×
1786
                                        e2Time := time.Unix(
×
1787
                                                row.Policy2LastUpdate.Int64, 0,
×
1788
                                        )
×
1789
                                        e2UpdateTime = &e2Time
×
1790
                                }
×
1791

1792
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1793
                                        info, e1UpdateTime, e2UpdateTime,
×
1794
                                )
×
1795
                        }
1796

1797
                        err = db.UpsertZombieChannel(
×
1798
                                ctx, sqlc.UpsertZombieChannelParams{
×
1799
                                        Version:  int16(ProtocolV1),
×
1800
                                        Scid:     channelIDToBytes(scid),
×
1801
                                        NodeKey1: nodeKey1[:],
×
1802
                                        NodeKey2: nodeKey2[:],
×
1803
                                },
×
1804
                        )
×
1805
                        if err != nil {
×
1806
                                return fmt.Errorf("unable to mark channel as "+
×
1807
                                        "zombie: %w", err)
×
1808
                        }
×
1809

1810
                        return nil
×
1811
                }
1812

1813
                err := s.forEachChanWithPoliciesInSCIDList(
×
1814
                        ctx, db, chanCallBack, chanIDs,
×
1815
                )
×
1816
                if err != nil {
×
1817
                        return err
×
1818
                }
×
1819

1820
                if len(chanLookup) > 0 {
×
1821
                        return ErrEdgeNotFound
×
1822
                }
×
1823

1824
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1825
        }, func() {
×
1826
                deleted = nil
×
1827

×
1828
                // Re-fill the lookup map.
×
1829
                for _, chanID := range chanIDs {
×
1830
                        chanLookup[chanID] = struct{}{}
×
1831
                }
×
1832
        })
1833
        if err != nil {
×
1834
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1835
                        err)
×
1836
        }
×
1837

1838
        for _, chanID := range chanIDs {
×
1839
                s.rejectCache.remove(chanID)
×
1840
                s.chanCache.remove(chanID)
×
1841
        }
×
1842

1843
        return deleted, nil
×
1844
}
1845

1846
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1847
// channel identified by the channel ID. If the channel can't be found, then
1848
// ErrEdgeNotFound is returned. A struct which houses the general information
1849
// for the channel itself is returned as well as two structs that contain the
1850
// routing policies for the channel in either direction.
1851
//
1852
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1853
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1854
// the ChannelEdgeInfo will only include the public keys of each node.
1855
//
1856
// NOTE: part of the V1Store interface.
1857
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1858
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1859
        *models.ChannelEdgePolicy, error) {
×
1860

×
1861
        var (
×
1862
                ctx              = context.TODO()
×
1863
                edge             *models.ChannelEdgeInfo
×
1864
                policy1, policy2 *models.ChannelEdgePolicy
×
1865
                chanIDB          = channelIDToBytes(chanID)
×
1866
        )
×
1867
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1868
                row, err := db.GetChannelBySCIDWithPolicies(
×
1869
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1870
                                Scid:    chanIDB,
×
1871
                                Version: int16(ProtocolV1),
×
1872
                        },
×
1873
                )
×
1874
                if errors.Is(err, sql.ErrNoRows) {
×
1875
                        // First check if this edge is perhaps in the zombie
×
1876
                        // index.
×
1877
                        zombie, err := db.GetZombieChannel(
×
1878
                                ctx, sqlc.GetZombieChannelParams{
×
1879
                                        Scid:    chanIDB,
×
1880
                                        Version: int16(ProtocolV1),
×
1881
                                },
×
1882
                        )
×
1883
                        if errors.Is(err, sql.ErrNoRows) {
×
1884
                                return ErrEdgeNotFound
×
1885
                        } else if err != nil {
×
1886
                                return fmt.Errorf("unable to check if "+
×
1887
                                        "channel is zombie: %w", err)
×
1888
                        }
×
1889

1890
                        // At this point, we know the channel is a zombie, so
1891
                        // we'll return an error indicating this, and we will
1892
                        // populate the edge info with the public keys of each
1893
                        // party as this is the only information we have about
1894
                        // it.
1895
                        edge = &models.ChannelEdgeInfo{}
×
1896
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1897
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1898

×
1899
                        return ErrZombieEdge
×
1900
                } else if err != nil {
×
1901
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1902
                }
×
1903

1904
                node1, node2, err := buildNodeVertices(
×
1905
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1906
                )
×
1907
                if err != nil {
×
1908
                        return err
×
1909
                }
×
1910

1911
                edge, err = getAndBuildEdgeInfo(
×
1912
                        ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1913
                        row.GraphChannel, node1, node2,
×
1914
                )
×
1915
                if err != nil {
×
1916
                        return fmt.Errorf("unable to build channel info: %w",
×
1917
                                err)
×
1918
                }
×
1919

1920
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1921
                if err != nil {
×
1922
                        return fmt.Errorf("unable to extract channel "+
×
1923
                                "policies: %w", err)
×
1924
                }
×
1925

1926
                policy1, policy2, err = getAndBuildChanPolicies(
×
1927
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1928
                )
×
1929
                if err != nil {
×
1930
                        return fmt.Errorf("unable to build channel "+
×
1931
                                "policies: %w", err)
×
1932
                }
×
1933

1934
                return nil
×
1935
        }, sqldb.NoOpReset)
1936
        if err != nil {
×
1937
                // If we are returning the ErrZombieEdge, then we also need to
×
1938
                // return the edge info as the method comment indicates that
×
1939
                // this will be populated when the edge is a zombie.
×
1940
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1941
                        err)
×
1942
        }
×
1943

1944
        return edge, policy1, policy2, nil
×
1945
}
1946

1947
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1948
// the channel identified by the funding outpoint. If the channel can't be
1949
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1950
// information for the channel itself is returned as well as two structs that
1951
// contain the routing policies for the channel in either direction.
1952
//
1953
// NOTE: part of the V1Store interface.
1954
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1955
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1956
        *models.ChannelEdgePolicy, error) {
×
1957

×
1958
        var (
×
1959
                ctx              = context.TODO()
×
1960
                edge             *models.ChannelEdgeInfo
×
1961
                policy1, policy2 *models.ChannelEdgePolicy
×
1962
        )
×
1963
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1964
                row, err := db.GetChannelByOutpointWithPolicies(
×
1965
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1966
                                Outpoint: op.String(),
×
1967
                                Version:  int16(ProtocolV1),
×
1968
                        },
×
1969
                )
×
1970
                if errors.Is(err, sql.ErrNoRows) {
×
1971
                        return ErrEdgeNotFound
×
1972
                } else if err != nil {
×
1973
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1974
                }
×
1975

1976
                node1, node2, err := buildNodeVertices(
×
1977
                        row.Node1Pubkey, row.Node2Pubkey,
×
1978
                )
×
1979
                if err != nil {
×
1980
                        return err
×
1981
                }
×
1982

1983
                edge, err = getAndBuildEdgeInfo(
×
1984
                        ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1985
                        row.GraphChannel, node1, node2,
×
1986
                )
×
1987
                if err != nil {
×
1988
                        return fmt.Errorf("unable to build channel info: %w",
×
1989
                                err)
×
1990
                }
×
1991

1992
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1993
                if err != nil {
×
1994
                        return fmt.Errorf("unable to extract channel "+
×
1995
                                "policies: %w", err)
×
1996
                }
×
1997

1998
                policy1, policy2, err = getAndBuildChanPolicies(
×
1999
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
2000
                )
×
2001
                if err != nil {
×
2002
                        return fmt.Errorf("unable to build channel "+
×
2003
                                "policies: %w", err)
×
2004
                }
×
2005

2006
                return nil
×
2007
        }, sqldb.NoOpReset)
2008
        if err != nil {
×
2009
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2010
                        err)
×
2011
        }
×
2012

2013
        return edge, policy1, policy2, nil
×
2014
}
2015

2016
// HasChannelEdge returns true if the database knows of a channel edge with the
2017
// passed channel ID, and false otherwise. If an edge with that ID is found
2018
// within the graph, then two time stamps representing the last time the edge
2019
// was updated for both directed edges are returned along with the boolean. If
2020
// it is not found, then the zombie index is checked and its result is returned
2021
// as the second boolean.
2022
//
2023
// NOTE: part of the V1Store interface.
2024
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
2025
        bool, error) {
×
2026

×
2027
        ctx := context.TODO()
×
2028

×
2029
        var (
×
2030
                exists          bool
×
2031
                isZombie        bool
×
2032
                node1LastUpdate time.Time
×
2033
                node2LastUpdate time.Time
×
2034
        )
×
2035

×
2036
        // We'll query the cache with the shared lock held to allow multiple
×
2037
        // readers to access values in the cache concurrently if they exist.
×
2038
        s.cacheMu.RLock()
×
2039
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2040
                s.cacheMu.RUnlock()
×
2041
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2042
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2043
                exists, isZombie = entry.flags.unpack()
×
2044

×
2045
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2046
        }
×
2047
        s.cacheMu.RUnlock()
×
2048

×
2049
        s.cacheMu.Lock()
×
2050
        defer s.cacheMu.Unlock()
×
2051

×
2052
        // The item was not found with the shared lock, so we'll acquire the
×
2053
        // exclusive lock and check the cache again in case another method added
×
2054
        // the entry to the cache while no lock was held.
×
2055
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2056
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2057
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2058
                exists, isZombie = entry.flags.unpack()
×
2059

×
2060
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2061
        }
×
2062

2063
        chanIDB := channelIDToBytes(chanID)
×
2064
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2065
                channel, err := db.GetChannelBySCID(
×
2066
                        ctx, sqlc.GetChannelBySCIDParams{
×
2067
                                Scid:    chanIDB,
×
2068
                                Version: int16(ProtocolV1),
×
2069
                        },
×
2070
                )
×
2071
                if errors.Is(err, sql.ErrNoRows) {
×
2072
                        // Check if it is a zombie channel.
×
2073
                        isZombie, err = db.IsZombieChannel(
×
2074
                                ctx, sqlc.IsZombieChannelParams{
×
2075
                                        Scid:    chanIDB,
×
2076
                                        Version: int16(ProtocolV1),
×
2077
                                },
×
2078
                        )
×
2079
                        if err != nil {
×
2080
                                return fmt.Errorf("could not check if channel "+
×
2081
                                        "is zombie: %w", err)
×
2082
                        }
×
2083

2084
                        return nil
×
2085
                } else if err != nil {
×
2086
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2087
                }
×
2088

2089
                exists = true
×
2090

×
2091
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2092
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2093
                                Version:   int16(ProtocolV1),
×
2094
                                ChannelID: channel.ID,
×
2095
                                NodeID:    channel.NodeID1,
×
2096
                        },
×
2097
                )
×
2098
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2099
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2100
                                err)
×
2101
                } else if err == nil {
×
2102
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2103
                }
×
2104

2105
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2106
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2107
                                Version:   int16(ProtocolV1),
×
2108
                                ChannelID: channel.ID,
×
2109
                                NodeID:    channel.NodeID2,
×
2110
                        },
×
2111
                )
×
2112
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2113
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2114
                                err)
×
2115
                } else if err == nil {
×
2116
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2117
                }
×
2118

2119
                return nil
×
2120
        }, sqldb.NoOpReset)
2121
        if err != nil {
×
2122
                return time.Time{}, time.Time{}, false, false,
×
2123
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2124
        }
×
2125

2126
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2127
                upd1Time: node1LastUpdate.Unix(),
×
2128
                upd2Time: node2LastUpdate.Unix(),
×
2129
                flags:    packRejectFlags(exists, isZombie),
×
2130
        })
×
2131

×
2132
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2133
}
2134

2135
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2136
// passed channel point (outpoint). If the passed channel doesn't exist within
2137
// the database, then ErrEdgeNotFound is returned.
2138
//
2139
// NOTE: part of the V1Store interface.
2140
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2141
        var (
×
2142
                ctx       = context.TODO()
×
2143
                channelID uint64
×
2144
        )
×
2145
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2146
                chanID, err := db.GetSCIDByOutpoint(
×
2147
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2148
                                Outpoint: chanPoint.String(),
×
2149
                                Version:  int16(ProtocolV1),
×
2150
                        },
×
2151
                )
×
2152
                if errors.Is(err, sql.ErrNoRows) {
×
2153
                        return ErrEdgeNotFound
×
2154
                } else if err != nil {
×
2155
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2156
                                err)
×
2157
                }
×
2158

2159
                channelID = byteOrder.Uint64(chanID)
×
2160

×
2161
                return nil
×
2162
        }, sqldb.NoOpReset)
2163
        if err != nil {
×
2164
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2165
        }
×
2166

2167
        return channelID, nil
×
2168
}
2169

2170
// IsPublicNode is a helper method that determines whether the node with the
2171
// given public key is seen as a public node in the graph from the graph's
2172
// source node's point of view.
2173
//
2174
// NOTE: part of the V1Store interface.
2175
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2176
        ctx := context.TODO()
×
2177

×
2178
        var isPublic bool
×
2179
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2180
                var err error
×
2181
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2182

×
2183
                return err
×
2184
        }, sqldb.NoOpReset)
×
2185
        if err != nil {
×
2186
                return false, fmt.Errorf("unable to check if node is "+
×
2187
                        "public: %w", err)
×
2188
        }
×
2189

2190
        return isPublic, nil
×
2191
}
2192

2193
// FetchChanInfos returns the set of channel edges that correspond to the passed
2194
// channel ID's. If an edge is the query is unknown to the database, it will
2195
// skipped and the result will contain only those edges that exist at the time
2196
// of the query. This can be used to respond to peer queries that are seeking to
2197
// fill in gaps in their view of the channel graph.
2198
//
2199
// NOTE: part of the V1Store interface.
2200
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2201
        var (
×
2202
                ctx   = context.TODO()
×
2203
                edges = make(map[uint64]ChannelEdge)
×
2204
        )
×
2205
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2206
                chanCallBack := func(ctx context.Context,
×
2207
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2208

×
2209
                        node1, node2, err := buildNodes(
×
2210
                                ctx, db, row.GraphNode, row.GraphNode_2,
×
2211
                        )
×
2212
                        if err != nil {
×
2213
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2214
                                        err)
×
2215
                        }
×
2216

2217
                        edge, err := getAndBuildEdgeInfo(
×
2218
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
2219
                                row.GraphChannel, node1.PubKeyBytes,
×
2220
                                node2.PubKeyBytes,
×
2221
                        )
×
2222
                        if err != nil {
×
2223
                                return fmt.Errorf("unable to build "+
×
2224
                                        "channel info: %w", err)
×
2225
                        }
×
2226

2227
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2228
                        if err != nil {
×
2229
                                return fmt.Errorf("unable to extract channel "+
×
2230
                                        "policies: %w", err)
×
2231
                        }
×
2232

2233
                        p1, p2, err := getAndBuildChanPolicies(
×
2234
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2235
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2236
                        )
×
2237
                        if err != nil {
×
2238
                                return fmt.Errorf("unable to build channel "+
×
2239
                                        "policies: %w", err)
×
2240
                        }
×
2241

2242
                        edges[edge.ChannelID] = ChannelEdge{
×
2243
                                Info:    edge,
×
2244
                                Policy1: p1,
×
2245
                                Policy2: p2,
×
2246
                                Node1:   node1,
×
2247
                                Node2:   node2,
×
2248
                        }
×
2249

×
2250
                        return nil
×
2251
                }
2252

2253
                return s.forEachChanWithPoliciesInSCIDList(
×
2254
                        ctx, db, chanCallBack, chanIDs,
×
2255
                )
×
2256
        }, func() {
×
2257
                clear(edges)
×
2258
        })
×
2259
        if err != nil {
×
2260
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2261
        }
×
2262

2263
        res := make([]ChannelEdge, 0, len(edges))
×
2264
        for _, chanID := range chanIDs {
×
2265
                edge, ok := edges[chanID]
×
2266
                if !ok {
×
2267
                        continue
×
2268
                }
2269

2270
                res = append(res, edge)
×
2271
        }
2272

2273
        return res, nil
×
2274
}
2275

2276
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2277
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2278
// channels in a paginated manner.
2279
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2280
        db SQLQueries, cb func(ctx context.Context,
2281
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
2282
        chanIDs []uint64) error {
×
2283

×
2284
        queryWrapper := func(ctx context.Context,
×
2285
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
2286
                error) {
×
2287

×
2288
                return db.GetChannelsBySCIDWithPolicies(
×
2289
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
2290
                                Version: int16(ProtocolV1),
×
2291
                                Scids:   scids,
×
2292
                        },
×
2293
                )
×
2294
        }
×
2295

2296
        return sqldb.ExecutePagedQuery(
×
2297
                ctx, s.cfg.PaginationCfg, chanIDs, channelIDToBytes,
×
2298
                queryWrapper, cb,
×
2299
        )
×
2300
}
2301

2302
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2303
// ID's that we don't know and are not known zombies of the passed set. In other
2304
// words, we perform a set difference of our set of chan ID's and the ones
2305
// passed in. This method can be used by callers to determine the set of
2306
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2307
// known zombies is also returned.
2308
//
2309
// NOTE: part of the V1Store interface.
2310
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2311
        []ChannelUpdateInfo, error) {
×
2312

×
2313
        var (
×
2314
                ctx          = context.TODO()
×
2315
                newChanIDs   []uint64
×
2316
                knownZombies []ChannelUpdateInfo
×
2317
                infoLookup   = make(
×
2318
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2319
                )
×
2320
        )
×
2321

×
2322
        // We first build a lookup map of the channel ID's to the
×
2323
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
2324
        // already know about.
×
2325
        for _, chanInfo := range chansInfo {
×
2326
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
2327
        }
×
2328

2329
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2330
                // The call-back function deletes known channels from
×
2331
                // infoLookup, so that we can later check which channels are
×
2332
                // zombies by only looking at the remaining channels in the set.
×
2333
                cb := func(ctx context.Context,
×
2334
                        channel sqlc.GraphChannel) error {
×
2335

×
2336
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2337

×
2338
                        return nil
×
2339
                }
×
2340

2341
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
2342
                if err != nil {
×
2343
                        return fmt.Errorf("unable to iterate through "+
×
2344
                                "channels: %w", err)
×
2345
                }
×
2346

2347
                // We want to ensure that we deal with the channels in the
2348
                // same order that they were passed in, so we iterate over the
2349
                // original chansInfo slice and then check if that channel is
2350
                // still in the infoLookup map.
2351
                for _, chanInfo := range chansInfo {
×
2352
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2353
                        if _, ok := infoLookup[channelID]; !ok {
×
2354
                                continue
×
2355
                        }
2356

2357
                        isZombie, err := db.IsZombieChannel(
×
2358
                                ctx, sqlc.IsZombieChannelParams{
×
2359
                                        Scid:    channelIDToBytes(channelID),
×
2360
                                        Version: int16(ProtocolV1),
×
2361
                                },
×
2362
                        )
×
2363
                        if err != nil {
×
2364
                                return fmt.Errorf("unable to fetch zombie "+
×
2365
                                        "channel: %w", err)
×
2366
                        }
×
2367

2368
                        if isZombie {
×
2369
                                knownZombies = append(knownZombies, chanInfo)
×
2370

×
2371
                                continue
×
2372
                        }
2373

2374
                        newChanIDs = append(newChanIDs, channelID)
×
2375
                }
2376

2377
                return nil
×
2378
        }, func() {
×
2379
                newChanIDs = nil
×
2380
                knownZombies = nil
×
2381
                // Rebuild the infoLookup map in case of a rollback.
×
2382
                for _, chanInfo := range chansInfo {
×
2383
                        scid := chanInfo.ShortChannelID.ToUint64()
×
2384
                        infoLookup[scid] = chanInfo
×
2385
                }
×
2386
        })
2387
        if err != nil {
×
2388
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2389
        }
×
2390

2391
        return newChanIDs, knownZombies, nil
×
2392
}
2393

2394
// forEachChanInSCIDList is a helper method that executes a paged query
2395
// against the database to fetch all channels that match the passed
2396
// ChannelUpdateInfo slice. The callback function is called for each channel
2397
// that is found.
2398
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2399
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
2400
        chansInfo []ChannelUpdateInfo) error {
×
2401

×
2402
        queryWrapper := func(ctx context.Context,
×
2403
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
2404

×
2405
                return db.GetChannelsBySCIDs(
×
2406
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
2407
                                Version: int16(ProtocolV1),
×
2408
                                Scids:   scids,
×
2409
                        },
×
2410
                )
×
2411
        }
×
2412

2413
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
2414
                channelID := chanInfo.ShortChannelID.ToUint64()
×
2415

×
2416
                return channelIDToBytes(channelID)
×
2417
        }
×
2418

2419
        return sqldb.ExecutePagedQuery(
×
2420
                ctx, s.cfg.PaginationCfg, chansInfo, chanIDConverter,
×
2421
                queryWrapper, cb,
×
2422
        )
×
2423
}
2424

2425
// PruneGraphNodes is a garbage collection method which attempts to prune out
2426
// any nodes from the channel graph that are currently unconnected. This ensure
2427
// that we only maintain a graph of reachable nodes. In the event that a pruned
2428
// node gains more channels, it will be re-added back to the graph.
2429
//
2430
// NOTE: this prunes nodes across protocol versions. It will never prune the
2431
// source nodes.
2432
//
2433
// NOTE: part of the V1Store interface.
2434
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2435
        var ctx = context.TODO()
×
2436

×
2437
        var prunedNodes []route.Vertex
×
2438
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2439
                var err error
×
2440
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2441

×
2442
                return err
×
2443
        }, func() {
×
2444
                prunedNodes = nil
×
2445
        })
×
2446
        if err != nil {
×
2447
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2448
        }
×
2449

2450
        return prunedNodes, nil
×
2451
}
2452

2453
// PruneGraph prunes newly closed channels from the channel graph in response
2454
// to a new block being solved on the network. Any transactions which spend the
2455
// funding output of any known channels within he graph will be deleted.
2456
// Additionally, the "prune tip", or the last block which has been used to
2457
// prune the graph is stored so callers can ensure the graph is fully in sync
2458
// with the current UTXO state. A slice of channels that have been closed by
2459
// the target block along with any pruned nodes are returned if the function
2460
// succeeds without error.
2461
//
2462
// NOTE: part of the V1Store interface.
2463
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2464
        blockHash *chainhash.Hash, blockHeight uint32) (
2465
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2466

×
2467
        ctx := context.TODO()
×
2468

×
2469
        s.cacheMu.Lock()
×
2470
        defer s.cacheMu.Unlock()
×
2471

×
2472
        var (
×
2473
                closedChans []*models.ChannelEdgeInfo
×
2474
                prunedNodes []route.Vertex
×
2475
        )
×
2476
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2477
                var chansToDelete []int64
×
2478

×
2479
                // Define the callback function for processing each channel.
×
2480
                channelCallback := func(ctx context.Context,
×
2481
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2482

×
2483
                        node1, node2, err := buildNodeVertices(
×
2484
                                row.Node1Pubkey, row.Node2Pubkey,
×
2485
                        )
×
2486
                        if err != nil {
×
2487
                                return err
×
2488
                        }
×
2489

2490
                        info, err := getAndBuildEdgeInfo(
×
2491
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
2492
                                row.GraphChannel, node1, node2,
×
2493
                        )
×
2494
                        if err != nil {
×
2495
                                return err
×
2496
                        }
×
2497

2498
                        closedChans = append(closedChans, info)
×
2499
                        chansToDelete = append(
×
2500
                                chansToDelete, row.GraphChannel.ID,
×
2501
                        )
×
2502

×
2503
                        return nil
×
2504
                }
2505

2506
                err := s.forEachChanInOutpoints(
×
2507
                        ctx, db, spentOutputs, channelCallback,
×
2508
                )
×
2509
                if err != nil {
×
2510
                        return fmt.Errorf("unable to fetch channels by "+
×
2511
                                "outpoints: %w", err)
×
2512
                }
×
2513

2514
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2515
                if err != nil {
×
2516
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2517
                }
×
2518

2519
                err = db.UpsertPruneLogEntry(
×
2520
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2521
                                BlockHash:   blockHash[:],
×
2522
                                BlockHeight: int64(blockHeight),
×
2523
                        },
×
2524
                )
×
2525
                if err != nil {
×
2526
                        return fmt.Errorf("unable to insert prune log "+
×
2527
                                "entry: %w", err)
×
2528
                }
×
2529

2530
                // Now that we've pruned some channels, we'll also prune any
2531
                // nodes that no longer have any channels.
2532
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2533
                if err != nil {
×
2534
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2535
                                err)
×
2536
                }
×
2537

2538
                return nil
×
2539
        }, func() {
×
2540
                prunedNodes = nil
×
2541
                closedChans = nil
×
2542
        })
×
2543
        if err != nil {
×
2544
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2545
        }
×
2546

2547
        for _, channel := range closedChans {
×
2548
                s.rejectCache.remove(channel.ChannelID)
×
2549
                s.chanCache.remove(channel.ChannelID)
×
2550
        }
×
2551

2552
        return closedChans, prunedNodes, nil
×
2553
}
2554

2555
// forEachChanInOutpoints is a helper function that executes a paginated
2556
// query to fetch channels by their outpoints and applies the given call-back
2557
// to each.
2558
//
2559
// NOTE: this fetches channels for all protocol versions.
2560
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2561
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2562
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
2563

×
2564
        // Create a wrapper that uses the transaction's db instance to execute
×
2565
        // the query.
×
2566
        queryWrapper := func(ctx context.Context,
×
2567
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
2568
                error) {
×
2569

×
2570
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
2571
        }
×
2572

2573
        // Define the conversion function from Outpoint to string.
2574
        outpointToString := func(outpoint *wire.OutPoint) string {
×
2575
                return outpoint.String()
×
2576
        }
×
2577

2578
        return sqldb.ExecutePagedQuery(
×
2579
                ctx, s.cfg.PaginationCfg, outpoints, outpointToString,
×
2580
                queryWrapper, cb,
×
2581
        )
×
2582
}
2583

2584
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
2585
        dbIDs []int64) error {
×
2586

×
2587
        // Create a wrapper that uses the transaction's db instance to execute
×
2588
        // the query.
×
2589
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
2590
                return nil, db.DeleteChannels(ctx, ids)
×
2591
        }
×
2592

2593
        idConverter := func(id int64) int64 {
×
2594
                return id
×
2595
        }
×
2596

2597
        return sqldb.ExecutePagedQuery(
×
2598
                ctx, s.cfg.PaginationCfg, dbIDs, idConverter,
×
2599
                queryWrapper, func(ctx context.Context, _ any) error {
×
2600
                        return nil
×
2601
                },
×
2602
        )
2603
}
2604

2605
// ChannelView returns the verifiable edge information for each active channel
2606
// within the known channel graph. The set of UTXOs (along with their scripts)
2607
// returned are the ones that need to be watched on chain to detect channel
2608
// closes on the resident blockchain.
2609
//
2610
// NOTE: part of the V1Store interface.
2611
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2612
        var (
×
2613
                ctx        = context.TODO()
×
2614
                edgePoints []EdgePoint
×
2615
        )
×
2616

×
2617
        handleChannel := func(db SQLQueries,
×
2618
                channel sqlc.ListChannelsPaginatedRow) error {
×
2619

×
2620
                pkScript, err := genMultiSigP2WSH(
×
2621
                        channel.BitcoinKey1, channel.BitcoinKey2,
×
2622
                )
×
2623
                if err != nil {
×
2624
                        return err
×
2625
                }
×
2626

2627
                op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2628
                if err != nil {
×
2629
                        return err
×
2630
                }
×
2631

2632
                edgePoints = append(edgePoints, EdgePoint{
×
2633
                        FundingPkScript: pkScript,
×
2634
                        OutPoint:        *op,
×
2635
                })
×
2636

×
2637
                return nil
×
2638
        }
2639

2640
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2641
                lastID := int64(-1)
×
2642
                for {
×
2643
                        rows, err := db.ListChannelsPaginated(
×
2644
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2645
                                        Version: int16(ProtocolV1),
×
2646
                                        ID:      lastID,
×
2647
                                        Limit:   pageSize,
×
2648
                                },
×
2649
                        )
×
2650
                        if err != nil {
×
2651
                                return err
×
2652
                        }
×
2653

2654
                        if len(rows) == 0 {
×
2655
                                break
×
2656
                        }
2657

2658
                        for _, row := range rows {
×
2659
                                err := handleChannel(db, row)
×
2660
                                if err != nil {
×
2661
                                        return err
×
2662
                                }
×
2663

2664
                                lastID = row.ID
×
2665
                        }
2666
                }
2667

2668
                return nil
×
2669
        }, func() {
×
2670
                edgePoints = nil
×
2671
        })
×
2672
        if err != nil {
×
2673
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2674
        }
×
2675

2676
        return edgePoints, nil
×
2677
}
2678

2679
// PruneTip returns the block height and hash of the latest block that has been
2680
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2681
// to tell if the graph is currently in sync with the current best known UTXO
2682
// state.
2683
//
2684
// NOTE: part of the V1Store interface.
2685
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2686
        var (
×
2687
                ctx       = context.TODO()
×
2688
                tipHash   chainhash.Hash
×
2689
                tipHeight uint32
×
2690
        )
×
2691
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2692
                pruneTip, err := db.GetPruneTip(ctx)
×
2693
                if errors.Is(err, sql.ErrNoRows) {
×
2694
                        return ErrGraphNeverPruned
×
2695
                } else if err != nil {
×
2696
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2697
                }
×
2698

2699
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2700
                tipHeight = uint32(pruneTip.BlockHeight)
×
2701

×
2702
                return nil
×
2703
        }, sqldb.NoOpReset)
2704
        if err != nil {
×
2705
                return nil, 0, err
×
2706
        }
×
2707

2708
        return &tipHash, tipHeight, nil
×
2709
}
2710

2711
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2712
//
2713
// NOTE: this prunes nodes across protocol versions. It will never prune the
2714
// source nodes.
2715
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2716
        db SQLQueries) ([]route.Vertex, error) {
×
2717

×
2718
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2719
        if err != nil {
×
2720
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2721
                        "nodes: %w", err)
×
2722
        }
×
2723

2724
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2725
        for i, nodeKey := range nodeKeys {
×
2726
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2727
                if err != nil {
×
2728
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2729
                                "from bytes: %w", err)
×
2730
                }
×
2731

2732
                prunedNodes[i] = pub
×
2733
        }
2734

2735
        return prunedNodes, nil
×
2736
}
2737

2738
// DisconnectBlockAtHeight is used to indicate that the block specified
2739
// by the passed height has been disconnected from the main chain. This
2740
// will "rewind" the graph back to the height below, deleting channels
2741
// that are no longer confirmed from the graph. The prune log will be
2742
// set to the last prune height valid for the remaining chain.
2743
// Channels that were removed from the graph resulting from the
2744
// disconnected block are returned.
2745
//
2746
// NOTE: part of the V1Store interface.
2747
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2748
        []*models.ChannelEdgeInfo, error) {
×
2749

×
2750
        ctx := context.TODO()
×
2751

×
2752
        var (
×
2753
                // Every channel having a ShortChannelID starting at 'height'
×
2754
                // will no longer be confirmed.
×
2755
                startShortChanID = lnwire.ShortChannelID{
×
2756
                        BlockHeight: height,
×
2757
                }
×
2758

×
2759
                // Delete everything after this height from the db up until the
×
2760
                // SCID alias range.
×
2761
                endShortChanID = aliasmgr.StartingAlias
×
2762

×
2763
                removedChans []*models.ChannelEdgeInfo
×
2764

×
2765
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2766
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2767
        )
×
2768

×
2769
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2770
                rows, err := db.GetChannelsBySCIDRange(
×
2771
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2772
                                StartScid: chanIDStart,
×
2773
                                EndScid:   chanIDEnd,
×
2774
                        },
×
2775
                )
×
2776
                if err != nil {
×
2777
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2778
                }
×
2779

2780
                chanIDsToDelete := make([]int64, len(rows))
×
2781
                for i, row := range rows {
×
2782
                        node1, node2, err := buildNodeVertices(
×
2783
                                row.Node1PubKey, row.Node2PubKey,
×
2784
                        )
×
2785
                        if err != nil {
×
2786
                                return err
×
2787
                        }
×
2788

2789
                        channel, err := getAndBuildEdgeInfo(
×
2790
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
2791
                                row.GraphChannel, node1, node2,
×
2792
                        )
×
2793
                        if err != nil {
×
2794
                                return err
×
2795
                        }
×
2796

2797
                        chanIDsToDelete[i] = row.GraphChannel.ID
×
2798
                        removedChans = append(removedChans, channel)
×
2799
                }
2800

2801
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
2802
                if err != nil {
×
2803
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2804
                }
×
2805

2806
                return db.DeletePruneLogEntriesInRange(
×
2807
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2808
                                StartHeight: int64(height),
×
2809
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2810
                        },
×
2811
                )
×
2812
        }, func() {
×
2813
                removedChans = nil
×
2814
        })
×
2815
        if err != nil {
×
2816
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2817
                        "height: %w", err)
×
2818
        }
×
2819

2820
        for _, channel := range removedChans {
×
2821
                s.rejectCache.remove(channel.ChannelID)
×
2822
                s.chanCache.remove(channel.ChannelID)
×
2823
        }
×
2824

2825
        return removedChans, nil
×
2826
}
2827

2828
// AddEdgeProof sets the proof of an existing edge in the graph database.
2829
//
2830
// NOTE: part of the V1Store interface.
2831
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2832
        proof *models.ChannelAuthProof) error {
×
2833

×
2834
        var (
×
2835
                ctx       = context.TODO()
×
2836
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2837
        )
×
2838

×
2839
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2840
                res, err := db.AddV1ChannelProof(
×
2841
                        ctx, sqlc.AddV1ChannelProofParams{
×
2842
                                Scid:              scidBytes,
×
2843
                                Node1Signature:    proof.NodeSig1Bytes,
×
2844
                                Node2Signature:    proof.NodeSig2Bytes,
×
2845
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2846
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2847
                        },
×
2848
                )
×
2849
                if err != nil {
×
2850
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2851
                }
×
2852

2853
                n, err := res.RowsAffected()
×
2854
                if err != nil {
×
2855
                        return err
×
2856
                }
×
2857

2858
                if n == 0 {
×
2859
                        return fmt.Errorf("no rows affected when adding edge "+
×
2860
                                "proof for SCID %v", scid)
×
2861
                } else if n > 1 {
×
2862
                        return fmt.Errorf("multiple rows affected when adding "+
×
2863
                                "edge proof for SCID %v: %d rows affected",
×
2864
                                scid, n)
×
2865
                }
×
2866

2867
                return nil
×
2868
        }, sqldb.NoOpReset)
2869
        if err != nil {
×
2870
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2871
        }
×
2872

2873
        return nil
×
2874
}
2875

2876
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2877
// that we can ignore channel announcements that we know to be closed without
2878
// having to validate them and fetch a block.
2879
//
2880
// NOTE: part of the V1Store interface.
2881
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2882
        var (
×
2883
                ctx     = context.TODO()
×
2884
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2885
        )
×
2886

×
2887
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2888
                return db.InsertClosedChannel(ctx, chanIDB)
×
2889
        }, sqldb.NoOpReset)
×
2890
}
2891

2892
// IsClosedScid checks whether a channel identified by the passed in scid is
2893
// closed. This helps avoid having to perform expensive validation checks.
2894
//
2895
// NOTE: part of the V1Store interface.
2896
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2897
        var (
×
2898
                ctx      = context.TODO()
×
2899
                isClosed bool
×
2900
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2901
        )
×
2902
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2903
                var err error
×
2904
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2905
                if err != nil {
×
2906
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2907
                                err)
×
2908
                }
×
2909

2910
                return nil
×
2911
        }, sqldb.NoOpReset)
2912
        if err != nil {
×
2913
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2914
                        err)
×
2915
        }
×
2916

2917
        return isClosed, nil
×
2918
}
2919

2920
// GraphSession will provide the call-back with access to a NodeTraverser
2921
// instance which can be used to perform queries against the channel graph.
2922
//
2923
// NOTE: part of the V1Store interface.
2924
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
2925
        reset func()) error {
×
2926

×
2927
        var ctx = context.TODO()
×
2928

×
2929
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2930
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2931
        }, reset)
×
2932
}
2933

2934
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2935
// read only transaction for a consistent view of the graph.
2936
type sqlNodeTraverser struct {
2937
        db    SQLQueries
2938
        chain chainhash.Hash
2939
}
2940

2941
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2942
// NodeTraverser interface.
2943
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2944

2945
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2946
func newSQLNodeTraverser(db SQLQueries,
2947
        chain chainhash.Hash) *sqlNodeTraverser {
×
2948

×
2949
        return &sqlNodeTraverser{
×
2950
                db:    db,
×
2951
                chain: chain,
×
2952
        }
×
2953
}
×
2954

2955
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2956
// node.
2957
//
2958
// NOTE: Part of the NodeTraverser interface.
2959
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2960
        cb func(channel *DirectedChannel) error, _ func()) error {
×
2961

×
2962
        ctx := context.TODO()
×
2963

×
2964
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2965
}
×
2966

2967
// FetchNodeFeatures returns the features of the given node. If the node is
2968
// unknown, assume no additional features are supported.
2969
//
2970
// NOTE: Part of the NodeTraverser interface.
2971
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2972
        *lnwire.FeatureVector, error) {
×
2973

×
2974
        ctx := context.TODO()
×
2975

×
2976
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2977
}
×
2978

2979
// forEachNodeDirectedChannel iterates through all channels of a given
2980
// node, executing the passed callback on the directed edge representing the
2981
// channel and its incoming policy. If the node is not found, no error is
2982
// returned.
2983
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2984
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2985

×
2986
        toNodeCallback := func() route.Vertex {
×
2987
                return nodePub
×
2988
        }
×
2989

2990
        dbID, err := db.GetNodeIDByPubKey(
×
2991
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2992
                        Version: int16(ProtocolV1),
×
2993
                        PubKey:  nodePub[:],
×
2994
                },
×
2995
        )
×
2996
        if errors.Is(err, sql.ErrNoRows) {
×
2997
                return nil
×
2998
        } else if err != nil {
×
2999
                return fmt.Errorf("unable to fetch node: %w", err)
×
3000
        }
×
3001

3002
        rows, err := db.ListChannelsByNodeID(
×
3003
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3004
                        Version: int16(ProtocolV1),
×
3005
                        NodeID1: dbID,
×
3006
                },
×
3007
        )
×
3008
        if err != nil {
×
3009
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3010
        }
×
3011

3012
        // Exit early if there are no channels for this node so we don't
3013
        // do the unnecessary feature fetching.
3014
        if len(rows) == 0 {
×
3015
                return nil
×
3016
        }
×
3017

3018
        features, err := getNodeFeatures(ctx, db, dbID)
×
3019
        if err != nil {
×
3020
                return fmt.Errorf("unable to fetch node features: %w", err)
×
3021
        }
×
3022

3023
        for _, row := range rows {
×
3024
                node1, node2, err := buildNodeVertices(
×
3025
                        row.Node1Pubkey, row.Node2Pubkey,
×
3026
                )
×
3027
                if err != nil {
×
3028
                        return fmt.Errorf("unable to build node vertices: %w",
×
3029
                                err)
×
3030
                }
×
3031

3032
                edge := buildCacheableChannelInfo(
×
NEW
3033
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
NEW
3034
                        node1, node2,
×
3035
                )
×
3036

×
3037
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3038
                if err != nil {
×
3039
                        return err
×
3040
                }
×
3041

3042
                var p1, p2 *models.CachedEdgePolicy
×
3043
                if dbPol1 != nil {
×
3044
                        policy1, err := buildChanPolicy(
×
3045
                                *dbPol1, edge.ChannelID, nil, node2,
×
3046
                        )
×
3047
                        if err != nil {
×
3048
                                return err
×
3049
                        }
×
3050

3051
                        p1 = models.NewCachedPolicy(policy1)
×
3052
                }
3053
                if dbPol2 != nil {
×
3054
                        policy2, err := buildChanPolicy(
×
3055
                                *dbPol2, edge.ChannelID, nil, node1,
×
3056
                        )
×
3057
                        if err != nil {
×
3058
                                return err
×
3059
                        }
×
3060

3061
                        p2 = models.NewCachedPolicy(policy2)
×
3062
                }
3063

3064
                // Determine the outgoing and incoming policy for this
3065
                // channel and node combo.
3066
                outPolicy, inPolicy := p1, p2
×
3067
                if p1 != nil && node2 == nodePub {
×
3068
                        outPolicy, inPolicy = p2, p1
×
3069
                } else if p2 != nil && node1 != nodePub {
×
3070
                        outPolicy, inPolicy = p2, p1
×
3071
                }
×
3072

3073
                var cachedInPolicy *models.CachedEdgePolicy
×
3074
                if inPolicy != nil {
×
3075
                        cachedInPolicy = inPolicy
×
3076
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
3077
                        cachedInPolicy.ToNodeFeatures = features
×
3078
                }
×
3079

3080
                directedChannel := &DirectedChannel{
×
3081
                        ChannelID:    edge.ChannelID,
×
3082
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
3083
                        OtherNode:    edge.NodeKey2Bytes,
×
3084
                        Capacity:     edge.Capacity,
×
3085
                        OutPolicySet: outPolicy != nil,
×
3086
                        InPolicy:     cachedInPolicy,
×
3087
                }
×
3088
                if outPolicy != nil {
×
3089
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3090
                                directedChannel.InboundFee = fee
×
3091
                        })
×
3092
                }
3093

3094
                if nodePub == edge.NodeKey2Bytes {
×
3095
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
3096
                }
×
3097

3098
                if err := cb(directedChannel); err != nil {
×
3099
                        return err
×
3100
                }
×
3101
        }
3102

3103
        return nil
×
3104
}
3105

3106
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
3107
// and executes the provided callback for each node.
3108
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
3109
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
3110

×
3111
        lastID := int64(-1)
×
3112

×
3113
        for {
×
3114
                nodes, err := db.ListNodeIDsAndPubKeys(
×
3115
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
3116
                                Version: int16(ProtocolV1),
×
3117
                                ID:      lastID,
×
3118
                                Limit:   pageSize,
×
3119
                        },
×
3120
                )
×
3121
                if err != nil {
×
3122
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
3123
                }
×
3124

3125
                if len(nodes) == 0 {
×
3126
                        break
×
3127
                }
3128

3129
                for _, node := range nodes {
×
3130
                        var pub route.Vertex
×
3131
                        copy(pub[:], node.PubKey)
×
3132

×
3133
                        if err := cb(node.ID, pub); err != nil {
×
3134
                                return fmt.Errorf("forEachNodeCacheable "+
×
3135
                                        "callback failed for node(id=%d): %w",
×
3136
                                        node.ID, err)
×
3137
                        }
×
3138

3139
                        lastID = node.ID
×
3140
                }
3141
        }
3142

3143
        return nil
×
3144
}
3145

3146
// forEachNodeChannel iterates through all channels of a node, executing
3147
// the passed callback on each. The call-back is provided with the channel's
3148
// edge information, the outgoing policy and the incoming policy for the
3149
// channel and node combo.
3150
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3151
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
3152
                *models.ChannelEdgePolicy,
3153
                *models.ChannelEdgePolicy) error) error {
×
3154

×
3155
        // Get all the V1 channels for this node.Add commentMore actions
×
3156
        rows, err := db.ListChannelsByNodeID(
×
3157
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3158
                        Version: int16(ProtocolV1),
×
3159
                        NodeID1: id,
×
3160
                },
×
3161
        )
×
3162
        if err != nil {
×
3163
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3164
        }
×
3165

3166
        // Call the call-back for each channel and its known policies.
3167
        for _, row := range rows {
×
3168
                node1, node2, err := buildNodeVertices(
×
3169
                        row.Node1Pubkey, row.Node2Pubkey,
×
3170
                )
×
3171
                if err != nil {
×
3172
                        return fmt.Errorf("unable to build node vertices: %w",
×
3173
                                err)
×
3174
                }
×
3175

3176
                edge, err := getAndBuildEdgeInfo(
×
3177
                        ctx, db, chain, row.GraphChannel.ID, row.GraphChannel,
×
3178
                        node1, node2,
×
3179
                )
×
3180
                if err != nil {
×
3181
                        return fmt.Errorf("unable to build channel info: %w",
×
3182
                                err)
×
3183
                }
×
3184

3185
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3186
                if err != nil {
×
3187
                        return fmt.Errorf("unable to extract channel "+
×
3188
                                "policies: %w", err)
×
3189
                }
×
3190

3191
                p1, p2, err := getAndBuildChanPolicies(
×
3192
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3193
                )
×
3194
                if err != nil {
×
3195
                        return fmt.Errorf("unable to build channel "+
×
3196
                                "policies: %w", err)
×
3197
                }
×
3198

3199
                // Determine the outgoing and incoming policy for this
3200
                // channel and node combo.
3201
                p1ToNode := row.GraphChannel.NodeID2
×
3202
                p2ToNode := row.GraphChannel.NodeID1
×
3203
                outPolicy, inPolicy := p1, p2
×
3204
                if (p1 != nil && p1ToNode == id) ||
×
3205
                        (p2 != nil && p2ToNode != id) {
×
3206

×
3207
                        outPolicy, inPolicy = p2, p1
×
3208
                }
×
3209

3210
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3211
                        return err
×
3212
                }
×
3213
        }
3214

3215
        return nil
×
3216
}
3217

3218
// updateChanEdgePolicy upserts the channel policy info we have stored for
3219
// a channel we already know of.
3220
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3221
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3222
        error) {
×
3223

×
3224
        var (
×
3225
                node1Pub, node2Pub route.Vertex
×
3226
                isNode1            bool
×
3227
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3228
        )
×
3229

×
3230
        // Check that this edge policy refers to a channel that we already
×
3231
        // know of. We do this explicitly so that we can return the appropriate
×
3232
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3233
        // abort the transaction which would abort the entire batch.
×
3234
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3235
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3236
                        Scid:    chanIDB,
×
3237
                        Version: int16(ProtocolV1),
×
3238
                },
×
3239
        )
×
3240
        if errors.Is(err, sql.ErrNoRows) {
×
3241
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3242
        } else if err != nil {
×
3243
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3244
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3245
        }
×
3246

3247
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3248
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3249

×
3250
        // Figure out which node this edge is from.
×
3251
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3252
        nodeID := dbChan.NodeID1
×
3253
        if !isNode1 {
×
3254
                nodeID = dbChan.NodeID2
×
3255
        }
×
3256

3257
        var (
×
3258
                inboundBase sql.NullInt64
×
3259
                inboundRate sql.NullInt64
×
3260
        )
×
3261
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3262
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3263
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3264
        })
×
3265

3266
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3267
                Version:     int16(ProtocolV1),
×
3268
                ChannelID:   dbChan.ID,
×
3269
                NodeID:      nodeID,
×
3270
                Timelock:    int32(edge.TimeLockDelta),
×
3271
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3272
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3273
                MinHtlcMsat: int64(edge.MinHTLC),
×
3274
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3275
                Disabled: sql.NullBool{
×
3276
                        Valid: true,
×
3277
                        Bool:  edge.IsDisabled(),
×
3278
                },
×
3279
                MaxHtlcMsat: sql.NullInt64{
×
3280
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3281
                        Int64: int64(edge.MaxHTLC),
×
3282
                },
×
3283
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3284
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3285
                InboundBaseFeeMsat:      inboundBase,
×
3286
                InboundFeeRateMilliMsat: inboundRate,
×
3287
                Signature:               edge.SigBytes,
×
3288
        })
×
3289
        if err != nil {
×
3290
                return node1Pub, node2Pub, isNode1,
×
3291
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3292
        }
×
3293

3294
        // Convert the flat extra opaque data into a map of TLV types to
3295
        // values.
3296
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3297
        if err != nil {
×
3298
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3299
                        "marshal extra opaque data: %w", err)
×
3300
        }
×
3301

3302
        // Update the channel policy's extra signed fields.
3303
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3304
        if err != nil {
×
3305
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3306
                        "policy extra TLVs: %w", err)
×
3307
        }
×
3308

3309
        return node1Pub, node2Pub, isNode1, nil
×
3310
}
3311

3312
// getNodeByPubKey attempts to look up a target node by its public key.
3313
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3314
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3315

×
3316
        dbNode, err := db.GetNodeByPubKey(
×
3317
                ctx, sqlc.GetNodeByPubKeyParams{
×
3318
                        Version: int16(ProtocolV1),
×
3319
                        PubKey:  pubKey[:],
×
3320
                },
×
3321
        )
×
3322
        if errors.Is(err, sql.ErrNoRows) {
×
3323
                return 0, nil, ErrGraphNodeNotFound
×
3324
        } else if err != nil {
×
3325
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3326
        }
×
3327

3328
        node, err := buildNode(ctx, db, &dbNode)
×
3329
        if err != nil {
×
3330
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3331
        }
×
3332

3333
        return dbNode.ID, node, nil
×
3334
}
3335

3336
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3337
// provided parameters.
3338
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
3339
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3340

×
3341
        return &models.CachedEdgeInfo{
×
NEW
3342
                ChannelID:     byteOrder.Uint64(scid),
×
3343
                NodeKey1Bytes: node1Pub,
×
3344
                NodeKey2Bytes: node2Pub,
×
NEW
3345
                Capacity:      btcutil.Amount(capacity),
×
3346
        }
×
3347
}
×
3348

3349
// buildNode constructs a LightningNode instance from the given database node
3350
// record. The node's features, addresses and extra signed fields are also
3351
// fetched from the database and set on the node.
3352
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.GraphNode) (
3353
        *models.LightningNode, error) {
×
3354

×
3355
        if dbNode.Version != int16(ProtocolV1) {
×
3356
                return nil, fmt.Errorf("unsupported node version: %d",
×
3357
                        dbNode.Version)
×
3358
        }
×
3359

3360
        var pub [33]byte
×
3361
        copy(pub[:], dbNode.PubKey)
×
3362

×
3363
        node := &models.LightningNode{
×
3364
                PubKeyBytes: pub,
×
3365
                Features:    lnwire.EmptyFeatureVector(),
×
3366
                LastUpdate:  time.Unix(0, 0),
×
3367
        }
×
3368

×
3369
        if len(dbNode.Signature) == 0 {
×
3370
                return node, nil
×
3371
        }
×
3372

3373
        node.HaveNodeAnnouncement = true
×
3374
        node.AuthSigBytes = dbNode.Signature
×
3375
        node.Alias = dbNode.Alias.String
×
3376
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3377

×
3378
        var err error
×
3379
        if dbNode.Color.Valid {
×
3380
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3381
                if err != nil {
×
3382
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3383
                                err)
×
3384
                }
×
3385
        }
3386

3387
        // Fetch the node's features.
3388
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
3389
        if err != nil {
×
3390
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3391
                        "features: %w", dbNode.ID, err)
×
3392
        }
×
3393

3394
        // Fetch the node's addresses.
NEW
3395
        node.Addresses, err = getNodeAddresses(ctx, db, dbNode.ID)
×
3396
        if err != nil {
×
3397
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3398
                        "addresses: %w", dbNode.ID, err)
×
3399
        }
×
3400

3401
        // Fetch the node's extra signed fields.
3402
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
3403
        if err != nil {
×
3404
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3405
                        "extra signed fields: %w", dbNode.ID, err)
×
3406
        }
×
3407

3408
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
3409
        if err != nil {
×
3410
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3411
                        "fields: %w", err)
×
3412
        }
×
3413

3414
        if len(recs) != 0 {
×
3415
                node.ExtraOpaqueData = recs
×
3416
        }
×
3417

3418
        return node, nil
×
3419
}
3420

3421
// getNodeFeatures fetches the feature bits and constructs the feature vector
3422
// for a node with the given DB ID.
3423
func getNodeFeatures(ctx context.Context, db SQLQueries,
3424
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3425

×
3426
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3427
        if err != nil {
×
3428
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3429
                        nodeID, err)
×
3430
        }
×
3431

3432
        features := lnwire.EmptyFeatureVector()
×
3433
        for _, feature := range rows {
×
3434
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3435
        }
×
3436

3437
        return features, nil
×
3438
}
3439

3440
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
3441
// given DB ID.
3442
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3443
        nodeID int64) (map[uint64][]byte, error) {
×
3444

×
3445
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3446
        if err != nil {
×
3447
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
3448
                        "signed fields: %w", nodeID, err)
×
3449
        }
×
3450

3451
        extraFields := make(map[uint64][]byte)
×
3452
        for _, field := range fields {
×
3453
                extraFields[uint64(field.Type)] = field.Value
×
3454
        }
×
3455

3456
        return extraFields, nil
×
3457
}
3458

3459
// upsertNode upserts the node record into the database. If the node already
3460
// exists, then the node's information is updated. If the node doesn't exist,
3461
// then a new node is created. The node's features, addresses and extra TLV
3462
// types are also updated. The node's DB ID is returned.
3463
func upsertNode(ctx context.Context, db SQLQueries,
3464
        node *models.LightningNode) (int64, error) {
×
3465

×
3466
        params := sqlc.UpsertNodeParams{
×
3467
                Version: int16(ProtocolV1),
×
3468
                PubKey:  node.PubKeyBytes[:],
×
3469
        }
×
3470

×
3471
        if node.HaveNodeAnnouncement {
×
3472
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3473
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3474
                params.Alias = sqldb.SQLStr(node.Alias)
×
3475
                params.Signature = node.AuthSigBytes
×
3476
        }
×
3477

3478
        nodeID, err := db.UpsertNode(ctx, params)
×
3479
        if err != nil {
×
3480
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3481
                        err)
×
3482
        }
×
3483

3484
        // We can exit here if we don't have the announcement yet.
3485
        if !node.HaveNodeAnnouncement {
×
3486
                return nodeID, nil
×
3487
        }
×
3488

3489
        // Update the node's features.
3490
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3491
        if err != nil {
×
3492
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3493
        }
×
3494

3495
        // Update the node's addresses.
3496
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3497
        if err != nil {
×
3498
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3499
        }
×
3500

3501
        // Convert the flat extra opaque data into a map of TLV types to
3502
        // values.
3503
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3504
        if err != nil {
×
3505
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3506
                        err)
×
3507
        }
×
3508

3509
        // Update the node's extra signed fields.
3510
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3511
        if err != nil {
×
3512
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3513
        }
×
3514

3515
        return nodeID, nil
×
3516
}
3517

3518
// upsertNodeFeatures updates the node's features node_features table. This
3519
// includes deleting any feature bits no longer present and inserting any new
3520
// feature bits. If the feature bit does not yet exist in the features table,
3521
// then an entry is created in that table first.
3522
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3523
        features *lnwire.FeatureVector) error {
×
3524

×
3525
        // Get any existing features for the node.
×
3526
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3527
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3528
                return err
×
3529
        }
×
3530

3531
        // Copy the nodes latest set of feature bits.
3532
        newFeatures := make(map[int32]struct{})
×
3533
        if features != nil {
×
3534
                for feature := range features.Features() {
×
3535
                        newFeatures[int32(feature)] = struct{}{}
×
3536
                }
×
3537
        }
3538

3539
        // For any current feature that already exists in the DB, remove it from
3540
        // the in-memory map. For any existing feature that does not exist in
3541
        // the in-memory map, delete it from the database.
3542
        for _, feature := range existingFeatures {
×
3543
                // The feature is still present, so there are no updates to be
×
3544
                // made.
×
3545
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3546
                        delete(newFeatures, feature.FeatureBit)
×
3547
                        continue
×
3548
                }
3549

3550
                // The feature is no longer present, so we remove it from the
3551
                // database.
3552
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3553
                        NodeID:     nodeID,
×
3554
                        FeatureBit: feature.FeatureBit,
×
3555
                })
×
3556
                if err != nil {
×
3557
                        return fmt.Errorf("unable to delete node(%d) "+
×
3558
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3559
                                err)
×
3560
                }
×
3561
        }
3562

3563
        // Any remaining entries in newFeatures are new features that need to be
3564
        // added to the database for the first time.
3565
        for feature := range newFeatures {
×
3566
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3567
                        NodeID:     nodeID,
×
3568
                        FeatureBit: feature,
×
3569
                })
×
3570
                if err != nil {
×
3571
                        return fmt.Errorf("unable to insert node(%d) "+
×
3572
                                "feature(%v): %w", nodeID, feature, err)
×
3573
                }
×
3574
        }
3575

3576
        return nil
×
3577
}
3578

3579
// fetchNodeFeatures fetches the features for a node with the given public key.
3580
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3581
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3582

×
3583
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3584
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3585
                        PubKey:  nodePub[:],
×
3586
                        Version: int16(ProtocolV1),
×
3587
                },
×
3588
        )
×
3589
        if err != nil {
×
3590
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3591
                        nodePub, err)
×
3592
        }
×
3593

3594
        features := lnwire.EmptyFeatureVector()
×
3595
        for _, bit := range rows {
×
3596
                features.Set(lnwire.FeatureBit(bit))
×
3597
        }
×
3598

3599
        return features, nil
×
3600
}
3601

3602
// dbAddressType identifies how an entry in the node_addresses table is
// encoded, and therefore how it must be serialised and deserialised.
type dbAddressType uint8

const (
	// addressTypeIPv4 marks a TCP address whose host is an IPv4 address.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 marks a TCP address whose host is an IPv6 address.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 marks a Tor v2 onion service address.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 marks a Tor v3 onion service address.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeOpaque marks an address carried as lnwire.OpaqueAddrs,
	// stored as its hex-encoded payload.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3614

3615
// upsertNodeAddresses updates the node's addresses in the database. This
3616
// includes deleting any existing addresses and inserting the new set of
3617
// addresses. The deletion is necessary since the ordering of the addresses may
3618
// change, and we need to ensure that the database reflects the latest set of
3619
// addresses so that at the time of reconstructing the node announcement, the
3620
// order is preserved and the signature over the message remains valid.
3621
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3622
        addresses []net.Addr) error {
×
3623

×
3624
        // Delete any existing addresses for the node. This is required since
×
3625
        // even if the new set of addresses is the same, the ordering may have
×
3626
        // changed for a given address type.
×
3627
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3628
        if err != nil {
×
3629
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3630
                        nodeID, err)
×
3631
        }
×
3632

3633
        // Copy the nodes latest set of addresses.
3634
        newAddresses := map[dbAddressType][]string{
×
3635
                addressTypeIPv4:   {},
×
3636
                addressTypeIPv6:   {},
×
3637
                addressTypeTorV2:  {},
×
3638
                addressTypeTorV3:  {},
×
3639
                addressTypeOpaque: {},
×
3640
        }
×
3641
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3642
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3643
        }
×
3644

3645
        for _, address := range addresses {
×
3646
                switch addr := address.(type) {
×
3647
                case *net.TCPAddr:
×
3648
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3649
                                addAddr(addressTypeIPv4, addr)
×
3650
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3651
                                addAddr(addressTypeIPv6, addr)
×
3652
                        } else {
×
3653
                                return fmt.Errorf("unhandled IP address: %v",
×
3654
                                        addr)
×
3655
                        }
×
3656

3657
                case *tor.OnionAddr:
×
3658
                        switch len(addr.OnionService) {
×
3659
                        case tor.V2Len:
×
3660
                                addAddr(addressTypeTorV2, addr)
×
3661
                        case tor.V3Len:
×
3662
                                addAddr(addressTypeTorV3, addr)
×
3663
                        default:
×
3664
                                return fmt.Errorf("invalid length for a tor " +
×
3665
                                        "address")
×
3666
                        }
3667

3668
                case *lnwire.OpaqueAddrs:
×
3669
                        addAddr(addressTypeOpaque, addr)
×
3670

3671
                default:
×
3672
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3673
                }
3674
        }
3675

3676
        // Any remaining entries in newAddresses are new addresses that need to
3677
        // be added to the database for the first time.
3678
        for addrType, addrList := range newAddresses {
×
3679
                for position, addr := range addrList {
×
3680
                        err := db.InsertNodeAddress(
×
3681
                                ctx, sqlc.InsertNodeAddressParams{
×
3682
                                        NodeID:   nodeID,
×
3683
                                        Type:     int16(addrType),
×
3684
                                        Address:  addr,
×
3685
                                        Position: int32(position),
×
3686
                                },
×
3687
                        )
×
3688
                        if err != nil {
×
3689
                                return fmt.Errorf("unable to insert "+
×
3690
                                        "node(%d) address(%v): %w", nodeID,
×
3691
                                        addr, err)
×
3692
                        }
×
3693
                }
3694
        }
3695

3696
        return nil
×
3697
}
3698

3699
// getNodeAddresses fetches the addresses for a node with the given DB ID.
3700
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
NEW
3701
        error) {
×
3702

×
NEW
3703
        // GetNodeAddresses ensures that the addresses for a given type are
×
NEW
3704
        // returned in the same order as they were inserted.
×
NEW
3705
        rows, err := db.GetNodeAddresses(ctx, id)
×
3706
        if err != nil {
×
NEW
3707
                return nil, err
×
3708
        }
×
3709

3710
        addresses := make([]net.Addr, 0, len(rows))
×
NEW
3711
        for _, row := range rows {
×
NEW
3712
                address := row.Address
×
3713

×
NEW
3714
                switch dbAddressType(row.Type) {
×
3715
                case addressTypeIPv4:
×
3716
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
3717
                        if err != nil {
×
NEW
3718
                                return nil, err
×
3719
                        }
×
3720
                        tcp.IP = tcp.IP.To4()
×
3721

×
3722
                        addresses = append(addresses, tcp)
×
3723

3724
                case addressTypeIPv6:
×
3725
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
3726
                        if err != nil {
×
NEW
3727
                                return nil, err
×
3728
                        }
×
3729
                        addresses = append(addresses, tcp)
×
3730

3731
                case addressTypeTorV3, addressTypeTorV2:
×
3732
                        service, portStr, err := net.SplitHostPort(address)
×
3733
                        if err != nil {
×
NEW
3734
                                return nil, fmt.Errorf("unable to "+
×
NEW
3735
                                        "split tor v3 address: %v", address)
×
UNCOV
3736
                        }
×
3737

3738
                        port, err := strconv.Atoi(portStr)
×
3739
                        if err != nil {
×
NEW
3740
                                return nil, err
×
3741
                        }
×
3742

3743
                        addresses = append(addresses, &tor.OnionAddr{
×
3744
                                OnionService: service,
×
3745
                                Port:         port,
×
3746
                        })
×
3747

3748
                case addressTypeOpaque:
×
3749
                        opaque, err := hex.DecodeString(address)
×
3750
                        if err != nil {
×
NEW
3751
                                return nil, fmt.Errorf("unable to "+
×
NEW
3752
                                        "decode opaque address: %v", address)
×
UNCOV
3753
                        }
×
3754

3755
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
3756
                                Payload: opaque,
×
3757
                        })
×
3758

3759
                default:
×
NEW
3760
                        return nil, fmt.Errorf("unknown address type: %v",
×
NEW
3761
                                row.Type)
×
3762
                }
3763
        }
3764

3765
        // If we have no addresses, then we'll return nil instead of an
3766
        // empty slice.
3767
        if len(addresses) == 0 {
×
3768
                addresses = nil
×
3769
        }
×
3770

NEW
3771
        return addresses, nil
×
3772
}
3773

3774
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3775
// database. This includes updating any existing types, inserting any new types,
3776
// and deleting any types that are no longer present.
3777
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3778
        nodeID int64, extraFields map[uint64][]byte) error {
×
3779

×
3780
        // Get any existing extra signed fields for the node.
×
3781
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3782
        if err != nil {
×
3783
                return err
×
3784
        }
×
3785

3786
        // Make a lookup map of the existing field types so that we can use it
3787
        // to keep track of any fields we should delete.
3788
        m := make(map[uint64]bool)
×
3789
        for _, field := range existingFields {
×
3790
                m[uint64(field.Type)] = true
×
3791
        }
×
3792

3793
        // For all the new fields, we'll upsert them and remove them from the
3794
        // map of existing fields.
3795
        for tlvType, value := range extraFields {
×
3796
                err = db.UpsertNodeExtraType(
×
3797
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3798
                                NodeID: nodeID,
×
3799
                                Type:   int64(tlvType),
×
3800
                                Value:  value,
×
3801
                        },
×
3802
                )
×
3803
                if err != nil {
×
3804
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3805
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3806
                }
×
3807

3808
                // Remove the field from the map of existing fields if it was
3809
                // present.
3810
                delete(m, tlvType)
×
3811
        }
3812

3813
        // For all the fields that are left in the map of existing fields, we'll
3814
        // delete them as they are no longer present in the new set of fields.
3815
        for tlvType := range m {
×
3816
                err = db.DeleteExtraNodeType(
×
3817
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3818
                                NodeID: nodeID,
×
3819
                                Type:   int64(tlvType),
×
3820
                        },
×
3821
                )
×
3822
                if err != nil {
×
3823
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3824
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3825
                }
×
3826
        }
3827

3828
        return nil
×
3829
}
3830

3831
// srcNodeInfo holds the information about the source node of the graph.
3832
type srcNodeInfo struct {
3833
        // id is the DB level ID of the source node entry in the "nodes" table.
3834
        id int64
3835

3836
        // pub is the public key of the source node.
3837
        pub route.Vertex
3838
}
3839

3840
// sourceNode returns the DB node ID and pub key of the source node for the
3841
// specified protocol version.
3842
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3843
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3844

×
3845
        s.srcNodeMu.Lock()
×
3846
        defer s.srcNodeMu.Unlock()
×
3847

×
3848
        // If we already have the source node ID and pub key cached, then
×
3849
        // return them.
×
3850
        if info, ok := s.srcNodes[version]; ok {
×
3851
                return info.id, info.pub, nil
×
3852
        }
×
3853

3854
        var pubKey route.Vertex
×
3855

×
3856
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3857
        if err != nil {
×
3858
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3859
                        err)
×
3860
        }
×
3861

3862
        if len(nodes) == 0 {
×
3863
                return 0, pubKey, ErrSourceNodeNotSet
×
3864
        } else if len(nodes) > 1 {
×
3865
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3866
                        "protocol %s found", version)
×
3867
        }
×
3868

3869
        copy(pubKey[:], nodes[0].PubKey)
×
3870

×
3871
        s.srcNodes[version] = &srcNodeInfo{
×
3872
                id:  nodes[0].NodeID,
×
3873
                pub: pubKey,
×
3874
        }
×
3875

×
3876
        return nodes[0].NodeID, pubKey, nil
×
3877
}
3878

3879
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3880
// This then produces a map from TLV type to value. If the input is not a
3881
// valid TLV stream, then an error is returned.
3882
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3883
        r := bytes.NewReader(data)
×
3884

×
3885
        tlvStream, err := tlv.NewStream()
×
3886
        if err != nil {
×
3887
                return nil, err
×
3888
        }
×
3889

3890
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3891
        // pass it into the P2P decoding variant.
3892
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3893
        if err != nil {
×
3894
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3895
        }
×
3896
        if len(parsedTypes) == 0 {
×
3897
                return nil, nil
×
3898
        }
×
3899

3900
        records := make(map[uint64][]byte)
×
3901
        for k, v := range parsedTypes {
×
3902
                records[uint64(k)] = v
×
3903
        }
×
3904

3905
        return records, nil
×
3906
}
3907

3908
// dbChanInfo groups the DB-level row ID of a channel with the row IDs of its
// two endpoint nodes.
type dbChanInfo struct {
	// channelID is the DB ID of the channel record.
	channelID int64

	// node1ID is the DB ID of the channel's first node.
	node1ID int64

	// node2ID is the DB ID of the channel's second node.
	node2ID int64
}
3915

3916
// insertChannel inserts a new channel record into the database.
3917
func insertChannel(ctx context.Context, db SQLQueries,
3918
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
3919

×
3920
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3921

×
3922
        // Make sure that the channel doesn't already exist. We do this
×
3923
        // explicitly instead of relying on catching a unique constraint error
×
3924
        // because relying on SQL to throw that error would abort the entire
×
3925
        // batch of transactions.
×
3926
        _, err := db.GetChannelBySCID(
×
3927
                ctx, sqlc.GetChannelBySCIDParams{
×
3928
                        Scid:    chanIDB,
×
3929
                        Version: int16(ProtocolV1),
×
3930
                },
×
3931
        )
×
3932
        if err == nil {
×
3933
                return nil, ErrEdgeAlreadyExist
×
3934
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3935
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
×
3936
        }
×
3937

3938
        // Make sure that at least a "shell" entry for each node is present in
3939
        // the nodes table.
3940
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3941
        if err != nil {
×
3942
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3943
        }
×
3944

3945
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3946
        if err != nil {
×
3947
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3948
        }
×
3949

3950
        var capacity sql.NullInt64
×
3951
        if edge.Capacity != 0 {
×
3952
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3953
        }
×
3954

3955
        createParams := sqlc.CreateChannelParams{
×
3956
                Version:     int16(ProtocolV1),
×
3957
                Scid:        chanIDB,
×
3958
                NodeID1:     node1DBID,
×
3959
                NodeID2:     node2DBID,
×
3960
                Outpoint:    edge.ChannelPoint.String(),
×
3961
                Capacity:    capacity,
×
3962
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3963
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3964
        }
×
3965

×
3966
        if edge.AuthProof != nil {
×
3967
                proof := edge.AuthProof
×
3968

×
3969
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3970
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3971
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3972
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3973
        }
×
3974

3975
        // Insert the new channel record.
3976
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3977
        if err != nil {
×
3978
                return nil, err
×
3979
        }
×
3980

3981
        // Insert any channel features.
3982
        for feature := range edge.Features.Features() {
×
3983
                err = db.InsertChannelFeature(
×
3984
                        ctx, sqlc.InsertChannelFeatureParams{
×
3985
                                ChannelID:  dbChanID,
×
3986
                                FeatureBit: int32(feature),
×
3987
                        },
×
3988
                )
×
3989
                if err != nil {
×
3990
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
×
3991
                                "feature(%v): %w", dbChanID, feature, err)
×
3992
                }
×
3993
        }
3994

3995
        // Finally, insert any extra TLV fields in the channel announcement.
3996
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3997
        if err != nil {
×
3998
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
3999
                        "data: %w", err)
×
4000
        }
×
4001

4002
        for tlvType, value := range extra {
×
4003
                err := db.CreateChannelExtraType(
×
4004
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
4005
                                ChannelID: dbChanID,
×
4006
                                Type:      int64(tlvType),
×
4007
                                Value:     value,
×
4008
                        },
×
4009
                )
×
4010
                if err != nil {
×
4011
                        return nil, fmt.Errorf("unable to upsert "+
×
4012
                                "channel(%d) extra signed field(%v): %w",
×
4013
                                edge.ChannelID, tlvType, err)
×
4014
                }
×
4015
        }
4016

4017
        return &dbChanInfo{
×
4018
                channelID: dbChanID,
×
4019
                node1ID:   node1DBID,
×
4020
                node2ID:   node2DBID,
×
4021
        }, nil
×
4022
}
4023

4024
// maybeCreateShellNode checks if a shell node entry exists for the
4025
// given public key. If it does not exist, then a new shell node entry is
4026
// created. The ID of the node is returned. A shell node only has a protocol
4027
// version and public key persisted.
4028
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
4029
        pubKey route.Vertex) (int64, error) {
×
4030

×
4031
        dbNode, err := db.GetNodeByPubKey(
×
4032
                ctx, sqlc.GetNodeByPubKeyParams{
×
4033
                        PubKey:  pubKey[:],
×
4034
                        Version: int16(ProtocolV1),
×
4035
                },
×
4036
        )
×
4037
        // The node exists. Return the ID.
×
4038
        if err == nil {
×
4039
                return dbNode.ID, nil
×
4040
        } else if !errors.Is(err, sql.ErrNoRows) {
×
4041
                return 0, err
×
4042
        }
×
4043

4044
        // Otherwise, the node does not exist, so we create a shell entry for
4045
        // it.
4046
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
4047
                Version: int16(ProtocolV1),
×
4048
                PubKey:  pubKey[:],
×
4049
        })
×
4050
        if err != nil {
×
4051
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
4052
        }
×
4053

4054
        return id, nil
×
4055
}
4056

4057
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
4058
// the database. This includes deleting any existing types and then inserting
4059
// the new types.
4060
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
4061
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
4062

×
4063
        // Delete all existing extra signed fields for the channel policy.
×
4064
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
4065
        if err != nil {
×
4066
                return fmt.Errorf("unable to delete "+
×
4067
                        "existing policy extra signed fields for policy %d: %w",
×
4068
                        chanPolicyID, err)
×
4069
        }
×
4070

4071
        // Insert all new extra signed fields for the channel policy.
4072
        for tlvType, value := range extraFields {
×
4073
                err = db.InsertChanPolicyExtraType(
×
4074
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
4075
                                ChannelPolicyID: chanPolicyID,
×
4076
                                Type:            int64(tlvType),
×
4077
                                Value:           value,
×
4078
                        },
×
4079
                )
×
4080
                if err != nil {
×
4081
                        return fmt.Errorf("unable to insert "+
×
4082
                                "channel_policy(%d) extra signed field(%v): %w",
×
4083
                                chanPolicyID, tlvType, err)
×
4084
                }
×
4085
        }
4086

4087
        return nil
×
4088
}
4089

4090
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
4091
// provided dbChanRow and also fetches any other required information
4092
// to construct the edge info.
4093
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
4094
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.GraphChannel, node1,
4095
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
4096

×
4097
        if dbChan.Version != int16(ProtocolV1) {
×
4098
                return nil, fmt.Errorf("unsupported channel version: %d",
×
4099
                        dbChan.Version)
×
4100
        }
×
4101

4102
        fv, extras, err := getChanFeaturesAndExtras(
×
4103
                ctx, db, dbChanID,
×
4104
        )
×
4105
        if err != nil {
×
4106
                return nil, err
×
4107
        }
×
4108

4109
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4110
        if err != nil {
×
4111
                return nil, err
×
4112
        }
×
4113

4114
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4115
        if err != nil {
×
4116
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4117
                        "fields: %w", err)
×
4118
        }
×
4119
        if recs == nil {
×
4120
                recs = make([]byte, 0)
×
4121
        }
×
4122

4123
        var btcKey1, btcKey2 route.Vertex
×
4124
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4125
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4126

×
4127
        channel := &models.ChannelEdgeInfo{
×
4128
                ChainHash:        chain,
×
4129
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4130
                NodeKey1Bytes:    node1,
×
4131
                NodeKey2Bytes:    node2,
×
4132
                BitcoinKey1Bytes: btcKey1,
×
4133
                BitcoinKey2Bytes: btcKey2,
×
4134
                ChannelPoint:     *op,
×
4135
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4136
                Features:         fv,
×
4137
                ExtraOpaqueData:  recs,
×
4138
        }
×
4139

×
4140
        // We always set all the signatures at the same time, so we can
×
4141
        // safely check if one signature is present to determine if we have the
×
4142
        // rest of the signatures for the auth proof.
×
4143
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4144
                channel.AuthProof = &models.ChannelAuthProof{
×
4145
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4146
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4147
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4148
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4149
                }
×
4150
        }
×
4151

4152
        return channel, nil
×
4153
}
4154

4155
// buildNodeVertices is a helper that converts raw node public keys
4156
// into route.Vertex instances.
4157
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4158
        route.Vertex, error) {
×
4159

×
4160
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4161
        if err != nil {
×
4162
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4163
                        "create vertex from node1 pubkey: %w", err)
×
4164
        }
×
4165

4166
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4167
        if err != nil {
×
4168
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4169
                        "create vertex from node2 pubkey: %w", err)
×
4170
        }
×
4171

4172
        return node1Vertex, node2Vertex, nil
×
4173
}
4174

4175
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
4176
// for a channel with the given ID.
4177
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
4178
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
4179

×
4180
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
4181
        if err != nil {
×
4182
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
4183
                        "features and extras: %w", err)
×
4184
        }
×
4185

4186
        var (
×
4187
                fv     = lnwire.EmptyFeatureVector()
×
4188
                extras = make(map[uint64][]byte)
×
4189
        )
×
4190
        for _, row := range rows {
×
4191
                if row.IsFeature {
×
4192
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
4193

×
4194
                        continue
×
4195
                }
4196

4197
                tlvType, ok := row.ExtraKey.(int64)
×
4198
                if !ok {
×
4199
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4200
                                "TLV type: %T", row.ExtraKey)
×
4201
                }
×
4202

4203
                valueBytes, ok := row.Value.([]byte)
×
4204
                if !ok {
×
4205
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4206
                                "Value: %T", row.Value)
×
4207
                }
×
4208

4209
                extras[uint64(tlvType)] = valueBytes
×
4210
        }
4211

4212
        return fv, extras, nil
×
4213
}
4214

4215
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4216
// retrieves all the extra info required to build the complete
4217
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4218
// the provided sqlc.GraphChannelPolicy records are nil.
4219
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4220
        dbPol1, dbPol2 *sqlc.GraphChannelPolicy, channelID uint64, node1,
4221
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4222
        *models.ChannelEdgePolicy, error) {
×
4223

×
4224
        if dbPol1 == nil && dbPol2 == nil {
×
4225
                return nil, nil, nil
×
4226
        }
×
4227

4228
        var (
×
4229
                policy1ID int64
×
4230
                policy2ID int64
×
4231
        )
×
4232
        if dbPol1 != nil {
×
4233
                policy1ID = dbPol1.ID
×
4234
        }
×
4235
        if dbPol2 != nil {
×
4236
                policy2ID = dbPol2.ID
×
4237
        }
×
4238
        rows, err := db.GetChannelPolicyExtraTypes(
×
4239
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
4240
                        ID:   policy1ID,
×
4241
                        ID_2: policy2ID,
×
4242
                },
×
4243
        )
×
4244
        if err != nil {
×
4245
                return nil, nil, err
×
4246
        }
×
4247

4248
        var (
×
4249
                dbPol1Extras = make(map[uint64][]byte)
×
4250
                dbPol2Extras = make(map[uint64][]byte)
×
4251
        )
×
4252
        for _, row := range rows {
×
4253
                switch row.PolicyID {
×
4254
                case policy1ID:
×
4255
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
4256
                case policy2ID:
×
4257
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
4258
                default:
×
4259
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
4260
                                "in row: %v", row.PolicyID, row)
×
4261
                }
4262
        }
4263

4264
        var pol1, pol2 *models.ChannelEdgePolicy
×
4265
        if dbPol1 != nil {
×
4266
                pol1, err = buildChanPolicy(
×
4267
                        *dbPol1, channelID, dbPol1Extras, node2,
×
4268
                )
×
4269
                if err != nil {
×
4270
                        return nil, nil, err
×
4271
                }
×
4272
        }
4273
        if dbPol2 != nil {
×
4274
                pol2, err = buildChanPolicy(
×
4275
                        *dbPol2, channelID, dbPol2Extras, node1,
×
4276
                )
×
4277
                if err != nil {
×
4278
                        return nil, nil, err
×
4279
                }
×
4280
        }
4281

4282
        return pol1, pol2, nil
×
4283
}
4284

4285
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4286
// provided sqlc.GraphChannelPolicy and other required information.
4287
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4288
        extras map[uint64][]byte,
4289
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4290

×
4291
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4292
        if err != nil {
×
4293
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4294
                        "fields: %w", err)
×
4295
        }
×
4296

4297
        var inboundFee fn.Option[lnwire.Fee]
×
4298
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4299
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4300

×
4301
                inboundFee = fn.Some(lnwire.Fee{
×
4302
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4303
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4304
                })
×
4305
        }
×
4306

4307
        return &models.ChannelEdgePolicy{
×
4308
                SigBytes:  dbPolicy.Signature,
×
4309
                ChannelID: channelID,
×
4310
                LastUpdate: time.Unix(
×
4311
                        dbPolicy.LastUpdate.Int64, 0,
×
4312
                ),
×
4313
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4314
                        dbPolicy.MessageFlags,
×
4315
                ),
×
4316
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4317
                        dbPolicy.ChannelFlags,
×
4318
                ),
×
4319
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4320
                MinHTLC: lnwire.MilliSatoshi(
×
4321
                        dbPolicy.MinHtlcMsat,
×
4322
                ),
×
4323
                MaxHTLC: lnwire.MilliSatoshi(
×
4324
                        dbPolicy.MaxHtlcMsat.Int64,
×
4325
                ),
×
4326
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4327
                        dbPolicy.BaseFeeMsat,
×
4328
                ),
×
4329
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4330
                ToNode:                    toNode,
×
4331
                InboundFee:                inboundFee,
×
4332
                ExtraOpaqueData:           recs,
×
4333
        }, nil
×
4334
}
4335

4336
// buildNodes builds the models.LightningNode instances for the
4337
// given row which is expected to be a sqlc type that contains node information.
4338
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4339
        dbNode2 sqlc.GraphNode) (*models.LightningNode, *models.LightningNode,
4340
        error) {
×
4341

×
4342
        node1, err := buildNode(ctx, db, &dbNode1)
×
4343
        if err != nil {
×
4344
                return nil, nil, err
×
4345
        }
×
4346

4347
        node2, err := buildNode(ctx, db, &dbNode2)
×
4348
        if err != nil {
×
4349
                return nil, nil, err
×
4350
        }
×
4351

4352
        return node1, node2, nil
×
4353
}
4354

4355
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the give
4356
// row which is expected to be a sqlc type that contains channel policy
4357
// information. It returns two policies, which may be nil if the policy
4358
// information is not present in the row.
4359
//
4360
//nolint:ll,dupl,funlen
4361
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4362
        *sqlc.GraphChannelPolicy, error) {
×
4363

×
4364
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4365
        switch r := row.(type) {
×
NEW
4366
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
×
NEW
4367
                if r.Policy1Timelock.Valid {
×
NEW
4368
                        policy1 = &sqlc.GraphChannelPolicy{
×
NEW
4369
                                Timelock:                r.Policy1Timelock.Int32,
×
NEW
4370
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
NEW
4371
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
NEW
4372
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
NEW
4373
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
NEW
4374
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
NEW
4375
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
NEW
4376
                                Disabled:                r.Policy1Disabled,
×
NEW
4377
                                MessageFlags:            r.Policy1MessageFlags,
×
NEW
4378
                                ChannelFlags:            r.Policy1ChannelFlags,
×
NEW
4379
                        }
×
NEW
4380
                }
×
NEW
4381
                if r.Policy2Timelock.Valid {
×
NEW
4382
                        policy2 = &sqlc.GraphChannelPolicy{
×
NEW
4383
                                Timelock:                r.Policy2Timelock.Int32,
×
NEW
4384
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
NEW
4385
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
NEW
4386
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
NEW
4387
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
NEW
4388
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
NEW
4389
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
NEW
4390
                                Disabled:                r.Policy2Disabled,
×
NEW
4391
                                MessageFlags:            r.Policy2MessageFlags,
×
NEW
4392
                                ChannelFlags:            r.Policy2ChannelFlags,
×
NEW
4393
                        }
×
NEW
4394
                }
×
4395

NEW
4396
                return policy1, policy2, nil
×
4397

4398
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
4399
                if r.Policy1ID.Valid {
×
4400
                        policy1 = &sqlc.GraphChannelPolicy{
×
4401
                                ID:                      r.Policy1ID.Int64,
×
4402
                                Version:                 r.Policy1Version.Int16,
×
4403
                                ChannelID:               r.GraphChannel.ID,
×
4404
                                NodeID:                  r.Policy1NodeID.Int64,
×
4405
                                Timelock:                r.Policy1Timelock.Int32,
×
4406
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4407
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4408
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4409
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4410
                                LastUpdate:              r.Policy1LastUpdate,
×
4411
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4412
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4413
                                Disabled:                r.Policy1Disabled,
×
4414
                                MessageFlags:            r.Policy1MessageFlags,
×
4415
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4416
                                Signature:               r.Policy1Signature,
×
4417
                        }
×
4418
                }
×
4419
                if r.Policy2ID.Valid {
×
4420
                        policy2 = &sqlc.GraphChannelPolicy{
×
4421
                                ID:                      r.Policy2ID.Int64,
×
4422
                                Version:                 r.Policy2Version.Int16,
×
4423
                                ChannelID:               r.GraphChannel.ID,
×
4424
                                NodeID:                  r.Policy2NodeID.Int64,
×
4425
                                Timelock:                r.Policy2Timelock.Int32,
×
4426
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4427
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4428
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4429
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4430
                                LastUpdate:              r.Policy2LastUpdate,
×
4431
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4432
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4433
                                Disabled:                r.Policy2Disabled,
×
4434
                                MessageFlags:            r.Policy2MessageFlags,
×
4435
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4436
                                Signature:               r.Policy2Signature,
×
4437
                        }
×
4438
                }
×
4439

4440
                return policy1, policy2, nil
×
4441

4442
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4443
                if r.Policy1ID.Valid {
×
4444
                        policy1 = &sqlc.GraphChannelPolicy{
×
4445
                                ID:                      r.Policy1ID.Int64,
×
4446
                                Version:                 r.Policy1Version.Int16,
×
4447
                                ChannelID:               r.GraphChannel.ID,
×
4448
                                NodeID:                  r.Policy1NodeID.Int64,
×
4449
                                Timelock:                r.Policy1Timelock.Int32,
×
4450
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4451
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4452
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4453
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4454
                                LastUpdate:              r.Policy1LastUpdate,
×
4455
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4456
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4457
                                Disabled:                r.Policy1Disabled,
×
4458
                                MessageFlags:            r.Policy1MessageFlags,
×
4459
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4460
                                Signature:               r.Policy1Signature,
×
4461
                        }
×
4462
                }
×
4463
                if r.Policy2ID.Valid {
×
4464
                        policy2 = &sqlc.GraphChannelPolicy{
×
4465
                                ID:                      r.Policy2ID.Int64,
×
4466
                                Version:                 r.Policy2Version.Int16,
×
4467
                                ChannelID:               r.GraphChannel.ID,
×
4468
                                NodeID:                  r.Policy2NodeID.Int64,
×
4469
                                Timelock:                r.Policy2Timelock.Int32,
×
4470
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4471
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4472
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4473
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4474
                                LastUpdate:              r.Policy2LastUpdate,
×
4475
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4476
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4477
                                Disabled:                r.Policy2Disabled,
×
4478
                                MessageFlags:            r.Policy2MessageFlags,
×
4479
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4480
                                Signature:               r.Policy2Signature,
×
4481
                        }
×
4482
                }
×
4483

4484
                return policy1, policy2, nil
×
4485

4486
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4487
                if r.Policy1ID.Valid {
×
4488
                        policy1 = &sqlc.GraphChannelPolicy{
×
4489
                                ID:                      r.Policy1ID.Int64,
×
4490
                                Version:                 r.Policy1Version.Int16,
×
4491
                                ChannelID:               r.GraphChannel.ID,
×
4492
                                NodeID:                  r.Policy1NodeID.Int64,
×
4493
                                Timelock:                r.Policy1Timelock.Int32,
×
4494
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4495
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4496
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4497
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4498
                                LastUpdate:              r.Policy1LastUpdate,
×
4499
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4500
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4501
                                Disabled:                r.Policy1Disabled,
×
4502
                                MessageFlags:            r.Policy1MessageFlags,
×
4503
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4504
                                Signature:               r.Policy1Signature,
×
4505
                        }
×
4506
                }
×
4507
                if r.Policy2ID.Valid {
×
4508
                        policy2 = &sqlc.GraphChannelPolicy{
×
4509
                                ID:                      r.Policy2ID.Int64,
×
4510
                                Version:                 r.Policy2Version.Int16,
×
4511
                                ChannelID:               r.GraphChannel.ID,
×
4512
                                NodeID:                  r.Policy2NodeID.Int64,
×
4513
                                Timelock:                r.Policy2Timelock.Int32,
×
4514
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4515
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4516
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4517
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4518
                                LastUpdate:              r.Policy2LastUpdate,
×
4519
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4520
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4521
                                Disabled:                r.Policy2Disabled,
×
4522
                                MessageFlags:            r.Policy2MessageFlags,
×
4523
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4524
                                Signature:               r.Policy2Signature,
×
4525
                        }
×
4526
                }
×
4527

4528
                return policy1, policy2, nil
×
4529

4530
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4531
                if r.Policy1ID.Valid {
×
4532
                        policy1 = &sqlc.GraphChannelPolicy{
×
4533
                                ID:                      r.Policy1ID.Int64,
×
4534
                                Version:                 r.Policy1Version.Int16,
×
4535
                                ChannelID:               r.GraphChannel.ID,
×
4536
                                NodeID:                  r.Policy1NodeID.Int64,
×
4537
                                Timelock:                r.Policy1Timelock.Int32,
×
4538
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4539
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4540
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4541
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4542
                                LastUpdate:              r.Policy1LastUpdate,
×
4543
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4544
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4545
                                Disabled:                r.Policy1Disabled,
×
4546
                                MessageFlags:            r.Policy1MessageFlags,
×
4547
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4548
                                Signature:               r.Policy1Signature,
×
4549
                        }
×
4550
                }
×
4551
                if r.Policy2ID.Valid {
×
4552
                        policy2 = &sqlc.GraphChannelPolicy{
×
4553
                                ID:                      r.Policy2ID.Int64,
×
4554
                                Version:                 r.Policy2Version.Int16,
×
4555
                                ChannelID:               r.GraphChannel.ID,
×
4556
                                NodeID:                  r.Policy2NodeID.Int64,
×
4557
                                Timelock:                r.Policy2Timelock.Int32,
×
4558
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4559
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4560
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4561
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4562
                                LastUpdate:              r.Policy2LastUpdate,
×
4563
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4564
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4565
                                Disabled:                r.Policy2Disabled,
×
4566
                                MessageFlags:            r.Policy2MessageFlags,
×
4567
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4568
                                Signature:               r.Policy2Signature,
×
4569
                        }
×
4570
                }
×
4571

4572
                return policy1, policy2, nil
×
4573

4574
        case sqlc.ListChannelsByNodeIDRow:
×
4575
                if r.Policy1ID.Valid {
×
4576
                        policy1 = &sqlc.GraphChannelPolicy{
×
4577
                                ID:                      r.Policy1ID.Int64,
×
4578
                                Version:                 r.Policy1Version.Int16,
×
4579
                                ChannelID:               r.GraphChannel.ID,
×
4580
                                NodeID:                  r.Policy1NodeID.Int64,
×
4581
                                Timelock:                r.Policy1Timelock.Int32,
×
4582
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4583
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4584
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4585
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4586
                                LastUpdate:              r.Policy1LastUpdate,
×
4587
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4588
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4589
                                Disabled:                r.Policy1Disabled,
×
4590
                                MessageFlags:            r.Policy1MessageFlags,
×
4591
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4592
                                Signature:               r.Policy1Signature,
×
4593
                        }
×
4594
                }
×
4595
                if r.Policy2ID.Valid {
×
4596
                        policy2 = &sqlc.GraphChannelPolicy{
×
4597
                                ID:                      r.Policy2ID.Int64,
×
4598
                                Version:                 r.Policy2Version.Int16,
×
4599
                                ChannelID:               r.GraphChannel.ID,
×
4600
                                NodeID:                  r.Policy2NodeID.Int64,
×
4601
                                Timelock:                r.Policy2Timelock.Int32,
×
4602
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4603
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4604
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4605
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4606
                                LastUpdate:              r.Policy2LastUpdate,
×
4607
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4608
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4609
                                Disabled:                r.Policy2Disabled,
×
4610
                                MessageFlags:            r.Policy2MessageFlags,
×
4611
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4612
                                Signature:               r.Policy2Signature,
×
4613
                        }
×
4614
                }
×
4615

4616
                return policy1, policy2, nil
×
4617

4618
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4619
                if r.Policy1ID.Valid {
×
4620
                        policy1 = &sqlc.GraphChannelPolicy{
×
4621
                                ID:                      r.Policy1ID.Int64,
×
4622
                                Version:                 r.Policy1Version.Int16,
×
4623
                                ChannelID:               r.GraphChannel.ID,
×
4624
                                NodeID:                  r.Policy1NodeID.Int64,
×
4625
                                Timelock:                r.Policy1Timelock.Int32,
×
4626
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4627
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4628
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4629
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4630
                                LastUpdate:              r.Policy1LastUpdate,
×
4631
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4632
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4633
                                Disabled:                r.Policy1Disabled,
×
4634
                                MessageFlags:            r.Policy1MessageFlags,
×
4635
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4636
                                Signature:               r.Policy1Signature,
×
4637
                        }
×
4638
                }
×
4639
                if r.Policy2ID.Valid {
×
4640
                        policy2 = &sqlc.GraphChannelPolicy{
×
4641
                                ID:                      r.Policy2ID.Int64,
×
4642
                                Version:                 r.Policy2Version.Int16,
×
4643
                                ChannelID:               r.GraphChannel.ID,
×
4644
                                NodeID:                  r.Policy2NodeID.Int64,
×
4645
                                Timelock:                r.Policy2Timelock.Int32,
×
4646
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4647
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4648
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4649
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4650
                                LastUpdate:              r.Policy2LastUpdate,
×
4651
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4652
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4653
                                Disabled:                r.Policy2Disabled,
×
4654
                                MessageFlags:            r.Policy2MessageFlags,
×
4655
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4656
                                Signature:               r.Policy2Signature,
×
4657
                        }
×
4658
                }
×
4659

4660
                return policy1, policy2, nil
×
4661
        default:
×
4662
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4663
                        "extractChannelPolicies: %T", r)
×
4664
        }
4665
}
4666

4667
// channelIDToBytes converts a channel ID (SCID) to a byte array
4668
// representation.
4669
func channelIDToBytes(channelID uint64) []byte {
×
4670
        var chanIDB [8]byte
×
4671
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4672

×
4673
        return chanIDB[:]
×
4674
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc