• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 16305584530

15 Jul 2025 10:21PM UTC coverage: 67.321% (+0.01%) from 67.307%
16305584530

push

github

web-flow
Merge pull request #10080 from ellemouton/graphPrefixTables

sqldb+graph/db: prefix graph SQL objects with "graph_"

0 of 211 new or added lines in 3 files covered. (0.0%)

58 existing lines in 19 files now uncovered.

135405 of 201132 relevant lines covered (67.32%)

21775.86 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// pageSize bounds how many records a single paginated query may return. The
// value is a starting point that can be tuned once benchmarks are available.
const pageSize = 2000

// ProtocolVersion identifies which gossip protocol version a message belongs
// to.
type ProtocolVersion uint8

const (
	// ProtocolV1 is the gossip protocol version defined in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String renders the protocol version as a "V"-prefixed decimal number, for
// example "V1".
func (v ProtocolVersion) String() string {
	// Format the underlying integer explicitly so that the formatting call
	// can never route back into this String method.
	return fmt.Sprintf("V%d", uint8(v))
}
×
51

52
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
// execute queries against the SQL graph tables. The methods are grouped below
// by the graph table (or group of tables) they operate on.
//
//nolint:ll,interfacebloat
type SQLQueries interface {
	/*
		Node queries.
	*/
	UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
	GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
	GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
	GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
	ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
	ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
	IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
	DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
	DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
	DeleteNode(ctx context.Context, id int64) error

	// Node extra (TLV) type records.
	GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
	UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
	DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error

	// Node network addresses.
	InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
	GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
	DeleteNodeAddresses(ctx context.Context, nodeID int64) error

	// Node feature bits.
	InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
	GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
	GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
	DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error

	/*
		Source node queries.
	*/
	AddSourceNode(ctx context.Context, nodeID int64) error
	GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)

	/*
		Channel queries.
	*/
	CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
	AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
	GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
	GetChannelByOutpoint(ctx context.Context, outpoint string) (sqlc.GetChannelByOutpointRow, error)
	GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
	GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
	GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
	GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
	HighestSCID(ctx context.Context, version int16) ([]byte, error)
	ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
	ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
	ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
	GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
	GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
	GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
	GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
	DeleteChannel(ctx context.Context, id int64) error

	// Channel extra (TLV) types and feature bits.
	CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
	InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error

	/*
		Channel Policy table queries.
	*/
	UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
	GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
	GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)

	// Channel policy extra (TLV) types.
	InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
	GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
	DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error

	/*
		Zombie index queries.
	*/
	UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
	GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
	CountZombieChannels(ctx context.Context, version int16) (int64, error)
	DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
	IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)

	/*
		Prune log table queries.
	*/
	GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
	GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
	UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
	DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error

	/*
		Closed SCID table queries.
	*/
	InsertClosedChannel(ctx context.Context, scid []byte) error
	IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
}
148

149
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
// database operations. Besides the plain SQLQueries methods, it embeds
// sqldb.BatchedTx, whose ExecTx method SQLStore uses to run a callback
// against SQLQueries inside a read or write transaction.
type BatchedSQLQueries interface {
	SQLQueries
	sqldb.BatchedTx[SQLQueries]
}
155

156
// SQLStore is an implementation of the V1Store interface that uses a SQL
// database as the backend.
type SQLStore struct {
	// cfg is the store's static configuration and db is the SQL backend
	// that all reads and writes go through.
	cfg *SQLStoreConfig
	db  BatchedSQLQueries

	// cacheMu guards all caches (rejectCache and chanCache). If this mutex
	// is acquired at the same time as the DB mutex, cacheMu MUST be
	// acquired first to prevent deadlock.
	cacheMu     sync.RWMutex
	rejectCache *rejectCache
	chanCache   *channelCache

	// chanScheduler batches channel writes (AddChannelEdge,
	// UpdateEdgePolicy) and is constructed with &cacheMu in NewSQLStore.
	// nodeScheduler batches node writes (AddLightningNode) and is
	// constructed with no lock.
	chanScheduler batch.Scheduler[SQLQueries]
	nodeScheduler batch.Scheduler[SQLQueries]

	// srcNodes appears to cache source node info per protocol version,
	// with srcNodeMu guarding it. NOTE(review): confirm against
	// getSourceNode, which reads/writes this map.
	srcNodes  map[ProtocolVersion]*srcNodeInfo
	srcNodeMu sync.Mutex
}

// A compile-time assertion to ensure that SQLStore implements the V1Store
// interface.
var _ V1Store = (*SQLStore)(nil)
179

180
// SQLStoreConfig holds the configuration for the SQLStore. It is passed to
// NewSQLStore, which keeps a reference to it in the store.
type SQLStoreConfig struct {
	// ChainHash is the genesis hash for the chain that all the gossip
	// messages in this store are aimed at.
	ChainHash chainhash.Hash
}
186

187
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
188
// storage backend.
189
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
190
        options ...StoreOptionModifier) (*SQLStore, error) {
×
191

×
192
        opts := DefaultOptions()
×
193
        for _, o := range options {
×
194
                o(opts)
×
195
        }
×
196

197
        if opts.NoMigration {
×
198
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
199
                        "supported for SQL stores")
×
200
        }
×
201

202
        s := &SQLStore{
×
203
                cfg:         cfg,
×
204
                db:          db,
×
205
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
206
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
207
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
208
        }
×
209

×
210
        s.chanScheduler = batch.NewTimeScheduler(
×
211
                db, &s.cacheMu, opts.BatchCommitInterval,
×
212
        )
×
213
        s.nodeScheduler = batch.NewTimeScheduler(
×
214
                db, nil, opts.BatchCommitInterval,
×
215
        )
×
216

×
217
        return s, nil
×
218
}
219

220
// AddLightningNode adds a vertex/node to the graph database. If the node is not
221
// in the database from before, this will add a new, unconnected one to the
222
// graph. If it is present from before, this will update that node's
223
// information.
224
//
225
// NOTE: part of the V1Store interface.
226
func (s *SQLStore) AddLightningNode(ctx context.Context,
227
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
228

×
229
        r := &batch.Request[SQLQueries]{
×
230
                Opts: batch.NewSchedulerOptions(opts...),
×
231
                Do: func(queries SQLQueries) error {
×
232
                        _, err := upsertNode(ctx, queries, node)
×
233
                        return err
×
234
                },
×
235
        }
236

237
        return s.nodeScheduler.Execute(ctx, r)
×
238
}
239

240
// FetchLightningNode attempts to look up a target node by its identity public
241
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
242
// returned.
243
//
244
// NOTE: part of the V1Store interface.
245
func (s *SQLStore) FetchLightningNode(ctx context.Context,
246
        pubKey route.Vertex) (*models.LightningNode, error) {
×
247

×
248
        var node *models.LightningNode
×
249
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
250
                var err error
×
251
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
252

×
253
                return err
×
254
        }, sqldb.NoOpReset)
×
255
        if err != nil {
×
256
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
257
        }
×
258

259
        return node, nil
×
260
}
261

262
// HasLightningNode determines if the graph has a vertex identified by the
263
// target node identity public key. If the node exists in the database, a
264
// timestamp of when the data for the node was lasted updated is returned along
265
// with a true boolean. Otherwise, an empty time.Time is returned with a false
266
// boolean.
267
//
268
// NOTE: part of the V1Store interface.
269
func (s *SQLStore) HasLightningNode(ctx context.Context,
270
        pubKey [33]byte) (time.Time, bool, error) {
×
271

×
272
        var (
×
273
                exists     bool
×
274
                lastUpdate time.Time
×
275
        )
×
276
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
277
                dbNode, err := db.GetNodeByPubKey(
×
278
                        ctx, sqlc.GetNodeByPubKeyParams{
×
279
                                Version: int16(ProtocolV1),
×
280
                                PubKey:  pubKey[:],
×
281
                        },
×
282
                )
×
283
                if errors.Is(err, sql.ErrNoRows) {
×
284
                        return nil
×
285
                } else if err != nil {
×
286
                        return fmt.Errorf("unable to fetch node: %w", err)
×
287
                }
×
288

289
                exists = true
×
290

×
291
                if dbNode.LastUpdate.Valid {
×
292
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
293
                }
×
294

295
                return nil
×
296
        }, sqldb.NoOpReset)
297
        if err != nil {
×
298
                return time.Time{}, false,
×
299
                        fmt.Errorf("unable to fetch node: %w", err)
×
300
        }
×
301

302
        return lastUpdate, exists, nil
×
303
}
304

305
// AddrsForNode returns all known addresses for the target node public key
306
// that the graph DB is aware of. The returned boolean indicates if the
307
// given node is unknown to the graph DB or not.
308
//
309
// NOTE: part of the V1Store interface.
310
func (s *SQLStore) AddrsForNode(ctx context.Context,
311
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
312

×
313
        var (
×
314
                addresses []net.Addr
×
315
                known     bool
×
316
        )
×
317
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
318
                var err error
×
319
                known, addresses, err = getNodeAddresses(
×
320
                        ctx, db, nodePub.SerializeCompressed(),
×
321
                )
×
322
                if err != nil {
×
323
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
324
                                err)
×
325
                }
×
326

327
                return nil
×
328
        }, sqldb.NoOpReset)
329
        if err != nil {
×
330
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
331
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
332
        }
×
333

334
        return known, addresses, nil
×
335
}
336

337
// DeleteLightningNode starts a new database transaction to remove a vertex/node
338
// from the database according to the node's public key.
339
//
340
// NOTE: part of the V1Store interface.
341
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
342
        pubKey route.Vertex) error {
×
343

×
344
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
345
                res, err := db.DeleteNodeByPubKey(
×
346
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
347
                                Version: int16(ProtocolV1),
×
348
                                PubKey:  pubKey[:],
×
349
                        },
×
350
                )
×
351
                if err != nil {
×
352
                        return err
×
353
                }
×
354

355
                rows, err := res.RowsAffected()
×
356
                if err != nil {
×
357
                        return err
×
358
                }
×
359

360
                if rows == 0 {
×
361
                        return ErrGraphNodeNotFound
×
362
                } else if rows > 1 {
×
363
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
364
                }
×
365

366
                return err
×
367
        }, sqldb.NoOpReset)
368
        if err != nil {
×
369
                return fmt.Errorf("unable to delete node: %w", err)
×
370
        }
×
371

372
        return nil
×
373
}
374

375
// FetchNodeFeatures returns the features of the given node. If no features are
376
// known for the node, an empty feature vector is returned.
377
//
378
// NOTE: this is part of the graphdb.NodeTraverser interface.
379
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
380
        *lnwire.FeatureVector, error) {
×
381

×
382
        ctx := context.TODO()
×
383

×
384
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
385
}
×
386

387
// DisabledChannelIDs returns the channel ids of disabled channels.
388
// A channel is disabled when two of the associated ChanelEdgePolicies
389
// have their disabled bit on.
390
//
391
// NOTE: part of the V1Store interface.
392
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
393
        var (
×
394
                ctx     = context.TODO()
×
395
                chanIDs []uint64
×
396
        )
×
397
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
398
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
399
                if err != nil {
×
400
                        return fmt.Errorf("unable to fetch disabled "+
×
401
                                "channels: %w", err)
×
402
                }
×
403

404
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
405

×
406
                return nil
×
407
        }, sqldb.NoOpReset)
408
        if err != nil {
×
409
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
410
                        err)
×
411
        }
×
412

413
        return chanIDs, nil
×
414
}
415

416
// LookupAlias attempts to return the alias as advertised by the target node.
417
//
418
// NOTE: part of the V1Store interface.
419
func (s *SQLStore) LookupAlias(ctx context.Context,
420
        pub *btcec.PublicKey) (string, error) {
×
421

×
422
        var alias string
×
423
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
424
                dbNode, err := db.GetNodeByPubKey(
×
425
                        ctx, sqlc.GetNodeByPubKeyParams{
×
426
                                Version: int16(ProtocolV1),
×
427
                                PubKey:  pub.SerializeCompressed(),
×
428
                        },
×
429
                )
×
430
                if errors.Is(err, sql.ErrNoRows) {
×
431
                        return ErrNodeAliasNotFound
×
432
                } else if err != nil {
×
433
                        return fmt.Errorf("unable to fetch node: %w", err)
×
434
                }
×
435

436
                if !dbNode.Alias.Valid {
×
437
                        return ErrNodeAliasNotFound
×
438
                }
×
439

440
                alias = dbNode.Alias.String
×
441

×
442
                return nil
×
443
        }, sqldb.NoOpReset)
444
        if err != nil {
×
445
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
446
        }
×
447

448
        return alias, nil
×
449
}
450

451
// SourceNode returns the source node of the graph. The source node is treated
452
// as the center node within a star-graph. This method may be used to kick off
453
// a path finding algorithm in order to explore the reachability of another
454
// node based off the source node.
455
//
456
// NOTE: part of the V1Store interface.
457
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
458
        error) {
×
459

×
460
        var node *models.LightningNode
×
461
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
462
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
463
                if err != nil {
×
464
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
465
                                err)
×
466
                }
×
467

468
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
469

×
470
                return err
×
471
        }, sqldb.NoOpReset)
472
        if err != nil {
×
473
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
474
        }
×
475

476
        return node, nil
×
477
}
478

479
// SetSourceNode sets the source node within the graph database. The source
480
// node is to be used as the center of a star-graph within path finding
481
// algorithms.
482
//
483
// NOTE: part of the V1Store interface.
484
func (s *SQLStore) SetSourceNode(ctx context.Context,
485
        node *models.LightningNode) error {
×
486

×
487
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
488
                id, err := upsertNode(ctx, db, node)
×
489
                if err != nil {
×
490
                        return fmt.Errorf("unable to upsert source node: %w",
×
491
                                err)
×
492
                }
×
493

494
                // Make sure that if a source node for this version is already
495
                // set, then the ID is the same as the one we are about to set.
496
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
497
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
498
                        return fmt.Errorf("unable to fetch source node: %w",
×
499
                                err)
×
500
                } else if err == nil {
×
501
                        if dbSourceNodeID != id {
×
502
                                return fmt.Errorf("v1 source node already "+
×
503
                                        "set to a different node: %d vs %d",
×
504
                                        dbSourceNodeID, id)
×
505
                        }
×
506

507
                        return nil
×
508
                }
509

510
                return db.AddSourceNode(ctx, id)
×
511
        }, sqldb.NoOpReset)
512
}
513

514
// NodeUpdatesInHorizon returns all the known lightning node which have an
515
// update timestamp within the passed range. This method can be used by two
516
// nodes to quickly determine if they have the same set of up to date node
517
// announcements.
518
//
519
// NOTE: This is part of the V1Store interface.
520
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
521
        endTime time.Time) ([]models.LightningNode, error) {
×
522

×
523
        ctx := context.TODO()
×
524

×
525
        var nodes []models.LightningNode
×
526
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
527
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
528
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
529
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
530
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
531
                        },
×
532
                )
×
533
                if err != nil {
×
534
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
535
                }
×
536

537
                for _, dbNode := range dbNodes {
×
538
                        node, err := buildNode(ctx, db, &dbNode)
×
539
                        if err != nil {
×
540
                                return fmt.Errorf("unable to build node: %w",
×
541
                                        err)
×
542
                        }
×
543

544
                        nodes = append(nodes, *node)
×
545
                }
546

547
                return nil
×
548
        }, sqldb.NoOpReset)
549
        if err != nil {
×
550
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
551
        }
×
552

553
        return nodes, nil
×
554
}
555

556
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
557
// undirected edge from the two target nodes are created. The information stored
558
// denotes the static attributes of the channel, such as the channelID, the keys
559
// involved in creation of the channel, and the set of features that the channel
560
// supports. The chanPoint and chanID are used to uniquely identify the edge
561
// globally within the database.
562
//
563
// NOTE: part of the V1Store interface.
564
func (s *SQLStore) AddChannelEdge(ctx context.Context,
565
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
566

×
567
        var alreadyExists bool
×
568
        r := &batch.Request[SQLQueries]{
×
569
                Opts: batch.NewSchedulerOptions(opts...),
×
570
                Reset: func() {
×
571
                        alreadyExists = false
×
572
                },
×
573
                Do: func(tx SQLQueries) error {
×
574
                        _, err := insertChannel(ctx, tx, edge)
×
575

×
576
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
577
                        // succeed, but propagate the error via local state.
×
578
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
579
                                alreadyExists = true
×
580
                                return nil
×
581
                        }
×
582

583
                        return err
×
584
                },
585
                OnCommit: func(err error) error {
×
586
                        switch {
×
587
                        case err != nil:
×
588
                                return err
×
589
                        case alreadyExists:
×
590
                                return ErrEdgeAlreadyExist
×
591
                        default:
×
592
                                s.rejectCache.remove(edge.ChannelID)
×
593
                                s.chanCache.remove(edge.ChannelID)
×
594
                                return nil
×
595
                        }
596
                },
597
        }
598

599
        return s.chanScheduler.Execute(ctx, r)
×
600
}
601

602
// HighestChanID returns the "highest" known channel ID in the channel graph.
603
// This represents the "newest" channel from the PoV of the chain. This method
604
// can be used by peers to quickly determine if their graphs are in sync.
605
//
606
// NOTE: This is part of the V1Store interface.
607
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
608
        var highestChanID uint64
×
609
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
610
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
611
                if errors.Is(err, sql.ErrNoRows) {
×
612
                        return nil
×
613
                } else if err != nil {
×
614
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
615
                                err)
×
616
                }
×
617

618
                highestChanID = byteOrder.Uint64(chanID)
×
619

×
620
                return nil
×
621
        }, sqldb.NoOpReset)
622
        if err != nil {
×
623
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
624
        }
×
625

626
        return highestChanID, nil
×
627
}
628

629
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
630
// within the database for the referenced channel. The `flags` attribute within
631
// the ChannelEdgePolicy determines which of the directed edges are being
632
// updated. If the flag is 1, then the first node's information is being
633
// updated, otherwise it's the second node's information. The node ordering is
634
// determined by the lexicographical ordering of the identity public keys of the
635
// nodes on either side of the channel.
636
//
637
// NOTE: part of the V1Store interface.
638
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
639
        edge *models.ChannelEdgePolicy,
640
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
641

×
642
        var (
×
643
                isUpdate1    bool
×
644
                edgeNotFound bool
×
645
                from, to     route.Vertex
×
646
        )
×
647

×
648
        r := &batch.Request[SQLQueries]{
×
649
                Opts: batch.NewSchedulerOptions(opts...),
×
650
                Reset: func() {
×
651
                        isUpdate1 = false
×
652
                        edgeNotFound = false
×
653
                },
×
654
                Do: func(tx SQLQueries) error {
×
655
                        var err error
×
656
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
657
                                ctx, tx, edge,
×
658
                        )
×
659
                        if err != nil {
×
660
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
661
                        }
×
662

663
                        // Silence ErrEdgeNotFound so that the batch can
664
                        // succeed, but propagate the error via local state.
665
                        if errors.Is(err, ErrEdgeNotFound) {
×
666
                                edgeNotFound = true
×
667
                                return nil
×
668
                        }
×
669

670
                        return err
×
671
                },
672
                OnCommit: func(err error) error {
×
673
                        switch {
×
674
                        case err != nil:
×
675
                                return err
×
676
                        case edgeNotFound:
×
677
                                return ErrEdgeNotFound
×
678
                        default:
×
679
                                s.updateEdgeCache(edge, isUpdate1)
×
680
                                return nil
×
681
                        }
682
                },
683
        }
684

685
        err := s.chanScheduler.Execute(ctx, r)
×
686

×
687
        return from, to, err
×
688
}
689

690
// updateEdgeCache updates our reject and channel caches with the new
691
// edge policy information.
692
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
693
        isUpdate1 bool) {
×
694

×
695
        // If an entry for this channel is found in reject cache, we'll modify
×
696
        // the entry with the updated timestamp for the direction that was just
×
697
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
698
        // during the next query for this edge.
×
699
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
700
                if isUpdate1 {
×
701
                        entry.upd1Time = e.LastUpdate.Unix()
×
702
                } else {
×
703
                        entry.upd2Time = e.LastUpdate.Unix()
×
704
                }
×
705
                s.rejectCache.insert(e.ChannelID, entry)
×
706
        }
707

708
        // If an entry for this channel is found in channel cache, we'll modify
709
        // the entry with the updated policy for the direction that was just
710
        // written. If the edge doesn't exist, we'll defer loading the info and
711
        // policies and lazily read from disk during the next query.
712
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
713
                if isUpdate1 {
×
714
                        channel.Policy1 = e
×
715
                } else {
×
716
                        channel.Policy2 = e
×
717
                }
×
718
                s.chanCache.insert(e.ChannelID, channel)
×
719
        }
720
}
721

722
// ForEachSourceNodeChannel iterates through all channels of the source node,
723
// executing the passed callback on each. The call-back is provided with the
724
// channel's outpoint, whether we have a policy for the channel and the channel
725
// peer's node information.
726
//
727
// NOTE: part of the V1Store interface.
728
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
729
        cb func(chanPoint wire.OutPoint, havePolicy bool,
730
                otherNode *models.LightningNode) error, reset func()) error {
×
731

×
732
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
733
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
734
                if err != nil {
×
735
                        return fmt.Errorf("unable to fetch source node: %w",
×
736
                                err)
×
737
                }
×
738

739
                return forEachNodeChannel(
×
740
                        ctx, db, s.cfg.ChainHash, nodeID,
×
741
                        func(info *models.ChannelEdgeInfo,
×
742
                                outPolicy *models.ChannelEdgePolicy,
×
743
                                _ *models.ChannelEdgePolicy) error {
×
744

×
745
                                // Fetch the other node.
×
746
                                var (
×
747
                                        otherNodePub [33]byte
×
748
                                        node1        = info.NodeKey1Bytes
×
749
                                        node2        = info.NodeKey2Bytes
×
750
                                )
×
751
                                switch {
×
752
                                case bytes.Equal(node1[:], nodePub[:]):
×
753
                                        otherNodePub = node2
×
754
                                case bytes.Equal(node2[:], nodePub[:]):
×
755
                                        otherNodePub = node1
×
756
                                default:
×
757
                                        return fmt.Errorf("node not " +
×
758
                                                "participating in this channel")
×
759
                                }
760

761
                                _, otherNode, err := getNodeByPubKey(
×
762
                                        ctx, db, otherNodePub,
×
763
                                )
×
764
                                if err != nil {
×
765
                                        return fmt.Errorf("unable to fetch "+
×
766
                                                "other node(%x): %w",
×
767
                                                otherNodePub, err)
×
768
                                }
×
769

770
                                return cb(
×
771
                                        info.ChannelPoint, outPolicy != nil,
×
772
                                        otherNode,
×
773
                                )
×
774
                        },
775
                )
776
        }, reset)
777
}
778

779
// ForEachNode iterates through all the stored vertices/nodes in the graph,
780
// executing the passed callback with each node encountered. If the callback
781
// returns an error, then the transaction is aborted and the iteration stops
782
// early. Any operations performed on the NodeTx passed to the call-back are
783
// executed under the same read transaction and so, methods on the NodeTx object
784
// _MUST_ only be called from within the call-back.
785
//
786
// NOTE: part of the V1Store interface.
787
func (s *SQLStore) ForEachNode(ctx context.Context,
788
        cb func(tx NodeRTx) error, reset func()) error {
×
789

×
790
        var lastID int64 = 0
×
NEW
791
        handleNode := func(db SQLQueries, dbNode sqlc.GraphNode) error {
×
792
                node, err := buildNode(ctx, db, &dbNode)
×
793
                if err != nil {
×
794
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
795
                                dbNode.ID, err)
×
796
                }
×
797

798
                err = cb(
×
799
                        newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
×
800
                )
×
801
                if err != nil {
×
802
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
803
                                dbNode.ID, err)
×
804
                }
×
805

806
                return nil
×
807
        }
808

809
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
810
                for {
×
811
                        nodes, err := db.ListNodesPaginated(
×
812
                                ctx, sqlc.ListNodesPaginatedParams{
×
813
                                        Version: int16(ProtocolV1),
×
814
                                        ID:      lastID,
×
815
                                        Limit:   pageSize,
×
816
                                },
×
817
                        )
×
818
                        if err != nil {
×
819
                                return fmt.Errorf("unable to fetch nodes: %w",
×
820
                                        err)
×
821
                        }
×
822

823
                        if len(nodes) == 0 {
×
824
                                break
×
825
                        }
826

827
                        for _, dbNode := range nodes {
×
828
                                err = handleNode(db, dbNode)
×
829
                                if err != nil {
×
830
                                        return err
×
831
                                }
×
832

833
                                lastID = dbNode.ID
×
834
                        }
835
                }
836

837
                return nil
×
838
        }, reset)
839
}
840

841
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
// SQLStore and a SQL transaction. All follow-up reads made through it
// (ForEachChannel, FetchNode) reuse the query handle it was created with.
type sqlGraphNodeTx struct {
	// db is the query handle of the transaction the node was read under.
	db SQLQueries

	// id is the database row ID of the wrapped node.
	id int64

	// node is the fully built node that was fetched.
	node *models.LightningNode

	// chain is the chain hash handed on to channel lookups.
	chain chainhash.Hash
}

// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
// interface.
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
853

854
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
855
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
856

×
857
        return &sqlGraphNodeTx{
×
858
                db:    db,
×
859
                chain: chain,
×
860
                id:    id,
×
861
                node:  node,
×
862
        }
×
863
}
×
864

865
// Node returns the raw information of the node.
866
//
867
// NOTE: This is a part of the NodeRTx interface.
868
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
869
        return s.node
×
870
}
×
871

872
// ForEachChannel can be used to iterate over the node's channels under the same
873
// transaction used to fetch the node.
874
//
875
// NOTE: This is a part of the NodeRTx interface.
876
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
877
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
878

×
879
        ctx := context.TODO()
×
880

×
881
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
882
}
×
883

884
// FetchNode fetches the node with the given pub key under the same transaction
885
// used to fetch the current node. The returned node is also a NodeRTx and any
886
// operations on that NodeRTx will also be done under the same transaction.
887
//
888
// NOTE: This is a part of the NodeRTx interface.
889
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
890
        ctx := context.TODO()
×
891

×
892
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
893
        if err != nil {
×
894
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
895
                        nodePub, err)
×
896
        }
×
897

898
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
899
}
900

901
// ForEachNodeDirectedChannel iterates through all channels of a given node,
902
// executing the passed callback on the directed edge representing the channel
903
// and its incoming policy. If the callback returns an error, then the iteration
904
// is halted with the error propagated back up to the caller.
905
//
906
// Unknown policies are passed into the callback as nil values.
907
//
908
// NOTE: this is part of the graphdb.NodeTraverser interface.
909
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
910
        cb func(channel *DirectedChannel) error, reset func()) error {
×
911

×
912
        var ctx = context.TODO()
×
913

×
914
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
915
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
916
        }, reset)
×
917
}
918

919
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
920
// graph, executing the passed callback with each node encountered. If the
921
// callback returns an error, then the transaction is aborted and the iteration
922
// stops early.
923
//
924
// NOTE: This is a part of the V1Store interface.
925
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
926
        cb func(route.Vertex, *lnwire.FeatureVector) error,
927
        reset func()) error {
×
928

×
929
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
930
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
931
                        nodePub route.Vertex) error {
×
932

×
933
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
934
                        if err != nil {
×
935
                                return fmt.Errorf("unable to fetch node "+
×
936
                                        "features: %w", err)
×
937
                        }
×
938

939
                        return cb(nodePub, features)
×
940
                })
941
        }, reset)
942
        if err != nil {
×
943
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
944
        }
×
945

946
        return nil
×
947
}
948

949
// ForEachNodeChannel iterates through all channels of the given node,
950
// executing the passed callback with an edge info structure and the policies
951
// of each end of the channel. The first edge policy is the outgoing edge *to*
952
// the connecting node, while the second is the incoming edge *from* the
953
// connecting node. If the callback returns an error, then the iteration is
954
// halted with the error propagated back up to the caller.
955
//
956
// Unknown policies are passed into the callback as nil values.
957
//
958
// NOTE: part of the V1Store interface.
959
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
960
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
961
                *models.ChannelEdgePolicy) error, reset func()) error {
×
962

×
963
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
964
                dbNode, err := db.GetNodeByPubKey(
×
965
                        ctx, sqlc.GetNodeByPubKeyParams{
×
966
                                Version: int16(ProtocolV1),
×
967
                                PubKey:  nodePub[:],
×
968
                        },
×
969
                )
×
970
                if errors.Is(err, sql.ErrNoRows) {
×
971
                        return nil
×
972
                } else if err != nil {
×
973
                        return fmt.Errorf("unable to fetch node: %w", err)
×
974
                }
×
975

976
                return forEachNodeChannel(
×
977
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
978
                )
×
979
        }, reset)
980
}
981

982
// ChanUpdatesInHorizon returns all the known channel edges which have at least
983
// one edge that has an update timestamp within the specified horizon.
984
//
985
// NOTE: This is part of the V1Store interface.
986
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
987
        endTime time.Time) ([]ChannelEdge, error) {
×
988

×
989
        s.cacheMu.Lock()
×
990
        defer s.cacheMu.Unlock()
×
991

×
992
        var (
×
993
                ctx = context.TODO()
×
994
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
995
                // an additional map to keep track of the edges already seen to
×
996
                // prevent re-adding it.
×
997
                edgesSeen    = make(map[uint64]struct{})
×
998
                edgesToCache = make(map[uint64]ChannelEdge)
×
999
                edges        []ChannelEdge
×
1000
                hits         int
×
1001
        )
×
1002
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1003
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1004
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1005
                                Version:   int16(ProtocolV1),
×
1006
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
1007
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
1008
                        },
×
1009
                )
×
1010
                if err != nil {
×
1011
                        return err
×
1012
                }
×
1013

1014
                for _, row := range rows {
×
1015
                        // If we've already retrieved the info and policies for
×
1016
                        // this edge, then we can skip it as we don't need to do
×
1017
                        // so again.
×
NEW
1018
                        chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid)
×
1019
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1020
                                continue
×
1021
                        }
1022

1023
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1024
                                hits++
×
1025
                                edgesSeen[chanIDInt] = struct{}{}
×
1026
                                edges = append(edges, channel)
×
1027

×
1028
                                continue
×
1029
                        }
1030

1031
                        node1, node2, err := buildNodes(
×
NEW
1032
                                ctx, db, row.GraphNode, row.GraphNode_2,
×
1033
                        )
×
1034
                        if err != nil {
×
1035
                                return err
×
1036
                        }
×
1037

1038
                        channel, err := getAndBuildEdgeInfo(
×
NEW
1039
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
NEW
1040
                                row.GraphChannel, node1.PubKeyBytes,
×
1041
                                node2.PubKeyBytes,
×
1042
                        )
×
1043
                        if err != nil {
×
1044
                                return fmt.Errorf("unable to build channel "+
×
1045
                                        "info: %w", err)
×
1046
                        }
×
1047

1048
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1049
                        if err != nil {
×
1050
                                return fmt.Errorf("unable to extract channel "+
×
1051
                                        "policies: %w", err)
×
1052
                        }
×
1053

1054
                        p1, p2, err := getAndBuildChanPolicies(
×
1055
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1056
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1057
                        )
×
1058
                        if err != nil {
×
1059
                                return fmt.Errorf("unable to build channel "+
×
1060
                                        "policies: %w", err)
×
1061
                        }
×
1062

1063
                        edgesSeen[chanIDInt] = struct{}{}
×
1064
                        chanEdge := ChannelEdge{
×
1065
                                Info:    channel,
×
1066
                                Policy1: p1,
×
1067
                                Policy2: p2,
×
1068
                                Node1:   node1,
×
1069
                                Node2:   node2,
×
1070
                        }
×
1071
                        edges = append(edges, chanEdge)
×
1072
                        edgesToCache[chanIDInt] = chanEdge
×
1073
                }
1074

1075
                return nil
×
1076
        }, func() {
×
1077
                edgesSeen = make(map[uint64]struct{})
×
1078
                edgesToCache = make(map[uint64]ChannelEdge)
×
1079
                edges = nil
×
1080
        })
×
1081
        if err != nil {
×
1082
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1083
        }
×
1084

1085
        // Insert any edges loaded from disk into the cache.
1086
        for chanid, channel := range edgesToCache {
×
1087
                s.chanCache.insert(chanid, channel)
×
1088
        }
×
1089

1090
        if len(edges) > 0 {
×
1091
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
1092
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1093
        } else {
×
1094
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1095
                        "horizon (%s, %s)", startTime, endTime)
×
1096
        }
×
1097

1098
        return edges, nil
×
1099
}
1100

1101
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1102
// data to the call-back.
1103
//
1104
// NOTE: The callback contents MUST not be modified.
1105
//
1106
// NOTE: part of the V1Store interface.
1107
func (s *SQLStore) ForEachNodeCached(ctx context.Context,
1108
        cb func(node route.Vertex, chans map[uint64]*DirectedChannel) error,
1109
        reset func()) error {
×
1110

×
1111
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1112
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
1113
                        nodePub route.Vertex) error {
×
1114

×
1115
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
1116
                        if err != nil {
×
1117
                                return fmt.Errorf("unable to fetch "+
×
1118
                                        "node(id=%d) features: %w", nodeID, err)
×
1119
                        }
×
1120

1121
                        toNodeCallback := func() route.Vertex {
×
1122
                                return nodePub
×
1123
                        }
×
1124

1125
                        rows, err := db.ListChannelsByNodeID(
×
1126
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
1127
                                        Version: int16(ProtocolV1),
×
1128
                                        NodeID1: nodeID,
×
1129
                                },
×
1130
                        )
×
1131
                        if err != nil {
×
1132
                                return fmt.Errorf("unable to fetch channels "+
×
1133
                                        "of node(id=%d): %w", nodeID, err)
×
1134
                        }
×
1135

1136
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
1137
                        for _, row := range rows {
×
1138
                                node1, node2, err := buildNodeVertices(
×
1139
                                        row.Node1Pubkey, row.Node2Pubkey,
×
1140
                                )
×
1141
                                if err != nil {
×
1142
                                        return err
×
1143
                                }
×
1144

1145
                                e, err := getAndBuildEdgeInfo(
×
1146
                                        ctx, db, s.cfg.ChainHash,
×
NEW
1147
                                        row.GraphChannel.ID, row.GraphChannel,
×
NEW
1148
                                        node1, node2,
×
1149
                                )
×
1150
                                if err != nil {
×
1151
                                        return fmt.Errorf("unable to build "+
×
1152
                                                "channel info: %w", err)
×
1153
                                }
×
1154

1155
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1156
                                        row,
×
1157
                                )
×
1158
                                if err != nil {
×
1159
                                        return fmt.Errorf("unable to "+
×
1160
                                                "extract channel "+
×
1161
                                                "policies: %w", err)
×
1162
                                }
×
1163

1164
                                p1, p2, err := getAndBuildChanPolicies(
×
1165
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
1166
                                        node1, node2,
×
1167
                                )
×
1168
                                if err != nil {
×
1169
                                        return fmt.Errorf("unable to "+
×
1170
                                                "build channel policies: %w",
×
1171
                                                err)
×
1172
                                }
×
1173

1174
                                // Determine the outgoing and incoming policy
1175
                                // for this channel and node combo.
1176
                                outPolicy, inPolicy := p1, p2
×
1177
                                if p1 != nil && p1.ToNode == nodePub {
×
1178
                                        outPolicy, inPolicy = p2, p1
×
1179
                                } else if p2 != nil && p2.ToNode != nodePub {
×
1180
                                        outPolicy, inPolicy = p2, p1
×
1181
                                }
×
1182

1183
                                var cachedInPolicy *models.CachedEdgePolicy
×
1184
                                if inPolicy != nil {
×
1185
                                        cachedInPolicy = models.NewCachedPolicy(
×
1186
                                                p2,
×
1187
                                        )
×
1188
                                        cachedInPolicy.ToNodePubKey =
×
1189
                                                toNodeCallback
×
1190
                                        cachedInPolicy.ToNodeFeatures =
×
1191
                                                features
×
1192
                                }
×
1193

1194
                                var inboundFee lnwire.Fee
×
1195
                                outPolicy.InboundFee.WhenSome(
×
1196
                                        func(fee lnwire.Fee) {
×
1197
                                                inboundFee = fee
×
1198
                                        },
×
1199
                                )
1200

1201
                                directedChannel := &DirectedChannel{
×
1202
                                        ChannelID: e.ChannelID,
×
1203
                                        IsNode1: nodePub ==
×
1204
                                                e.NodeKey1Bytes,
×
1205
                                        OtherNode:    e.NodeKey2Bytes,
×
1206
                                        Capacity:     e.Capacity,
×
1207
                                        OutPolicySet: p1 != nil,
×
1208
                                        InPolicy:     cachedInPolicy,
×
1209
                                        InboundFee:   inboundFee,
×
1210
                                }
×
1211

×
1212
                                if nodePub == e.NodeKey2Bytes {
×
1213
                                        directedChannel.OtherNode =
×
1214
                                                e.NodeKey1Bytes
×
1215
                                }
×
1216

1217
                                channels[e.ChannelID] = directedChannel
×
1218
                        }
1219

1220
                        return cb(nodePub, channels)
×
1221
                })
1222
        }, reset)
1223
}
1224

1225
// ForEachChannelCacheable iterates through all the channel edges stored
1226
// within the graph and invokes the passed callback for each edge. The
1227
// callback takes two edges as since this is a directed graph, both the
1228
// in/out edges are visited. If the callback returns an error, then the
1229
// transaction is aborted and the iteration stops early.
1230
//
1231
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1232
// pointer for that particular channel edge routing policy will be
1233
// passed into the callback.
1234
//
1235
// NOTE: this method is like ForEachChannel but fetches only the data
1236
// required for the graph cache.
1237
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1238
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1239
        reset func()) error {
×
1240

×
1241
        ctx := context.TODO()
×
1242

×
1243
        handleChannel := func(db SQLQueries,
×
1244
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1245

×
1246
                node1, node2, err := buildNodeVertices(
×
1247
                        row.Node1Pubkey, row.Node2Pubkey,
×
1248
                )
×
1249
                if err != nil {
×
1250
                        return err
×
1251
                }
×
1252

NEW
1253
                edge := buildCacheableChannelInfo(
×
NEW
1254
                        row.GraphChannel, node1, node2,
×
NEW
1255
                )
×
1256

×
1257
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1258
                if err != nil {
×
1259
                        return err
×
1260
                }
×
1261

1262
                var pol1, pol2 *models.CachedEdgePolicy
×
1263
                if dbPol1 != nil {
×
1264
                        policy1, err := buildChanPolicy(
×
1265
                                *dbPol1, edge.ChannelID, nil, node2,
×
1266
                        )
×
1267
                        if err != nil {
×
1268
                                return err
×
1269
                        }
×
1270

1271
                        pol1 = models.NewCachedPolicy(policy1)
×
1272
                }
1273
                if dbPol2 != nil {
×
1274
                        policy2, err := buildChanPolicy(
×
1275
                                *dbPol2, edge.ChannelID, nil, node1,
×
1276
                        )
×
1277
                        if err != nil {
×
1278
                                return err
×
1279
                        }
×
1280

1281
                        pol2 = models.NewCachedPolicy(policy2)
×
1282
                }
1283

1284
                if err := cb(edge, pol1, pol2); err != nil {
×
1285
                        return err
×
1286
                }
×
1287

1288
                return nil
×
1289
        }
1290

1291
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1292
                lastID := int64(-1)
×
1293
                for {
×
1294
                        //nolint:ll
×
1295
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1296
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1297
                                        Version: int16(ProtocolV1),
×
1298
                                        ID:      lastID,
×
1299
                                        Limit:   pageSize,
×
1300
                                },
×
1301
                        )
×
1302
                        if err != nil {
×
1303
                                return err
×
1304
                        }
×
1305

1306
                        if len(rows) == 0 {
×
1307
                                break
×
1308
                        }
1309

1310
                        for _, row := range rows {
×
1311
                                err := handleChannel(db, row)
×
1312
                                if err != nil {
×
1313
                                        return err
×
1314
                                }
×
1315

NEW
1316
                                lastID = row.GraphChannel.ID
×
1317
                        }
1318
                }
1319

1320
                return nil
×
1321
        }, reset)
1322
}
1323

1324
// ForEachChannel iterates through all the channel edges stored within the
1325
// graph and invokes the passed callback for each edge. The callback takes two
1326
// edges as since this is a directed graph, both the in/out edges are visited.
1327
// If the callback returns an error, then the transaction is aborted and the
1328
// iteration stops early.
1329
//
1330
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1331
// for that particular channel edge routing policy will be passed into the
1332
// callback.
1333
//
1334
// NOTE: part of the V1Store interface.
1335
func (s *SQLStore) ForEachChannel(ctx context.Context,
1336
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1337
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1338

×
1339
        handleChannel := func(db SQLQueries,
×
1340
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1341

×
1342
                node1, node2, err := buildNodeVertices(
×
1343
                        row.Node1Pubkey, row.Node2Pubkey,
×
1344
                )
×
1345
                if err != nil {
×
1346
                        return fmt.Errorf("unable to build node vertices: %w",
×
1347
                                err)
×
1348
                }
×
1349

1350
                edge, err := getAndBuildEdgeInfo(
×
NEW
1351
                        ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
NEW
1352
                        row.GraphChannel, node1, node2,
×
1353
                )
×
1354
                if err != nil {
×
1355
                        return fmt.Errorf("unable to build channel info: %w",
×
1356
                                err)
×
1357
                }
×
1358

1359
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1360
                if err != nil {
×
1361
                        return fmt.Errorf("unable to extract channel "+
×
1362
                                "policies: %w", err)
×
1363
                }
×
1364

1365
                p1, p2, err := getAndBuildChanPolicies(
×
1366
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1367
                )
×
1368
                if err != nil {
×
1369
                        return fmt.Errorf("unable to build channel "+
×
1370
                                "policies: %w", err)
×
1371
                }
×
1372

1373
                err = cb(edge, p1, p2)
×
1374
                if err != nil {
×
1375
                        return fmt.Errorf("callback failed for channel "+
×
1376
                                "id=%d: %w", edge.ChannelID, err)
×
1377
                }
×
1378

1379
                return nil
×
1380
        }
1381

1382
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1383
                lastID := int64(-1)
×
1384
                for {
×
1385
                        //nolint:ll
×
1386
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1387
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1388
                                        Version: int16(ProtocolV1),
×
1389
                                        ID:      lastID,
×
1390
                                        Limit:   pageSize,
×
1391
                                },
×
1392
                        )
×
1393
                        if err != nil {
×
1394
                                return err
×
1395
                        }
×
1396

1397
                        if len(rows) == 0 {
×
1398
                                break
×
1399
                        }
1400

1401
                        for _, row := range rows {
×
1402
                                err := handleChannel(db, row)
×
1403
                                if err != nil {
×
1404
                                        return err
×
1405
                                }
×
1406

NEW
1407
                                lastID = row.GraphChannel.ID
×
1408
                        }
1409
                }
1410

1411
                return nil
×
1412
        }, reset)
1413
}
1414

1415
// FilterChannelRange returns the channel ID's of all known channels which were
1416
// mined in a block height within the passed range. The channel IDs are grouped
1417
// by their common block height. This method can be used to quickly share with a
1418
// peer the set of channels we know of within a particular range to catch them
1419
// up after a period of time offline. If withTimestamps is true then the
1420
// timestamp info of the latest received channel update messages of the channel
1421
// will be included in the response.
1422
//
1423
// NOTE: This is part of the V1Store interface.
1424
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1425
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1426

×
1427
        var (
×
1428
                ctx       = context.TODO()
×
1429
                startSCID = &lnwire.ShortChannelID{
×
1430
                        BlockHeight: startHeight,
×
1431
                }
×
1432
                endSCID = lnwire.ShortChannelID{
×
1433
                        BlockHeight: endHeight,
×
1434
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1435
                        TxPosition:  math.MaxUint16,
×
1436
                }
×
1437
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1438
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1439
        )
×
1440

×
1441
        // 1) get all channels where channelID is between start and end chan ID.
×
1442
        // 2) skip if not public (ie, no channel_proof)
×
1443
        // 3) collect that channel.
×
1444
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1445
        //    and add those timestamps to the collected channel.
×
1446
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1447
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1448
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1449
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1450
                                StartScid: chanIDStart,
×
1451
                                EndScid:   chanIDEnd,
×
1452
                        },
×
1453
                )
×
1454
                if err != nil {
×
1455
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1456
                                err)
×
1457
                }
×
1458

1459
                for _, dbChan := range dbChans {
×
1460
                        cid := lnwire.NewShortChanIDFromInt(
×
1461
                                byteOrder.Uint64(dbChan.Scid),
×
1462
                        )
×
1463
                        chanInfo := NewChannelUpdateInfo(
×
1464
                                cid, time.Time{}, time.Time{},
×
1465
                        )
×
1466

×
1467
                        if !withTimestamps {
×
1468
                                channelsPerBlock[cid.BlockHeight] = append(
×
1469
                                        channelsPerBlock[cid.BlockHeight],
×
1470
                                        chanInfo,
×
1471
                                )
×
1472

×
1473
                                continue
×
1474
                        }
1475

1476
                        //nolint:ll
1477
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1478
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1479
                                        Version:   int16(ProtocolV1),
×
1480
                                        ChannelID: dbChan.ID,
×
1481
                                        NodeID:    dbChan.NodeID1,
×
1482
                                },
×
1483
                        )
×
1484
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1485
                                return fmt.Errorf("unable to fetch node1 "+
×
1486
                                        "policy: %w", err)
×
1487
                        } else if err == nil {
×
1488
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1489
                                        node1Policy.LastUpdate.Int64, 0,
×
1490
                                )
×
1491
                        }
×
1492

1493
                        //nolint:ll
1494
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1495
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1496
                                        Version:   int16(ProtocolV1),
×
1497
                                        ChannelID: dbChan.ID,
×
1498
                                        NodeID:    dbChan.NodeID2,
×
1499
                                },
×
1500
                        )
×
1501
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1502
                                return fmt.Errorf("unable to fetch node2 "+
×
1503
                                        "policy: %w", err)
×
1504
                        } else if err == nil {
×
1505
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1506
                                        node2Policy.LastUpdate.Int64, 0,
×
1507
                                )
×
1508
                        }
×
1509

1510
                        channelsPerBlock[cid.BlockHeight] = append(
×
1511
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1512
                        )
×
1513
                }
1514

1515
                return nil
×
1516
        }, func() {
×
1517
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1518
        })
×
1519
        if err != nil {
×
1520
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1521
        }
×
1522

1523
        if len(channelsPerBlock) == 0 {
×
1524
                return nil, nil
×
1525
        }
×
1526

1527
        // Return the channel ranges in ascending block height order.
1528
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1529
        slices.Sort(blocks)
×
1530

×
1531
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1532
                return BlockChannelRange{
×
1533
                        Height:   block,
×
1534
                        Channels: channelsPerBlock[block],
×
1535
                }
×
1536
        }), nil
×
1537
}
1538

1539
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1540
// zombie. This method is used on an ad-hoc basis, when channels need to be
1541
// marked as zombies outside the normal pruning cycle.
1542
//
1543
// NOTE: part of the V1Store interface.
1544
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1545
        pubKey1, pubKey2 [33]byte) error {
×
1546

×
1547
        ctx := context.TODO()
×
1548

×
1549
        s.cacheMu.Lock()
×
1550
        defer s.cacheMu.Unlock()
×
1551

×
1552
        chanIDB := channelIDToBytes(chanID)
×
1553

×
1554
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1555
                return db.UpsertZombieChannel(
×
1556
                        ctx, sqlc.UpsertZombieChannelParams{
×
1557
                                Version:  int16(ProtocolV1),
×
1558
                                Scid:     chanIDB,
×
1559
                                NodeKey1: pubKey1[:],
×
1560
                                NodeKey2: pubKey2[:],
×
1561
                        },
×
1562
                )
×
1563
        }, sqldb.NoOpReset)
×
1564
        if err != nil {
×
1565
                return fmt.Errorf("unable to upsert zombie channel "+
×
1566
                        "(channel_id=%d): %w", chanID, err)
×
1567
        }
×
1568

1569
        s.rejectCache.remove(chanID)
×
1570
        s.chanCache.remove(chanID)
×
1571

×
1572
        return nil
×
1573
}
1574

1575
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1576
//
1577
// NOTE: part of the V1Store interface.
1578
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1579
        s.cacheMu.Lock()
×
1580
        defer s.cacheMu.Unlock()
×
1581

×
1582
        var (
×
1583
                ctx     = context.TODO()
×
1584
                chanIDB = channelIDToBytes(chanID)
×
1585
        )
×
1586

×
1587
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1588
                res, err := db.DeleteZombieChannel(
×
1589
                        ctx, sqlc.DeleteZombieChannelParams{
×
1590
                                Scid:    chanIDB,
×
1591
                                Version: int16(ProtocolV1),
×
1592
                        },
×
1593
                )
×
1594
                if err != nil {
×
1595
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1596
                                err)
×
1597
                }
×
1598

1599
                rows, err := res.RowsAffected()
×
1600
                if err != nil {
×
1601
                        return err
×
1602
                }
×
1603

1604
                if rows == 0 {
×
1605
                        return ErrZombieEdgeNotFound
×
1606
                } else if rows > 1 {
×
1607
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1608
                                "expected 1", rows)
×
1609
                }
×
1610

1611
                return nil
×
1612
        }, sqldb.NoOpReset)
1613
        if err != nil {
×
1614
                return fmt.Errorf("unable to mark edge live "+
×
1615
                        "(channel_id=%d): %w", chanID, err)
×
1616
        }
×
1617

1618
        s.rejectCache.remove(chanID)
×
1619
        s.chanCache.remove(chanID)
×
1620

×
1621
        return err
×
1622
}
1623

1624
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1625
// zombie, then the two node public keys corresponding to this edge are also
1626
// returned.
1627
//
1628
// NOTE: part of the V1Store interface.
1629
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1630
        error) {
×
1631

×
1632
        var (
×
1633
                ctx              = context.TODO()
×
1634
                isZombie         bool
×
1635
                pubKey1, pubKey2 route.Vertex
×
1636
                chanIDB          = channelIDToBytes(chanID)
×
1637
        )
×
1638

×
1639
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1640
                zombie, err := db.GetZombieChannel(
×
1641
                        ctx, sqlc.GetZombieChannelParams{
×
1642
                                Scid:    chanIDB,
×
1643
                                Version: int16(ProtocolV1),
×
1644
                        },
×
1645
                )
×
1646
                if errors.Is(err, sql.ErrNoRows) {
×
1647
                        return nil
×
1648
                }
×
1649
                if err != nil {
×
1650
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1651
                                err)
×
1652
                }
×
1653

1654
                copy(pubKey1[:], zombie.NodeKey1)
×
1655
                copy(pubKey2[:], zombie.NodeKey2)
×
1656
                isZombie = true
×
1657

×
1658
                return nil
×
1659
        }, sqldb.NoOpReset)
1660
        if err != nil {
×
1661
                return false, route.Vertex{}, route.Vertex{},
×
1662
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1663
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1664
        }
×
1665

1666
        return isZombie, pubKey1, pubKey2, nil
×
1667
}
1668

1669
// NumZombies returns the current number of zombie channels in the graph.
1670
//
1671
// NOTE: part of the V1Store interface.
1672
func (s *SQLStore) NumZombies() (uint64, error) {
×
1673
        var (
×
1674
                ctx        = context.TODO()
×
1675
                numZombies uint64
×
1676
        )
×
1677
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1678
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1679
                if err != nil {
×
1680
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1681
                                err)
×
1682
                }
×
1683

1684
                numZombies = uint64(count)
×
1685

×
1686
                return nil
×
1687
        }, sqldb.NoOpReset)
1688
        if err != nil {
×
1689
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1690
        }
×
1691

1692
        return numZombies, nil
×
1693
}
1694

1695
// DeleteChannelEdges removes edges with the given channel IDs from the
1696
// database and marks them as zombies. This ensures that we're unable to re-add
1697
// it to our database once again. If an edge does not exist within the
1698
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1699
// true, then when we mark these edges as zombies, we'll set up the keys such
1700
// that we require the node that failed to send the fresh update to be the one
1701
// that resurrects the channel from its zombie state. The markZombie bool
1702
// denotes whether to mark the channel as a zombie.
1703
//
1704
// NOTE: part of the V1Store interface.
1705
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1706
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1707

×
1708
        s.cacheMu.Lock()
×
1709
        defer s.cacheMu.Unlock()
×
1710

×
1711
        var (
×
1712
                ctx     = context.TODO()
×
1713
                deleted []*models.ChannelEdgeInfo
×
1714
        )
×
1715
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1716
                for _, chanID := range chanIDs {
×
1717
                        chanIDB := channelIDToBytes(chanID)
×
1718

×
1719
                        row, err := db.GetChannelBySCIDWithPolicies(
×
1720
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1721
                                        Scid:    chanIDB,
×
1722
                                        Version: int16(ProtocolV1),
×
1723
                                },
×
1724
                        )
×
1725
                        if errors.Is(err, sql.ErrNoRows) {
×
1726
                                return ErrEdgeNotFound
×
1727
                        } else if err != nil {
×
1728
                                return fmt.Errorf("unable to fetch channel: %w",
×
1729
                                        err)
×
1730
                        }
×
1731

1732
                        node1, node2, err := buildNodeVertices(
×
NEW
1733
                                row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1734
                        )
×
1735
                        if err != nil {
×
1736
                                return err
×
1737
                        }
×
1738

1739
                        info, err := getAndBuildEdgeInfo(
×
NEW
1740
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
NEW
1741
                                row.GraphChannel, node1, node2,
×
1742
                        )
×
1743
                        if err != nil {
×
1744
                                return err
×
1745
                        }
×
1746

NEW
1747
                        err = db.DeleteChannel(ctx, row.GraphChannel.ID)
×
1748
                        if err != nil {
×
1749
                                return fmt.Errorf("unable to delete "+
×
1750
                                        "channel: %w", err)
×
1751
                        }
×
1752

1753
                        deleted = append(deleted, info)
×
1754

×
1755
                        if !markZombie {
×
1756
                                continue
×
1757
                        }
1758

1759
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1760
                                info.NodeKey2Bytes
×
1761
                        if strictZombiePruning {
×
1762
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1763
                                if row.Policy1LastUpdate.Valid {
×
1764
                                        e1Time := time.Unix(
×
1765
                                                row.Policy1LastUpdate.Int64, 0,
×
1766
                                        )
×
1767
                                        e1UpdateTime = &e1Time
×
1768
                                }
×
1769
                                if row.Policy2LastUpdate.Valid {
×
1770
                                        e2Time := time.Unix(
×
1771
                                                row.Policy2LastUpdate.Int64, 0,
×
1772
                                        )
×
1773
                                        e2UpdateTime = &e2Time
×
1774
                                }
×
1775

1776
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1777
                                        info, e1UpdateTime, e2UpdateTime,
×
1778
                                )
×
1779
                        }
1780

1781
                        err = db.UpsertZombieChannel(
×
1782
                                ctx, sqlc.UpsertZombieChannelParams{
×
1783
                                        Version:  int16(ProtocolV1),
×
1784
                                        Scid:     chanIDB,
×
1785
                                        NodeKey1: nodeKey1[:],
×
1786
                                        NodeKey2: nodeKey2[:],
×
1787
                                },
×
1788
                        )
×
1789
                        if err != nil {
×
1790
                                return fmt.Errorf("unable to mark channel as "+
×
1791
                                        "zombie: %w", err)
×
1792
                        }
×
1793
                }
1794

1795
                return nil
×
1796
        }, func() {
×
1797
                deleted = nil
×
1798
        })
×
1799
        if err != nil {
×
1800
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1801
                        err)
×
1802
        }
×
1803

1804
        for _, chanID := range chanIDs {
×
1805
                s.rejectCache.remove(chanID)
×
1806
                s.chanCache.remove(chanID)
×
1807
        }
×
1808

1809
        return deleted, nil
×
1810
}
1811

1812
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1813
// channel identified by the channel ID. If the channel can't be found, then
1814
// ErrEdgeNotFound is returned. A struct which houses the general information
1815
// for the channel itself is returned as well as two structs that contain the
1816
// routing policies for the channel in either direction.
1817
//
1818
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1819
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1820
// the ChannelEdgeInfo will only include the public keys of each node.
1821
//
1822
// NOTE: part of the V1Store interface.
1823
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1824
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1825
        *models.ChannelEdgePolicy, error) {
×
1826

×
1827
        var (
×
1828
                ctx              = context.TODO()
×
1829
                edge             *models.ChannelEdgeInfo
×
1830
                policy1, policy2 *models.ChannelEdgePolicy
×
1831
                chanIDB          = channelIDToBytes(chanID)
×
1832
        )
×
1833
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1834
                row, err := db.GetChannelBySCIDWithPolicies(
×
1835
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1836
                                Scid:    chanIDB,
×
1837
                                Version: int16(ProtocolV1),
×
1838
                        },
×
1839
                )
×
1840
                if errors.Is(err, sql.ErrNoRows) {
×
1841
                        // First check if this edge is perhaps in the zombie
×
1842
                        // index.
×
1843
                        zombie, err := db.GetZombieChannel(
×
1844
                                ctx, sqlc.GetZombieChannelParams{
×
1845
                                        Scid:    chanIDB,
×
1846
                                        Version: int16(ProtocolV1),
×
1847
                                },
×
1848
                        )
×
1849
                        if errors.Is(err, sql.ErrNoRows) {
×
1850
                                return ErrEdgeNotFound
×
1851
                        } else if err != nil {
×
1852
                                return fmt.Errorf("unable to check if "+
×
1853
                                        "channel is zombie: %w", err)
×
1854
                        }
×
1855

1856
                        // At this point, we know the channel is a zombie, so
1857
                        // we'll return an error indicating this, and we will
1858
                        // populate the edge info with the public keys of each
1859
                        // party as this is the only information we have about
1860
                        // it.
1861
                        edge = &models.ChannelEdgeInfo{}
×
1862
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1863
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1864

×
1865
                        return ErrZombieEdge
×
1866
                } else if err != nil {
×
1867
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1868
                }
×
1869

1870
                node1, node2, err := buildNodeVertices(
×
NEW
1871
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1872
                )
×
1873
                if err != nil {
×
1874
                        return err
×
1875
                }
×
1876

1877
                edge, err = getAndBuildEdgeInfo(
×
NEW
1878
                        ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
NEW
1879
                        row.GraphChannel, node1, node2,
×
1880
                )
×
1881
                if err != nil {
×
1882
                        return fmt.Errorf("unable to build channel info: %w",
×
1883
                                err)
×
1884
                }
×
1885

1886
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1887
                if err != nil {
×
1888
                        return fmt.Errorf("unable to extract channel "+
×
1889
                                "policies: %w", err)
×
1890
                }
×
1891

1892
                policy1, policy2, err = getAndBuildChanPolicies(
×
1893
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1894
                )
×
1895
                if err != nil {
×
1896
                        return fmt.Errorf("unable to build channel "+
×
1897
                                "policies: %w", err)
×
1898
                }
×
1899

1900
                return nil
×
1901
        }, sqldb.NoOpReset)
1902
        if err != nil {
×
1903
                // If we are returning the ErrZombieEdge, then we also need to
×
1904
                // return the edge info as the method comment indicates that
×
1905
                // this will be populated when the edge is a zombie.
×
1906
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1907
                        err)
×
1908
        }
×
1909

1910
        return edge, policy1, policy2, nil
×
1911
}
1912

1913
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1914
// the channel identified by the funding outpoint. If the channel can't be
1915
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1916
// information for the channel itself is returned as well as two structs that
1917
// contain the routing policies for the channel in either direction.
1918
//
1919
// NOTE: part of the V1Store interface.
1920
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1921
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1922
        *models.ChannelEdgePolicy, error) {
×
1923

×
1924
        var (
×
1925
                ctx              = context.TODO()
×
1926
                edge             *models.ChannelEdgeInfo
×
1927
                policy1, policy2 *models.ChannelEdgePolicy
×
1928
        )
×
1929
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1930
                row, err := db.GetChannelByOutpointWithPolicies(
×
1931
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1932
                                Outpoint: op.String(),
×
1933
                                Version:  int16(ProtocolV1),
×
1934
                        },
×
1935
                )
×
1936
                if errors.Is(err, sql.ErrNoRows) {
×
1937
                        return ErrEdgeNotFound
×
1938
                } else if err != nil {
×
1939
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1940
                }
×
1941

1942
                node1, node2, err := buildNodeVertices(
×
1943
                        row.Node1Pubkey, row.Node2Pubkey,
×
1944
                )
×
1945
                if err != nil {
×
1946
                        return err
×
1947
                }
×
1948

1949
                edge, err = getAndBuildEdgeInfo(
×
NEW
1950
                        ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
NEW
1951
                        row.GraphChannel, node1, node2,
×
1952
                )
×
1953
                if err != nil {
×
1954
                        return fmt.Errorf("unable to build channel info: %w",
×
1955
                                err)
×
1956
                }
×
1957

1958
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1959
                if err != nil {
×
1960
                        return fmt.Errorf("unable to extract channel "+
×
1961
                                "policies: %w", err)
×
1962
                }
×
1963

1964
                policy1, policy2, err = getAndBuildChanPolicies(
×
1965
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1966
                )
×
1967
                if err != nil {
×
1968
                        return fmt.Errorf("unable to build channel "+
×
1969
                                "policies: %w", err)
×
1970
                }
×
1971

1972
                return nil
×
1973
        }, sqldb.NoOpReset)
1974
        if err != nil {
×
1975
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1976
                        err)
×
1977
        }
×
1978

1979
        return edge, policy1, policy2, nil
×
1980
}
1981

1982
// HasChannelEdge returns true if the database knows of a channel edge with the
1983
// passed channel ID, and false otherwise. If an edge with that ID is found
1984
// within the graph, then two time stamps representing the last time the edge
1985
// was updated for both directed edges are returned along with the boolean. If
1986
// it is not found, then the zombie index is checked and its result is returned
1987
// as the second boolean.
1988
//
1989
// NOTE: part of the V1Store interface.
1990
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1991
        bool, error) {
×
1992

×
1993
        ctx := context.TODO()
×
1994

×
1995
        var (
×
1996
                exists          bool
×
1997
                isZombie        bool
×
1998
                node1LastUpdate time.Time
×
1999
                node2LastUpdate time.Time
×
2000
        )
×
2001

×
2002
        // We'll query the cache with the shared lock held to allow multiple
×
2003
        // readers to access values in the cache concurrently if they exist.
×
2004
        s.cacheMu.RLock()
×
2005
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2006
                s.cacheMu.RUnlock()
×
2007
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2008
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2009
                exists, isZombie = entry.flags.unpack()
×
2010

×
2011
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2012
        }
×
2013
        s.cacheMu.RUnlock()
×
2014

×
2015
        s.cacheMu.Lock()
×
2016
        defer s.cacheMu.Unlock()
×
2017

×
2018
        // The item was not found with the shared lock, so we'll acquire the
×
2019
        // exclusive lock and check the cache again in case another method added
×
2020
        // the entry to the cache while no lock was held.
×
2021
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2022
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2023
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2024
                exists, isZombie = entry.flags.unpack()
×
2025

×
2026
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2027
        }
×
2028

2029
        chanIDB := channelIDToBytes(chanID)
×
2030
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2031
                channel, err := db.GetChannelBySCID(
×
2032
                        ctx, sqlc.GetChannelBySCIDParams{
×
2033
                                Scid:    chanIDB,
×
2034
                                Version: int16(ProtocolV1),
×
2035
                        },
×
2036
                )
×
2037
                if errors.Is(err, sql.ErrNoRows) {
×
2038
                        // Check if it is a zombie channel.
×
2039
                        isZombie, err = db.IsZombieChannel(
×
2040
                                ctx, sqlc.IsZombieChannelParams{
×
2041
                                        Scid:    chanIDB,
×
2042
                                        Version: int16(ProtocolV1),
×
2043
                                },
×
2044
                        )
×
2045
                        if err != nil {
×
2046
                                return fmt.Errorf("could not check if channel "+
×
2047
                                        "is zombie: %w", err)
×
2048
                        }
×
2049

2050
                        return nil
×
2051
                } else if err != nil {
×
2052
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2053
                }
×
2054

2055
                exists = true
×
2056

×
2057
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2058
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2059
                                Version:   int16(ProtocolV1),
×
2060
                                ChannelID: channel.ID,
×
2061
                                NodeID:    channel.NodeID1,
×
2062
                        },
×
2063
                )
×
2064
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2065
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2066
                                err)
×
2067
                } else if err == nil {
×
2068
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2069
                }
×
2070

2071
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2072
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2073
                                Version:   int16(ProtocolV1),
×
2074
                                ChannelID: channel.ID,
×
2075
                                NodeID:    channel.NodeID2,
×
2076
                        },
×
2077
                )
×
2078
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2079
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2080
                                err)
×
2081
                } else if err == nil {
×
2082
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2083
                }
×
2084

2085
                return nil
×
2086
        }, sqldb.NoOpReset)
2087
        if err != nil {
×
2088
                return time.Time{}, time.Time{}, false, false,
×
2089
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2090
        }
×
2091

2092
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2093
                upd1Time: node1LastUpdate.Unix(),
×
2094
                upd2Time: node2LastUpdate.Unix(),
×
2095
                flags:    packRejectFlags(exists, isZombie),
×
2096
        })
×
2097

×
2098
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2099
}
2100

2101
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2102
// passed channel point (outpoint). If the passed channel doesn't exist within
2103
// the database, then ErrEdgeNotFound is returned.
2104
//
2105
// NOTE: part of the V1Store interface.
2106
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2107
        var (
×
2108
                ctx       = context.TODO()
×
2109
                channelID uint64
×
2110
        )
×
2111
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2112
                chanID, err := db.GetSCIDByOutpoint(
×
2113
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2114
                                Outpoint: chanPoint.String(),
×
2115
                                Version:  int16(ProtocolV1),
×
2116
                        },
×
2117
                )
×
2118
                if errors.Is(err, sql.ErrNoRows) {
×
2119
                        return ErrEdgeNotFound
×
2120
                } else if err != nil {
×
2121
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2122
                                err)
×
2123
                }
×
2124

2125
                channelID = byteOrder.Uint64(chanID)
×
2126

×
2127
                return nil
×
2128
        }, sqldb.NoOpReset)
2129
        if err != nil {
×
2130
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2131
        }
×
2132

2133
        return channelID, nil
×
2134
}
2135

2136
// IsPublicNode is a helper method that determines whether the node with the
2137
// given public key is seen as a public node in the graph from the graph's
2138
// source node's point of view.
2139
//
2140
// NOTE: part of the V1Store interface.
2141
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2142
        ctx := context.TODO()
×
2143

×
2144
        var isPublic bool
×
2145
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2146
                var err error
×
2147
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2148

×
2149
                return err
×
2150
        }, sqldb.NoOpReset)
×
2151
        if err != nil {
×
2152
                return false, fmt.Errorf("unable to check if node is "+
×
2153
                        "public: %w", err)
×
2154
        }
×
2155

2156
        return isPublic, nil
×
2157
}
2158

2159
// FetchChanInfos returns the set of channel edges that correspond to the passed
2160
// channel ID's. If an edge is the query is unknown to the database, it will
2161
// skipped and the result will contain only those edges that exist at the time
2162
// of the query. This can be used to respond to peer queries that are seeking to
2163
// fill in gaps in their view of the channel graph.
2164
//
2165
// NOTE: part of the V1Store interface.
2166
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2167
        var (
×
2168
                ctx   = context.TODO()
×
2169
                edges []ChannelEdge
×
2170
        )
×
2171
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2172
                for _, chanID := range chanIDs {
×
2173
                        chanIDB := channelIDToBytes(chanID)
×
2174

×
2175
                        // TODO(elle): potentially optimize this by using
×
2176
                        //  sqlc.slice() once that works for both SQLite and
×
2177
                        //  Postgres.
×
2178
                        row, err := db.GetChannelBySCIDWithPolicies(
×
2179
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
2180
                                        Scid:    chanIDB,
×
2181
                                        Version: int16(ProtocolV1),
×
2182
                                },
×
2183
                        )
×
2184
                        if errors.Is(err, sql.ErrNoRows) {
×
2185
                                continue
×
2186
                        } else if err != nil {
×
2187
                                return fmt.Errorf("unable to fetch channel: %w",
×
2188
                                        err)
×
2189
                        }
×
2190

2191
                        node1, node2, err := buildNodes(
×
NEW
2192
                                ctx, db, row.GraphNode, row.GraphNode_2,
×
2193
                        )
×
2194
                        if err != nil {
×
2195
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2196
                                        err)
×
2197
                        }
×
2198

2199
                        edge, err := getAndBuildEdgeInfo(
×
NEW
2200
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
NEW
2201
                                row.GraphChannel, node1.PubKeyBytes,
×
2202
                                node2.PubKeyBytes,
×
2203
                        )
×
2204
                        if err != nil {
×
2205
                                return fmt.Errorf("unable to build "+
×
2206
                                        "channel info: %w", err)
×
2207
                        }
×
2208

2209
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2210
                        if err != nil {
×
2211
                                return fmt.Errorf("unable to extract channel "+
×
2212
                                        "policies: %w", err)
×
2213
                        }
×
2214

2215
                        p1, p2, err := getAndBuildChanPolicies(
×
2216
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2217
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2218
                        )
×
2219
                        if err != nil {
×
2220
                                return fmt.Errorf("unable to build channel "+
×
2221
                                        "policies: %w", err)
×
2222
                        }
×
2223

2224
                        edges = append(edges, ChannelEdge{
×
2225
                                Info:    edge,
×
2226
                                Policy1: p1,
×
2227
                                Policy2: p2,
×
2228
                                Node1:   node1,
×
2229
                                Node2:   node2,
×
2230
                        })
×
2231
                }
2232

2233
                return nil
×
2234
        }, func() {
×
2235
                edges = nil
×
2236
        })
×
2237
        if err != nil {
×
2238
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2239
        }
×
2240

2241
        return edges, nil
×
2242
}
2243

2244
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2245
// ID's that we don't know and are not known zombies of the passed set. In other
2246
// words, we perform a set difference of our set of chan ID's and the ones
2247
// passed in. This method can be used by callers to determine the set of
2248
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2249
// known zombies is also returned.
2250
//
2251
// NOTE: part of the V1Store interface.
2252
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2253
        []ChannelUpdateInfo, error) {
×
2254

×
2255
        var (
×
2256
                ctx          = context.TODO()
×
2257
                newChanIDs   []uint64
×
2258
                knownZombies []ChannelUpdateInfo
×
2259
        )
×
2260
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2261
                for _, chanInfo := range chansInfo {
×
2262
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2263
                        chanIDB := channelIDToBytes(channelID)
×
2264

×
2265
                        // TODO(elle): potentially optimize this by using
×
2266
                        //  sqlc.slice() once that works for both SQLite and
×
2267
                        //  Postgres.
×
2268
                        _, err := db.GetChannelBySCID(
×
2269
                                ctx, sqlc.GetChannelBySCIDParams{
×
2270
                                        Version: int16(ProtocolV1),
×
2271
                                        Scid:    chanIDB,
×
2272
                                },
×
2273
                        )
×
2274
                        if err == nil {
×
2275
                                continue
×
2276
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
2277
                                return fmt.Errorf("unable to fetch channel: %w",
×
2278
                                        err)
×
2279
                        }
×
2280

2281
                        isZombie, err := db.IsZombieChannel(
×
2282
                                ctx, sqlc.IsZombieChannelParams{
×
2283
                                        Scid:    chanIDB,
×
2284
                                        Version: int16(ProtocolV1),
×
2285
                                },
×
2286
                        )
×
2287
                        if err != nil {
×
2288
                                return fmt.Errorf("unable to fetch zombie "+
×
2289
                                        "channel: %w", err)
×
2290
                        }
×
2291

2292
                        if isZombie {
×
2293
                                knownZombies = append(knownZombies, chanInfo)
×
2294

×
2295
                                continue
×
2296
                        }
2297

2298
                        newChanIDs = append(newChanIDs, channelID)
×
2299
                }
2300

2301
                return nil
×
2302
        }, func() {
×
2303
                newChanIDs = nil
×
2304
                knownZombies = nil
×
2305
        })
×
2306
        if err != nil {
×
2307
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2308
        }
×
2309

2310
        return newChanIDs, knownZombies, nil
×
2311
}
2312

2313
// PruneGraphNodes is a garbage collection method which attempts to prune out
2314
// any nodes from the channel graph that are currently unconnected. This ensure
2315
// that we only maintain a graph of reachable nodes. In the event that a pruned
2316
// node gains more channels, it will be re-added back to the graph.
2317
//
2318
// NOTE: this prunes nodes across protocol versions. It will never prune the
2319
// source nodes.
2320
//
2321
// NOTE: part of the V1Store interface.
2322
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2323
        var ctx = context.TODO()
×
2324

×
2325
        var prunedNodes []route.Vertex
×
2326
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2327
                var err error
×
2328
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2329

×
2330
                return err
×
2331
        }, func() {
×
2332
                prunedNodes = nil
×
2333
        })
×
2334
        if err != nil {
×
2335
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2336
        }
×
2337

2338
        return prunedNodes, nil
×
2339
}
2340

2341
// PruneGraph prunes newly closed channels from the channel graph in response
2342
// to a new block being solved on the network. Any transactions which spend the
2343
// funding output of any known channels within he graph will be deleted.
2344
// Additionally, the "prune tip", or the last block which has been used to
2345
// prune the graph is stored so callers can ensure the graph is fully in sync
2346
// with the current UTXO state. A slice of channels that have been closed by
2347
// the target block along with any pruned nodes are returned if the function
2348
// succeeds without error.
2349
//
2350
// NOTE: part of the V1Store interface.
2351
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2352
        blockHash *chainhash.Hash, blockHeight uint32) (
2353
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2354

×
2355
        ctx := context.TODO()
×
2356

×
2357
        s.cacheMu.Lock()
×
2358
        defer s.cacheMu.Unlock()
×
2359

×
2360
        var (
×
2361
                closedChans []*models.ChannelEdgeInfo
×
2362
                prunedNodes []route.Vertex
×
2363
        )
×
2364
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2365
                for _, outpoint := range spentOutputs {
×
2366
                        // TODO(elle): potentially optimize this by using
×
2367
                        //  sqlc.slice() once that works for both SQLite and
×
2368
                        //  Postgres.
×
2369
                        //
×
2370
                        // NOTE: this fetches channels for all protocol
×
2371
                        // versions.
×
2372
                        row, err := db.GetChannelByOutpoint(
×
2373
                                ctx, outpoint.String(),
×
2374
                        )
×
2375
                        if errors.Is(err, sql.ErrNoRows) {
×
2376
                                continue
×
2377
                        } else if err != nil {
×
2378
                                return fmt.Errorf("unable to fetch channel: %w",
×
2379
                                        err)
×
2380
                        }
×
2381

2382
                        node1, node2, err := buildNodeVertices(
×
2383
                                row.Node1Pubkey, row.Node2Pubkey,
×
2384
                        )
×
2385
                        if err != nil {
×
2386
                                return err
×
2387
                        }
×
2388

2389
                        info, err := getAndBuildEdgeInfo(
×
NEW
2390
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
NEW
2391
                                row.GraphChannel, node1, node2,
×
2392
                        )
×
2393
                        if err != nil {
×
2394
                                return err
×
2395
                        }
×
2396

NEW
2397
                        err = db.DeleteChannel(ctx, row.GraphChannel.ID)
×
2398
                        if err != nil {
×
2399
                                return fmt.Errorf("unable to delete "+
×
2400
                                        "channel: %w", err)
×
2401
                        }
×
2402

2403
                        closedChans = append(closedChans, info)
×
2404
                }
2405

2406
                err := db.UpsertPruneLogEntry(
×
2407
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2408
                                BlockHash:   blockHash[:],
×
2409
                                BlockHeight: int64(blockHeight),
×
2410
                        },
×
2411
                )
×
2412
                if err != nil {
×
2413
                        return fmt.Errorf("unable to insert prune log "+
×
2414
                                "entry: %w", err)
×
2415
                }
×
2416

2417
                // Now that we've pruned some channels, we'll also prune any
2418
                // nodes that no longer have any channels.
2419
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2420
                if err != nil {
×
2421
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2422
                                err)
×
2423
                }
×
2424

2425
                return nil
×
2426
        }, func() {
×
2427
                prunedNodes = nil
×
2428
                closedChans = nil
×
2429
        })
×
2430
        if err != nil {
×
2431
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2432
        }
×
2433

2434
        for _, channel := range closedChans {
×
2435
                s.rejectCache.remove(channel.ChannelID)
×
2436
                s.chanCache.remove(channel.ChannelID)
×
2437
        }
×
2438

2439
        return closedChans, prunedNodes, nil
×
2440
}
2441

2442
// ChannelView returns the verifiable edge information for each active channel
2443
// within the known channel graph. The set of UTXOs (along with their scripts)
2444
// returned are the ones that need to be watched on chain to detect channel
2445
// closes on the resident blockchain.
2446
//
2447
// NOTE: part of the V1Store interface.
2448
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2449
        var (
×
2450
                ctx        = context.TODO()
×
2451
                edgePoints []EdgePoint
×
2452
        )
×
2453

×
2454
        handleChannel := func(db SQLQueries,
×
2455
                channel sqlc.ListChannelsPaginatedRow) error {
×
2456

×
2457
                pkScript, err := genMultiSigP2WSH(
×
2458
                        channel.BitcoinKey1, channel.BitcoinKey2,
×
2459
                )
×
2460
                if err != nil {
×
2461
                        return err
×
2462
                }
×
2463

2464
                op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2465
                if err != nil {
×
2466
                        return err
×
2467
                }
×
2468

2469
                edgePoints = append(edgePoints, EdgePoint{
×
2470
                        FundingPkScript: pkScript,
×
2471
                        OutPoint:        *op,
×
2472
                })
×
2473

×
2474
                return nil
×
2475
        }
2476

2477
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2478
                lastID := int64(-1)
×
2479
                for {
×
2480
                        rows, err := db.ListChannelsPaginated(
×
2481
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2482
                                        Version: int16(ProtocolV1),
×
2483
                                        ID:      lastID,
×
2484
                                        Limit:   pageSize,
×
2485
                                },
×
2486
                        )
×
2487
                        if err != nil {
×
2488
                                return err
×
2489
                        }
×
2490

2491
                        if len(rows) == 0 {
×
2492
                                break
×
2493
                        }
2494

2495
                        for _, row := range rows {
×
2496
                                err := handleChannel(db, row)
×
2497
                                if err != nil {
×
2498
                                        return err
×
2499
                                }
×
2500

2501
                                lastID = row.ID
×
2502
                        }
2503
                }
2504

2505
                return nil
×
2506
        }, func() {
×
2507
                edgePoints = nil
×
2508
        })
×
2509
        if err != nil {
×
2510
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2511
        }
×
2512

2513
        return edgePoints, nil
×
2514
}
2515

2516
// PruneTip returns the block height and hash of the latest block that has been
2517
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2518
// to tell if the graph is currently in sync with the current best known UTXO
2519
// state.
2520
//
2521
// NOTE: part of the V1Store interface.
2522
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2523
        var (
×
2524
                ctx       = context.TODO()
×
2525
                tipHash   chainhash.Hash
×
2526
                tipHeight uint32
×
2527
        )
×
2528
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2529
                pruneTip, err := db.GetPruneTip(ctx)
×
2530
                if errors.Is(err, sql.ErrNoRows) {
×
2531
                        return ErrGraphNeverPruned
×
2532
                } else if err != nil {
×
2533
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2534
                }
×
2535

2536
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2537
                tipHeight = uint32(pruneTip.BlockHeight)
×
2538

×
2539
                return nil
×
2540
        }, sqldb.NoOpReset)
2541
        if err != nil {
×
2542
                return nil, 0, err
×
2543
        }
×
2544

2545
        return &tipHash, tipHeight, nil
×
2546
}
2547

2548
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2549
//
2550
// NOTE: this prunes nodes across protocol versions. It will never prune the
2551
// source nodes.
2552
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2553
        db SQLQueries) ([]route.Vertex, error) {
×
2554

×
2555
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2556
        if err != nil {
×
2557
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2558
                        "nodes: %w", err)
×
2559
        }
×
2560

2561
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2562
        for i, nodeKey := range nodeKeys {
×
2563
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2564
                if err != nil {
×
2565
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2566
                                "from bytes: %w", err)
×
2567
                }
×
2568

2569
                prunedNodes[i] = pub
×
2570
        }
2571

2572
        return prunedNodes, nil
×
2573
}
2574

2575
// DisconnectBlockAtHeight is used to indicate that the block specified
2576
// by the passed height has been disconnected from the main chain. This
2577
// will "rewind" the graph back to the height below, deleting channels
2578
// that are no longer confirmed from the graph. The prune log will be
2579
// set to the last prune height valid for the remaining chain.
2580
// Channels that were removed from the graph resulting from the
2581
// disconnected block are returned.
2582
//
2583
// NOTE: part of the V1Store interface.
2584
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2585
        []*models.ChannelEdgeInfo, error) {
×
2586

×
2587
        ctx := context.TODO()
×
2588

×
2589
        var (
×
2590
                // Every channel having a ShortChannelID starting at 'height'
×
2591
                // will no longer be confirmed.
×
2592
                startShortChanID = lnwire.ShortChannelID{
×
2593
                        BlockHeight: height,
×
2594
                }
×
2595

×
2596
                // Delete everything after this height from the db up until the
×
2597
                // SCID alias range.
×
2598
                endShortChanID = aliasmgr.StartingAlias
×
2599

×
2600
                removedChans []*models.ChannelEdgeInfo
×
2601

×
2602
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2603
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2604
        )
×
2605

×
2606
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2607
                rows, err := db.GetChannelsBySCIDRange(
×
2608
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2609
                                StartScid: chanIDStart,
×
2610
                                EndScid:   chanIDEnd,
×
2611
                        },
×
2612
                )
×
2613
                if err != nil {
×
2614
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2615
                }
×
2616

2617
                for _, row := range rows {
×
2618
                        node1, node2, err := buildNodeVertices(
×
2619
                                row.Node1PubKey, row.Node2PubKey,
×
2620
                        )
×
2621
                        if err != nil {
×
2622
                                return err
×
2623
                        }
×
2624

2625
                        channel, err := getAndBuildEdgeInfo(
×
NEW
2626
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
NEW
2627
                                row.GraphChannel, node1, node2,
×
2628
                        )
×
2629
                        if err != nil {
×
2630
                                return err
×
2631
                        }
×
2632

NEW
2633
                        err = db.DeleteChannel(ctx, row.GraphChannel.ID)
×
2634
                        if err != nil {
×
2635
                                return fmt.Errorf("unable to delete "+
×
2636
                                        "channel: %w", err)
×
2637
                        }
×
2638

2639
                        removedChans = append(removedChans, channel)
×
2640
                }
2641

2642
                return db.DeletePruneLogEntriesInRange(
×
2643
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2644
                                StartHeight: int64(height),
×
2645
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2646
                        },
×
2647
                )
×
2648
        }, func() {
×
2649
                removedChans = nil
×
2650
        })
×
2651
        if err != nil {
×
2652
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2653
                        "height: %w", err)
×
2654
        }
×
2655

2656
        for _, channel := range removedChans {
×
2657
                s.rejectCache.remove(channel.ChannelID)
×
2658
                s.chanCache.remove(channel.ChannelID)
×
2659
        }
×
2660

2661
        return removedChans, nil
×
2662
}
2663

2664
// AddEdgeProof sets the proof of an existing edge in the graph database.
2665
//
2666
// NOTE: part of the V1Store interface.
2667
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2668
        proof *models.ChannelAuthProof) error {
×
2669

×
2670
        var (
×
2671
                ctx       = context.TODO()
×
2672
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2673
        )
×
2674

×
2675
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2676
                res, err := db.AddV1ChannelProof(
×
2677
                        ctx, sqlc.AddV1ChannelProofParams{
×
2678
                                Scid:              scidBytes,
×
2679
                                Node1Signature:    proof.NodeSig1Bytes,
×
2680
                                Node2Signature:    proof.NodeSig2Bytes,
×
2681
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2682
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2683
                        },
×
2684
                )
×
2685
                if err != nil {
×
2686
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2687
                }
×
2688

2689
                n, err := res.RowsAffected()
×
2690
                if err != nil {
×
2691
                        return err
×
2692
                }
×
2693

2694
                if n == 0 {
×
2695
                        return fmt.Errorf("no rows affected when adding edge "+
×
2696
                                "proof for SCID %v", scid)
×
2697
                } else if n > 1 {
×
2698
                        return fmt.Errorf("multiple rows affected when adding "+
×
2699
                                "edge proof for SCID %v: %d rows affected",
×
2700
                                scid, n)
×
2701
                }
×
2702

2703
                return nil
×
2704
        }, sqldb.NoOpReset)
2705
        if err != nil {
×
2706
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2707
        }
×
2708

2709
        return nil
×
2710
}
2711

2712
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2713
// that we can ignore channel announcements that we know to be closed without
2714
// having to validate them and fetch a block.
2715
//
2716
// NOTE: part of the V1Store interface.
2717
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2718
        var (
×
2719
                ctx     = context.TODO()
×
2720
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2721
        )
×
2722

×
2723
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2724
                return db.InsertClosedChannel(ctx, chanIDB)
×
2725
        }, sqldb.NoOpReset)
×
2726
}
2727

2728
// IsClosedScid checks whether a channel identified by the passed in scid is
2729
// closed. This helps avoid having to perform expensive validation checks.
2730
//
2731
// NOTE: part of the V1Store interface.
2732
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2733
        var (
×
2734
                ctx      = context.TODO()
×
2735
                isClosed bool
×
2736
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2737
        )
×
2738
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2739
                var err error
×
2740
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2741
                if err != nil {
×
2742
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2743
                                err)
×
2744
                }
×
2745

2746
                return nil
×
2747
        }, sqldb.NoOpReset)
2748
        if err != nil {
×
2749
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2750
                        err)
×
2751
        }
×
2752

2753
        return isClosed, nil
×
2754
}
2755

2756
// GraphSession will provide the call-back with access to a NodeTraverser
2757
// instance which can be used to perform queries against the channel graph.
2758
//
2759
// NOTE: part of the V1Store interface.
2760
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
2761
        reset func()) error {
×
2762

×
2763
        var ctx = context.TODO()
×
2764

×
2765
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2766
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2767
        }, reset)
×
2768
}
2769

2770
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2771
// read only transaction for a consistent view of the graph.
2772
type sqlNodeTraverser struct {
2773
        db    SQLQueries
2774
        chain chainhash.Hash
2775
}
2776

2777
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2778
// NodeTraverser interface.
2779
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2780

2781
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2782
func newSQLNodeTraverser(db SQLQueries,
2783
        chain chainhash.Hash) *sqlNodeTraverser {
×
2784

×
2785
        return &sqlNodeTraverser{
×
2786
                db:    db,
×
2787
                chain: chain,
×
2788
        }
×
2789
}
×
2790

2791
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2792
// node.
2793
//
2794
// NOTE: Part of the NodeTraverser interface.
2795
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2796
        cb func(channel *DirectedChannel) error, _ func()) error {
×
2797

×
2798
        ctx := context.TODO()
×
2799

×
2800
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2801
}
×
2802

2803
// FetchNodeFeatures returns the features of the given node. If the node is
2804
// unknown, assume no additional features are supported.
2805
//
2806
// NOTE: Part of the NodeTraverser interface.
2807
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2808
        *lnwire.FeatureVector, error) {
×
2809

×
2810
        ctx := context.TODO()
×
2811

×
2812
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2813
}
×
2814

2815
// forEachNodeDirectedChannel iterates through all channels of a given
2816
// node, executing the passed callback on the directed edge representing the
2817
// channel and its incoming policy. If the node is not found, no error is
2818
// returned.
2819
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2820
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2821

×
2822
        toNodeCallback := func() route.Vertex {
×
2823
                return nodePub
×
2824
        }
×
2825

2826
        dbID, err := db.GetNodeIDByPubKey(
×
2827
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2828
                        Version: int16(ProtocolV1),
×
2829
                        PubKey:  nodePub[:],
×
2830
                },
×
2831
        )
×
2832
        if errors.Is(err, sql.ErrNoRows) {
×
2833
                return nil
×
2834
        } else if err != nil {
×
2835
                return fmt.Errorf("unable to fetch node: %w", err)
×
2836
        }
×
2837

2838
        rows, err := db.ListChannelsByNodeID(
×
2839
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2840
                        Version: int16(ProtocolV1),
×
2841
                        NodeID1: dbID,
×
2842
                },
×
2843
        )
×
2844
        if err != nil {
×
2845
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2846
        }
×
2847

2848
        // Exit early if there are no channels for this node so we don't
2849
        // do the unnecessary feature fetching.
2850
        if len(rows) == 0 {
×
2851
                return nil
×
2852
        }
×
2853

2854
        features, err := getNodeFeatures(ctx, db, dbID)
×
2855
        if err != nil {
×
2856
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2857
        }
×
2858

2859
        for _, row := range rows {
×
2860
                node1, node2, err := buildNodeVertices(
×
2861
                        row.Node1Pubkey, row.Node2Pubkey,
×
2862
                )
×
2863
                if err != nil {
×
2864
                        return fmt.Errorf("unable to build node vertices: %w",
×
2865
                                err)
×
2866
                }
×
2867

NEW
2868
                edge := buildCacheableChannelInfo(
×
NEW
2869
                        row.GraphChannel, node1, node2,
×
NEW
2870
                )
×
2871

×
2872
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2873
                if err != nil {
×
2874
                        return err
×
2875
                }
×
2876

2877
                var p1, p2 *models.CachedEdgePolicy
×
2878
                if dbPol1 != nil {
×
2879
                        policy1, err := buildChanPolicy(
×
2880
                                *dbPol1, edge.ChannelID, nil, node2,
×
2881
                        )
×
2882
                        if err != nil {
×
2883
                                return err
×
2884
                        }
×
2885

2886
                        p1 = models.NewCachedPolicy(policy1)
×
2887
                }
2888
                if dbPol2 != nil {
×
2889
                        policy2, err := buildChanPolicy(
×
2890
                                *dbPol2, edge.ChannelID, nil, node1,
×
2891
                        )
×
2892
                        if err != nil {
×
2893
                                return err
×
2894
                        }
×
2895

2896
                        p2 = models.NewCachedPolicy(policy2)
×
2897
                }
2898

2899
                // Determine the outgoing and incoming policy for this
2900
                // channel and node combo.
2901
                outPolicy, inPolicy := p1, p2
×
2902
                if p1 != nil && node2 == nodePub {
×
2903
                        outPolicy, inPolicy = p2, p1
×
2904
                } else if p2 != nil && node1 != nodePub {
×
2905
                        outPolicy, inPolicy = p2, p1
×
2906
                }
×
2907

2908
                var cachedInPolicy *models.CachedEdgePolicy
×
2909
                if inPolicy != nil {
×
2910
                        cachedInPolicy = inPolicy
×
2911
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
2912
                        cachedInPolicy.ToNodeFeatures = features
×
2913
                }
×
2914

2915
                directedChannel := &DirectedChannel{
×
2916
                        ChannelID:    edge.ChannelID,
×
2917
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2918
                        OtherNode:    edge.NodeKey2Bytes,
×
2919
                        Capacity:     edge.Capacity,
×
2920
                        OutPolicySet: outPolicy != nil,
×
2921
                        InPolicy:     cachedInPolicy,
×
2922
                }
×
2923
                if outPolicy != nil {
×
2924
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2925
                                directedChannel.InboundFee = fee
×
2926
                        })
×
2927
                }
2928

2929
                if nodePub == edge.NodeKey2Bytes {
×
2930
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
2931
                }
×
2932

2933
                if err := cb(directedChannel); err != nil {
×
2934
                        return err
×
2935
                }
×
2936
        }
2937

2938
        return nil
×
2939
}
2940

2941
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
2942
// and executes the provided callback for each node.
2943
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
2944
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
2945

×
2946
        lastID := int64(-1)
×
2947

×
2948
        for {
×
2949
                nodes, err := db.ListNodeIDsAndPubKeys(
×
2950
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
2951
                                Version: int16(ProtocolV1),
×
2952
                                ID:      lastID,
×
2953
                                Limit:   pageSize,
×
2954
                        },
×
2955
                )
×
2956
                if err != nil {
×
2957
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
2958
                }
×
2959

2960
                if len(nodes) == 0 {
×
2961
                        break
×
2962
                }
2963

2964
                for _, node := range nodes {
×
2965
                        var pub route.Vertex
×
2966
                        copy(pub[:], node.PubKey)
×
2967

×
2968
                        if err := cb(node.ID, pub); err != nil {
×
2969
                                return fmt.Errorf("forEachNodeCacheable "+
×
2970
                                        "callback failed for node(id=%d): %w",
×
2971
                                        node.ID, err)
×
2972
                        }
×
2973

2974
                        lastID = node.ID
×
2975
                }
2976
        }
2977

2978
        return nil
×
2979
}
2980

2981
// forEachNodeChannel iterates through all channels of a node, executing
2982
// the passed callback on each. The call-back is provided with the channel's
2983
// edge information, the outgoing policy and the incoming policy for the
2984
// channel and node combo.
2985
func forEachNodeChannel(ctx context.Context, db SQLQueries,
2986
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
2987
                *models.ChannelEdgePolicy,
2988
                *models.ChannelEdgePolicy) error) error {
×
2989

×
2990
        // Get all the V1 channels for this node.Add commentMore actions
×
2991
        rows, err := db.ListChannelsByNodeID(
×
2992
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2993
                        Version: int16(ProtocolV1),
×
2994
                        NodeID1: id,
×
2995
                },
×
2996
        )
×
2997
        if err != nil {
×
2998
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2999
        }
×
3000

3001
        // Call the call-back for each channel and its known policies.
3002
        for _, row := range rows {
×
3003
                node1, node2, err := buildNodeVertices(
×
3004
                        row.Node1Pubkey, row.Node2Pubkey,
×
3005
                )
×
3006
                if err != nil {
×
3007
                        return fmt.Errorf("unable to build node vertices: %w",
×
3008
                                err)
×
3009
                }
×
3010

3011
                edge, err := getAndBuildEdgeInfo(
×
NEW
3012
                        ctx, db, chain, row.GraphChannel.ID, row.GraphChannel,
×
NEW
3013
                        node1, node2,
×
3014
                )
×
3015
                if err != nil {
×
3016
                        return fmt.Errorf("unable to build channel info: %w",
×
3017
                                err)
×
3018
                }
×
3019

3020
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3021
                if err != nil {
×
3022
                        return fmt.Errorf("unable to extract channel "+
×
3023
                                "policies: %w", err)
×
3024
                }
×
3025

3026
                p1, p2, err := getAndBuildChanPolicies(
×
3027
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3028
                )
×
3029
                if err != nil {
×
3030
                        return fmt.Errorf("unable to build channel "+
×
3031
                                "policies: %w", err)
×
3032
                }
×
3033

3034
                // Determine the outgoing and incoming policy for this
3035
                // channel and node combo.
NEW
3036
                p1ToNode := row.GraphChannel.NodeID2
×
NEW
3037
                p2ToNode := row.GraphChannel.NodeID1
×
3038
                outPolicy, inPolicy := p1, p2
×
3039
                if (p1 != nil && p1ToNode == id) ||
×
3040
                        (p2 != nil && p2ToNode != id) {
×
3041

×
3042
                        outPolicy, inPolicy = p2, p1
×
3043
                }
×
3044

3045
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3046
                        return err
×
3047
                }
×
3048
        }
3049

3050
        return nil
×
3051
}
3052

3053
// updateChanEdgePolicy upserts the channel policy info we have stored for
3054
// a channel we already know of.
3055
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3056
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3057
        error) {
×
3058

×
3059
        var (
×
3060
                node1Pub, node2Pub route.Vertex
×
3061
                isNode1            bool
×
3062
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3063
        )
×
3064

×
3065
        // Check that this edge policy refers to a channel that we already
×
3066
        // know of. We do this explicitly so that we can return the appropriate
×
3067
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3068
        // abort the transaction which would abort the entire batch.
×
3069
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3070
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3071
                        Scid:    chanIDB,
×
3072
                        Version: int16(ProtocolV1),
×
3073
                },
×
3074
        )
×
3075
        if errors.Is(err, sql.ErrNoRows) {
×
3076
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3077
        } else if err != nil {
×
3078
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3079
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3080
        }
×
3081

3082
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3083
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3084

×
3085
        // Figure out which node this edge is from.
×
3086
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3087
        nodeID := dbChan.NodeID1
×
3088
        if !isNode1 {
×
3089
                nodeID = dbChan.NodeID2
×
3090
        }
×
3091

3092
        var (
×
3093
                inboundBase sql.NullInt64
×
3094
                inboundRate sql.NullInt64
×
3095
        )
×
3096
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3097
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3098
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3099
        })
×
3100

3101
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3102
                Version:     int16(ProtocolV1),
×
3103
                ChannelID:   dbChan.ID,
×
3104
                NodeID:      nodeID,
×
3105
                Timelock:    int32(edge.TimeLockDelta),
×
3106
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3107
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3108
                MinHtlcMsat: int64(edge.MinHTLC),
×
3109
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3110
                Disabled: sql.NullBool{
×
3111
                        Valid: true,
×
3112
                        Bool:  edge.IsDisabled(),
×
3113
                },
×
3114
                MaxHtlcMsat: sql.NullInt64{
×
3115
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3116
                        Int64: int64(edge.MaxHTLC),
×
3117
                },
×
3118
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3119
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3120
                InboundBaseFeeMsat:      inboundBase,
×
3121
                InboundFeeRateMilliMsat: inboundRate,
×
3122
                Signature:               edge.SigBytes,
×
3123
        })
×
3124
        if err != nil {
×
3125
                return node1Pub, node2Pub, isNode1,
×
3126
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3127
        }
×
3128

3129
        // Convert the flat extra opaque data into a map of TLV types to
3130
        // values.
3131
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3132
        if err != nil {
×
3133
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3134
                        "marshal extra opaque data: %w", err)
×
3135
        }
×
3136

3137
        // Update the channel policy's extra signed fields.
3138
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3139
        if err != nil {
×
3140
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3141
                        "policy extra TLVs: %w", err)
×
3142
        }
×
3143

3144
        return node1Pub, node2Pub, isNode1, nil
×
3145
}
3146

3147
// getNodeByPubKey attempts to look up a target node by its public key.
3148
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3149
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3150

×
3151
        dbNode, err := db.GetNodeByPubKey(
×
3152
                ctx, sqlc.GetNodeByPubKeyParams{
×
3153
                        Version: int16(ProtocolV1),
×
3154
                        PubKey:  pubKey[:],
×
3155
                },
×
3156
        )
×
3157
        if errors.Is(err, sql.ErrNoRows) {
×
3158
                return 0, nil, ErrGraphNodeNotFound
×
3159
        } else if err != nil {
×
3160
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3161
        }
×
3162

3163
        node, err := buildNode(ctx, db, &dbNode)
×
3164
        if err != nil {
×
3165
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3166
        }
×
3167

3168
        return dbNode.ID, node, nil
×
3169
}
3170

3171
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3172
// provided database channel row and the public keys of the two nodes
3173
// involved in the channel.
3174
func buildCacheableChannelInfo(dbChan sqlc.GraphChannel, node1Pub,
3175
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3176

×
3177
        return &models.CachedEdgeInfo{
×
3178
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
3179
                NodeKey1Bytes: node1Pub,
×
3180
                NodeKey2Bytes: node2Pub,
×
3181
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
3182
        }
×
3183
}
×
3184

3185
// buildNode constructs a LightningNode instance from the given database node
3186
// record. The node's features, addresses and extra signed fields are also
3187
// fetched from the database and set on the node.
3188
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.GraphNode) (
3189
        *models.LightningNode, error) {
×
3190

×
3191
        if dbNode.Version != int16(ProtocolV1) {
×
3192
                return nil, fmt.Errorf("unsupported node version: %d",
×
3193
                        dbNode.Version)
×
3194
        }
×
3195

3196
        var pub [33]byte
×
3197
        copy(pub[:], dbNode.PubKey)
×
3198

×
3199
        node := &models.LightningNode{
×
3200
                PubKeyBytes: pub,
×
3201
                Features:    lnwire.EmptyFeatureVector(),
×
3202
                LastUpdate:  time.Unix(0, 0),
×
3203
        }
×
3204

×
3205
        if len(dbNode.Signature) == 0 {
×
3206
                return node, nil
×
3207
        }
×
3208

3209
        node.HaveNodeAnnouncement = true
×
3210
        node.AuthSigBytes = dbNode.Signature
×
3211
        node.Alias = dbNode.Alias.String
×
3212
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3213

×
3214
        var err error
×
3215
        if dbNode.Color.Valid {
×
3216
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3217
                if err != nil {
×
3218
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3219
                                err)
×
3220
                }
×
3221
        }
3222

3223
        // Fetch the node's features.
3224
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
3225
        if err != nil {
×
3226
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3227
                        "features: %w", dbNode.ID, err)
×
3228
        }
×
3229

3230
        // Fetch the node's addresses.
3231
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
3232
        if err != nil {
×
3233
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3234
                        "addresses: %w", dbNode.ID, err)
×
3235
        }
×
3236

3237
        // Fetch the node's extra signed fields.
3238
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
3239
        if err != nil {
×
3240
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3241
                        "extra signed fields: %w", dbNode.ID, err)
×
3242
        }
×
3243

3244
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
3245
        if err != nil {
×
3246
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3247
                        "fields: %w", err)
×
3248
        }
×
3249

3250
        if len(recs) != 0 {
×
3251
                node.ExtraOpaqueData = recs
×
3252
        }
×
3253

3254
        return node, nil
×
3255
}
3256

3257
// getNodeFeatures fetches the feature bits and constructs the feature vector
3258
// for a node with the given DB ID.
3259
func getNodeFeatures(ctx context.Context, db SQLQueries,
3260
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3261

×
3262
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3263
        if err != nil {
×
3264
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3265
                        nodeID, err)
×
3266
        }
×
3267

3268
        features := lnwire.EmptyFeatureVector()
×
3269
        for _, feature := range rows {
×
3270
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3271
        }
×
3272

3273
        return features, nil
×
3274
}
3275

3276
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
3277
// given DB ID.
3278
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3279
        nodeID int64) (map[uint64][]byte, error) {
×
3280

×
3281
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3282
        if err != nil {
×
3283
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
3284
                        "signed fields: %w", nodeID, err)
×
3285
        }
×
3286

3287
        extraFields := make(map[uint64][]byte)
×
3288
        for _, field := range fields {
×
3289
                extraFields[uint64(field.Type)] = field.Value
×
3290
        }
×
3291

3292
        return extraFields, nil
×
3293
}
3294

3295
// upsertNode upserts the node record into the database. If the node already
3296
// exists, then the node's information is updated. If the node doesn't exist,
3297
// then a new node is created. The node's features, addresses and extra TLV
3298
// types are also updated. The node's DB ID is returned.
3299
func upsertNode(ctx context.Context, db SQLQueries,
3300
        node *models.LightningNode) (int64, error) {
×
3301

×
3302
        params := sqlc.UpsertNodeParams{
×
3303
                Version: int16(ProtocolV1),
×
3304
                PubKey:  node.PubKeyBytes[:],
×
3305
        }
×
3306

×
3307
        if node.HaveNodeAnnouncement {
×
3308
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3309
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3310
                params.Alias = sqldb.SQLStr(node.Alias)
×
3311
                params.Signature = node.AuthSigBytes
×
3312
        }
×
3313

3314
        nodeID, err := db.UpsertNode(ctx, params)
×
3315
        if err != nil {
×
3316
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3317
                        err)
×
3318
        }
×
3319

3320
        // We can exit here if we don't have the announcement yet.
3321
        if !node.HaveNodeAnnouncement {
×
3322
                return nodeID, nil
×
3323
        }
×
3324

3325
        // Update the node's features.
3326
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3327
        if err != nil {
×
3328
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3329
        }
×
3330

3331
        // Update the node's addresses.
3332
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3333
        if err != nil {
×
3334
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3335
        }
×
3336

3337
        // Convert the flat extra opaque data into a map of TLV types to
3338
        // values.
3339
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3340
        if err != nil {
×
3341
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3342
                        err)
×
3343
        }
×
3344

3345
        // Update the node's extra signed fields.
3346
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3347
        if err != nil {
×
3348
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3349
        }
×
3350

3351
        return nodeID, nil
×
3352
}
3353

3354
// upsertNodeFeatures updates the node's features node_features table. This
3355
// includes deleting any feature bits no longer present and inserting any new
3356
// feature bits. If the feature bit does not yet exist in the features table,
3357
// then an entry is created in that table first.
3358
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3359
        features *lnwire.FeatureVector) error {
×
3360

×
3361
        // Get any existing features for the node.
×
3362
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3363
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3364
                return err
×
3365
        }
×
3366

3367
        // Copy the nodes latest set of feature bits.
3368
        newFeatures := make(map[int32]struct{})
×
3369
        if features != nil {
×
3370
                for feature := range features.Features() {
×
3371
                        newFeatures[int32(feature)] = struct{}{}
×
3372
                }
×
3373
        }
3374

3375
        // For any current feature that already exists in the DB, remove it from
3376
        // the in-memory map. For any existing feature that does not exist in
3377
        // the in-memory map, delete it from the database.
3378
        for _, feature := range existingFeatures {
×
3379
                // The feature is still present, so there are no updates to be
×
3380
                // made.
×
3381
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3382
                        delete(newFeatures, feature.FeatureBit)
×
3383
                        continue
×
3384
                }
3385

3386
                // The feature is no longer present, so we remove it from the
3387
                // database.
3388
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3389
                        NodeID:     nodeID,
×
3390
                        FeatureBit: feature.FeatureBit,
×
3391
                })
×
3392
                if err != nil {
×
3393
                        return fmt.Errorf("unable to delete node(%d) "+
×
3394
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3395
                                err)
×
3396
                }
×
3397
        }
3398

3399
        // Any remaining entries in newFeatures are new features that need to be
3400
        // added to the database for the first time.
3401
        for feature := range newFeatures {
×
3402
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3403
                        NodeID:     nodeID,
×
3404
                        FeatureBit: feature,
×
3405
                })
×
3406
                if err != nil {
×
3407
                        return fmt.Errorf("unable to insert node(%d) "+
×
3408
                                "feature(%v): %w", nodeID, feature, err)
×
3409
                }
×
3410
        }
3411

3412
        return nil
×
3413
}
3414

3415
// fetchNodeFeatures fetches the features for a node with the given public key.
3416
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3417
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3418

×
3419
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3420
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3421
                        PubKey:  nodePub[:],
×
3422
                        Version: int16(ProtocolV1),
×
3423
                },
×
3424
        )
×
3425
        if err != nil {
×
3426
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3427
                        nodePub, err)
×
3428
        }
×
3429

3430
        features := lnwire.EmptyFeatureVector()
×
3431
        for _, bit := range rows {
×
3432
                features.Set(lnwire.FeatureBit(bit))
×
3433
        }
×
3434

3435
        return features, nil
×
3436
}
3437

3438
// dbAddressType is an enum type that represents the different address types
3439
// that we store in the node_addresses table. The address type determines how
3440
// the address is to be serialised/deserialize.
3441
type dbAddressType uint8
3442

3443
const (
3444
        addressTypeIPv4   dbAddressType = 1
3445
        addressTypeIPv6   dbAddressType = 2
3446
        addressTypeTorV2  dbAddressType = 3
3447
        addressTypeTorV3  dbAddressType = 4
3448
        addressTypeOpaque dbAddressType = math.MaxInt8
3449
)
3450

3451
// upsertNodeAddresses updates the node's addresses in the database. This
3452
// includes deleting any existing addresses and inserting the new set of
3453
// addresses. The deletion is necessary since the ordering of the addresses may
3454
// change, and we need to ensure that the database reflects the latest set of
3455
// addresses so that at the time of reconstructing the node announcement, the
3456
// order is preserved and the signature over the message remains valid.
3457
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3458
        addresses []net.Addr) error {
×
3459

×
3460
        // Delete any existing addresses for the node. This is required since
×
3461
        // even if the new set of addresses is the same, the ordering may have
×
3462
        // changed for a given address type.
×
3463
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3464
        if err != nil {
×
3465
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3466
                        nodeID, err)
×
3467
        }
×
3468

3469
        // Copy the nodes latest set of addresses.
3470
        newAddresses := map[dbAddressType][]string{
×
3471
                addressTypeIPv4:   {},
×
3472
                addressTypeIPv6:   {},
×
3473
                addressTypeTorV2:  {},
×
3474
                addressTypeTorV3:  {},
×
3475
                addressTypeOpaque: {},
×
3476
        }
×
3477
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3478
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3479
        }
×
3480

3481
        for _, address := range addresses {
×
3482
                switch addr := address.(type) {
×
3483
                case *net.TCPAddr:
×
3484
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3485
                                addAddr(addressTypeIPv4, addr)
×
3486
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3487
                                addAddr(addressTypeIPv6, addr)
×
3488
                        } else {
×
3489
                                return fmt.Errorf("unhandled IP address: %v",
×
3490
                                        addr)
×
3491
                        }
×
3492

3493
                case *tor.OnionAddr:
×
3494
                        switch len(addr.OnionService) {
×
3495
                        case tor.V2Len:
×
3496
                                addAddr(addressTypeTorV2, addr)
×
3497
                        case tor.V3Len:
×
3498
                                addAddr(addressTypeTorV3, addr)
×
3499
                        default:
×
3500
                                return fmt.Errorf("invalid length for a tor " +
×
3501
                                        "address")
×
3502
                        }
3503

3504
                case *lnwire.OpaqueAddrs:
×
3505
                        addAddr(addressTypeOpaque, addr)
×
3506

3507
                default:
×
3508
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3509
                }
3510
        }
3511

3512
        // Any remaining entries in newAddresses are new addresses that need to
3513
        // be added to the database for the first time.
3514
        for addrType, addrList := range newAddresses {
×
3515
                for position, addr := range addrList {
×
3516
                        err := db.InsertNodeAddress(
×
3517
                                ctx, sqlc.InsertNodeAddressParams{
×
3518
                                        NodeID:   nodeID,
×
3519
                                        Type:     int16(addrType),
×
3520
                                        Address:  addr,
×
3521
                                        Position: int32(position),
×
3522
                                },
×
3523
                        )
×
3524
                        if err != nil {
×
3525
                                return fmt.Errorf("unable to insert "+
×
3526
                                        "node(%d) address(%v): %w", nodeID,
×
3527
                                        addr, err)
×
3528
                        }
×
3529
                }
3530
        }
3531

3532
        return nil
×
3533
}
3534

3535
// getNodeAddresses fetches the addresses for a node with the given public key.
3536
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3537
        []net.Addr, error) {
×
3538

×
3539
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
3540
        // are returned in the same order as they were inserted.
×
3541
        rows, err := db.GetNodeAddressesByPubKey(
×
3542
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
3543
                        Version: int16(ProtocolV1),
×
3544
                        PubKey:  nodePub,
×
3545
                },
×
3546
        )
×
3547
        if err != nil {
×
3548
                return false, nil, err
×
3549
        }
×
3550

3551
        // GetNodeAddressesByPubKey uses a left join so there should always be
3552
        // at least one row returned if the node exists even if it has no
3553
        // addresses.
3554
        if len(rows) == 0 {
×
3555
                return false, nil, nil
×
3556
        }
×
3557

3558
        addresses := make([]net.Addr, 0, len(rows))
×
3559
        for _, addr := range rows {
×
3560
                if !(addr.Type.Valid && addr.Address.Valid) {
×
3561
                        continue
×
3562
                }
3563

3564
                address := addr.Address.String
×
3565

×
3566
                switch dbAddressType(addr.Type.Int16) {
×
3567
                case addressTypeIPv4:
×
3568
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
3569
                        if err != nil {
×
3570
                                return false, nil, nil
×
3571
                        }
×
3572
                        tcp.IP = tcp.IP.To4()
×
3573

×
3574
                        addresses = append(addresses, tcp)
×
3575

3576
                case addressTypeIPv6:
×
3577
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
3578
                        if err != nil {
×
3579
                                return false, nil, nil
×
3580
                        }
×
3581
                        addresses = append(addresses, tcp)
×
3582

3583
                case addressTypeTorV3, addressTypeTorV2:
×
3584
                        service, portStr, err := net.SplitHostPort(address)
×
3585
                        if err != nil {
×
3586
                                return false, nil, fmt.Errorf("unable to "+
×
3587
                                        "split tor v3 address: %v",
×
3588
                                        addr.Address)
×
3589
                        }
×
3590

3591
                        port, err := strconv.Atoi(portStr)
×
3592
                        if err != nil {
×
3593
                                return false, nil, err
×
3594
                        }
×
3595

3596
                        addresses = append(addresses, &tor.OnionAddr{
×
3597
                                OnionService: service,
×
3598
                                Port:         port,
×
3599
                        })
×
3600

3601
                case addressTypeOpaque:
×
3602
                        opaque, err := hex.DecodeString(address)
×
3603
                        if err != nil {
×
3604
                                return false, nil, fmt.Errorf("unable to "+
×
3605
                                        "decode opaque address: %v", addr)
×
3606
                        }
×
3607

3608
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
3609
                                Payload: opaque,
×
3610
                        })
×
3611

3612
                default:
×
3613
                        return false, nil, fmt.Errorf("unknown address "+
×
3614
                                "type: %v", addr.Type)
×
3615
                }
3616
        }
3617

3618
        // If we have no addresses, then we'll return nil instead of an
3619
        // empty slice.
3620
        if len(addresses) == 0 {
×
3621
                addresses = nil
×
3622
        }
×
3623

3624
        return true, addresses, nil
×
3625
}
3626

3627
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3628
// database. This includes updating any existing types, inserting any new types,
3629
// and deleting any types that are no longer present.
3630
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3631
        nodeID int64, extraFields map[uint64][]byte) error {
×
3632

×
3633
        // Get any existing extra signed fields for the node.
×
3634
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3635
        if err != nil {
×
3636
                return err
×
3637
        }
×
3638

3639
        // Make a lookup map of the existing field types so that we can use it
3640
        // to keep track of any fields we should delete.
3641
        m := make(map[uint64]bool)
×
3642
        for _, field := range existingFields {
×
3643
                m[uint64(field.Type)] = true
×
3644
        }
×
3645

3646
        // For all the new fields, we'll upsert them and remove them from the
3647
        // map of existing fields.
3648
        for tlvType, value := range extraFields {
×
3649
                err = db.UpsertNodeExtraType(
×
3650
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3651
                                NodeID: nodeID,
×
3652
                                Type:   int64(tlvType),
×
3653
                                Value:  value,
×
3654
                        },
×
3655
                )
×
3656
                if err != nil {
×
3657
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3658
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3659
                }
×
3660

3661
                // Remove the field from the map of existing fields if it was
3662
                // present.
3663
                delete(m, tlvType)
×
3664
        }
3665

3666
        // For all the fields that are left in the map of existing fields, we'll
3667
        // delete them as they are no longer present in the new set of fields.
3668
        for tlvType := range m {
×
3669
                err = db.DeleteExtraNodeType(
×
3670
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3671
                                NodeID: nodeID,
×
3672
                                Type:   int64(tlvType),
×
3673
                        },
×
3674
                )
×
3675
                if err != nil {
×
3676
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3677
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3678
                }
×
3679
        }
3680

3681
        return nil
×
3682
}
3683

3684
// srcNodeInfo holds the information about the source node of the graph.
3685
type srcNodeInfo struct {
3686
        // id is the DB level ID of the source node entry in the "nodes" table.
3687
        id int64
3688

3689
        // pub is the public key of the source node.
3690
        pub route.Vertex
3691
}
3692

3693
// sourceNode returns the DB node ID and pub key of the source node for the
3694
// specified protocol version.
3695
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3696
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3697

×
3698
        s.srcNodeMu.Lock()
×
3699
        defer s.srcNodeMu.Unlock()
×
3700

×
3701
        // If we already have the source node ID and pub key cached, then
×
3702
        // return them.
×
3703
        if info, ok := s.srcNodes[version]; ok {
×
3704
                return info.id, info.pub, nil
×
3705
        }
×
3706

3707
        var pubKey route.Vertex
×
3708

×
3709
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3710
        if err != nil {
×
3711
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3712
                        err)
×
3713
        }
×
3714

3715
        if len(nodes) == 0 {
×
3716
                return 0, pubKey, ErrSourceNodeNotSet
×
3717
        } else if len(nodes) > 1 {
×
3718
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3719
                        "protocol %s found", version)
×
3720
        }
×
3721

3722
        copy(pubKey[:], nodes[0].PubKey)
×
3723

×
3724
        s.srcNodes[version] = &srcNodeInfo{
×
3725
                id:  nodes[0].NodeID,
×
3726
                pub: pubKey,
×
3727
        }
×
3728

×
3729
        return nodes[0].NodeID, pubKey, nil
×
3730
}
3731

3732
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3733
// This then produces a map from TLV type to value. If the input is not a
3734
// valid TLV stream, then an error is returned.
3735
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3736
        r := bytes.NewReader(data)
×
3737

×
3738
        tlvStream, err := tlv.NewStream()
×
3739
        if err != nil {
×
3740
                return nil, err
×
3741
        }
×
3742

3743
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3744
        // pass it into the P2P decoding variant.
3745
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3746
        if err != nil {
×
3747
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3748
        }
×
3749
        if len(parsedTypes) == 0 {
×
3750
                return nil, nil
×
3751
        }
×
3752

3753
        records := make(map[uint64][]byte)
×
3754
        for k, v := range parsedTypes {
×
3755
                records[uint64(k)] = v
×
3756
        }
×
3757

3758
        return records, nil
×
3759
}
3760

3761
// dbChanInfo holds the DB level IDs of a channel and the nodes involved in the
3762
// channel.
3763
type dbChanInfo struct {
3764
        channelID int64
3765
        node1ID   int64
3766
        node2ID   int64
3767
}
3768

3769
// insertChannel inserts a new channel record into the database.
3770
func insertChannel(ctx context.Context, db SQLQueries,
3771
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
3772

×
3773
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3774

×
3775
        // Make sure that the channel doesn't already exist. We do this
×
3776
        // explicitly instead of relying on catching a unique constraint error
×
3777
        // because relying on SQL to throw that error would abort the entire
×
3778
        // batch of transactions.
×
3779
        _, err := db.GetChannelBySCID(
×
3780
                ctx, sqlc.GetChannelBySCIDParams{
×
3781
                        Scid:    chanIDB,
×
3782
                        Version: int16(ProtocolV1),
×
3783
                },
×
3784
        )
×
3785
        if err == nil {
×
3786
                return nil, ErrEdgeAlreadyExist
×
3787
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3788
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
×
3789
        }
×
3790

3791
        // Make sure that at least a "shell" entry for each node is present in
3792
        // the nodes table.
3793
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3794
        if err != nil {
×
3795
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3796
        }
×
3797

3798
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3799
        if err != nil {
×
3800
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3801
        }
×
3802

3803
        var capacity sql.NullInt64
×
3804
        if edge.Capacity != 0 {
×
3805
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3806
        }
×
3807

3808
        createParams := sqlc.CreateChannelParams{
×
3809
                Version:     int16(ProtocolV1),
×
3810
                Scid:        chanIDB,
×
3811
                NodeID1:     node1DBID,
×
3812
                NodeID2:     node2DBID,
×
3813
                Outpoint:    edge.ChannelPoint.String(),
×
3814
                Capacity:    capacity,
×
3815
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3816
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3817
        }
×
3818

×
3819
        if edge.AuthProof != nil {
×
3820
                proof := edge.AuthProof
×
3821

×
3822
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3823
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3824
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3825
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3826
        }
×
3827

3828
        // Insert the new channel record.
3829
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3830
        if err != nil {
×
3831
                return nil, err
×
3832
        }
×
3833

3834
        // Insert any channel features.
3835
        for feature := range edge.Features.Features() {
×
3836
                err = db.InsertChannelFeature(
×
3837
                        ctx, sqlc.InsertChannelFeatureParams{
×
3838
                                ChannelID:  dbChanID,
×
3839
                                FeatureBit: int32(feature),
×
3840
                        },
×
3841
                )
×
3842
                if err != nil {
×
3843
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
×
3844
                                "feature(%v): %w", dbChanID, feature, err)
×
3845
                }
×
3846
        }
3847

3848
        // Finally, insert any extra TLV fields in the channel announcement.
3849
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3850
        if err != nil {
×
3851
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
3852
                        "data: %w", err)
×
3853
        }
×
3854

3855
        for tlvType, value := range extra {
×
3856
                err := db.CreateChannelExtraType(
×
3857
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3858
                                ChannelID: dbChanID,
×
3859
                                Type:      int64(tlvType),
×
3860
                                Value:     value,
×
3861
                        },
×
3862
                )
×
3863
                if err != nil {
×
3864
                        return nil, fmt.Errorf("unable to upsert "+
×
3865
                                "channel(%d) extra signed field(%v): %w",
×
3866
                                edge.ChannelID, tlvType, err)
×
3867
                }
×
3868
        }
3869

3870
        return &dbChanInfo{
×
3871
                channelID: dbChanID,
×
3872
                node1ID:   node1DBID,
×
3873
                node2ID:   node2DBID,
×
3874
        }, nil
×
3875
}
3876

3877
// maybeCreateShellNode checks if a shell node entry exists for the
3878
// given public key. If it does not exist, then a new shell node entry is
3879
// created. The ID of the node is returned. A shell node only has a protocol
3880
// version and public key persisted.
3881
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3882
        pubKey route.Vertex) (int64, error) {
×
3883

×
3884
        dbNode, err := db.GetNodeByPubKey(
×
3885
                ctx, sqlc.GetNodeByPubKeyParams{
×
3886
                        PubKey:  pubKey[:],
×
3887
                        Version: int16(ProtocolV1),
×
3888
                },
×
3889
        )
×
3890
        // The node exists. Return the ID.
×
3891
        if err == nil {
×
3892
                return dbNode.ID, nil
×
3893
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3894
                return 0, err
×
3895
        }
×
3896

3897
        // Otherwise, the node does not exist, so we create a shell entry for
3898
        // it.
3899
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3900
                Version: int16(ProtocolV1),
×
3901
                PubKey:  pubKey[:],
×
3902
        })
×
3903
        if err != nil {
×
3904
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3905
        }
×
3906

3907
        return id, nil
×
3908
}
3909

3910
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3911
// the database. This includes deleting any existing types and then inserting
3912
// the new types.
3913
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3914
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3915

×
3916
        // Delete all existing extra signed fields for the channel policy.
×
3917
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
3918
        if err != nil {
×
3919
                return fmt.Errorf("unable to delete "+
×
3920
                        "existing policy extra signed fields for policy %d: %w",
×
3921
                        chanPolicyID, err)
×
3922
        }
×
3923

3924
        // Insert all new extra signed fields for the channel policy.
3925
        for tlvType, value := range extraFields {
×
3926
                err = db.InsertChanPolicyExtraType(
×
3927
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
3928
                                ChannelPolicyID: chanPolicyID,
×
3929
                                Type:            int64(tlvType),
×
3930
                                Value:           value,
×
3931
                        },
×
3932
                )
×
3933
                if err != nil {
×
3934
                        return fmt.Errorf("unable to insert "+
×
3935
                                "channel_policy(%d) extra signed field(%v): %w",
×
3936
                                chanPolicyID, tlvType, err)
×
3937
                }
×
3938
        }
3939

3940
        return nil
×
3941
}
3942

3943
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3944
// provided dbChanRow and also fetches any other required information
3945
// to construct the edge info.
3946
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
3947
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.GraphChannel, node1,
3948
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
3949

×
3950
        if dbChan.Version != int16(ProtocolV1) {
×
3951
                return nil, fmt.Errorf("unsupported channel version: %d",
×
3952
                        dbChan.Version)
×
3953
        }
×
3954

3955
        fv, extras, err := getChanFeaturesAndExtras(
×
3956
                ctx, db, dbChanID,
×
3957
        )
×
3958
        if err != nil {
×
3959
                return nil, err
×
3960
        }
×
3961

3962
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
3963
        if err != nil {
×
3964
                return nil, err
×
3965
        }
×
3966

3967
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
3968
        if err != nil {
×
3969
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3970
                        "fields: %w", err)
×
3971
        }
×
3972
        if recs == nil {
×
3973
                recs = make([]byte, 0)
×
3974
        }
×
3975

3976
        var btcKey1, btcKey2 route.Vertex
×
3977
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
3978
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
3979

×
3980
        channel := &models.ChannelEdgeInfo{
×
3981
                ChainHash:        chain,
×
3982
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
3983
                NodeKey1Bytes:    node1,
×
3984
                NodeKey2Bytes:    node2,
×
3985
                BitcoinKey1Bytes: btcKey1,
×
3986
                BitcoinKey2Bytes: btcKey2,
×
3987
                ChannelPoint:     *op,
×
3988
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
3989
                Features:         fv,
×
3990
                ExtraOpaqueData:  recs,
×
3991
        }
×
3992

×
3993
        // We always set all the signatures at the same time, so we can
×
3994
        // safely check if one signature is present to determine if we have the
×
3995
        // rest of the signatures for the auth proof.
×
3996
        if len(dbChan.Bitcoin1Signature) > 0 {
×
3997
                channel.AuthProof = &models.ChannelAuthProof{
×
3998
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
3999
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4000
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4001
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4002
                }
×
4003
        }
×
4004

4005
        return channel, nil
×
4006
}
4007

4008
// buildNodeVertices is a helper that converts raw node public keys
4009
// into route.Vertex instances.
4010
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4011
        route.Vertex, error) {
×
4012

×
4013
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4014
        if err != nil {
×
4015
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4016
                        "create vertex from node1 pubkey: %w", err)
×
4017
        }
×
4018

4019
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4020
        if err != nil {
×
4021
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4022
                        "create vertex from node2 pubkey: %w", err)
×
4023
        }
×
4024

4025
        return node1Vertex, node2Vertex, nil
×
4026
}
4027

4028
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
4029
// for a channel with the given ID.
4030
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
4031
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
4032

×
4033
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
4034
        if err != nil {
×
4035
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
4036
                        "features and extras: %w", err)
×
4037
        }
×
4038

4039
        var (
×
4040
                fv     = lnwire.EmptyFeatureVector()
×
4041
                extras = make(map[uint64][]byte)
×
4042
        )
×
4043
        for _, row := range rows {
×
4044
                if row.IsFeature {
×
4045
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
4046

×
4047
                        continue
×
4048
                }
4049

4050
                tlvType, ok := row.ExtraKey.(int64)
×
4051
                if !ok {
×
4052
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4053
                                "TLV type: %T", row.ExtraKey)
×
4054
                }
×
4055

4056
                valueBytes, ok := row.Value.([]byte)
×
4057
                if !ok {
×
4058
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4059
                                "Value: %T", row.Value)
×
4060
                }
×
4061

4062
                extras[uint64(tlvType)] = valueBytes
×
4063
        }
4064

4065
        return fv, extras, nil
×
4066
}
4067

4068
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4069
// retrieves all the extra info required to build the complete
4070
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4071
// the provided sqlc.GraphChannelPolicy records are nil.
4072
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4073
        dbPol1, dbPol2 *sqlc.GraphChannelPolicy, channelID uint64, node1,
4074
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4075
        *models.ChannelEdgePolicy, error) {
×
4076

×
4077
        if dbPol1 == nil && dbPol2 == nil {
×
4078
                return nil, nil, nil
×
4079
        }
×
4080

4081
        var (
×
4082
                policy1ID int64
×
4083
                policy2ID int64
×
4084
        )
×
4085
        if dbPol1 != nil {
×
4086
                policy1ID = dbPol1.ID
×
4087
        }
×
4088
        if dbPol2 != nil {
×
4089
                policy2ID = dbPol2.ID
×
4090
        }
×
4091
        rows, err := db.GetChannelPolicyExtraTypes(
×
4092
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
4093
                        ID:   policy1ID,
×
4094
                        ID_2: policy2ID,
×
4095
                },
×
4096
        )
×
4097
        if err != nil {
×
4098
                return nil, nil, err
×
4099
        }
×
4100

4101
        var (
×
4102
                dbPol1Extras = make(map[uint64][]byte)
×
4103
                dbPol2Extras = make(map[uint64][]byte)
×
4104
        )
×
4105
        for _, row := range rows {
×
4106
                switch row.PolicyID {
×
4107
                case policy1ID:
×
4108
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
4109
                case policy2ID:
×
4110
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
4111
                default:
×
4112
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
4113
                                "in row: %v", row.PolicyID, row)
×
4114
                }
4115
        }
4116

4117
        var pol1, pol2 *models.ChannelEdgePolicy
×
4118
        if dbPol1 != nil {
×
4119
                pol1, err = buildChanPolicy(
×
4120
                        *dbPol1, channelID, dbPol1Extras, node2,
×
4121
                )
×
4122
                if err != nil {
×
4123
                        return nil, nil, err
×
4124
                }
×
4125
        }
4126
        if dbPol2 != nil {
×
4127
                pol2, err = buildChanPolicy(
×
4128
                        *dbPol2, channelID, dbPol2Extras, node1,
×
4129
                )
×
4130
                if err != nil {
×
4131
                        return nil, nil, err
×
4132
                }
×
4133
        }
4134

4135
        return pol1, pol2, nil
×
4136
}
4137

4138
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4139
// provided sqlc.GraphChannelPolicy and other required information.
4140
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4141
        extras map[uint64][]byte,
4142
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4143

×
4144
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4145
        if err != nil {
×
4146
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4147
                        "fields: %w", err)
×
4148
        }
×
4149

4150
        var inboundFee fn.Option[lnwire.Fee]
×
4151
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4152
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4153

×
4154
                inboundFee = fn.Some(lnwire.Fee{
×
4155
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4156
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4157
                })
×
4158
        }
×
4159

4160
        return &models.ChannelEdgePolicy{
×
4161
                SigBytes:  dbPolicy.Signature,
×
4162
                ChannelID: channelID,
×
4163
                LastUpdate: time.Unix(
×
4164
                        dbPolicy.LastUpdate.Int64, 0,
×
4165
                ),
×
4166
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4167
                        dbPolicy.MessageFlags,
×
4168
                ),
×
4169
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4170
                        dbPolicy.ChannelFlags,
×
4171
                ),
×
4172
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4173
                MinHTLC: lnwire.MilliSatoshi(
×
4174
                        dbPolicy.MinHtlcMsat,
×
4175
                ),
×
4176
                MaxHTLC: lnwire.MilliSatoshi(
×
4177
                        dbPolicy.MaxHtlcMsat.Int64,
×
4178
                ),
×
4179
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4180
                        dbPolicy.BaseFeeMsat,
×
4181
                ),
×
4182
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4183
                ToNode:                    toNode,
×
4184
                InboundFee:                inboundFee,
×
4185
                ExtraOpaqueData:           recs,
×
4186
        }, nil
×
4187
}
4188

4189
// buildNodes builds the models.LightningNode instances for the
4190
// given row which is expected to be a sqlc type that contains node information.
4191
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4192
        dbNode2 sqlc.GraphNode) (*models.LightningNode, *models.LightningNode,
4193
        error) {
×
4194

×
4195
        node1, err := buildNode(ctx, db, &dbNode1)
×
4196
        if err != nil {
×
4197
                return nil, nil, err
×
4198
        }
×
4199

4200
        node2, err := buildNode(ctx, db, &dbNode2)
×
4201
        if err != nil {
×
4202
                return nil, nil, err
×
4203
        }
×
4204

4205
        return node1, node2, nil
×
4206
}
4207

4208
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the give
4209
// row which is expected to be a sqlc type that contains channel policy
4210
// information. It returns two policies, which may be nil if the policy
4211
// information is not present in the row.
4212
//
4213
//nolint:ll,dupl,funlen
4214
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
NEW
4215
        *sqlc.GraphChannelPolicy, error) {
×
4216

×
NEW
4217
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4218
        switch r := row.(type) {
×
4219
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4220
                if r.Policy1ID.Valid {
×
NEW
4221
                        policy1 = &sqlc.GraphChannelPolicy{
×
4222
                                ID:                      r.Policy1ID.Int64,
×
4223
                                Version:                 r.Policy1Version.Int16,
×
NEW
4224
                                ChannelID:               r.GraphChannel.ID,
×
4225
                                NodeID:                  r.Policy1NodeID.Int64,
×
4226
                                Timelock:                r.Policy1Timelock.Int32,
×
4227
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4228
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4229
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4230
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4231
                                LastUpdate:              r.Policy1LastUpdate,
×
4232
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4233
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4234
                                Disabled:                r.Policy1Disabled,
×
4235
                                MessageFlags:            r.Policy1MessageFlags,
×
4236
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4237
                                Signature:               r.Policy1Signature,
×
4238
                        }
×
4239
                }
×
4240
                if r.Policy2ID.Valid {
×
NEW
4241
                        policy2 = &sqlc.GraphChannelPolicy{
×
4242
                                ID:                      r.Policy2ID.Int64,
×
4243
                                Version:                 r.Policy2Version.Int16,
×
NEW
4244
                                ChannelID:               r.GraphChannel.ID,
×
4245
                                NodeID:                  r.Policy2NodeID.Int64,
×
4246
                                Timelock:                r.Policy2Timelock.Int32,
×
4247
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4248
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4249
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4250
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4251
                                LastUpdate:              r.Policy2LastUpdate,
×
4252
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4253
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4254
                                Disabled:                r.Policy2Disabled,
×
4255
                                MessageFlags:            r.Policy2MessageFlags,
×
4256
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4257
                                Signature:               r.Policy2Signature,
×
4258
                        }
×
4259
                }
×
4260

4261
                return policy1, policy2, nil
×
4262

4263
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4264
                if r.Policy1ID.Valid {
×
NEW
4265
                        policy1 = &sqlc.GraphChannelPolicy{
×
4266
                                ID:                      r.Policy1ID.Int64,
×
4267
                                Version:                 r.Policy1Version.Int16,
×
NEW
4268
                                ChannelID:               r.GraphChannel.ID,
×
4269
                                NodeID:                  r.Policy1NodeID.Int64,
×
4270
                                Timelock:                r.Policy1Timelock.Int32,
×
4271
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4272
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4273
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4274
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4275
                                LastUpdate:              r.Policy1LastUpdate,
×
4276
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4277
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4278
                                Disabled:                r.Policy1Disabled,
×
4279
                                MessageFlags:            r.Policy1MessageFlags,
×
4280
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4281
                                Signature:               r.Policy1Signature,
×
4282
                        }
×
4283
                }
×
4284
                if r.Policy2ID.Valid {
×
NEW
4285
                        policy2 = &sqlc.GraphChannelPolicy{
×
4286
                                ID:                      r.Policy2ID.Int64,
×
4287
                                Version:                 r.Policy2Version.Int16,
×
NEW
4288
                                ChannelID:               r.GraphChannel.ID,
×
4289
                                NodeID:                  r.Policy2NodeID.Int64,
×
4290
                                Timelock:                r.Policy2Timelock.Int32,
×
4291
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4292
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4293
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4294
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4295
                                LastUpdate:              r.Policy2LastUpdate,
×
4296
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4297
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4298
                                Disabled:                r.Policy2Disabled,
×
4299
                                MessageFlags:            r.Policy2MessageFlags,
×
4300
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4301
                                Signature:               r.Policy2Signature,
×
4302
                        }
×
4303
                }
×
4304

4305
                return policy1, policy2, nil
×
4306

4307
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4308
                if r.Policy1ID.Valid {
×
NEW
4309
                        policy1 = &sqlc.GraphChannelPolicy{
×
4310
                                ID:                      r.Policy1ID.Int64,
×
4311
                                Version:                 r.Policy1Version.Int16,
×
NEW
4312
                                ChannelID:               r.GraphChannel.ID,
×
4313
                                NodeID:                  r.Policy1NodeID.Int64,
×
4314
                                Timelock:                r.Policy1Timelock.Int32,
×
4315
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4316
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4317
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4318
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4319
                                LastUpdate:              r.Policy1LastUpdate,
×
4320
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4321
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4322
                                Disabled:                r.Policy1Disabled,
×
4323
                                MessageFlags:            r.Policy1MessageFlags,
×
4324
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4325
                                Signature:               r.Policy1Signature,
×
4326
                        }
×
4327
                }
×
4328
                if r.Policy2ID.Valid {
×
NEW
4329
                        policy2 = &sqlc.GraphChannelPolicy{
×
4330
                                ID:                      r.Policy2ID.Int64,
×
4331
                                Version:                 r.Policy2Version.Int16,
×
NEW
4332
                                ChannelID:               r.GraphChannel.ID,
×
4333
                                NodeID:                  r.Policy2NodeID.Int64,
×
4334
                                Timelock:                r.Policy2Timelock.Int32,
×
4335
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4336
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4337
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4338
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4339
                                LastUpdate:              r.Policy2LastUpdate,
×
4340
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4341
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4342
                                Disabled:                r.Policy2Disabled,
×
4343
                                MessageFlags:            r.Policy2MessageFlags,
×
4344
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4345
                                Signature:               r.Policy2Signature,
×
4346
                        }
×
4347
                }
×
4348

4349
                return policy1, policy2, nil
×
4350

4351
        case sqlc.ListChannelsByNodeIDRow:
×
4352
                if r.Policy1ID.Valid {
×
NEW
4353
                        policy1 = &sqlc.GraphChannelPolicy{
×
4354
                                ID:                      r.Policy1ID.Int64,
×
4355
                                Version:                 r.Policy1Version.Int16,
×
NEW
4356
                                ChannelID:               r.GraphChannel.ID,
×
4357
                                NodeID:                  r.Policy1NodeID.Int64,
×
4358
                                Timelock:                r.Policy1Timelock.Int32,
×
4359
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4360
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4361
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4362
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4363
                                LastUpdate:              r.Policy1LastUpdate,
×
4364
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4365
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4366
                                Disabled:                r.Policy1Disabled,
×
4367
                                MessageFlags:            r.Policy1MessageFlags,
×
4368
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4369
                                Signature:               r.Policy1Signature,
×
4370
                        }
×
4371
                }
×
4372
                if r.Policy2ID.Valid {
×
NEW
4373
                        policy2 = &sqlc.GraphChannelPolicy{
×
4374
                                ID:                      r.Policy2ID.Int64,
×
4375
                                Version:                 r.Policy2Version.Int16,
×
NEW
4376
                                ChannelID:               r.GraphChannel.ID,
×
4377
                                NodeID:                  r.Policy2NodeID.Int64,
×
4378
                                Timelock:                r.Policy2Timelock.Int32,
×
4379
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4380
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4381
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4382
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4383
                                LastUpdate:              r.Policy2LastUpdate,
×
4384
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4385
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4386
                                Disabled:                r.Policy2Disabled,
×
4387
                                MessageFlags:            r.Policy2MessageFlags,
×
4388
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4389
                                Signature:               r.Policy2Signature,
×
4390
                        }
×
4391
                }
×
4392

4393
                return policy1, policy2, nil
×
4394

4395
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4396
                if r.Policy1ID.Valid {
×
NEW
4397
                        policy1 = &sqlc.GraphChannelPolicy{
×
4398
                                ID:                      r.Policy1ID.Int64,
×
4399
                                Version:                 r.Policy1Version.Int16,
×
NEW
4400
                                ChannelID:               r.GraphChannel.ID,
×
4401
                                NodeID:                  r.Policy1NodeID.Int64,
×
4402
                                Timelock:                r.Policy1Timelock.Int32,
×
4403
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4404
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4405
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4406
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4407
                                LastUpdate:              r.Policy1LastUpdate,
×
4408
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4409
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4410
                                Disabled:                r.Policy1Disabled,
×
4411
                                MessageFlags:            r.Policy1MessageFlags,
×
4412
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4413
                                Signature:               r.Policy1Signature,
×
4414
                        }
×
4415
                }
×
4416
                if r.Policy2ID.Valid {
×
NEW
4417
                        policy2 = &sqlc.GraphChannelPolicy{
×
4418
                                ID:                      r.Policy2ID.Int64,
×
4419
                                Version:                 r.Policy2Version.Int16,
×
NEW
4420
                                ChannelID:               r.GraphChannel.ID,
×
4421
                                NodeID:                  r.Policy2NodeID.Int64,
×
4422
                                Timelock:                r.Policy2Timelock.Int32,
×
4423
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4424
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4425
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4426
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4427
                                LastUpdate:              r.Policy2LastUpdate,
×
4428
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4429
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4430
                                Disabled:                r.Policy2Disabled,
×
4431
                                MessageFlags:            r.Policy2MessageFlags,
×
4432
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4433
                                Signature:               r.Policy2Signature,
×
4434
                        }
×
4435
                }
×
4436

4437
                return policy1, policy2, nil
×
4438
        default:
×
4439
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4440
                        "extractChannelPolicies: %T", r)
×
4441
        }
4442
}
4443

4444
// channelIDToBytes converts a channel ID (SCID) to a byte array
4445
// representation.
4446
func channelIDToBytes(channelID uint64) []byte {
×
4447
        var chanIDB [8]byte
×
4448
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4449

×
4450
        return chanIDB[:]
×
4451
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc