• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15997023090

01 Jul 2025 10:40AM UTC coverage: 55.469% (-2.3%) from 57.809%
15997023090

push

github

web-flow
Merge pull request #10010 from ellemouton/sqlGraphUpdates

graph/db: various misc updates

0 of 31 new or added lines in 2 files covered. (0.0%)

23617 existing lines in 280 files now uncovered.

108414 of 195451 relevant lines covered (55.47%)

22394.66 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// pageSize caps the number of rows a single paginated query may return. The
// value is a starting point and may be tuned once benchmarks are available.
const (
	pageSize = 2000
)
37

38
// ProtocolVersion is an enum that defines the gossip protocol version of a
39
// message.
40
type ProtocolVersion uint8
41

42
const (
43
        // ProtocolV1 is the gossip protocol version defined in BOLT #7.
44
        ProtocolV1 ProtocolVersion = 1
45
)
46

47
// String returns a string representation of the protocol version.
48
func (v ProtocolVersion) String() string {
×
49
        return fmt.Sprintf("V%d", v)
×
50
}
×
51

52
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
53
// execute queries against the SQL graph tables.
54
//
55
//nolint:ll,interfacebloat
56
type SQLQueries interface {
57
        /*
58
                Node queries.
59
        */
60
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
61
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
62
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
63
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
64
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error)
65
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
66
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
67
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
68
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
69
        DeleteNode(ctx context.Context, id int64) error
70

71
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
72
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
73
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
74

75
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
76
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
77
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
78

79
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
80
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
81
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
82
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
83

84
        /*
85
                Source node queries.
86
        */
87
        AddSourceNode(ctx context.Context, nodeID int64) error
88
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
89

90
        /*
91
                Channel queries.
92
        */
93
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
94
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
95
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
96
        GetChannelByOutpoint(ctx context.Context, outpoint string) (sqlc.GetChannelByOutpointRow, error)
97
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
98
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
99
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
100
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
101
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
102
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
103
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
104
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
105
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
106
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
107
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
108
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
109
        DeleteChannel(ctx context.Context, id int64) error
110

111
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
112
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
113

114
        /*
115
                Channel Policy table queries.
116
        */
117
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
118
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.ChannelPolicy, error)
119
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
120

121
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
122
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
123
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
124

125
        /*
126
                Zombie index queries.
127
        */
128
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
129
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.ZombieChannel, error)
130
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
131
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
132
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
133

134
        /*
135
                Prune log table queries.
136
        */
137
        GetPruneTip(ctx context.Context) (sqlc.PruneLog, error)
138
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
139
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
140

141
        /*
142
                Closed SCID table queries.
143
        */
144
        InsertClosedChannel(ctx context.Context, scid []byte) error
145
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
146
}
147

148
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
149
// database operations.
150
type BatchedSQLQueries interface {
151
        SQLQueries
152
        sqldb.BatchedTx[SQLQueries]
153
}
154

155
// SQLStore is an implementation of the V1Store interface that uses a SQL
156
// database as the backend.
157
type SQLStore struct {
158
        cfg *SQLStoreConfig
159
        db  BatchedSQLQueries
160

161
        // cacheMu guards all caches (rejectCache and chanCache). If
162
        // this mutex will be acquired at the same time as the DB mutex then
163
        // the cacheMu MUST be acquired first to prevent deadlock.
164
        cacheMu     sync.RWMutex
165
        rejectCache *rejectCache
166
        chanCache   *channelCache
167

168
        chanScheduler batch.Scheduler[SQLQueries]
169
        nodeScheduler batch.Scheduler[SQLQueries]
170

171
        srcNodes  map[ProtocolVersion]*srcNodeInfo
172
        srcNodeMu sync.Mutex
173
}
174

175
// A compile-time assertion to ensure that SQLStore implements the V1Store
176
// interface.
177
var _ V1Store = (*SQLStore)(nil)
178

179
// SQLStoreConfig holds the configuration for the SQLStore.
180
type SQLStoreConfig struct {
181
        // ChainHash is the genesis hash for the chain that all the gossip
182
        // messages in this store are aimed at.
183
        ChainHash chainhash.Hash
184
}
185

186
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
187
// storage backend.
188
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
189
        options ...StoreOptionModifier) (*SQLStore, error) {
×
190

×
191
        opts := DefaultOptions()
×
192
        for _, o := range options {
×
193
                o(opts)
×
194
        }
×
195

196
        if opts.NoMigration {
×
197
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
198
                        "supported for SQL stores")
×
199
        }
×
200

201
        s := &SQLStore{
×
202
                cfg:         cfg,
×
203
                db:          db,
×
204
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
205
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
206
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
207
        }
×
208

×
209
        s.chanScheduler = batch.NewTimeScheduler(
×
210
                db, &s.cacheMu, opts.BatchCommitInterval,
×
211
        )
×
212
        s.nodeScheduler = batch.NewTimeScheduler(
×
213
                db, nil, opts.BatchCommitInterval,
×
214
        )
×
215

×
216
        return s, nil
×
217
}
218

219
// AddLightningNode adds a vertex/node to the graph database. If the node is not
220
// in the database from before, this will add a new, unconnected one to the
221
// graph. If it is present from before, this will update that node's
222
// information.
223
//
224
// NOTE: part of the V1Store interface.
225
func (s *SQLStore) AddLightningNode(ctx context.Context,
226
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
227

×
228
        r := &batch.Request[SQLQueries]{
×
229
                Opts: batch.NewSchedulerOptions(opts...),
×
230
                Do: func(queries SQLQueries) error {
×
231
                        _, err := upsertNode(ctx, queries, node)
×
232
                        return err
×
233
                },
×
234
        }
235

236
        return s.nodeScheduler.Execute(ctx, r)
×
237
}
238

239
// FetchLightningNode looks up a node by its identity public key inside a
// read transaction. Any lookup failure is wrapped and returned. That includes
// an unknown node, which getNodeByPubKey is expected to report as
// ErrGraphNodeNotFound.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchLightningNode(ctx context.Context,
	pubKey route.Vertex) (*models.LightningNode, error) {

	var result *models.LightningNode
	fetch := func(db SQLQueries) error {
		_, dbNode, err := getNodeByPubKey(ctx, db, pubKey)
		if err != nil {
			return err
		}
		result = dbNode

		return nil
	}

	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), fetch, sqldb.NoOpReset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch node: %w", err)
	}

	return result, nil
}
260

261
// HasLightningNode reports whether a V1 node with the given identity public
// key exists in the graph.
//
// If the node is found, the boolean is true and the returned time is the
// node's last-update timestamp. That time is zero if no timestamp is stored.
// If the node is not found, the boolean is false and the time is zero.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) HasLightningNode(ctx context.Context,
	pubKey [33]byte) (time.Time, bool, error) {

	var (
		found   bool
		updated time.Time
	)
	lookup := func(db SQLQueries) error {
		params := sqlc.GetNodeByPubKeyParams{
			Version: int16(ProtocolV1),
			PubKey:  pubKey[:],
		}
		dbNode, err := db.GetNodeByPubKey(ctx, params)
		switch {
		case errors.Is(err, sql.ErrNoRows):
			// Unknown node: not an error, just report absence.
			return nil

		case err != nil:
			return fmt.Errorf("unable to fetch node: %w", err)
		}

		found = true
		if dbNode.LastUpdate.Valid {
			updated = time.Unix(dbNode.LastUpdate.Int64, 0)
		}

		return nil
	}

	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), lookup, sqldb.NoOpReset)
	if err != nil {
		return time.Time{}, false,
			fmt.Errorf("unable to fetch node: %w", err)
	}

	return updated, found, nil
}
303

304
// AddrsForNode returns every address the graph DB knows for the given node
// public key. The boolean result is the "known" flag reported by
// getNodeAddresses for that key.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddrsForNode(ctx context.Context,
	nodePub *btcec.PublicKey) (bool, []net.Addr, error) {

	pubBytes := nodePub.SerializeCompressed()

	var (
		known bool
		addrs []net.Addr
	)
	fetch := func(db SQLQueries) error {
		var err error
		known, addrs, err = getNodeAddresses(ctx, db, pubBytes)
		if err != nil {
			return fmt.Errorf("unable to fetch node addresses: %w",
				err)
		}

		return nil
	}

	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), fetch, sqldb.NoOpReset)
	if err != nil {
		return false, nil, fmt.Errorf("unable to get addresses for "+
			"node(%x): %w", pubBytes, err)
	}

	return known, addrs, nil
}
335

336
// DeleteLightningNode starts a new database transaction to remove a vertex/node
337
// from the database according to the node's public key.
338
//
339
// NOTE: part of the V1Store interface.
340
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
341
        pubKey route.Vertex) error {
×
342

×
343
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
344
                res, err := db.DeleteNodeByPubKey(
×
345
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
346
                                Version: int16(ProtocolV1),
×
347
                                PubKey:  pubKey[:],
×
348
                        },
×
349
                )
×
350
                if err != nil {
×
351
                        return err
×
352
                }
×
353

354
                rows, err := res.RowsAffected()
×
355
                if err != nil {
×
356
                        return err
×
357
                }
×
358

359
                if rows == 0 {
×
360
                        return ErrGraphNodeNotFound
×
361
                } else if rows > 1 {
×
362
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
363
                }
×
364

365
                return err
×
366
        }, sqldb.NoOpReset)
367
        if err != nil {
×
368
                return fmt.Errorf("unable to delete node: %w", err)
×
369
        }
×
370

371
        return nil
×
372
}
373

374
// FetchNodeFeatures returns the features of the given node. If no features are
375
// known for the node, an empty feature vector is returned.
376
//
377
// NOTE: this is part of the graphdb.NodeTraverser interface.
378
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
379
        *lnwire.FeatureVector, error) {
×
380

×
381
        ctx := context.TODO()
×
382

×
383
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
384
}
×
385

386
// DisabledChannelIDs returns the channel IDs of all disabled V1 channels. A
// channel counts as disabled when both of its ChannelEdgePolicies have the
// disabled bit set.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
	ctx := context.TODO()

	var chanIDs []uint64
	collect := func(db SQLQueries) error {
		scids, err := db.GetV1DisabledSCIDs(ctx)
		if err != nil {
			return fmt.Errorf("unable to fetch disabled "+
				"channels: %w", err)
		}

		// Each SCID is stored as bytes; decode into its uint64 form.
		chanIDs = fn.Map(scids, byteOrder.Uint64)

		return nil
	}

	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), collect, sqldb.NoOpReset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch disabled channels: %w",
			err)
	}

	return chanIDs, nil
}
414

415
// LookupAlias returns the alias advertised by the given node. The returned
// error wraps ErrNodeAliasNotFound if the node is unknown or has no alias
// stored.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) LookupAlias(ctx context.Context,
	pub *btcec.PublicKey) (string, error) {

	var alias string
	lookup := func(db SQLQueries) error {
		params := sqlc.GetNodeByPubKeyParams{
			Version: int16(ProtocolV1),
			PubKey:  pub.SerializeCompressed(),
		}
		dbNode, err := db.GetNodeByPubKey(ctx, params)
		switch {
		case errors.Is(err, sql.ErrNoRows):
			return ErrNodeAliasNotFound

		case err != nil:
			return fmt.Errorf("unable to fetch node: %w", err)

		case !dbNode.Alias.Valid:
			return ErrNodeAliasNotFound
		}

		alias = dbNode.Alias.String

		return nil
	}

	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), lookup, sqldb.NoOpReset)
	if err != nil {
		return "", fmt.Errorf("unable to look up alias: %w", err)
	}

	return alias, nil
}
449

450
// SourceNode returns the graph's V1 source node, i.e. the centre of the
// star-graph from which path finding explores the reachability of other
// nodes.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
	error) {

	var source *models.LightningNode
	load := func(db SQLQueries) error {
		_, srcPub, err := s.getSourceNode(ctx, db, ProtocolV1)
		if err != nil {
			return fmt.Errorf("unable to fetch V1 source node: %w",
				err)
		}

		_, source, err = getNodeByPubKey(ctx, db, srcPub)

		return err
	}

	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), load, sqldb.NoOpReset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch source node: %w", err)
	}

	return source, nil
}
477

478
// SetSourceNode sets the source node within the graph database. The source
479
// node is to be used as the center of a star-graph within path finding
480
// algorithms.
481
//
482
// NOTE: part of the V1Store interface.
483
func (s *SQLStore) SetSourceNode(ctx context.Context,
484
        node *models.LightningNode) error {
×
485

×
486
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
487
                id, err := upsertNode(ctx, db, node)
×
488
                if err != nil {
×
489
                        return fmt.Errorf("unable to upsert source node: %w",
×
490
                                err)
×
491
                }
×
492

493
                // Make sure that if a source node for this version is already
494
                // set, then the ID is the same as the one we are about to set.
495
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
496
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
497
                        return fmt.Errorf("unable to fetch source node: %w",
×
498
                                err)
×
499
                } else if err == nil {
×
500
                        if dbSourceNodeID != id {
×
501
                                return fmt.Errorf("v1 source node already "+
×
502
                                        "set to a different node: %d vs %d",
×
503
                                        dbSourceNodeID, id)
×
504
                        }
×
505

506
                        return nil
×
507
                }
508

509
                return db.AddSourceNode(ctx, id)
×
510
        }, sqldb.NoOpReset)
511
}
512

513
// NodeUpdatesInHorizon returns all known lightning nodes whose update
// timestamp lies within the passed range. Two nodes can use this to quickly
// determine whether they hold the same set of up-to-date node announcements.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
	endTime time.Time) ([]models.LightningNode, error) {

	ctx := context.TODO()

	var nodes []models.LightningNode
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbNodes, err := db.GetNodesByLastUpdateRange(
			ctx, sqlc.GetNodesByLastUpdateRangeParams{
				StartTime: sqldb.SQLInt64(startTime.Unix()),
				EndTime:   sqldb.SQLInt64(endTime.Unix()),
			},
		)
		if err != nil {
			return fmt.Errorf("unable to fetch nodes: %w", err)
		}

		// Index into the slice rather than ranging by value so each
		// row is not copied before being handed to buildNode.
		for i := range dbNodes {
			node, err := buildNode(ctx, db, &dbNodes[i])
			if err != nil {
				return fmt.Errorf("unable to build node: %w",
					err)
			}

			nodes = append(nodes, *node)
		}

		return nil
	}, func() {
		// The closure appends to the outer slice, so state collected
		// by a failed attempt must be discarded before the
		// transaction is retried. Otherwise the result would contain
		// duplicate nodes.
		nodes = nil
	})
	if err != nil {
		return nil, fmt.Errorf("unable to fetch nodes: %w", err)
	}

	return nodes, nil
}
554

555
// AddChannelEdge adds a new, undirected channel edge between its two nodes.
// The stored data covers the channel's static attributes: its channel ID, the
// keys involved in creating it, and the features it supports. The chanPoint
// and chanID uniquely identify the edge across the whole database.
//
// ErrEdgeAlreadyExist is returned if the edge is already present. On success,
// any reject-cache and channel-cache entries for the channel are dropped.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddChannelEdge(ctx context.Context,
	edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {

	// ErrEdgeAlreadyExist is swallowed inside the batch so the batch can
	// still succeed, then surfaced again from OnCommit via this flag.
	var duplicate bool

	return s.chanScheduler.Execute(ctx, &batch.Request[SQLQueries]{
		Opts: batch.NewSchedulerOptions(opts...),
		Reset: func() {
			duplicate = false
		},
		Do: func(tx SQLQueries) error {
			err := insertChannel(ctx, tx, edge)
			if errors.Is(err, ErrEdgeAlreadyExist) {
				duplicate = true
				return nil
			}

			return err
		},
		OnCommit: func(err error) error {
			if err != nil {
				return err
			}
			if duplicate {
				return ErrEdgeAlreadyExist
			}

			// The edge is now stored, so drop any stale cache
			// entries for this channel ID.
			s.rejectCache.remove(edge.ChannelID)
			s.chanCache.remove(edge.ChannelID)

			return nil
		},
	})
}
600

601
// HighestChanID returns the largest V1 channel ID stored in the graph, i.e.
// the newest channel from the chain's point of view. Zero is returned when no
// V1 channel exists. Peers can use this to quickly check whether their graphs
// are in sync.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
	var highest uint64
	query := func(db SQLQueries) error {
		scid, err := db.HighestSCID(ctx, int16(ProtocolV1))
		switch {
		case errors.Is(err, sql.ErrNoRows):
			// No V1 channels yet; leave the result at zero.
			return nil

		case err != nil:
			return fmt.Errorf("unable to fetch highest chan ID: %w",
				err)
		}

		highest = byteOrder.Uint64(scid)

		return nil
	}

	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), query, sqldb.NoOpReset)
	if err != nil {
		return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
	}

	return highest, nil
}
627

628
// UpdateEdgePolicy writes the routing policy for a single direction of the
// referenced channel. updateChanEdgePolicy decides which direction is
// written; on success the cached policy for that direction is refreshed via
// updateEdgeCache.
//
// The two returned vertices are the ones reported by updateChanEdgePolicy
// (presumably the policy's from/to nodes). ErrEdgeNotFound is returned if the
// referenced channel is unknown.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
	edge *models.ChannelEdgePolicy,
	opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {

	var (
		isUpdate1    bool
		edgeNotFound bool
		from, to     route.Vertex
	)

	r := &batch.Request[SQLQueries]{
		Opts: batch.NewSchedulerOptions(opts...),
		Reset: func() {
			// Clear every piece of state written by Do so that a
			// retried request never observes a stale attempt.
			isUpdate1 = false
			edgeNotFound = false
			from, to = route.Vertex{}, route.Vertex{}
		},
		Do: func(tx SQLQueries) error {
			var err error
			from, to, isUpdate1, err = updateChanEdgePolicy(
				ctx, tx, edge,
			)

			switch {
			// Silence ErrEdgeNotFound so that the batch can
			// succeed, but propagate the error via local state.
			// It is an expected outcome, so it is not logged as
			// an error.
			case errors.Is(err, ErrEdgeNotFound):
				edgeNotFound = true
				return nil

			case err != nil:
				log.Errorf("UpdateEdgePolicy failed: %v", err)
				return err
			}

			return nil
		},
		OnCommit: func(err error) error {
			switch {
			case err != nil:
				return err
			case edgeNotFound:
				return ErrEdgeNotFound
			default:
				s.updateEdgeCache(edge, isUpdate1)
				return nil
			}
		},
	}

	err := s.chanScheduler.Execute(ctx, r)

	return from, to, err
}
688

689
// updateEdgeCache updates our reject and channel caches with the new
690
// edge policy information.
691
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
692
        isUpdate1 bool) {
×
693

×
694
        // If an entry for this channel is found in reject cache, we'll modify
×
695
        // the entry with the updated timestamp for the direction that was just
×
696
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
697
        // during the next query for this edge.
×
698
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
699
                if isUpdate1 {
×
700
                        entry.upd1Time = e.LastUpdate.Unix()
×
701
                } else {
×
702
                        entry.upd2Time = e.LastUpdate.Unix()
×
703
                }
×
704
                s.rejectCache.insert(e.ChannelID, entry)
×
705
        }
706

707
        // If an entry for this channel is found in channel cache, we'll modify
708
        // the entry with the updated policy for the direction that was just
709
        // written. If the edge doesn't exist, we'll defer loading the info and
710
        // policies and lazily read from disk during the next query.
711
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
712
                if isUpdate1 {
×
713
                        channel.Policy1 = e
×
714
                } else {
×
715
                        channel.Policy2 = e
×
716
                }
×
717
                s.chanCache.insert(e.ChannelID, channel)
×
718
        }
719
}
720

721
// ForEachSourceNodeChannel iterates through all channels of the source node,
722
// executing the passed callback on each. The call-back is provided with the
723
// channel's outpoint, whether we have a policy for the channel and the channel
724
// peer's node information.
725
//
726
// NOTE: part of the V1Store interface.
727
func (s *SQLStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint,
728
        havePolicy bool, otherNode *models.LightningNode) error) error {
×
729

×
730
        var ctx = context.TODO()
×
731

×
732
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
733
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
734
                if err != nil {
×
735
                        return fmt.Errorf("unable to fetch source node: %w",
×
736
                                err)
×
737
                }
×
738

739
                return forEachNodeChannel(
×
740
                        ctx, db, s.cfg.ChainHash, nodeID,
×
741
                        func(info *models.ChannelEdgeInfo,
×
742
                                outPolicy *models.ChannelEdgePolicy,
×
743
                                _ *models.ChannelEdgePolicy) error {
×
744

×
745
                                // Fetch the other node.
×
746
                                var (
×
747
                                        otherNodePub [33]byte
×
748
                                        node1        = info.NodeKey1Bytes
×
749
                                        node2        = info.NodeKey2Bytes
×
750
                                )
×
751
                                switch {
×
752
                                case bytes.Equal(node1[:], nodePub[:]):
×
753
                                        otherNodePub = node2
×
754
                                case bytes.Equal(node2[:], nodePub[:]):
×
755
                                        otherNodePub = node1
×
756
                                default:
×
757
                                        return fmt.Errorf("node not " +
×
758
                                                "participating in this channel")
×
759
                                }
760

761
                                _, otherNode, err := getNodeByPubKey(
×
762
                                        ctx, db, otherNodePub,
×
763
                                )
×
764
                                if err != nil {
×
765
                                        return fmt.Errorf("unable to fetch "+
×
766
                                                "other node(%x): %w",
×
767
                                                otherNodePub, err)
×
768
                                }
×
769

770
                                return cb(
×
771
                                        info.ChannelPoint, outPolicy != nil,
×
772
                                        otherNode,
×
773
                                )
×
774
                        },
775
                )
776
        }, sqldb.NoOpReset)
777
}
778

779
// ForEachNode iterates through all the stored vertices/nodes in the graph,
780
// executing the passed callback with each node encountered. If the callback
781
// returns an error, then the transaction is aborted and the iteration stops
782
// early. Any operations performed on the NodeTx passed to the call-back are
783
// executed under the same read transaction and so, methods on the NodeTx object
784
// _MUST_ only be called from within the call-back.
785
//
786
// NOTE: part of the V1Store interface.
787
func (s *SQLStore) ForEachNode(cb func(tx NodeRTx) error) error {
×
788
        var (
×
789
                ctx          = context.TODO()
×
790
                lastID int64 = 0
×
791
        )
×
792

×
793
        handleNode := func(db SQLQueries, dbNode sqlc.Node) error {
×
794
                node, err := buildNode(ctx, db, &dbNode)
×
795
                if err != nil {
×
796
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
797
                                dbNode.ID, err)
×
798
                }
×
799

800
                err = cb(
×
801
                        newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
×
802
                )
×
803
                if err != nil {
×
804
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
805
                                dbNode.ID, err)
×
806
                }
×
807

808
                return nil
×
809
        }
810

811
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
812
                for {
×
813
                        nodes, err := db.ListNodesPaginated(
×
814
                                ctx, sqlc.ListNodesPaginatedParams{
×
815
                                        Version: int16(ProtocolV1),
×
816
                                        ID:      lastID,
×
817
                                        Limit:   pageSize,
×
818
                                },
×
819
                        )
×
820
                        if err != nil {
×
821
                                return fmt.Errorf("unable to fetch nodes: %w",
×
822
                                        err)
×
823
                        }
×
824

825
                        if len(nodes) == 0 {
×
826
                                break
×
827
                        }
828

829
                        for _, dbNode := range nodes {
×
830
                                err = handleNode(db, dbNode)
×
831
                                if err != nil {
×
832
                                        return err
×
833
                                }
×
834

835
                                lastID = dbNode.ID
×
836
                        }
837
                }
838

839
                return nil
×
840
        }, sqldb.NoOpReset)
841
}
842

843
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
// SQLStore and a SQL transaction.
//
// Fields:
//   - db: the query handle the node was read with. ForEachChannel and
//     FetchNode issue their queries through it, so they run under the same
//     transaction as the original read.
//   - id: the node's database row ID, used to look up the node's channels.
//   - node: the already-loaded node, returned as-is by Node.
//   - chain: the chain hash passed along when building the node's channel
//     info.
type sqlGraphNodeTx struct {
	db    SQLQueries
	id    int64
	node  *models.LightningNode
	chain chainhash.Hash
}

// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
// interface.
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
855

856
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
857
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
858

×
859
        return &sqlGraphNodeTx{
×
860
                db:    db,
×
861
                chain: chain,
×
862
                id:    id,
×
863
                node:  node,
×
864
        }
×
865
}
×
866

867
// Node returns the raw information of the node.
868
//
869
// NOTE: This is a part of the NodeRTx interface.
870
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
871
        return s.node
×
872
}
×
873

874
// ForEachChannel can be used to iterate over the node's channels under the same
875
// transaction used to fetch the node.
876
//
877
// NOTE: This is a part of the NodeRTx interface.
878
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
879
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
880

×
881
        ctx := context.TODO()
×
882

×
883
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
884
}
×
885

886
// FetchNode fetches the node with the given pub key under the same transaction
887
// used to fetch the current node. The returned node is also a NodeRTx and any
888
// operations on that NodeRTx will also be done under the same transaction.
889
//
890
// NOTE: This is a part of the NodeRTx interface.
891
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
892
        ctx := context.TODO()
×
893

×
894
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
895
        if err != nil {
×
896
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
897
                        nodePub, err)
×
898
        }
×
899

900
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
901
}
902

903
// ForEachNodeDirectedChannel iterates through all channels of a given node,
904
// executing the passed callback on the directed edge representing the channel
905
// and its incoming policy. If the callback returns an error, then the iteration
906
// is halted with the error propagated back up to the caller.
907
//
908
// Unknown policies are passed into the callback as nil values.
909
//
910
// NOTE: this is part of the graphdb.NodeTraverser interface.
911
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
912
        cb func(channel *DirectedChannel) error) error {
×
913

×
914
        var ctx = context.TODO()
×
915

×
916
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
917
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
918
        }, sqldb.NoOpReset)
×
919
}
920

921
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
922
// graph, executing the passed callback with each node encountered. If the
923
// callback returns an error, then the transaction is aborted and the iteration
924
// stops early.
925
//
926
// NOTE: This is a part of the V1Store interface.
927
func (s *SQLStore) ForEachNodeCacheable(cb func(route.Vertex,
928
        *lnwire.FeatureVector) error) error {
×
929

×
930
        ctx := context.TODO()
×
931

×
932
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
933
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
934
                        nodePub route.Vertex) error {
×
935

×
936
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
937
                        if err != nil {
×
938
                                return fmt.Errorf("unable to fetch node "+
×
939
                                        "features: %w", err)
×
940
                        }
×
941

942
                        return cb(nodePub, features)
×
943
                })
944
        }, sqldb.NoOpReset)
945
        if err != nil {
×
946
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
947
        }
×
948

949
        return nil
×
950
}
951

952
// ForEachNodeChannel iterates through all channels of the given node,
953
// executing the passed callback with an edge info structure and the policies
954
// of each end of the channel. The first edge policy is the outgoing edge *to*
955
// the connecting node, while the second is the incoming edge *from* the
956
// connecting node. If the callback returns an error, then the iteration is
957
// halted with the error propagated back up to the caller.
958
//
959
// Unknown policies are passed into the callback as nil values.
960
//
961
// NOTE: part of the V1Store interface.
962
func (s *SQLStore) ForEachNodeChannel(nodePub route.Vertex,
963
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
964
                *models.ChannelEdgePolicy) error) error {
×
965

×
966
        var ctx = context.TODO()
×
967

×
968
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
969
                dbNode, err := db.GetNodeByPubKey(
×
970
                        ctx, sqlc.GetNodeByPubKeyParams{
×
971
                                Version: int16(ProtocolV1),
×
972
                                PubKey:  nodePub[:],
×
973
                        },
×
974
                )
×
975
                if errors.Is(err, sql.ErrNoRows) {
×
976
                        return nil
×
977
                } else if err != nil {
×
978
                        return fmt.Errorf("unable to fetch node: %w", err)
×
979
                }
×
980

981
                return forEachNodeChannel(
×
982
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
983
                )
×
984
        }, sqldb.NoOpReset)
985
}
986

987
// ChanUpdatesInHorizon returns all the known channel edges which have at least
988
// one edge that has an update timestamp within the specified horizon.
989
//
990
// NOTE: This is part of the V1Store interface.
991
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
992
        endTime time.Time) ([]ChannelEdge, error) {
×
993

×
994
        s.cacheMu.Lock()
×
995
        defer s.cacheMu.Unlock()
×
996

×
997
        var (
×
998
                ctx = context.TODO()
×
999
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
1000
                // an additional map to keep track of the edges already seen to
×
1001
                // prevent re-adding it.
×
1002
                edgesSeen    = make(map[uint64]struct{})
×
1003
                edgesToCache = make(map[uint64]ChannelEdge)
×
1004
                edges        []ChannelEdge
×
1005
                hits         int
×
1006
        )
×
1007
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1008
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1009
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1010
                                Version:   int16(ProtocolV1),
×
1011
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
1012
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
1013
                        },
×
1014
                )
×
1015
                if err != nil {
×
1016
                        return err
×
1017
                }
×
1018

1019
                for _, row := range rows {
×
1020
                        // If we've already retrieved the info and policies for
×
1021
                        // this edge, then we can skip it as we don't need to do
×
1022
                        // so again.
×
1023
                        chanIDInt := byteOrder.Uint64(row.Channel.Scid)
×
1024
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1025
                                continue
×
1026
                        }
1027

1028
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1029
                                hits++
×
1030
                                edgesSeen[chanIDInt] = struct{}{}
×
1031
                                edges = append(edges, channel)
×
1032

×
1033
                                continue
×
1034
                        }
1035

1036
                        node1, node2, err := buildNodes(
×
1037
                                ctx, db, row.Node, row.Node_2,
×
1038
                        )
×
1039
                        if err != nil {
×
1040
                                return err
×
1041
                        }
×
1042

1043
                        channel, err := getAndBuildEdgeInfo(
×
1044
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1045
                                row.Channel, node1.PubKeyBytes,
×
1046
                                node2.PubKeyBytes,
×
1047
                        )
×
1048
                        if err != nil {
×
1049
                                return fmt.Errorf("unable to build channel "+
×
1050
                                        "info: %w", err)
×
1051
                        }
×
1052

1053
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1054
                        if err != nil {
×
1055
                                return fmt.Errorf("unable to extract channel "+
×
1056
                                        "policies: %w", err)
×
1057
                        }
×
1058

1059
                        p1, p2, err := getAndBuildChanPolicies(
×
1060
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1061
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1062
                        )
×
1063
                        if err != nil {
×
1064
                                return fmt.Errorf("unable to build channel "+
×
1065
                                        "policies: %w", err)
×
1066
                        }
×
1067

1068
                        edgesSeen[chanIDInt] = struct{}{}
×
1069
                        chanEdge := ChannelEdge{
×
1070
                                Info:    channel,
×
1071
                                Policy1: p1,
×
1072
                                Policy2: p2,
×
1073
                                Node1:   node1,
×
1074
                                Node2:   node2,
×
1075
                        }
×
1076
                        edges = append(edges, chanEdge)
×
1077
                        edgesToCache[chanIDInt] = chanEdge
×
1078
                }
1079

1080
                return nil
×
1081
        }, func() {
×
1082
                edgesSeen = make(map[uint64]struct{})
×
1083
                edgesToCache = make(map[uint64]ChannelEdge)
×
1084
                edges = nil
×
1085
        })
×
1086
        if err != nil {
×
1087
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1088
        }
×
1089

1090
        // Insert any edges loaded from disk into the cache.
1091
        for chanid, channel := range edgesToCache {
×
1092
                s.chanCache.insert(chanid, channel)
×
1093
        }
×
1094

1095
        if len(edges) > 0 {
×
1096
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
1097
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1098
        } else {
×
1099
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1100
                        "horizon (%s, %s)", startTime, endTime)
×
1101
        }
×
1102

1103
        return edges, nil
×
1104
}
1105

1106
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1107
// data to the call-back.
1108
//
1109
// NOTE: The callback contents MUST not be modified.
1110
//
1111
// NOTE: part of the V1Store interface.
1112
func (s *SQLStore) ForEachNodeCached(cb func(node route.Vertex,
1113
        chans map[uint64]*DirectedChannel) error) error {
×
1114

×
1115
        var ctx = context.TODO()
×
1116

×
1117
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1118
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
1119
                        nodePub route.Vertex) error {
×
1120

×
1121
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
1122
                        if err != nil {
×
1123
                                return fmt.Errorf("unable to fetch "+
×
1124
                                        "node(id=%d) features: %w", nodeID, err)
×
1125
                        }
×
1126

1127
                        toNodeCallback := func() route.Vertex {
×
1128
                                return nodePub
×
1129
                        }
×
1130

1131
                        rows, err := db.ListChannelsByNodeID(
×
1132
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
1133
                                        Version: int16(ProtocolV1),
×
1134
                                        NodeID1: nodeID,
×
1135
                                },
×
1136
                        )
×
1137
                        if err != nil {
×
1138
                                return fmt.Errorf("unable to fetch channels "+
×
1139
                                        "of node(id=%d): %w", nodeID, err)
×
1140
                        }
×
1141

1142
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
1143
                        for _, row := range rows {
×
1144
                                node1, node2, err := buildNodeVertices(
×
1145
                                        row.Node1Pubkey, row.Node2Pubkey,
×
1146
                                )
×
1147
                                if err != nil {
×
1148
                                        return err
×
1149
                                }
×
1150

1151
                                e, err := getAndBuildEdgeInfo(
×
1152
                                        ctx, db, s.cfg.ChainHash,
×
1153
                                        row.Channel.ID, row.Channel, node1,
×
1154
                                        node2,
×
1155
                                )
×
1156
                                if err != nil {
×
1157
                                        return fmt.Errorf("unable to build "+
×
1158
                                                "channel info: %w", err)
×
1159
                                }
×
1160

1161
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1162
                                        row,
×
1163
                                )
×
1164
                                if err != nil {
×
1165
                                        return fmt.Errorf("unable to "+
×
1166
                                                "extract channel "+
×
1167
                                                "policies: %w", err)
×
1168
                                }
×
1169

1170
                                p1, p2, err := getAndBuildChanPolicies(
×
1171
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
1172
                                        node1, node2,
×
1173
                                )
×
1174
                                if err != nil {
×
1175
                                        return fmt.Errorf("unable to "+
×
1176
                                                "build channel policies: %w",
×
1177
                                                err)
×
1178
                                }
×
1179

1180
                                // Determine the outgoing and incoming policy
1181
                                // for this channel and node combo.
1182
                                outPolicy, inPolicy := p1, p2
×
1183
                                if p1 != nil && p1.ToNode == nodePub {
×
1184
                                        outPolicy, inPolicy = p2, p1
×
1185
                                } else if p2 != nil && p2.ToNode != nodePub {
×
1186
                                        outPolicy, inPolicy = p2, p1
×
1187
                                }
×
1188

1189
                                var cachedInPolicy *models.CachedEdgePolicy
×
1190
                                if inPolicy != nil {
×
1191
                                        cachedInPolicy = models.NewCachedPolicy(
×
1192
                                                p2,
×
1193
                                        )
×
1194
                                        cachedInPolicy.ToNodePubKey =
×
1195
                                                toNodeCallback
×
1196
                                        cachedInPolicy.ToNodeFeatures =
×
1197
                                                features
×
1198
                                }
×
1199

1200
                                var inboundFee lnwire.Fee
×
1201
                                outPolicy.InboundFee.WhenSome(
×
1202
                                        func(fee lnwire.Fee) {
×
1203
                                                inboundFee = fee
×
1204
                                        },
×
1205
                                )
1206

1207
                                directedChannel := &DirectedChannel{
×
1208
                                        ChannelID: e.ChannelID,
×
1209
                                        IsNode1: nodePub ==
×
1210
                                                e.NodeKey1Bytes,
×
1211
                                        OtherNode:    e.NodeKey2Bytes,
×
1212
                                        Capacity:     e.Capacity,
×
1213
                                        OutPolicySet: p1 != nil,
×
1214
                                        InPolicy:     cachedInPolicy,
×
1215
                                        InboundFee:   inboundFee,
×
1216
                                }
×
1217

×
1218
                                if nodePub == e.NodeKey2Bytes {
×
1219
                                        directedChannel.OtherNode =
×
1220
                                                e.NodeKey1Bytes
×
1221
                                }
×
1222

1223
                                channels[e.ChannelID] = directedChannel
×
1224
                        }
1225

1226
                        return cb(nodePub, channels)
×
1227
                })
1228
        }, sqldb.NoOpReset)
1229
}
1230

1231
// ForEachChannelCacheable iterates through all the channel edges stored
1232
// within the graph and invokes the passed callback for each edge. The
1233
// callback takes two edges as since this is a directed graph, both the
1234
// in/out edges are visited. If the callback returns an error, then the
1235
// transaction is aborted and the iteration stops early.
1236
//
1237
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1238
// pointer for that particular channel edge routing policy will be
1239
// passed into the callback.
1240
//
1241
// NOTE: this method is like ForEachChannel but fetches only the data
1242
// required for the graph cache.
1243
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1244
        *models.CachedEdgePolicy,
1245
        *models.CachedEdgePolicy) error) error {
×
1246

×
1247
        ctx := context.TODO()
×
1248

×
1249
        handleChannel := func(db SQLQueries,
×
1250
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1251

×
1252
                node1, node2, err := buildNodeVertices(
×
1253
                        row.Node1Pubkey, row.Node2Pubkey,
×
1254
                )
×
1255
                if err != nil {
×
1256
                        return err
×
1257
                }
×
1258

1259
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
1260

×
1261
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1262
                if err != nil {
×
1263
                        return err
×
1264
                }
×
1265

1266
                var pol1, pol2 *models.CachedEdgePolicy
×
1267
                if dbPol1 != nil {
×
1268
                        policy1, err := buildChanPolicy(
×
1269
                                *dbPol1, edge.ChannelID, nil, node2,
×
1270
                        )
×
1271
                        if err != nil {
×
1272
                                return err
×
1273
                        }
×
1274

1275
                        pol1 = models.NewCachedPolicy(policy1)
×
1276
                }
1277
                if dbPol2 != nil {
×
1278
                        policy2, err := buildChanPolicy(
×
1279
                                *dbPol2, edge.ChannelID, nil, node1,
×
1280
                        )
×
1281
                        if err != nil {
×
1282
                                return err
×
1283
                        }
×
1284

1285
                        pol2 = models.NewCachedPolicy(policy2)
×
1286
                }
1287

1288
                if err := cb(edge, pol1, pol2); err != nil {
×
1289
                        return err
×
1290
                }
×
1291

1292
                return nil
×
1293
        }
1294

1295
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1296
                lastID := int64(-1)
×
1297
                for {
×
1298
                        //nolint:ll
×
1299
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1300
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1301
                                        Version: int16(ProtocolV1),
×
1302
                                        ID:      lastID,
×
1303
                                        Limit:   pageSize,
×
1304
                                },
×
1305
                        )
×
1306
                        if err != nil {
×
1307
                                return err
×
1308
                        }
×
1309

1310
                        if len(rows) == 0 {
×
1311
                                break
×
1312
                        }
1313

1314
                        for _, row := range rows {
×
1315
                                err := handleChannel(db, row)
×
1316
                                if err != nil {
×
1317
                                        return err
×
1318
                                }
×
1319

1320
                                lastID = row.Channel.ID
×
1321
                        }
1322
                }
1323

1324
                return nil
×
1325
        }, sqldb.NoOpReset)
1326
}
1327

1328
// ForEachChannel iterates through all the channel edges stored within the
1329
// graph and invokes the passed callback for each edge. The callback takes two
1330
// edges as since this is a directed graph, both the in/out edges are visited.
1331
// If the callback returns an error, then the transaction is aborted and the
1332
// iteration stops early.
1333
//
1334
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1335
// for that particular channel edge routing policy will be passed into the
1336
// callback.
1337
//
1338
// NOTE: part of the V1Store interface.
1339
func (s *SQLStore) ForEachChannel(cb func(*models.ChannelEdgeInfo,
1340
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
1341

×
1342
        ctx := context.TODO()
×
1343

×
1344
        handleChannel := func(db SQLQueries,
×
1345
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1346

×
1347
                node1, node2, err := buildNodeVertices(
×
1348
                        row.Node1Pubkey, row.Node2Pubkey,
×
1349
                )
×
1350
                if err != nil {
×
1351
                        return fmt.Errorf("unable to build node vertices: %w",
×
1352
                                err)
×
1353
                }
×
1354

1355
                edge, err := getAndBuildEdgeInfo(
×
1356
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1357
                        node1, node2,
×
1358
                )
×
1359
                if err != nil {
×
1360
                        return fmt.Errorf("unable to build channel info: %w",
×
1361
                                err)
×
1362
                }
×
1363

1364
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1365
                if err != nil {
×
1366
                        return fmt.Errorf("unable to extract channel "+
×
1367
                                "policies: %w", err)
×
1368
                }
×
1369

1370
                p1, p2, err := getAndBuildChanPolicies(
×
1371
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1372
                )
×
1373
                if err != nil {
×
1374
                        return fmt.Errorf("unable to build channel "+
×
1375
                                "policies: %w", err)
×
1376
                }
×
1377

1378
                err = cb(edge, p1, p2)
×
1379
                if err != nil {
×
1380
                        return fmt.Errorf("callback failed for channel "+
×
1381
                                "id=%d: %w", edge.ChannelID, err)
×
1382
                }
×
1383

1384
                return nil
×
1385
        }
1386

1387
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1388
                lastID := int64(-1)
×
1389
                for {
×
1390
                        //nolint:ll
×
1391
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1392
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1393
                                        Version: int16(ProtocolV1),
×
1394
                                        ID:      lastID,
×
1395
                                        Limit:   pageSize,
×
1396
                                },
×
1397
                        )
×
1398
                        if err != nil {
×
1399
                                return err
×
1400
                        }
×
1401

1402
                        if len(rows) == 0 {
×
1403
                                break
×
1404
                        }
1405

1406
                        for _, row := range rows {
×
1407
                                err := handleChannel(db, row)
×
1408
                                if err != nil {
×
1409
                                        return err
×
1410
                                }
×
1411

1412
                                lastID = row.Channel.ID
×
1413
                        }
1414
                }
1415

1416
                return nil
×
1417
        }, sqldb.NoOpReset)
1418
}
1419

1420
// FilterChannelRange returns the channel ID's of all known channels which were
1421
// mined in a block height within the passed range. The channel IDs are grouped
1422
// by their common block height. This method can be used to quickly share with a
1423
// peer the set of channels we know of within a particular range to catch them
1424
// up after a period of time offline. If withTimestamps is true then the
1425
// timestamp info of the latest received channel update messages of the channel
1426
// will be included in the response.
1427
//
1428
// NOTE: This is part of the V1Store interface.
1429
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1430
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1431

×
1432
        var (
×
1433
                ctx       = context.TODO()
×
1434
                startSCID = &lnwire.ShortChannelID{
×
1435
                        BlockHeight: startHeight,
×
1436
                }
×
1437
                endSCID = lnwire.ShortChannelID{
×
1438
                        BlockHeight: endHeight,
×
1439
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1440
                        TxPosition:  math.MaxUint16,
×
1441
                }
×
1442
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1443
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1444
        )
×
1445

×
1446
        // 1) get all channels where channelID is between start and end chan ID.
×
1447
        // 2) skip if not public (ie, no channel_proof)
×
1448
        // 3) collect that channel.
×
1449
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1450
        //    and add those timestamps to the collected channel.
×
1451
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1452
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1453
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1454
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1455
                                StartScid: chanIDStart[:],
×
1456
                                EndScid:   chanIDEnd[:],
×
1457
                        },
×
1458
                )
×
1459
                if err != nil {
×
1460
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1461
                                err)
×
1462
                }
×
1463

1464
                for _, dbChan := range dbChans {
×
1465
                        cid := lnwire.NewShortChanIDFromInt(
×
1466
                                byteOrder.Uint64(dbChan.Scid),
×
1467
                        )
×
1468
                        chanInfo := NewChannelUpdateInfo(
×
1469
                                cid, time.Time{}, time.Time{},
×
1470
                        )
×
1471

×
1472
                        if !withTimestamps {
×
1473
                                channelsPerBlock[cid.BlockHeight] = append(
×
1474
                                        channelsPerBlock[cid.BlockHeight],
×
1475
                                        chanInfo,
×
1476
                                )
×
1477

×
1478
                                continue
×
1479
                        }
1480

1481
                        //nolint:ll
1482
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1483
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1484
                                        Version:   int16(ProtocolV1),
×
1485
                                        ChannelID: dbChan.ID,
×
1486
                                        NodeID:    dbChan.NodeID1,
×
1487
                                },
×
1488
                        )
×
1489
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1490
                                return fmt.Errorf("unable to fetch node1 "+
×
1491
                                        "policy: %w", err)
×
1492
                        } else if err == nil {
×
1493
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1494
                                        node1Policy.LastUpdate.Int64, 0,
×
1495
                                )
×
1496
                        }
×
1497

1498
                        //nolint:ll
1499
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1500
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1501
                                        Version:   int16(ProtocolV1),
×
1502
                                        ChannelID: dbChan.ID,
×
1503
                                        NodeID:    dbChan.NodeID2,
×
1504
                                },
×
1505
                        )
×
1506
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1507
                                return fmt.Errorf("unable to fetch node2 "+
×
1508
                                        "policy: %w", err)
×
1509
                        } else if err == nil {
×
1510
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1511
                                        node2Policy.LastUpdate.Int64, 0,
×
1512
                                )
×
1513
                        }
×
1514

1515
                        channelsPerBlock[cid.BlockHeight] = append(
×
1516
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1517
                        )
×
1518
                }
1519

1520
                return nil
×
1521
        }, func() {
×
1522
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1523
        })
×
1524
        if err != nil {
×
1525
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1526
        }
×
1527

1528
        if len(channelsPerBlock) == 0 {
×
1529
                return nil, nil
×
1530
        }
×
1531

1532
        // Return the channel ranges in ascending block height order.
1533
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1534
        slices.Sort(blocks)
×
1535

×
1536
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1537
                return BlockChannelRange{
×
1538
                        Height:   block,
×
1539
                        Channels: channelsPerBlock[block],
×
1540
                }
×
1541
        }), nil
×
1542
}
1543

1544
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1545
// zombie. This method is used on an ad-hoc basis, when channels need to be
1546
// marked as zombies outside the normal pruning cycle.
1547
//
1548
// NOTE: part of the V1Store interface.
1549
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1550
        pubKey1, pubKey2 [33]byte) error {
×
1551

×
1552
        ctx := context.TODO()
×
1553

×
1554
        s.cacheMu.Lock()
×
1555
        defer s.cacheMu.Unlock()
×
1556

×
1557
        chanIDB := channelIDToBytes(chanID)
×
1558

×
1559
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1560
                return db.UpsertZombieChannel(
×
1561
                        ctx, sqlc.UpsertZombieChannelParams{
×
1562
                                Version:  int16(ProtocolV1),
×
1563
                                Scid:     chanIDB[:],
×
1564
                                NodeKey1: pubKey1[:],
×
1565
                                NodeKey2: pubKey2[:],
×
1566
                        },
×
1567
                )
×
1568
        }, sqldb.NoOpReset)
×
1569
        if err != nil {
×
1570
                return fmt.Errorf("unable to upsert zombie channel "+
×
1571
                        "(channel_id=%d): %w", chanID, err)
×
1572
        }
×
1573

1574
        s.rejectCache.remove(chanID)
×
1575
        s.chanCache.remove(chanID)
×
1576

×
1577
        return nil
×
1578
}
1579

1580
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1581
//
1582
// NOTE: part of the V1Store interface.
1583
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1584
        s.cacheMu.Lock()
×
1585
        defer s.cacheMu.Unlock()
×
1586

×
1587
        var (
×
1588
                ctx     = context.TODO()
×
1589
                chanIDB = channelIDToBytes(chanID)
×
1590
        )
×
1591

×
1592
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1593
                res, err := db.DeleteZombieChannel(
×
1594
                        ctx, sqlc.DeleteZombieChannelParams{
×
1595
                                Scid:    chanIDB[:],
×
1596
                                Version: int16(ProtocolV1),
×
1597
                        },
×
1598
                )
×
1599
                if err != nil {
×
1600
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1601
                                err)
×
1602
                }
×
1603

1604
                rows, err := res.RowsAffected()
×
1605
                if err != nil {
×
1606
                        return err
×
1607
                }
×
1608

1609
                if rows == 0 {
×
1610
                        return ErrZombieEdgeNotFound
×
1611
                } else if rows > 1 {
×
1612
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1613
                                "expected 1", rows)
×
1614
                }
×
1615

1616
                return nil
×
1617
        }, sqldb.NoOpReset)
1618
        if err != nil {
×
1619
                return fmt.Errorf("unable to mark edge live "+
×
1620
                        "(channel_id=%d): %w", chanID, err)
×
1621
        }
×
1622

1623
        s.rejectCache.remove(chanID)
×
1624
        s.chanCache.remove(chanID)
×
1625

×
1626
        return err
×
1627
}
1628

1629
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1630
// zombie, then the two node public keys corresponding to this edge are also
1631
// returned.
1632
//
1633
// NOTE: part of the V1Store interface.
1634
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1635
        error) {
×
1636

×
1637
        var (
×
1638
                ctx              = context.TODO()
×
1639
                isZombie         bool
×
1640
                pubKey1, pubKey2 route.Vertex
×
1641
                chanIDB          = channelIDToBytes(chanID)
×
1642
        )
×
1643

×
1644
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1645
                zombie, err := db.GetZombieChannel(
×
1646
                        ctx, sqlc.GetZombieChannelParams{
×
1647
                                Scid:    chanIDB[:],
×
1648
                                Version: int16(ProtocolV1),
×
1649
                        },
×
1650
                )
×
1651
                if errors.Is(err, sql.ErrNoRows) {
×
1652
                        return nil
×
1653
                }
×
1654
                if err != nil {
×
1655
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1656
                                err)
×
1657
                }
×
1658

1659
                copy(pubKey1[:], zombie.NodeKey1)
×
1660
                copy(pubKey2[:], zombie.NodeKey2)
×
1661
                isZombie = true
×
1662

×
1663
                return nil
×
1664
        }, sqldb.NoOpReset)
1665
        if err != nil {
×
1666
                return false, route.Vertex{}, route.Vertex{},
×
1667
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1668
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1669
        }
×
1670

1671
        return isZombie, pubKey1, pubKey2, nil
×
1672
}
1673

1674
// NumZombies returns the current number of zombie channels in the graph.
1675
//
1676
// NOTE: part of the V1Store interface.
1677
func (s *SQLStore) NumZombies() (uint64, error) {
×
1678
        var (
×
1679
                ctx        = context.TODO()
×
1680
                numZombies uint64
×
1681
        )
×
1682
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1683
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1684
                if err != nil {
×
1685
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1686
                                err)
×
1687
                }
×
1688

1689
                numZombies = uint64(count)
×
1690

×
1691
                return nil
×
1692
        }, sqldb.NoOpReset)
1693
        if err != nil {
×
1694
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1695
        }
×
1696

1697
        return numZombies, nil
×
1698
}
1699

1700
// DeleteChannelEdges removes edges with the given channel IDs from the
1701
// database and marks them as zombies. This ensures that we're unable to re-add
1702
// it to our database once again. If an edge does not exist within the
1703
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1704
// true, then when we mark these edges as zombies, we'll set up the keys such
1705
// that we require the node that failed to send the fresh update to be the one
1706
// that resurrects the channel from its zombie state. The markZombie bool
1707
// denotes whether to mark the channel as a zombie.
1708
//
1709
// NOTE: part of the V1Store interface.
1710
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1711
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1712

×
1713
        s.cacheMu.Lock()
×
1714
        defer s.cacheMu.Unlock()
×
1715

×
1716
        var (
×
1717
                ctx     = context.TODO()
×
1718
                deleted []*models.ChannelEdgeInfo
×
1719
        )
×
1720
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1721
                for _, chanID := range chanIDs {
×
1722
                        chanIDB := channelIDToBytes(chanID)
×
1723

×
1724
                        row, err := db.GetChannelBySCIDWithPolicies(
×
1725
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1726
                                        Scid:    chanIDB[:],
×
1727
                                        Version: int16(ProtocolV1),
×
1728
                                },
×
1729
                        )
×
1730
                        if errors.Is(err, sql.ErrNoRows) {
×
1731
                                return ErrEdgeNotFound
×
1732
                        } else if err != nil {
×
1733
                                return fmt.Errorf("unable to fetch channel: %w",
×
1734
                                        err)
×
1735
                        }
×
1736

1737
                        node1, node2, err := buildNodeVertices(
×
1738
                                row.Node.PubKey, row.Node_2.PubKey,
×
1739
                        )
×
1740
                        if err != nil {
×
1741
                                return err
×
1742
                        }
×
1743

1744
                        info, err := getAndBuildEdgeInfo(
×
1745
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1746
                                row.Channel, node1, node2,
×
1747
                        )
×
1748
                        if err != nil {
×
1749
                                return err
×
1750
                        }
×
1751

1752
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
1753
                        if err != nil {
×
1754
                                return fmt.Errorf("unable to delete "+
×
1755
                                        "channel: %w", err)
×
1756
                        }
×
1757

1758
                        deleted = append(deleted, info)
×
1759

×
1760
                        if !markZombie {
×
1761
                                continue
×
1762
                        }
1763

1764
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1765
                                info.NodeKey2Bytes
×
1766
                        if strictZombiePruning {
×
1767
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1768
                                if row.Policy1LastUpdate.Valid {
×
1769
                                        e1Time := time.Unix(
×
1770
                                                row.Policy1LastUpdate.Int64, 0,
×
1771
                                        )
×
1772
                                        e1UpdateTime = &e1Time
×
1773
                                }
×
1774
                                if row.Policy2LastUpdate.Valid {
×
1775
                                        e2Time := time.Unix(
×
1776
                                                row.Policy2LastUpdate.Int64, 0,
×
1777
                                        )
×
1778
                                        e2UpdateTime = &e2Time
×
1779
                                }
×
1780

1781
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1782
                                        info, e1UpdateTime, e2UpdateTime,
×
1783
                                )
×
1784
                        }
1785

1786
                        err = db.UpsertZombieChannel(
×
1787
                                ctx, sqlc.UpsertZombieChannelParams{
×
1788
                                        Version:  int16(ProtocolV1),
×
1789
                                        Scid:     chanIDB[:],
×
1790
                                        NodeKey1: nodeKey1[:],
×
1791
                                        NodeKey2: nodeKey2[:],
×
1792
                                },
×
1793
                        )
×
1794
                        if err != nil {
×
1795
                                return fmt.Errorf("unable to mark channel as "+
×
1796
                                        "zombie: %w", err)
×
1797
                        }
×
1798
                }
1799

1800
                return nil
×
1801
        }, func() {
×
1802
                deleted = nil
×
1803
        })
×
1804
        if err != nil {
×
1805
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1806
                        err)
×
1807
        }
×
1808

1809
        for _, chanID := range chanIDs {
×
1810
                s.rejectCache.remove(chanID)
×
1811
                s.chanCache.remove(chanID)
×
1812
        }
×
1813

1814
        return deleted, nil
×
1815
}
1816

1817
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1818
// channel identified by the channel ID. If the channel can't be found, then
1819
// ErrEdgeNotFound is returned. A struct which houses the general information
1820
// for the channel itself is returned as well as two structs that contain the
1821
// routing policies for the channel in either direction.
1822
//
1823
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1824
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1825
// the ChannelEdgeInfo will only include the public keys of each node.
1826
//
1827
// NOTE: part of the V1Store interface.
1828
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1829
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1830
        *models.ChannelEdgePolicy, error) {
×
1831

×
1832
        var (
×
1833
                ctx              = context.TODO()
×
1834
                edge             *models.ChannelEdgeInfo
×
1835
                policy1, policy2 *models.ChannelEdgePolicy
×
1836
        )
×
1837
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1838
                var chanIDB [8]byte
×
1839
                byteOrder.PutUint64(chanIDB[:], chanID)
×
1840

×
1841
                row, err := db.GetChannelBySCIDWithPolicies(
×
1842
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1843
                                Scid:    chanIDB[:],
×
1844
                                Version: int16(ProtocolV1),
×
1845
                        },
×
1846
                )
×
1847
                if errors.Is(err, sql.ErrNoRows) {
×
1848
                        // First check if this edge is perhaps in the zombie
×
1849
                        // index.
×
1850
                        zombie, err := db.GetZombieChannel(
×
1851
                                ctx, sqlc.GetZombieChannelParams{
×
1852
                                        Scid:    chanIDB[:],
×
1853
                                        Version: int16(ProtocolV1),
×
1854
                                },
×
1855
                        )
×
1856
                        if errors.Is(err, sql.ErrNoRows) {
×
1857
                                return ErrEdgeNotFound
×
1858
                        } else if err != nil {
×
1859
                                return fmt.Errorf("unable to check if "+
×
1860
                                        "channel is zombie: %w", err)
×
1861
                        }
×
1862

1863
                        // At this point, we know the channel is a zombie, so
1864
                        // we'll return an error indicating this, and we will
1865
                        // populate the edge info with the public keys of each
1866
                        // party as this is the only information we have about
1867
                        // it.
1868
                        edge = &models.ChannelEdgeInfo{}
×
1869
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1870
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1871

×
1872
                        return ErrZombieEdge
×
1873
                } else if err != nil {
×
1874
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1875
                }
×
1876

1877
                node1, node2, err := buildNodeVertices(
×
1878
                        row.Node.PubKey, row.Node_2.PubKey,
×
1879
                )
×
1880
                if err != nil {
×
1881
                        return err
×
1882
                }
×
1883

1884
                edge, err = getAndBuildEdgeInfo(
×
1885
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1886
                        node1, node2,
×
1887
                )
×
1888
                if err != nil {
×
1889
                        return fmt.Errorf("unable to build channel info: %w",
×
1890
                                err)
×
1891
                }
×
1892

1893
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1894
                if err != nil {
×
1895
                        return fmt.Errorf("unable to extract channel "+
×
1896
                                "policies: %w", err)
×
1897
                }
×
1898

1899
                policy1, policy2, err = getAndBuildChanPolicies(
×
1900
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1901
                )
×
1902
                if err != nil {
×
1903
                        return fmt.Errorf("unable to build channel "+
×
1904
                                "policies: %w", err)
×
1905
                }
×
1906

1907
                return nil
×
1908
        }, sqldb.NoOpReset)
1909
        if err != nil {
×
1910
                // If we are returning the ErrZombieEdge, then we also need to
×
1911
                // return the edge info as the method comment indicates that
×
1912
                // this will be populated when the edge is a zombie.
×
1913
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1914
                        err)
×
1915
        }
×
1916

1917
        return edge, policy1, policy2, nil
×
1918
}
1919

1920
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1921
// the channel identified by the funding outpoint. If the channel can't be
1922
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1923
// information for the channel itself is returned as well as two structs that
1924
// contain the routing policies for the channel in either direction.
1925
//
1926
// NOTE: part of the V1Store interface.
1927
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1928
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1929
        *models.ChannelEdgePolicy, error) {
×
1930

×
1931
        var (
×
1932
                ctx              = context.TODO()
×
1933
                edge             *models.ChannelEdgeInfo
×
1934
                policy1, policy2 *models.ChannelEdgePolicy
×
1935
        )
×
1936
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1937
                row, err := db.GetChannelByOutpointWithPolicies(
×
1938
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1939
                                Outpoint: op.String(),
×
1940
                                Version:  int16(ProtocolV1),
×
1941
                        },
×
1942
                )
×
1943
                if errors.Is(err, sql.ErrNoRows) {
×
1944
                        return ErrEdgeNotFound
×
1945
                } else if err != nil {
×
1946
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1947
                }
×
1948

1949
                node1, node2, err := buildNodeVertices(
×
1950
                        row.Node1Pubkey, row.Node2Pubkey,
×
1951
                )
×
1952
                if err != nil {
×
1953
                        return err
×
1954
                }
×
1955

1956
                edge, err = getAndBuildEdgeInfo(
×
1957
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1958
                        node1, node2,
×
1959
                )
×
1960
                if err != nil {
×
1961
                        return fmt.Errorf("unable to build channel info: %w",
×
1962
                                err)
×
1963
                }
×
1964

1965
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1966
                if err != nil {
×
1967
                        return fmt.Errorf("unable to extract channel "+
×
1968
                                "policies: %w", err)
×
1969
                }
×
1970

1971
                policy1, policy2, err = getAndBuildChanPolicies(
×
1972
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1973
                )
×
1974
                if err != nil {
×
1975
                        return fmt.Errorf("unable to build channel "+
×
1976
                                "policies: %w", err)
×
1977
                }
×
1978

1979
                return nil
×
1980
        }, sqldb.NoOpReset)
1981
        if err != nil {
×
1982
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1983
                        err)
×
1984
        }
×
1985

1986
        return edge, policy1, policy2, nil
×
1987
}
1988

1989
// HasChannelEdge returns true if the database knows of a channel edge with the
1990
// passed channel ID, and false otherwise. If an edge with that ID is found
1991
// within the graph, then two time stamps representing the last time the edge
1992
// was updated for both directed edges are returned along with the boolean. If
1993
// it is not found, then the zombie index is checked and its result is returned
1994
// as the second boolean.
1995
//
1996
// NOTE: part of the V1Store interface.
1997
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1998
        bool, error) {
×
1999

×
2000
        ctx := context.TODO()
×
2001

×
2002
        var (
×
2003
                exists          bool
×
2004
                isZombie        bool
×
2005
                node1LastUpdate time.Time
×
2006
                node2LastUpdate time.Time
×
2007
        )
×
2008

×
2009
        // We'll query the cache with the shared lock held to allow multiple
×
2010
        // readers to access values in the cache concurrently if they exist.
×
2011
        s.cacheMu.RLock()
×
2012
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2013
                s.cacheMu.RUnlock()
×
2014
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2015
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2016
                exists, isZombie = entry.flags.unpack()
×
2017

×
2018
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2019
        }
×
2020
        s.cacheMu.RUnlock()
×
2021

×
2022
        s.cacheMu.Lock()
×
2023
        defer s.cacheMu.Unlock()
×
2024

×
2025
        // The item was not found with the shared lock, so we'll acquire the
×
2026
        // exclusive lock and check the cache again in case another method added
×
2027
        // the entry to the cache while no lock was held.
×
2028
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2029
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2030
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2031
                exists, isZombie = entry.flags.unpack()
×
2032

×
2033
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2034
        }
×
2035

2036
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2037
                var chanIDB [8]byte
×
2038
                byteOrder.PutUint64(chanIDB[:], chanID)
×
2039

×
2040
                channel, err := db.GetChannelBySCID(
×
2041
                        ctx, sqlc.GetChannelBySCIDParams{
×
2042
                                Scid:    chanIDB[:],
×
2043
                                Version: int16(ProtocolV1),
×
2044
                        },
×
2045
                )
×
2046
                if errors.Is(err, sql.ErrNoRows) {
×
2047
                        // Check if it is a zombie channel.
×
2048
                        isZombie, err = db.IsZombieChannel(
×
2049
                                ctx, sqlc.IsZombieChannelParams{
×
2050
                                        Scid:    chanIDB[:],
×
2051
                                        Version: int16(ProtocolV1),
×
2052
                                },
×
2053
                        )
×
2054
                        if err != nil {
×
2055
                                return fmt.Errorf("could not check if channel "+
×
2056
                                        "is zombie: %w", err)
×
2057
                        }
×
2058

2059
                        return nil
×
2060
                } else if err != nil {
×
2061
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2062
                }
×
2063

2064
                exists = true
×
2065

×
2066
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2067
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2068
                                Version:   int16(ProtocolV1),
×
2069
                                ChannelID: channel.ID,
×
2070
                                NodeID:    channel.NodeID1,
×
2071
                        },
×
2072
                )
×
2073
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2074
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2075
                                err)
×
2076
                } else if err == nil {
×
2077
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2078
                }
×
2079

2080
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2081
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2082
                                Version:   int16(ProtocolV1),
×
2083
                                ChannelID: channel.ID,
×
2084
                                NodeID:    channel.NodeID2,
×
2085
                        },
×
2086
                )
×
2087
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2088
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2089
                                err)
×
2090
                } else if err == nil {
×
2091
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2092
                }
×
2093

2094
                return nil
×
2095
        }, sqldb.NoOpReset)
2096
        if err != nil {
×
2097
                return time.Time{}, time.Time{}, false, false,
×
2098
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2099
        }
×
2100

2101
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2102
                upd1Time: node1LastUpdate.Unix(),
×
2103
                upd2Time: node2LastUpdate.Unix(),
×
2104
                flags:    packRejectFlags(exists, isZombie),
×
2105
        })
×
2106

×
2107
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2108
}
2109

2110
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2111
// passed channel point (outpoint). If the passed channel doesn't exist within
2112
// the database, then ErrEdgeNotFound is returned.
2113
//
2114
// NOTE: part of the V1Store interface.
2115
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2116
        var (
×
2117
                ctx       = context.TODO()
×
2118
                channelID uint64
×
2119
        )
×
2120
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2121
                chanID, err := db.GetSCIDByOutpoint(
×
2122
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2123
                                Outpoint: chanPoint.String(),
×
2124
                                Version:  int16(ProtocolV1),
×
2125
                        },
×
2126
                )
×
2127
                if errors.Is(err, sql.ErrNoRows) {
×
2128
                        return ErrEdgeNotFound
×
2129
                } else if err != nil {
×
2130
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2131
                                err)
×
2132
                }
×
2133

2134
                channelID = byteOrder.Uint64(chanID)
×
2135

×
2136
                return nil
×
2137
        }, sqldb.NoOpReset)
2138
        if err != nil {
×
2139
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2140
        }
×
2141

2142
        return channelID, nil
×
2143
}
2144

2145
// IsPublicNode is a helper method that determines whether the node with the
2146
// given public key is seen as a public node in the graph from the graph's
2147
// source node's point of view.
2148
//
2149
// NOTE: part of the V1Store interface.
2150
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2151
        ctx := context.TODO()
×
2152

×
2153
        var isPublic bool
×
2154
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2155
                var err error
×
2156
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2157

×
2158
                return err
×
2159
        }, sqldb.NoOpReset)
×
2160
        if err != nil {
×
2161
                return false, fmt.Errorf("unable to check if node is "+
×
2162
                        "public: %w", err)
×
2163
        }
×
2164

2165
        return isPublic, nil
×
2166
}
2167

2168
// FetchChanInfos returns the set of channel edges that correspond to the passed
2169
// channel ID's. If an edge is the query is unknown to the database, it will
2170
// skipped and the result will contain only those edges that exist at the time
2171
// of the query. This can be used to respond to peer queries that are seeking to
2172
// fill in gaps in their view of the channel graph.
2173
//
2174
// NOTE: part of the V1Store interface.
2175
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2176
        var (
×
2177
                ctx   = context.TODO()
×
2178
                edges []ChannelEdge
×
2179
        )
×
2180
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2181
                for _, chanID := range chanIDs {
×
2182
                        var chanIDB [8]byte
×
2183
                        byteOrder.PutUint64(chanIDB[:], chanID)
×
2184

×
2185
                        // TODO(elle): potentially optimize this by using
×
2186
                        //  sqlc.slice() once that works for both SQLite and
×
2187
                        //  Postgres.
×
2188
                        row, err := db.GetChannelBySCIDWithPolicies(
×
2189
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
2190
                                        Scid:    chanIDB[:],
×
2191
                                        Version: int16(ProtocolV1),
×
2192
                                },
×
2193
                        )
×
2194
                        if errors.Is(err, sql.ErrNoRows) {
×
2195
                                continue
×
2196
                        } else if err != nil {
×
2197
                                return fmt.Errorf("unable to fetch channel: %w",
×
2198
                                        err)
×
2199
                        }
×
2200

2201
                        node1, node2, err := buildNodes(
×
2202
                                ctx, db, row.Node, row.Node_2,
×
2203
                        )
×
2204
                        if err != nil {
×
2205
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2206
                                        err)
×
2207
                        }
×
2208

2209
                        edge, err := getAndBuildEdgeInfo(
×
2210
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2211
                                row.Channel, node1.PubKeyBytes,
×
2212
                                node2.PubKeyBytes,
×
2213
                        )
×
2214
                        if err != nil {
×
2215
                                return fmt.Errorf("unable to build "+
×
2216
                                        "channel info: %w", err)
×
2217
                        }
×
2218

2219
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2220
                        if err != nil {
×
2221
                                return fmt.Errorf("unable to extract channel "+
×
2222
                                        "policies: %w", err)
×
2223
                        }
×
2224

2225
                        p1, p2, err := getAndBuildChanPolicies(
×
2226
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2227
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2228
                        )
×
2229
                        if err != nil {
×
2230
                                return fmt.Errorf("unable to build channel "+
×
2231
                                        "policies: %w", err)
×
2232
                        }
×
2233

2234
                        edges = append(edges, ChannelEdge{
×
2235
                                Info:    edge,
×
2236
                                Policy1: p1,
×
2237
                                Policy2: p2,
×
2238
                                Node1:   node1,
×
2239
                                Node2:   node2,
×
2240
                        })
×
2241
                }
2242

2243
                return nil
×
2244
        }, func() {
×
2245
                edges = nil
×
2246
        })
×
2247
        if err != nil {
×
2248
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2249
        }
×
2250

2251
        return edges, nil
×
2252
}
2253

2254
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2255
// ID's that we don't know and are not known zombies of the passed set. In other
2256
// words, we perform a set difference of our set of chan ID's and the ones
2257
// passed in. This method can be used by callers to determine the set of
2258
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2259
// known zombies is also returned.
2260
//
2261
// NOTE: part of the V1Store interface.
2262
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2263
        []ChannelUpdateInfo, error) {
×
2264

×
2265
        var (
×
2266
                ctx          = context.TODO()
×
2267
                newChanIDs   []uint64
×
2268
                knownZombies []ChannelUpdateInfo
×
2269
        )
×
2270
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2271
                for _, chanInfo := range chansInfo {
×
2272
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2273
                        var chanIDB [8]byte
×
2274
                        byteOrder.PutUint64(chanIDB[:], channelID)
×
2275

×
2276
                        // TODO(elle): potentially optimize this by using
×
2277
                        //  sqlc.slice() once that works for both SQLite and
×
2278
                        //  Postgres.
×
2279
                        _, err := db.GetChannelBySCID(
×
2280
                                ctx, sqlc.GetChannelBySCIDParams{
×
2281
                                        Version: int16(ProtocolV1),
×
2282
                                        Scid:    chanIDB[:],
×
2283
                                },
×
2284
                        )
×
2285
                        if err == nil {
×
2286
                                continue
×
2287
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
2288
                                return fmt.Errorf("unable to fetch channel: %w",
×
2289
                                        err)
×
2290
                        }
×
2291

2292
                        isZombie, err := db.IsZombieChannel(
×
2293
                                ctx, sqlc.IsZombieChannelParams{
×
2294
                                        Scid:    chanIDB[:],
×
2295
                                        Version: int16(ProtocolV1),
×
2296
                                },
×
2297
                        )
×
2298
                        if err != nil {
×
2299
                                return fmt.Errorf("unable to fetch zombie "+
×
2300
                                        "channel: %w", err)
×
2301
                        }
×
2302

2303
                        if isZombie {
×
2304
                                knownZombies = append(knownZombies, chanInfo)
×
2305

×
2306
                                continue
×
2307
                        }
2308

2309
                        newChanIDs = append(newChanIDs, channelID)
×
2310
                }
2311

2312
                return nil
×
2313
        }, func() {
×
2314
                newChanIDs = nil
×
2315
                knownZombies = nil
×
2316
        })
×
2317
        if err != nil {
×
2318
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2319
        }
×
2320

2321
        return newChanIDs, knownZombies, nil
×
2322
}
2323

2324
// PruneGraphNodes is a garbage collection method which attempts to prune out
2325
// any nodes from the channel graph that are currently unconnected. This ensure
2326
// that we only maintain a graph of reachable nodes. In the event that a pruned
2327
// node gains more channels, it will be re-added back to the graph.
2328
//
2329
// NOTE: this prunes nodes across protocol versions. It will never prune the
2330
// source nodes.
2331
//
2332
// NOTE: part of the V1Store interface.
2333
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2334
        var ctx = context.TODO()
×
2335

×
2336
        var prunedNodes []route.Vertex
×
2337
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2338
                var err error
×
2339
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2340

×
2341
                return err
×
2342
        }, func() {
×
2343
                prunedNodes = nil
×
2344
        })
×
2345
        if err != nil {
×
2346
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2347
        }
×
2348

2349
        return prunedNodes, nil
×
2350
}
2351

2352
// PruneGraph prunes newly closed channels from the channel graph in response
2353
// to a new block being solved on the network. Any transactions which spend the
2354
// funding output of any known channels within he graph will be deleted.
2355
// Additionally, the "prune tip", or the last block which has been used to
2356
// prune the graph is stored so callers can ensure the graph is fully in sync
2357
// with the current UTXO state. A slice of channels that have been closed by
2358
// the target block along with any pruned nodes are returned if the function
2359
// succeeds without error.
2360
//
2361
// NOTE: part of the V1Store interface.
2362
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2363
        blockHash *chainhash.Hash, blockHeight uint32) (
2364
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2365

×
2366
        ctx := context.TODO()
×
2367

×
2368
        s.cacheMu.Lock()
×
2369
        defer s.cacheMu.Unlock()
×
2370

×
2371
        var (
×
2372
                closedChans []*models.ChannelEdgeInfo
×
2373
                prunedNodes []route.Vertex
×
2374
        )
×
2375
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2376
                for _, outpoint := range spentOutputs {
×
2377
                        // TODO(elle): potentially optimize this by using
×
2378
                        //  sqlc.slice() once that works for both SQLite and
×
2379
                        //  Postgres.
×
2380
                        //
×
2381
                        // NOTE: this fetches channels for all protocol
×
2382
                        // versions.
×
2383
                        row, err := db.GetChannelByOutpoint(
×
2384
                                ctx, outpoint.String(),
×
2385
                        )
×
2386
                        if errors.Is(err, sql.ErrNoRows) {
×
2387
                                continue
×
2388
                        } else if err != nil {
×
2389
                                return fmt.Errorf("unable to fetch channel: %w",
×
2390
                                        err)
×
2391
                        }
×
2392

2393
                        node1, node2, err := buildNodeVertices(
×
2394
                                row.Node1Pubkey, row.Node2Pubkey,
×
2395
                        )
×
2396
                        if err != nil {
×
2397
                                return err
×
2398
                        }
×
2399

2400
                        info, err := getAndBuildEdgeInfo(
×
2401
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2402
                                row.Channel, node1, node2,
×
2403
                        )
×
2404
                        if err != nil {
×
2405
                                return err
×
2406
                        }
×
2407

2408
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2409
                        if err != nil {
×
2410
                                return fmt.Errorf("unable to delete "+
×
2411
                                        "channel: %w", err)
×
2412
                        }
×
2413

2414
                        closedChans = append(closedChans, info)
×
2415
                }
2416

2417
                err := db.UpsertPruneLogEntry(
×
2418
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2419
                                BlockHash:   blockHash[:],
×
2420
                                BlockHeight: int64(blockHeight),
×
2421
                        },
×
2422
                )
×
2423
                if err != nil {
×
2424
                        return fmt.Errorf("unable to insert prune log "+
×
2425
                                "entry: %w", err)
×
2426
                }
×
2427

2428
                // Now that we've pruned some channels, we'll also prune any
2429
                // nodes that no longer have any channels.
2430
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2431
                if err != nil {
×
2432
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2433
                                err)
×
2434
                }
×
2435

2436
                return nil
×
2437
        }, func() {
×
2438
                prunedNodes = nil
×
2439
                closedChans = nil
×
2440
        })
×
2441
        if err != nil {
×
2442
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2443
        }
×
2444

2445
        for _, channel := range closedChans {
×
2446
                s.rejectCache.remove(channel.ChannelID)
×
2447
                s.chanCache.remove(channel.ChannelID)
×
2448
        }
×
2449

2450
        return closedChans, prunedNodes, nil
×
2451
}
2452

2453
// ChannelView returns the verifiable edge information for each active channel
2454
// within the known channel graph. The set of UTXOs (along with their scripts)
2455
// returned are the ones that need to be watched on chain to detect channel
2456
// closes on the resident blockchain.
2457
//
2458
// NOTE: part of the V1Store interface.
2459
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2460
        var (
×
2461
                ctx        = context.TODO()
×
2462
                edgePoints []EdgePoint
×
2463
        )
×
2464

×
2465
        handleChannel := func(db SQLQueries,
×
2466
                channel sqlc.ListChannelsPaginatedRow) error {
×
2467

×
2468
                pkScript, err := genMultiSigP2WSH(
×
2469
                        channel.BitcoinKey1, channel.BitcoinKey2,
×
2470
                )
×
2471
                if err != nil {
×
2472
                        return err
×
2473
                }
×
2474

2475
                op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2476
                if err != nil {
×
2477
                        return err
×
2478
                }
×
2479

2480
                edgePoints = append(edgePoints, EdgePoint{
×
2481
                        FundingPkScript: pkScript,
×
2482
                        OutPoint:        *op,
×
2483
                })
×
2484

×
2485
                return nil
×
2486
        }
2487

2488
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2489
                lastID := int64(-1)
×
2490
                for {
×
2491
                        rows, err := db.ListChannelsPaginated(
×
2492
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2493
                                        Version: int16(ProtocolV1),
×
2494
                                        ID:      lastID,
×
2495
                                        Limit:   pageSize,
×
2496
                                },
×
2497
                        )
×
2498
                        if err != nil {
×
2499
                                return err
×
2500
                        }
×
2501

2502
                        if len(rows) == 0 {
×
2503
                                break
×
2504
                        }
2505

2506
                        for _, row := range rows {
×
2507
                                err := handleChannel(db, row)
×
2508
                                if err != nil {
×
2509
                                        return err
×
2510
                                }
×
2511

2512
                                lastID = row.ID
×
2513
                        }
2514
                }
2515

2516
                return nil
×
2517
        }, func() {
×
2518
                edgePoints = nil
×
2519
        })
×
2520
        if err != nil {
×
2521
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2522
        }
×
2523

2524
        return edgePoints, nil
×
2525
}
2526

2527
// PruneTip returns the block height and hash of the latest block that has been
2528
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2529
// to tell if the graph is currently in sync with the current best known UTXO
2530
// state.
2531
//
2532
// NOTE: part of the V1Store interface.
2533
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2534
        var (
×
2535
                ctx       = context.TODO()
×
2536
                tipHash   chainhash.Hash
×
2537
                tipHeight uint32
×
2538
        )
×
2539
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2540
                pruneTip, err := db.GetPruneTip(ctx)
×
2541
                if errors.Is(err, sql.ErrNoRows) {
×
2542
                        return ErrGraphNeverPruned
×
2543
                } else if err != nil {
×
2544
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2545
                }
×
2546

2547
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2548
                tipHeight = uint32(pruneTip.BlockHeight)
×
2549

×
2550
                return nil
×
2551
        }, sqldb.NoOpReset)
2552
        if err != nil {
×
2553
                return nil, 0, err
×
2554
        }
×
2555

2556
        return &tipHash, tipHeight, nil
×
2557
}
2558

2559
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2560
//
2561
// NOTE: this prunes nodes across protocol versions. It will never prune the
2562
// source nodes.
2563
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2564
        db SQLQueries) ([]route.Vertex, error) {
×
2565

×
NEW
2566
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2567
        if err != nil {
×
NEW
2568
                return nil, fmt.Errorf("unable to delete unconnected "+
×
NEW
2569
                        "nodes: %w", err)
×
UNCOV
2570
        }
×
2571

NEW
2572
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
NEW
2573
        for i, nodeKey := range nodeKeys {
×
NEW
2574
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2575
                if err != nil {
×
2576
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
NEW
2577
                                "from bytes: %w", err)
×
2578
                }
×
2579

NEW
2580
                prunedNodes[i] = pub
×
2581
        }
2582

2583
        return prunedNodes, nil
×
2584
}
2585

2586
// DisconnectBlockAtHeight is used to indicate that the block specified
2587
// by the passed height has been disconnected from the main chain. This
2588
// will "rewind" the graph back to the height below, deleting channels
2589
// that are no longer confirmed from the graph. The prune log will be
2590
// set to the last prune height valid for the remaining chain.
2591
// Channels that were removed from the graph resulting from the
2592
// disconnected block are returned.
2593
//
2594
// NOTE: part of the V1Store interface.
2595
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2596
        []*models.ChannelEdgeInfo, error) {
×
2597

×
2598
        ctx := context.TODO()
×
2599

×
2600
        var (
×
2601
                // Every channel having a ShortChannelID starting at 'height'
×
2602
                // will no longer be confirmed.
×
2603
                startShortChanID = lnwire.ShortChannelID{
×
2604
                        BlockHeight: height,
×
2605
                }
×
2606

×
2607
                // Delete everything after this height from the db up until the
×
2608
                // SCID alias range.
×
2609
                endShortChanID = aliasmgr.StartingAlias
×
2610

×
2611
                removedChans []*models.ChannelEdgeInfo
×
2612
        )
×
2613

×
2614
        var chanIDStart [8]byte
×
2615
        byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64())
×
2616
        var chanIDEnd [8]byte
×
2617
        byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64())
×
2618

×
2619
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2620
                rows, err := db.GetChannelsBySCIDRange(
×
2621
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2622
                                StartScid: chanIDStart[:],
×
2623
                                EndScid:   chanIDEnd[:],
×
2624
                        },
×
2625
                )
×
2626
                if err != nil {
×
2627
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2628
                }
×
2629

2630
                for _, row := range rows {
×
2631
                        node1, node2, err := buildNodeVertices(
×
2632
                                row.Node1PubKey, row.Node2PubKey,
×
2633
                        )
×
2634
                        if err != nil {
×
2635
                                return err
×
2636
                        }
×
2637

2638
                        channel, err := getAndBuildEdgeInfo(
×
2639
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2640
                                row.Channel, node1, node2,
×
2641
                        )
×
2642
                        if err != nil {
×
2643
                                return err
×
2644
                        }
×
2645

2646
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2647
                        if err != nil {
×
2648
                                return fmt.Errorf("unable to delete "+
×
2649
                                        "channel: %w", err)
×
2650
                        }
×
2651

2652
                        removedChans = append(removedChans, channel)
×
2653
                }
2654

2655
                return db.DeletePruneLogEntriesInRange(
×
2656
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2657
                                StartHeight: int64(height),
×
2658
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2659
                        },
×
2660
                )
×
2661
        }, func() {
×
2662
                removedChans = nil
×
2663
        })
×
2664
        if err != nil {
×
2665
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2666
                        "height: %w", err)
×
2667
        }
×
2668

2669
        for _, channel := range removedChans {
×
2670
                s.rejectCache.remove(channel.ChannelID)
×
2671
                s.chanCache.remove(channel.ChannelID)
×
2672
        }
×
2673

2674
        return removedChans, nil
×
2675
}
2676

2677
// AddEdgeProof sets the proof of an existing edge in the graph database.
2678
//
2679
// NOTE: part of the V1Store interface.
2680
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2681
        proof *models.ChannelAuthProof) error {
×
2682

×
2683
        var (
×
2684
                ctx       = context.TODO()
×
2685
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2686
        )
×
2687

×
2688
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2689
                res, err := db.AddV1ChannelProof(
×
2690
                        ctx, sqlc.AddV1ChannelProofParams{
×
2691
                                Scid:              scidBytes[:],
×
2692
                                Node1Signature:    proof.NodeSig1Bytes,
×
2693
                                Node2Signature:    proof.NodeSig2Bytes,
×
2694
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2695
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2696
                        },
×
2697
                )
×
2698
                if err != nil {
×
2699
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2700
                }
×
2701

2702
                n, err := res.RowsAffected()
×
2703
                if err != nil {
×
2704
                        return err
×
2705
                }
×
2706

2707
                if n == 0 {
×
2708
                        return fmt.Errorf("no rows affected when adding edge "+
×
2709
                                "proof for SCID %v", scid)
×
2710
                } else if n > 1 {
×
2711
                        return fmt.Errorf("multiple rows affected when adding "+
×
2712
                                "edge proof for SCID %v: %d rows affected",
×
2713
                                scid, n)
×
2714
                }
×
2715

2716
                return nil
×
2717
        }, sqldb.NoOpReset)
2718
        if err != nil {
×
2719
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2720
        }
×
2721

2722
        return nil
×
2723
}
2724

2725
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2726
// that we can ignore channel announcements that we know to be closed without
2727
// having to validate them and fetch a block.
2728
//
2729
// NOTE: part of the V1Store interface.
2730
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2731
        var (
×
2732
                ctx     = context.TODO()
×
2733
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2734
        )
×
2735

×
2736
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2737
                return db.InsertClosedChannel(ctx, chanIDB[:])
×
2738
        }, sqldb.NoOpReset)
×
2739
}
2740

2741
// IsClosedScid checks whether a channel identified by the passed in scid is
2742
// closed. This helps avoid having to perform expensive validation checks.
2743
//
2744
// NOTE: part of the V1Store interface.
2745
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2746
        var (
×
2747
                ctx      = context.TODO()
×
2748
                isClosed bool
×
2749
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2750
        )
×
2751
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2752
                var err error
×
2753
                isClosed, err = db.IsClosedChannel(ctx, chanIDB[:])
×
2754
                if err != nil {
×
2755
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2756
                                err)
×
2757
                }
×
2758

2759
                return nil
×
2760
        }, sqldb.NoOpReset)
2761
        if err != nil {
×
2762
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2763
                        err)
×
2764
        }
×
2765

2766
        return isClosed, nil
×
2767
}
2768

2769
// GraphSession will provide the call-back with access to a NodeTraverser
2770
// instance which can be used to perform queries against the channel graph.
2771
//
2772
// NOTE: part of the V1Store interface.
2773
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error) error {
×
2774
        var ctx = context.TODO()
×
2775

×
2776
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2777
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2778
        }, sqldb.NoOpReset)
×
2779
}
2780

2781
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2782
// read only transaction for a consistent view of the graph.
2783
type sqlNodeTraverser struct {
2784
        db    SQLQueries
2785
        chain chainhash.Hash
2786
}
2787

2788
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2789
// NodeTraverser interface.
2790
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2791

2792
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2793
func newSQLNodeTraverser(db SQLQueries,
2794
        chain chainhash.Hash) *sqlNodeTraverser {
×
2795

×
2796
        return &sqlNodeTraverser{
×
2797
                db:    db,
×
2798
                chain: chain,
×
2799
        }
×
2800
}
×
2801

2802
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2803
// node.
2804
//
2805
// NOTE: Part of the NodeTraverser interface.
2806
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2807
        cb func(channel *DirectedChannel) error) error {
×
2808

×
2809
        ctx := context.TODO()
×
2810

×
2811
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2812
}
×
2813

2814
// FetchNodeFeatures returns the features of the given node. If the node is
2815
// unknown, assume no additional features are supported.
2816
//
2817
// NOTE: Part of the NodeTraverser interface.
2818
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2819
        *lnwire.FeatureVector, error) {
×
2820

×
2821
        ctx := context.TODO()
×
2822

×
2823
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2824
}
×
2825

2826
// forEachNodeDirectedChannel iterates through all channels of a given
2827
// node, executing the passed callback on the directed edge representing the
2828
// channel and its incoming policy. If the node is not found, no error is
2829
// returned.
2830
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2831
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2832

×
2833
        toNodeCallback := func() route.Vertex {
×
2834
                return nodePub
×
2835
        }
×
2836

2837
        dbID, err := db.GetNodeIDByPubKey(
×
2838
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2839
                        Version: int16(ProtocolV1),
×
2840
                        PubKey:  nodePub[:],
×
2841
                },
×
2842
        )
×
2843
        if errors.Is(err, sql.ErrNoRows) {
×
2844
                return nil
×
2845
        } else if err != nil {
×
2846
                return fmt.Errorf("unable to fetch node: %w", err)
×
2847
        }
×
2848

2849
        rows, err := db.ListChannelsByNodeID(
×
2850
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2851
                        Version: int16(ProtocolV1),
×
2852
                        NodeID1: dbID,
×
2853
                },
×
2854
        )
×
2855
        if err != nil {
×
2856
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2857
        }
×
2858

2859
        // Exit early if there are no channels for this node so we don't
2860
        // do the unnecessary feature fetching.
2861
        if len(rows) == 0 {
×
2862
                return nil
×
2863
        }
×
2864

2865
        features, err := getNodeFeatures(ctx, db, dbID)
×
2866
        if err != nil {
×
2867
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2868
        }
×
2869

2870
        for _, row := range rows {
×
2871
                node1, node2, err := buildNodeVertices(
×
2872
                        row.Node1Pubkey, row.Node2Pubkey,
×
2873
                )
×
2874
                if err != nil {
×
2875
                        return fmt.Errorf("unable to build node vertices: %w",
×
2876
                                err)
×
2877
                }
×
2878

2879
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
2880

×
2881
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2882
                if err != nil {
×
2883
                        return err
×
2884
                }
×
2885

2886
                var p1, p2 *models.CachedEdgePolicy
×
2887
                if dbPol1 != nil {
×
2888
                        policy1, err := buildChanPolicy(
×
2889
                                *dbPol1, edge.ChannelID, nil, node2,
×
2890
                        )
×
2891
                        if err != nil {
×
2892
                                return err
×
2893
                        }
×
2894

2895
                        p1 = models.NewCachedPolicy(policy1)
×
2896
                }
2897
                if dbPol2 != nil {
×
2898
                        policy2, err := buildChanPolicy(
×
2899
                                *dbPol2, edge.ChannelID, nil, node1,
×
2900
                        )
×
2901
                        if err != nil {
×
2902
                                return err
×
2903
                        }
×
2904

2905
                        p2 = models.NewCachedPolicy(policy2)
×
2906
                }
2907

2908
                // Determine the outgoing and incoming policy for this
2909
                // channel and node combo.
2910
                outPolicy, inPolicy := p1, p2
×
2911
                if p1 != nil && node2 == nodePub {
×
2912
                        outPolicy, inPolicy = p2, p1
×
2913
                } else if p2 != nil && node1 != nodePub {
×
2914
                        outPolicy, inPolicy = p2, p1
×
2915
                }
×
2916

2917
                var cachedInPolicy *models.CachedEdgePolicy
×
2918
                if inPolicy != nil {
×
2919
                        cachedInPolicy = inPolicy
×
2920
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
2921
                        cachedInPolicy.ToNodeFeatures = features
×
2922
                }
×
2923

2924
                directedChannel := &DirectedChannel{
×
2925
                        ChannelID:    edge.ChannelID,
×
2926
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2927
                        OtherNode:    edge.NodeKey2Bytes,
×
2928
                        Capacity:     edge.Capacity,
×
2929
                        OutPolicySet: outPolicy != nil,
×
2930
                        InPolicy:     cachedInPolicy,
×
2931
                }
×
2932
                if outPolicy != nil {
×
2933
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2934
                                directedChannel.InboundFee = fee
×
2935
                        })
×
2936
                }
2937

2938
                if nodePub == edge.NodeKey2Bytes {
×
2939
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
2940
                }
×
2941

2942
                if err := cb(directedChannel); err != nil {
×
2943
                        return err
×
2944
                }
×
2945
        }
2946

2947
        return nil
×
2948
}
2949

2950
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
2951
// and executes the provided callback for each node.
2952
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
2953
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
2954

×
2955
        lastID := int64(-1)
×
2956

×
2957
        for {
×
2958
                nodes, err := db.ListNodeIDsAndPubKeys(
×
2959
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
2960
                                Version: int16(ProtocolV1),
×
2961
                                ID:      lastID,
×
2962
                                Limit:   pageSize,
×
2963
                        },
×
2964
                )
×
2965
                if err != nil {
×
2966
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
2967
                }
×
2968

2969
                if len(nodes) == 0 {
×
2970
                        break
×
2971
                }
2972

2973
                for _, node := range nodes {
×
2974
                        var pub route.Vertex
×
2975
                        copy(pub[:], node.PubKey)
×
2976

×
2977
                        if err := cb(node.ID, pub); err != nil {
×
2978
                                return fmt.Errorf("forEachNodeCacheable "+
×
2979
                                        "callback failed for node(id=%d): %w",
×
2980
                                        node.ID, err)
×
2981
                        }
×
2982

2983
                        lastID = node.ID
×
2984
                }
2985
        }
2986

2987
        return nil
×
2988
}
2989

2990
// forEachNodeChannel iterates through all channels of a node, executing
2991
// the passed callback on each. The call-back is provided with the channel's
2992
// edge information, the outgoing policy and the incoming policy for the
2993
// channel and node combo.
2994
func forEachNodeChannel(ctx context.Context, db SQLQueries,
2995
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
2996
                *models.ChannelEdgePolicy,
2997
                *models.ChannelEdgePolicy) error) error {
×
2998

×
2999
        // Get all the V1 channels for this node.Add commentMore actions
×
3000
        rows, err := db.ListChannelsByNodeID(
×
3001
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3002
                        Version: int16(ProtocolV1),
×
3003
                        NodeID1: id,
×
3004
                },
×
3005
        )
×
3006
        if err != nil {
×
3007
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3008
        }
×
3009

3010
        // Call the call-back for each channel and its known policies.
3011
        for _, row := range rows {
×
3012
                node1, node2, err := buildNodeVertices(
×
3013
                        row.Node1Pubkey, row.Node2Pubkey,
×
3014
                )
×
3015
                if err != nil {
×
3016
                        return fmt.Errorf("unable to build node vertices: %w",
×
3017
                                err)
×
3018
                }
×
3019

3020
                edge, err := getAndBuildEdgeInfo(
×
3021
                        ctx, db, chain, row.Channel.ID, row.Channel, node1,
×
3022
                        node2,
×
3023
                )
×
3024
                if err != nil {
×
3025
                        return fmt.Errorf("unable to build channel info: %w",
×
3026
                                err)
×
3027
                }
×
3028

3029
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3030
                if err != nil {
×
3031
                        return fmt.Errorf("unable to extract channel "+
×
3032
                                "policies: %w", err)
×
3033
                }
×
3034

3035
                p1, p2, err := getAndBuildChanPolicies(
×
3036
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3037
                )
×
3038
                if err != nil {
×
3039
                        return fmt.Errorf("unable to build channel "+
×
3040
                                "policies: %w", err)
×
3041
                }
×
3042

3043
                // Determine the outgoing and incoming policy for this
3044
                // channel and node combo.
3045
                p1ToNode := row.Channel.NodeID2
×
3046
                p2ToNode := row.Channel.NodeID1
×
3047
                outPolicy, inPolicy := p1, p2
×
3048
                if (p1 != nil && p1ToNode == id) ||
×
3049
                        (p2 != nil && p2ToNode != id) {
×
3050

×
3051
                        outPolicy, inPolicy = p2, p1
×
3052
                }
×
3053

3054
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3055
                        return err
×
3056
                }
×
3057
        }
3058

3059
        return nil
×
3060
}
3061

3062
// updateChanEdgePolicy upserts the channel policy info we have stored for
3063
// a channel we already know of.
3064
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3065
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3066
        error) {
×
3067

×
3068
        var (
×
3069
                node1Pub, node2Pub route.Vertex
×
3070
                isNode1            bool
×
3071
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3072
        )
×
3073

×
3074
        // Check that this edge policy refers to a channel that we already
×
3075
        // know of. We do this explicitly so that we can return the appropriate
×
3076
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3077
        // abort the transaction which would abort the entire batch.
×
3078
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3079
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3080
                        Scid:    chanIDB[:],
×
3081
                        Version: int16(ProtocolV1),
×
3082
                },
×
3083
        )
×
3084
        if errors.Is(err, sql.ErrNoRows) {
×
3085
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3086
        } else if err != nil {
×
3087
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3088
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3089
        }
×
3090

3091
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3092
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3093

×
3094
        // Figure out which node this edge is from.
×
3095
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3096
        nodeID := dbChan.NodeID1
×
3097
        if !isNode1 {
×
3098
                nodeID = dbChan.NodeID2
×
3099
        }
×
3100

3101
        var (
×
3102
                inboundBase sql.NullInt64
×
3103
                inboundRate sql.NullInt64
×
3104
        )
×
3105
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3106
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3107
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3108
        })
×
3109

3110
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3111
                Version:     int16(ProtocolV1),
×
3112
                ChannelID:   dbChan.ID,
×
3113
                NodeID:      nodeID,
×
3114
                Timelock:    int32(edge.TimeLockDelta),
×
3115
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3116
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3117
                MinHtlcMsat: int64(edge.MinHTLC),
×
3118
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3119
                Disabled: sql.NullBool{
×
3120
                        Valid: true,
×
3121
                        Bool:  edge.IsDisabled(),
×
3122
                },
×
3123
                MaxHtlcMsat: sql.NullInt64{
×
3124
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3125
                        Int64: int64(edge.MaxHTLC),
×
3126
                },
×
3127
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3128
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3129
                InboundBaseFeeMsat:      inboundBase,
×
3130
                InboundFeeRateMilliMsat: inboundRate,
×
3131
                Signature:               edge.SigBytes,
×
3132
        })
×
3133
        if err != nil {
×
3134
                return node1Pub, node2Pub, isNode1,
×
3135
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3136
        }
×
3137

3138
        // Convert the flat extra opaque data into a map of TLV types to
3139
        // values.
3140
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3141
        if err != nil {
×
3142
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3143
                        "marshal extra opaque data: %w", err)
×
3144
        }
×
3145

3146
        // Update the channel policy's extra signed fields.
3147
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3148
        if err != nil {
×
3149
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3150
                        "policy extra TLVs: %w", err)
×
3151
        }
×
3152

3153
        return node1Pub, node2Pub, isNode1, nil
×
3154
}
3155

3156
// getNodeByPubKey attempts to look up a target node by its public key.
3157
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3158
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3159

×
3160
        dbNode, err := db.GetNodeByPubKey(
×
3161
                ctx, sqlc.GetNodeByPubKeyParams{
×
3162
                        Version: int16(ProtocolV1),
×
3163
                        PubKey:  pubKey[:],
×
3164
                },
×
3165
        )
×
3166
        if errors.Is(err, sql.ErrNoRows) {
×
3167
                return 0, nil, ErrGraphNodeNotFound
×
3168
        } else if err != nil {
×
3169
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3170
        }
×
3171

3172
        node, err := buildNode(ctx, db, &dbNode)
×
3173
        if err != nil {
×
3174
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3175
        }
×
3176

3177
        return dbNode.ID, node, nil
×
3178
}
3179

3180
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3181
// provided database channel row and the public keys of the two nodes
3182
// involved in the channel.
3183
func buildCacheableChannelInfo(dbChan sqlc.Channel, node1Pub,
3184
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3185

×
3186
        return &models.CachedEdgeInfo{
×
3187
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
3188
                NodeKey1Bytes: node1Pub,
×
3189
                NodeKey2Bytes: node2Pub,
×
3190
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
3191
        }
×
3192
}
×
3193

3194
// buildNode constructs a LightningNode instance from the given database node
3195
// record. The node's features, addresses and extra signed fields are also
3196
// fetched from the database and set on the node.
3197
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
3198
        *models.LightningNode, error) {
×
3199

×
3200
        if dbNode.Version != int16(ProtocolV1) {
×
3201
                return nil, fmt.Errorf("unsupported node version: %d",
×
3202
                        dbNode.Version)
×
3203
        }
×
3204

3205
        var pub [33]byte
×
3206
        copy(pub[:], dbNode.PubKey)
×
3207

×
3208
        node := &models.LightningNode{
×
3209
                PubKeyBytes: pub,
×
3210
                Features:    lnwire.EmptyFeatureVector(),
×
3211
                LastUpdate:  time.Unix(0, 0),
×
3212
        }
×
3213

×
3214
        if len(dbNode.Signature) == 0 {
×
3215
                return node, nil
×
3216
        }
×
3217

3218
        node.HaveNodeAnnouncement = true
×
3219
        node.AuthSigBytes = dbNode.Signature
×
3220
        node.Alias = dbNode.Alias.String
×
3221
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3222

×
3223
        var err error
×
3224
        if dbNode.Color.Valid {
×
3225
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3226
                if err != nil {
×
3227
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3228
                                err)
×
3229
                }
×
3230
        }
3231

3232
        // Fetch the node's features.
3233
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
3234
        if err != nil {
×
3235
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3236
                        "features: %w", dbNode.ID, err)
×
3237
        }
×
3238

3239
        // Fetch the node's addresses.
3240
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
3241
        if err != nil {
×
3242
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3243
                        "addresses: %w", dbNode.ID, err)
×
3244
        }
×
3245

3246
        // Fetch the node's extra signed fields.
3247
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
3248
        if err != nil {
×
3249
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3250
                        "extra signed fields: %w", dbNode.ID, err)
×
3251
        }
×
3252

3253
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
3254
        if err != nil {
×
3255
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3256
                        "fields: %w", err)
×
3257
        }
×
3258

3259
        if len(recs) != 0 {
×
3260
                node.ExtraOpaqueData = recs
×
3261
        }
×
3262

3263
        return node, nil
×
3264
}
3265

3266
// getNodeFeatures fetches the feature bits and constructs the feature vector
3267
// for a node with the given DB ID.
3268
func getNodeFeatures(ctx context.Context, db SQLQueries,
3269
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3270

×
3271
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3272
        if err != nil {
×
3273
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3274
                        nodeID, err)
×
3275
        }
×
3276

3277
        features := lnwire.EmptyFeatureVector()
×
3278
        for _, feature := range rows {
×
3279
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3280
        }
×
3281

3282
        return features, nil
×
3283
}
3284

3285
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
3286
// given DB ID.
3287
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3288
        nodeID int64) (map[uint64][]byte, error) {
×
3289

×
3290
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3291
        if err != nil {
×
3292
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
3293
                        "signed fields: %w", nodeID, err)
×
3294
        }
×
3295

3296
        extraFields := make(map[uint64][]byte)
×
3297
        for _, field := range fields {
×
3298
                extraFields[uint64(field.Type)] = field.Value
×
3299
        }
×
3300

3301
        return extraFields, nil
×
3302
}
3303

3304
// upsertNode upserts the node record into the database. If the node already
3305
// exists, then the node's information is updated. If the node doesn't exist,
3306
// then a new node is created. The node's features, addresses and extra TLV
3307
// types are also updated. The node's DB ID is returned.
3308
func upsertNode(ctx context.Context, db SQLQueries,
3309
        node *models.LightningNode) (int64, error) {
×
3310

×
3311
        params := sqlc.UpsertNodeParams{
×
3312
                Version: int16(ProtocolV1),
×
3313
                PubKey:  node.PubKeyBytes[:],
×
3314
        }
×
3315

×
3316
        if node.HaveNodeAnnouncement {
×
3317
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3318
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3319
                params.Alias = sqldb.SQLStr(node.Alias)
×
3320
                params.Signature = node.AuthSigBytes
×
3321
        }
×
3322

3323
        nodeID, err := db.UpsertNode(ctx, params)
×
3324
        if err != nil {
×
3325
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3326
                        err)
×
3327
        }
×
3328

3329
        // We can exit here if we don't have the announcement yet.
3330
        if !node.HaveNodeAnnouncement {
×
3331
                return nodeID, nil
×
3332
        }
×
3333

3334
        // Update the node's features.
3335
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3336
        if err != nil {
×
3337
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3338
        }
×
3339

3340
        // Update the node's addresses.
3341
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3342
        if err != nil {
×
3343
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3344
        }
×
3345

3346
        // Convert the flat extra opaque data into a map of TLV types to
3347
        // values.
3348
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3349
        if err != nil {
×
3350
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3351
                        err)
×
3352
        }
×
3353

3354
        // Update the node's extra signed fields.
3355
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3356
        if err != nil {
×
3357
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3358
        }
×
3359

3360
        return nodeID, nil
×
3361
}
3362

3363
// upsertNodeFeatures updates the node's features node_features table. This
3364
// includes deleting any feature bits no longer present and inserting any new
3365
// feature bits. If the feature bit does not yet exist in the features table,
3366
// then an entry is created in that table first.
3367
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3368
        features *lnwire.FeatureVector) error {
×
3369

×
3370
        // Get any existing features for the node.
×
3371
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3372
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3373
                return err
×
3374
        }
×
3375

3376
        // Copy the nodes latest set of feature bits.
3377
        newFeatures := make(map[int32]struct{})
×
3378
        if features != nil {
×
3379
                for feature := range features.Features() {
×
3380
                        newFeatures[int32(feature)] = struct{}{}
×
3381
                }
×
3382
        }
3383

3384
        // For any current feature that already exists in the DB, remove it from
3385
        // the in-memory map. For any existing feature that does not exist in
3386
        // the in-memory map, delete it from the database.
3387
        for _, feature := range existingFeatures {
×
3388
                // The feature is still present, so there are no updates to be
×
3389
                // made.
×
3390
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3391
                        delete(newFeatures, feature.FeatureBit)
×
3392
                        continue
×
3393
                }
3394

3395
                // The feature is no longer present, so we remove it from the
3396
                // database.
3397
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3398
                        NodeID:     nodeID,
×
3399
                        FeatureBit: feature.FeatureBit,
×
3400
                })
×
3401
                if err != nil {
×
3402
                        return fmt.Errorf("unable to delete node(%d) "+
×
3403
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3404
                                err)
×
3405
                }
×
3406
        }
3407

3408
        // Any remaining entries in newFeatures are new features that need to be
3409
        // added to the database for the first time.
3410
        for feature := range newFeatures {
×
3411
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3412
                        NodeID:     nodeID,
×
3413
                        FeatureBit: feature,
×
3414
                })
×
3415
                if err != nil {
×
3416
                        return fmt.Errorf("unable to insert node(%d) "+
×
3417
                                "feature(%v): %w", nodeID, feature, err)
×
3418
                }
×
3419
        }
3420

3421
        return nil
×
3422
}
3423

3424
// fetchNodeFeatures fetches the features for a node with the given public key.
3425
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3426
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3427

×
3428
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3429
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3430
                        PubKey:  nodePub[:],
×
3431
                        Version: int16(ProtocolV1),
×
3432
                },
×
3433
        )
×
3434
        if err != nil {
×
3435
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3436
                        nodePub, err)
×
3437
        }
×
3438

3439
        features := lnwire.EmptyFeatureVector()
×
3440
        for _, bit := range rows {
×
3441
                features.Set(lnwire.FeatureBit(bit))
×
3442
        }
×
3443

3444
        return features, nil
×
3445
}
3446

3447
// dbAddressType is an enum type that represents the different address types
3448
// that we store in the node_addresses table. The address type determines how
3449
// the address is to be serialised/deserialize.
3450
type dbAddressType uint8
3451

3452
const (
3453
        addressTypeIPv4   dbAddressType = 1
3454
        addressTypeIPv6   dbAddressType = 2
3455
        addressTypeTorV2  dbAddressType = 3
3456
        addressTypeTorV3  dbAddressType = 4
3457
        addressTypeOpaque dbAddressType = math.MaxInt8
3458
)
3459

3460
// upsertNodeAddresses updates the node's addresses in the database. This
3461
// includes deleting any existing addresses and inserting the new set of
3462
// addresses. The deletion is necessary since the ordering of the addresses may
3463
// change, and we need to ensure that the database reflects the latest set of
3464
// addresses so that at the time of reconstructing the node announcement, the
3465
// order is preserved and the signature over the message remains valid.
3466
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3467
        addresses []net.Addr) error {
×
3468

×
3469
        // Delete any existing addresses for the node. This is required since
×
3470
        // even if the new set of addresses is the same, the ordering may have
×
3471
        // changed for a given address type.
×
3472
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3473
        if err != nil {
×
3474
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3475
                        nodeID, err)
×
3476
        }
×
3477

3478
        // Copy the nodes latest set of addresses.
3479
        newAddresses := map[dbAddressType][]string{
×
3480
                addressTypeIPv4:   {},
×
3481
                addressTypeIPv6:   {},
×
3482
                addressTypeTorV2:  {},
×
3483
                addressTypeTorV3:  {},
×
3484
                addressTypeOpaque: {},
×
3485
        }
×
3486
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3487
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3488
        }
×
3489

3490
        for _, address := range addresses {
×
3491
                switch addr := address.(type) {
×
3492
                case *net.TCPAddr:
×
3493
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3494
                                addAddr(addressTypeIPv4, addr)
×
3495
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3496
                                addAddr(addressTypeIPv6, addr)
×
3497
                        } else {
×
3498
                                return fmt.Errorf("unhandled IP address: %v",
×
3499
                                        addr)
×
3500
                        }
×
3501

3502
                case *tor.OnionAddr:
×
3503
                        switch len(addr.OnionService) {
×
3504
                        case tor.V2Len:
×
3505
                                addAddr(addressTypeTorV2, addr)
×
3506
                        case tor.V3Len:
×
3507
                                addAddr(addressTypeTorV3, addr)
×
3508
                        default:
×
3509
                                return fmt.Errorf("invalid length for a tor " +
×
3510
                                        "address")
×
3511
                        }
3512

3513
                case *lnwire.OpaqueAddrs:
×
3514
                        addAddr(addressTypeOpaque, addr)
×
3515

3516
                default:
×
3517
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3518
                }
3519
        }
3520

3521
        // Any remaining entries in newAddresses are new addresses that need to
3522
        // be added to the database for the first time.
3523
        for addrType, addrList := range newAddresses {
×
3524
                for position, addr := range addrList {
×
3525
                        err := db.InsertNodeAddress(
×
3526
                                ctx, sqlc.InsertNodeAddressParams{
×
3527
                                        NodeID:   nodeID,
×
3528
                                        Type:     int16(addrType),
×
3529
                                        Address:  addr,
×
3530
                                        Position: int32(position),
×
3531
                                },
×
3532
                        )
×
3533
                        if err != nil {
×
3534
                                return fmt.Errorf("unable to insert "+
×
3535
                                        "node(%d) address(%v): %w", nodeID,
×
3536
                                        addr, err)
×
3537
                        }
×
3538
                }
3539
        }
3540

3541
        return nil
×
3542
}
3543

3544
// getNodeAddresses fetches the addresses for a node with the given public key.
3545
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3546
        []net.Addr, error) {
×
3547

×
3548
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
3549
        // are returned in the same order as they were inserted.
×
3550
        rows, err := db.GetNodeAddressesByPubKey(
×
3551
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
3552
                        Version: int16(ProtocolV1),
×
3553
                        PubKey:  nodePub,
×
3554
                },
×
3555
        )
×
3556
        if err != nil {
×
3557
                return false, nil, err
×
3558
        }
×
3559

3560
        // GetNodeAddressesByPubKey uses a left join so there should always be
3561
        // at least one row returned if the node exists even if it has no
3562
        // addresses.
3563
        if len(rows) == 0 {
×
3564
                return false, nil, nil
×
3565
        }
×
3566

3567
        addresses := make([]net.Addr, 0, len(rows))
×
3568
        for _, addr := range rows {
×
3569
                if !(addr.Type.Valid && addr.Address.Valid) {
×
3570
                        continue
×
3571
                }
3572

3573
                address := addr.Address.String
×
3574

×
3575
                switch dbAddressType(addr.Type.Int16) {
×
3576
                case addressTypeIPv4:
×
3577
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
3578
                        if err != nil {
×
3579
                                return false, nil, nil
×
3580
                        }
×
3581
                        tcp.IP = tcp.IP.To4()
×
3582

×
3583
                        addresses = append(addresses, tcp)
×
3584

3585
                case addressTypeIPv6:
×
3586
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
3587
                        if err != nil {
×
3588
                                return false, nil, nil
×
3589
                        }
×
3590
                        addresses = append(addresses, tcp)
×
3591

3592
                case addressTypeTorV3, addressTypeTorV2:
×
3593
                        service, portStr, err := net.SplitHostPort(address)
×
3594
                        if err != nil {
×
3595
                                return false, nil, fmt.Errorf("unable to "+
×
3596
                                        "split tor v3 address: %v",
×
3597
                                        addr.Address)
×
3598
                        }
×
3599

3600
                        port, err := strconv.Atoi(portStr)
×
3601
                        if err != nil {
×
3602
                                return false, nil, err
×
3603
                        }
×
3604

3605
                        addresses = append(addresses, &tor.OnionAddr{
×
3606
                                OnionService: service,
×
3607
                                Port:         port,
×
3608
                        })
×
3609

3610
                case addressTypeOpaque:
×
3611
                        opaque, err := hex.DecodeString(address)
×
3612
                        if err != nil {
×
3613
                                return false, nil, fmt.Errorf("unable to "+
×
3614
                                        "decode opaque address: %v", addr)
×
3615
                        }
×
3616

3617
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
3618
                                Payload: opaque,
×
3619
                        })
×
3620

3621
                default:
×
3622
                        return false, nil, fmt.Errorf("unknown address "+
×
3623
                                "type: %v", addr.Type)
×
3624
                }
3625
        }
3626

3627
        // If we have no addresses, then we'll return nil instead of an
3628
        // empty slice.
NEW
3629
        if len(addresses) == 0 {
×
NEW
3630
                addresses = nil
×
NEW
3631
        }
×
3632

UNCOV
3633
        return true, addresses, nil
×
3634
}
3635

3636
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3637
// database. This includes updating any existing types, inserting any new types,
3638
// and deleting any types that are no longer present.
3639
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3640
        nodeID int64, extraFields map[uint64][]byte) error {
×
3641

×
3642
        // Get any existing extra signed fields for the node.
×
3643
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3644
        if err != nil {
×
3645
                return err
×
3646
        }
×
3647

3648
        // Make a lookup map of the existing field types so that we can use it
3649
        // to keep track of any fields we should delete.
3650
        m := make(map[uint64]bool)
×
3651
        for _, field := range existingFields {
×
3652
                m[uint64(field.Type)] = true
×
3653
        }
×
3654

3655
        // For all the new fields, we'll upsert them and remove them from the
3656
        // map of existing fields.
3657
        for tlvType, value := range extraFields {
×
3658
                err = db.UpsertNodeExtraType(
×
3659
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3660
                                NodeID: nodeID,
×
3661
                                Type:   int64(tlvType),
×
3662
                                Value:  value,
×
3663
                        },
×
3664
                )
×
3665
                if err != nil {
×
3666
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3667
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3668
                }
×
3669

3670
                // Remove the field from the map of existing fields if it was
3671
                // present.
3672
                delete(m, tlvType)
×
3673
        }
3674

3675
        // For all the fields that are left in the map of existing fields, we'll
3676
        // delete them as they are no longer present in the new set of fields.
3677
        for tlvType := range m {
×
3678
                err = db.DeleteExtraNodeType(
×
3679
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3680
                                NodeID: nodeID,
×
3681
                                Type:   int64(tlvType),
×
3682
                        },
×
3683
                )
×
3684
                if err != nil {
×
3685
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3686
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3687
                }
×
3688
        }
3689

3690
        return nil
×
3691
}
3692

3693
// srcNodeInfo holds the information about the source node of the graph.
3694
type srcNodeInfo struct {
3695
        // id is the DB level ID of the source node entry in the "nodes" table.
3696
        id int64
3697

3698
        // pub is the public key of the source node.
3699
        pub route.Vertex
3700
}
3701

3702
// getSourceNode returns the DB node ID and pub key of the source node for the
3703
// specified protocol version.
3704
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3705
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3706

×
3707
        s.srcNodeMu.Lock()
×
3708
        defer s.srcNodeMu.Unlock()
×
3709

×
3710
        // If we already have the source node ID and pub key cached, then
×
3711
        // return them.
×
3712
        if info, ok := s.srcNodes[version]; ok {
×
3713
                return info.id, info.pub, nil
×
3714
        }
×
3715

3716
        var pubKey route.Vertex
×
3717

×
3718
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3719
        if err != nil {
×
3720
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3721
                        err)
×
3722
        }
×
3723

3724
        if len(nodes) == 0 {
×
3725
                return 0, pubKey, ErrSourceNodeNotSet
×
3726
        } else if len(nodes) > 1 {
×
3727
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3728
                        "protocol %s found", version)
×
3729
        }
×
3730

3731
        copy(pubKey[:], nodes[0].PubKey)
×
3732

×
3733
        s.srcNodes[version] = &srcNodeInfo{
×
3734
                id:  nodes[0].NodeID,
×
3735
                pub: pubKey,
×
3736
        }
×
3737

×
3738
        return nodes[0].NodeID, pubKey, nil
×
3739
}
3740

3741
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3742
// This then produces a map from TLV type to value. If the input is not a
3743
// valid TLV stream, then an error is returned.
3744
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3745
        r := bytes.NewReader(data)
×
3746

×
3747
        tlvStream, err := tlv.NewStream()
×
3748
        if err != nil {
×
3749
                return nil, err
×
3750
        }
×
3751

3752
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3753
        // pass it into the P2P decoding variant.
3754
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3755
        if err != nil {
×
3756
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3757
        }
×
3758
        if len(parsedTypes) == 0 {
×
3759
                return nil, nil
×
3760
        }
×
3761

3762
        records := make(map[uint64][]byte)
×
3763
        for k, v := range parsedTypes {
×
3764
                records[uint64(k)] = v
×
3765
        }
×
3766

3767
        return records, nil
×
3768
}
3769

3770
// insertChannel inserts a new channel record into the database.
3771
func insertChannel(ctx context.Context, db SQLQueries,
3772
        edge *models.ChannelEdgeInfo) error {
×
3773

×
3774
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3775

×
3776
        // Make sure that the channel doesn't already exist. We do this
×
3777
        // explicitly instead of relying on catching a unique constraint error
×
3778
        // because relying on SQL to throw that error would abort the entire
×
3779
        // batch of transactions.
×
3780
        _, err := db.GetChannelBySCID(
×
3781
                ctx, sqlc.GetChannelBySCIDParams{
×
3782
                        Scid:    chanIDB[:],
×
3783
                        Version: int16(ProtocolV1),
×
3784
                },
×
3785
        )
×
3786
        if err == nil {
×
3787
                return ErrEdgeAlreadyExist
×
3788
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3789
                return fmt.Errorf("unable to fetch channel: %w", err)
×
3790
        }
×
3791

3792
        // Make sure that at least a "shell" entry for each node is present in
3793
        // the nodes table.
3794
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3795
        if err != nil {
×
3796
                return fmt.Errorf("unable to create shell node: %w", err)
×
3797
        }
×
3798

3799
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3800
        if err != nil {
×
3801
                return fmt.Errorf("unable to create shell node: %w", err)
×
3802
        }
×
3803

3804
        var capacity sql.NullInt64
×
3805
        if edge.Capacity != 0 {
×
3806
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3807
        }
×
3808

3809
        createParams := sqlc.CreateChannelParams{
×
3810
                Version:     int16(ProtocolV1),
×
3811
                Scid:        chanIDB[:],
×
3812
                NodeID1:     node1DBID,
×
3813
                NodeID2:     node2DBID,
×
3814
                Outpoint:    edge.ChannelPoint.String(),
×
3815
                Capacity:    capacity,
×
3816
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3817
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3818
        }
×
3819

×
3820
        if edge.AuthProof != nil {
×
3821
                proof := edge.AuthProof
×
3822

×
3823
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3824
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3825
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3826
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3827
        }
×
3828

3829
        // Insert the new channel record.
3830
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3831
        if err != nil {
×
3832
                return err
×
3833
        }
×
3834

3835
        // Insert any channel features.
3836
        if len(edge.Features) != 0 {
×
3837
                chanFeatures := lnwire.NewRawFeatureVector()
×
3838
                err := chanFeatures.Decode(bytes.NewReader(edge.Features))
×
3839
                if err != nil {
×
3840
                        return err
×
3841
                }
×
3842

3843
                fv := lnwire.NewFeatureVector(chanFeatures, lnwire.Features)
×
3844
                for feature := range fv.Features() {
×
3845
                        err = db.InsertChannelFeature(
×
3846
                                ctx, sqlc.InsertChannelFeatureParams{
×
3847
                                        ChannelID:  dbChanID,
×
3848
                                        FeatureBit: int32(feature),
×
3849
                                },
×
3850
                        )
×
3851
                        if err != nil {
×
3852
                                return fmt.Errorf("unable to insert "+
×
3853
                                        "channel(%d) feature(%v): %w", dbChanID,
×
3854
                                        feature, err)
×
3855
                        }
×
3856
                }
3857
        }
3858

3859
        // Finally, insert any extra TLV fields in the channel announcement.
3860
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3861
        if err != nil {
×
3862
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
3863
                        err)
×
3864
        }
×
3865

3866
        for tlvType, value := range extra {
×
3867
                err := db.CreateChannelExtraType(
×
3868
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3869
                                ChannelID: dbChanID,
×
3870
                                Type:      int64(tlvType),
×
3871
                                Value:     value,
×
3872
                        },
×
3873
                )
×
3874
                if err != nil {
×
3875
                        return fmt.Errorf("unable to upsert channel(%d) extra "+
×
3876
                                "signed field(%v): %w", edge.ChannelID,
×
3877
                                tlvType, err)
×
3878
                }
×
3879
        }
3880

3881
        return nil
×
3882
}
3883

3884
// maybeCreateShellNode checks if a shell node entry exists for the
3885
// given public key. If it does not exist, then a new shell node entry is
3886
// created. The ID of the node is returned. A shell node only has a protocol
3887
// version and public key persisted.
3888
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3889
        pubKey route.Vertex) (int64, error) {
×
3890

×
3891
        dbNode, err := db.GetNodeByPubKey(
×
3892
                ctx, sqlc.GetNodeByPubKeyParams{
×
3893
                        PubKey:  pubKey[:],
×
3894
                        Version: int16(ProtocolV1),
×
3895
                },
×
3896
        )
×
3897
        // The node exists. Return the ID.
×
3898
        if err == nil {
×
3899
                return dbNode.ID, nil
×
3900
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3901
                return 0, err
×
3902
        }
×
3903

3904
        // Otherwise, the node does not exist, so we create a shell entry for
3905
        // it.
3906
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3907
                Version: int16(ProtocolV1),
×
3908
                PubKey:  pubKey[:],
×
3909
        })
×
3910
        if err != nil {
×
3911
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3912
        }
×
3913

3914
        return id, nil
×
3915
}
3916

3917
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3918
// the database. This includes deleting any existing types and then inserting
3919
// the new types.
3920
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3921
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3922

×
3923
        // Delete all existing extra signed fields for the channel policy.
×
3924
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
3925
        if err != nil {
×
3926
                return fmt.Errorf("unable to delete "+
×
3927
                        "existing policy extra signed fields for policy %d: %w",
×
3928
                        chanPolicyID, err)
×
3929
        }
×
3930

3931
        // Insert all new extra signed fields for the channel policy.
3932
        for tlvType, value := range extraFields {
×
3933
                err = db.InsertChanPolicyExtraType(
×
3934
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
3935
                                ChannelPolicyID: chanPolicyID,
×
3936
                                Type:            int64(tlvType),
×
3937
                                Value:           value,
×
3938
                        },
×
3939
                )
×
3940
                if err != nil {
×
3941
                        return fmt.Errorf("unable to insert "+
×
3942
                                "channel_policy(%d) extra signed field(%v): %w",
×
3943
                                chanPolicyID, tlvType, err)
×
3944
                }
×
3945
        }
3946

3947
        return nil
×
3948
}
3949

3950
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3951
// provided dbChanRow and also fetches any other required information
3952
// to construct the edge info.
3953
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
3954
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.Channel, node1,
3955
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
3956

×
3957
        if dbChan.Version != int16(ProtocolV1) {
×
3958
                return nil, fmt.Errorf("unsupported channel version: %d",
×
3959
                        dbChan.Version)
×
3960
        }
×
3961

3962
        fv, extras, err := getChanFeaturesAndExtras(
×
3963
                ctx, db, dbChanID,
×
3964
        )
×
3965
        if err != nil {
×
3966
                return nil, err
×
3967
        }
×
3968

3969
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
3970
        if err != nil {
×
3971
                return nil, err
×
3972
        }
×
3973

3974
        var featureBuf bytes.Buffer
×
3975
        if err := fv.Encode(&featureBuf); err != nil {
×
3976
                return nil, fmt.Errorf("unable to encode features: %w", err)
×
3977
        }
×
3978

3979
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
3980
        if err != nil {
×
3981
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3982
                        "fields: %w", err)
×
3983
        }
×
3984
        if recs == nil {
×
3985
                recs = make([]byte, 0)
×
3986
        }
×
3987

3988
        var btcKey1, btcKey2 route.Vertex
×
3989
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
3990
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
3991

×
3992
        channel := &models.ChannelEdgeInfo{
×
3993
                ChainHash:        chain,
×
3994
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
3995
                NodeKey1Bytes:    node1,
×
3996
                NodeKey2Bytes:    node2,
×
3997
                BitcoinKey1Bytes: btcKey1,
×
3998
                BitcoinKey2Bytes: btcKey2,
×
3999
                ChannelPoint:     *op,
×
4000
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4001
                Features:         featureBuf.Bytes(),
×
4002
                ExtraOpaqueData:  recs,
×
4003
        }
×
4004

×
4005
        // We always set all the signatures at the same time, so we can
×
4006
        // safely check if one signature is present to determine if we have the
×
4007
        // rest of the signatures for the auth proof.
×
4008
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4009
                channel.AuthProof = &models.ChannelAuthProof{
×
4010
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4011
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4012
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4013
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4014
                }
×
4015
        }
×
4016

4017
        return channel, nil
×
4018
}
4019

4020
// buildNodeVertices is a helper that converts raw node public keys
4021
// into route.Vertex instances.
4022
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4023
        route.Vertex, error) {
×
4024

×
4025
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4026
        if err != nil {
×
4027
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4028
                        "create vertex from node1 pubkey: %w", err)
×
4029
        }
×
4030

4031
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4032
        if err != nil {
×
4033
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4034
                        "create vertex from node2 pubkey: %w", err)
×
4035
        }
×
4036

4037
        return node1Vertex, node2Vertex, nil
×
4038
}
4039

4040
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
4041
// for a channel with the given ID.
4042
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
4043
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
4044

×
4045
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
4046
        if err != nil {
×
4047
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
4048
                        "features and extras: %w", err)
×
4049
        }
×
4050

4051
        var (
×
4052
                fv     = lnwire.EmptyFeatureVector()
×
4053
                extras = make(map[uint64][]byte)
×
4054
        )
×
4055
        for _, row := range rows {
×
4056
                if row.IsFeature {
×
4057
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
4058

×
4059
                        continue
×
4060
                }
4061

4062
                tlvType, ok := row.ExtraKey.(int64)
×
4063
                if !ok {
×
4064
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4065
                                "TLV type: %T", row.ExtraKey)
×
4066
                }
×
4067

4068
                valueBytes, ok := row.Value.([]byte)
×
4069
                if !ok {
×
4070
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4071
                                "Value: %T", row.Value)
×
4072
                }
×
4073

4074
                extras[uint64(tlvType)] = valueBytes
×
4075
        }
4076

4077
        return fv, extras, nil
×
4078
}
4079

4080
// getAndBuildChanPolicies uses the given sqlc.ChannelPolicy and also retrieves
4081
// all the extra info required to build the complete models.ChannelEdgePolicy
4082
// types. It returns two policies, which may be nil if the provided
4083
// sqlc.ChannelPolicy records are nil.
4084
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4085
        dbPol1, dbPol2 *sqlc.ChannelPolicy, channelID uint64, node1,
4086
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4087
        *models.ChannelEdgePolicy, error) {
×
4088

×
4089
        if dbPol1 == nil && dbPol2 == nil {
×
4090
                return nil, nil, nil
×
4091
        }
×
4092

4093
        var (
×
4094
                policy1ID int64
×
4095
                policy2ID int64
×
4096
        )
×
4097
        if dbPol1 != nil {
×
4098
                policy1ID = dbPol1.ID
×
4099
        }
×
4100
        if dbPol2 != nil {
×
4101
                policy2ID = dbPol2.ID
×
4102
        }
×
4103
        rows, err := db.GetChannelPolicyExtraTypes(
×
4104
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
4105
                        ID:   policy1ID,
×
4106
                        ID_2: policy2ID,
×
4107
                },
×
4108
        )
×
4109
        if err != nil {
×
4110
                return nil, nil, err
×
4111
        }
×
4112

4113
        var (
×
4114
                dbPol1Extras = make(map[uint64][]byte)
×
4115
                dbPol2Extras = make(map[uint64][]byte)
×
4116
        )
×
4117
        for _, row := range rows {
×
4118
                switch row.PolicyID {
×
4119
                case policy1ID:
×
4120
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
4121
                case policy2ID:
×
4122
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
4123
                default:
×
4124
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
4125
                                "in row: %v", row.PolicyID, row)
×
4126
                }
4127
        }
4128

4129
        var pol1, pol2 *models.ChannelEdgePolicy
×
4130
        if dbPol1 != nil {
×
4131
                pol1, err = buildChanPolicy(
×
4132
                        *dbPol1, channelID, dbPol1Extras, node2,
×
4133
                )
×
4134
                if err != nil {
×
4135
                        return nil, nil, err
×
4136
                }
×
4137
        }
4138
        if dbPol2 != nil {
×
4139
                pol2, err = buildChanPolicy(
×
4140
                        *dbPol2, channelID, dbPol2Extras, node1,
×
4141
                )
×
4142
                if err != nil {
×
4143
                        return nil, nil, err
×
4144
                }
×
4145
        }
4146

4147
        return pol1, pol2, nil
×
4148
}
4149

4150
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4151
// provided sqlc.ChannelPolicy and other required information.
4152
func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
4153
        extras map[uint64][]byte,
4154
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4155

×
4156
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4157
        if err != nil {
×
4158
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4159
                        "fields: %w", err)
×
4160
        }
×
4161

4162
        var inboundFee fn.Option[lnwire.Fee]
×
4163
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4164
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4165

×
4166
                inboundFee = fn.Some(lnwire.Fee{
×
4167
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4168
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4169
                })
×
4170
        }
×
4171

4172
        return &models.ChannelEdgePolicy{
×
4173
                SigBytes:  dbPolicy.Signature,
×
4174
                ChannelID: channelID,
×
4175
                LastUpdate: time.Unix(
×
4176
                        dbPolicy.LastUpdate.Int64, 0,
×
4177
                ),
×
4178
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4179
                        dbPolicy.MessageFlags,
×
4180
                ),
×
4181
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4182
                        dbPolicy.ChannelFlags,
×
4183
                ),
×
4184
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4185
                MinHTLC: lnwire.MilliSatoshi(
×
4186
                        dbPolicy.MinHtlcMsat,
×
4187
                ),
×
4188
                MaxHTLC: lnwire.MilliSatoshi(
×
4189
                        dbPolicy.MaxHtlcMsat.Int64,
×
4190
                ),
×
4191
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4192
                        dbPolicy.BaseFeeMsat,
×
4193
                ),
×
4194
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4195
                ToNode:                    toNode,
×
4196
                InboundFee:                inboundFee,
×
4197
                ExtraOpaqueData:           recs,
×
4198
        }, nil
×
4199
}
4200

4201
// buildNodes builds the models.LightningNode instances for the
4202
// given row which is expected to be a sqlc type that contains node information.
4203
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4204
        dbNode2 sqlc.Node) (*models.LightningNode, *models.LightningNode,
4205
        error) {
×
4206

×
4207
        node1, err := buildNode(ctx, db, &dbNode1)
×
4208
        if err != nil {
×
4209
                return nil, nil, err
×
4210
        }
×
4211

4212
        node2, err := buildNode(ctx, db, &dbNode2)
×
4213
        if err != nil {
×
4214
                return nil, nil, err
×
4215
        }
×
4216

4217
        return node1, node2, nil
×
4218
}
4219

4220
// extractChannelPolicies extracts the sqlc.ChannelPolicy records from the give
4221
// row which is expected to be a sqlc type that contains channel policy
4222
// information. It returns two policies, which may be nil if the policy
4223
// information is not present in the row.
4224
//
4225
//nolint:ll,dupl,funlen
4226
func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
4227
        error) {
×
4228

×
4229
        var policy1, policy2 *sqlc.ChannelPolicy
×
4230
        switch r := row.(type) {
×
4231
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4232
                if r.Policy1ID.Valid {
×
4233
                        policy1 = &sqlc.ChannelPolicy{
×
4234
                                ID:                      r.Policy1ID.Int64,
×
4235
                                Version:                 r.Policy1Version.Int16,
×
4236
                                ChannelID:               r.Channel.ID,
×
4237
                                NodeID:                  r.Policy1NodeID.Int64,
×
4238
                                Timelock:                r.Policy1Timelock.Int32,
×
4239
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4240
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4241
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4242
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4243
                                LastUpdate:              r.Policy1LastUpdate,
×
4244
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4245
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4246
                                Disabled:                r.Policy1Disabled,
×
4247
                                MessageFlags:            r.Policy1MessageFlags,
×
4248
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4249
                                Signature:               r.Policy1Signature,
×
4250
                        }
×
4251
                }
×
4252
                if r.Policy2ID.Valid {
×
4253
                        policy2 = &sqlc.ChannelPolicy{
×
4254
                                ID:                      r.Policy2ID.Int64,
×
4255
                                Version:                 r.Policy2Version.Int16,
×
4256
                                ChannelID:               r.Channel.ID,
×
4257
                                NodeID:                  r.Policy2NodeID.Int64,
×
4258
                                Timelock:                r.Policy2Timelock.Int32,
×
4259
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4260
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4261
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4262
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4263
                                LastUpdate:              r.Policy2LastUpdate,
×
4264
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4265
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4266
                                Disabled:                r.Policy2Disabled,
×
4267
                                MessageFlags:            r.Policy2MessageFlags,
×
4268
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4269
                                Signature:               r.Policy2Signature,
×
4270
                        }
×
4271
                }
×
4272

4273
                return policy1, policy2, nil
×
4274

4275
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4276
                if r.Policy1ID.Valid {
×
4277
                        policy1 = &sqlc.ChannelPolicy{
×
4278
                                ID:                      r.Policy1ID.Int64,
×
4279
                                Version:                 r.Policy1Version.Int16,
×
4280
                                ChannelID:               r.Channel.ID,
×
4281
                                NodeID:                  r.Policy1NodeID.Int64,
×
4282
                                Timelock:                r.Policy1Timelock.Int32,
×
4283
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4284
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4285
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4286
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4287
                                LastUpdate:              r.Policy1LastUpdate,
×
4288
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4289
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4290
                                Disabled:                r.Policy1Disabled,
×
4291
                                MessageFlags:            r.Policy1MessageFlags,
×
4292
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4293
                                Signature:               r.Policy1Signature,
×
4294
                        }
×
4295
                }
×
4296
                if r.Policy2ID.Valid {
×
4297
                        policy2 = &sqlc.ChannelPolicy{
×
4298
                                ID:                      r.Policy2ID.Int64,
×
4299
                                Version:                 r.Policy2Version.Int16,
×
4300
                                ChannelID:               r.Channel.ID,
×
4301
                                NodeID:                  r.Policy2NodeID.Int64,
×
4302
                                Timelock:                r.Policy2Timelock.Int32,
×
4303
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4304
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4305
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4306
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4307
                                LastUpdate:              r.Policy2LastUpdate,
×
4308
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4309
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4310
                                Disabled:                r.Policy2Disabled,
×
4311
                                MessageFlags:            r.Policy2MessageFlags,
×
4312
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4313
                                Signature:               r.Policy2Signature,
×
4314
                        }
×
4315
                }
×
4316

4317
                return policy1, policy2, nil
×
4318

4319
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4320
                if r.Policy1ID.Valid {
×
4321
                        policy1 = &sqlc.ChannelPolicy{
×
4322
                                ID:                      r.Policy1ID.Int64,
×
4323
                                Version:                 r.Policy1Version.Int16,
×
4324
                                ChannelID:               r.Channel.ID,
×
4325
                                NodeID:                  r.Policy1NodeID.Int64,
×
4326
                                Timelock:                r.Policy1Timelock.Int32,
×
4327
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4328
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4329
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4330
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4331
                                LastUpdate:              r.Policy1LastUpdate,
×
4332
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4333
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4334
                                Disabled:                r.Policy1Disabled,
×
4335
                                MessageFlags:            r.Policy1MessageFlags,
×
4336
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4337
                                Signature:               r.Policy1Signature,
×
4338
                        }
×
4339
                }
×
4340
                if r.Policy2ID.Valid {
×
4341
                        policy2 = &sqlc.ChannelPolicy{
×
4342
                                ID:                      r.Policy2ID.Int64,
×
4343
                                Version:                 r.Policy2Version.Int16,
×
4344
                                ChannelID:               r.Channel.ID,
×
4345
                                NodeID:                  r.Policy2NodeID.Int64,
×
4346
                                Timelock:                r.Policy2Timelock.Int32,
×
4347
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4348
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4349
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4350
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4351
                                LastUpdate:              r.Policy2LastUpdate,
×
4352
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4353
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4354
                                Disabled:                r.Policy2Disabled,
×
4355
                                MessageFlags:            r.Policy2MessageFlags,
×
4356
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4357
                                Signature:               r.Policy2Signature,
×
4358
                        }
×
4359
                }
×
4360

4361
                return policy1, policy2, nil
×
4362

4363
        case sqlc.ListChannelsByNodeIDRow:
×
4364
                if r.Policy1ID.Valid {
×
4365
                        policy1 = &sqlc.ChannelPolicy{
×
4366
                                ID:                      r.Policy1ID.Int64,
×
4367
                                Version:                 r.Policy1Version.Int16,
×
4368
                                ChannelID:               r.Channel.ID,
×
4369
                                NodeID:                  r.Policy1NodeID.Int64,
×
4370
                                Timelock:                r.Policy1Timelock.Int32,
×
4371
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4372
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4373
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4374
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4375
                                LastUpdate:              r.Policy1LastUpdate,
×
4376
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4377
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4378
                                Disabled:                r.Policy1Disabled,
×
4379
                                MessageFlags:            r.Policy1MessageFlags,
×
4380
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4381
                                Signature:               r.Policy1Signature,
×
4382
                        }
×
4383
                }
×
4384
                if r.Policy2ID.Valid {
×
4385
                        policy2 = &sqlc.ChannelPolicy{
×
4386
                                ID:                      r.Policy2ID.Int64,
×
4387
                                Version:                 r.Policy2Version.Int16,
×
4388
                                ChannelID:               r.Channel.ID,
×
4389
                                NodeID:                  r.Policy2NodeID.Int64,
×
4390
                                Timelock:                r.Policy2Timelock.Int32,
×
4391
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4392
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4393
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4394
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4395
                                LastUpdate:              r.Policy2LastUpdate,
×
4396
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4397
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4398
                                Disabled:                r.Policy2Disabled,
×
4399
                                MessageFlags:            r.Policy2MessageFlags,
×
4400
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4401
                                Signature:               r.Policy2Signature,
×
4402
                        }
×
4403
                }
×
4404

4405
                return policy1, policy2, nil
×
4406

4407
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4408
                if r.Policy1ID.Valid {
×
4409
                        policy1 = &sqlc.ChannelPolicy{
×
4410
                                ID:                      r.Policy1ID.Int64,
×
4411
                                Version:                 r.Policy1Version.Int16,
×
4412
                                ChannelID:               r.Channel.ID,
×
4413
                                NodeID:                  r.Policy1NodeID.Int64,
×
4414
                                Timelock:                r.Policy1Timelock.Int32,
×
4415
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4416
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4417
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4418
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4419
                                LastUpdate:              r.Policy1LastUpdate,
×
4420
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4421
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4422
                                Disabled:                r.Policy1Disabled,
×
4423
                                MessageFlags:            r.Policy1MessageFlags,
×
4424
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4425
                                Signature:               r.Policy1Signature,
×
4426
                        }
×
4427
                }
×
4428
                if r.Policy2ID.Valid {
×
4429
                        policy2 = &sqlc.ChannelPolicy{
×
4430
                                ID:                      r.Policy2ID.Int64,
×
4431
                                Version:                 r.Policy2Version.Int16,
×
4432
                                ChannelID:               r.Channel.ID,
×
4433
                                NodeID:                  r.Policy2NodeID.Int64,
×
4434
                                Timelock:                r.Policy2Timelock.Int32,
×
4435
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4436
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4437
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4438
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4439
                                LastUpdate:              r.Policy2LastUpdate,
×
4440
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4441
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4442
                                Disabled:                r.Policy2Disabled,
×
4443
                                MessageFlags:            r.Policy2MessageFlags,
×
4444
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4445
                                Signature:               r.Policy2Signature,
×
4446
                        }
×
4447
                }
×
4448

4449
                return policy1, policy2, nil
×
4450
        default:
×
4451
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4452
                        "extractChannelPolicies: %T", r)
×
4453
        }
4454
}
4455

4456
// channelIDToBytes converts a channel ID (SCID) to a byte array
4457
// representation.
4458
func channelIDToBytes(channelID uint64) [8]byte {
×
4459
        var chanIDB [8]byte
×
4460
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4461

×
4462
        return chanIDB
×
4463
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc