• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15606538456

12 Jun 2025 09:14AM UTC coverage: 67.414% (+9.1%) from 58.333%
15606538456

Pull #9932

github

web-flow
Merge 25e652669 into 35102e7c3
Pull Request #9932: [draft] graph/db+sqldb: graph store SQL implementation + migration

23 of 3319 new or added lines in 7 files covered. (0.69%)

39 existing lines in 8 files now uncovered.

134459 of 199453 relevant lines covered (67.41%)

21872.74 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "math"
11
        "net"
12
        "sort"
13
        "strconv"
14
        "sync"
15
        "time"
16

17
        "github.com/btcsuite/btcd/btcec/v2"
18
        "github.com/btcsuite/btcd/btcutil"
19
        "github.com/btcsuite/btcd/chaincfg/chainhash"
20
        "github.com/btcsuite/btcd/wire"
21
        "github.com/lightningnetwork/lnd/aliasmgr"
22
        "github.com/lightningnetwork/lnd/batch"
23
        "github.com/lightningnetwork/lnd/fn/v2"
24
        "github.com/lightningnetwork/lnd/graph/db/models"
25
        "github.com/lightningnetwork/lnd/lnwire"
26
        "github.com/lightningnetwork/lnd/routing/route"
27
        "github.com/lightningnetwork/lnd/sqldb"
28
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
29
        "github.com/lightningnetwork/lnd/tlv"
30
        "github.com/lightningnetwork/lnd/tor"
31
)
32

33
// ProtocolVersion enumerates the gossip protocol versions that a message may
// belong to.
type ProtocolVersion uint8

const (
	// ProtocolV1 is the gossip protocol version specified by BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String renders the protocol version as the letter "V" followed by its
// numeric value, e.g. "V1".
func (v ProtocolVersion) String() string {
	// Convert to the underlying uint8 so fmt prints the number rather than
	// recursing back into this String method. "V" is a string operand, so
	// Sprint inserts no separating space.
	return fmt.Sprint("V", uint8(v))
}
×
46

47
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
48
// execute queries against the SQL graph tables.
49
//
50
//nolint:ll,interfacebloat
51
type SQLQueries interface {
52
        /*
53
                Node queries.
54
        */
55
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
56
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
57
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
58
        ListNodes(ctx context.Context, version int16) ([]sqlc.Node, error)
59
        ListNodeIDsAndPubKeys(ctx context.Context, version int16) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
60
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
61
        GetUnconnectedNodes(ctx context.Context) ([]sqlc.GetUnconnectedNodesRow, error)
62
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
63

64
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
65
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
66
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
67

68
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
69
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
70
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
71

72
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
73
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
74
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
75
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
76

77
        /*
78
                Source node queries.
79
        */
80
        AddSourceNode(ctx context.Context, nodeID int64) error
81
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
82

83
        /*
84
                Channel queries.
85
        */
86
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
87
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) error
88
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
89
        GetChannelByOutpoint(ctx context.Context, arg sqlc.GetChannelByOutpointParams) (sqlc.GetChannelByOutpointRow, error)
90
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
91
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
92
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
93
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
94
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
95
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
96
        ListAllChannels(ctx context.Context, version int16) ([]sqlc.ListAllChannelsRow, error)
97
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
98
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
99
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
100
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
101
        DeleteChannel(ctx context.Context, id int64) error
102

103
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
104
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
105

106
        /*
107
                Channel Policy table queries.
108
        */
109
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
110
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.ChannelPolicy, error)
111
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
112

113
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
114
        DeleteChannelPolicyExtraType(ctx context.Context, arg sqlc.DeleteChannelPolicyExtraTypeParams) error
115
        UpsertChanPolicyExtraType(ctx context.Context, arg sqlc.UpsertChanPolicyExtraTypeParams) error
116

117
        /*
118
                Zombie index queries.
119
        */
120
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
121
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.ZombieChannel, error)
122
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
123
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) error
124
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
125

126
        /*
127
                Prune log table queries.
128
        */
129
        GetPruneTip(ctx context.Context) (sqlc.PruneLog, error)
130
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
131
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
132

133
        /*
134
                Closed SCID table queries.
135
        */
136
        InsertClosedChannel(ctx context.Context, scid []byte) error
137
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
138
}
139

140
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
141
// database operations.
142
type BatchedSQLQueries interface {
143
        SQLQueries
144
        sqldb.BatchedTx[SQLQueries]
145
}
146

147
// SQLStore is an implementation of the V1Store interface that uses a SQL
148
// database as the backend.
149
//
150
// NOTE: currently, this temporarily embeds the KVStore struct so that we can
151
// implement the V1Store interface incrementally. For any method not
152
// implemented,  things will fall back to the KVStore. This is ONLY the case
153
// for the time being while this struct is purely used in unit tests only.
154
type SQLStore struct {
155
        cfg *SQLStoreConfig
156
        db  BatchedSQLQueries
157

158
        // cacheMu guards all caches (rejectCache and chanCache). If
159
        // this mutex will be acquired at the same time as the DB mutex then
160
        // the cacheMu MUST be acquired first to prevent deadlock.
161
        cacheMu     sync.RWMutex
162
        rejectCache *rejectCache
163
        chanCache   *channelCache
164

165
        chanScheduler batch.Scheduler[SQLQueries]
166
        nodeScheduler batch.Scheduler[SQLQueries]
167

168
        srcNodeID  int64
169
        srcNodePub route.Vertex
170
        srcNodeMu  sync.Mutex
171
}
172

173
// A compile-time assertion to ensure that SQLStore implements the V1Store
174
// interface.
175
var _ V1Store = (*SQLStore)(nil)
176

177
// SQLStoreConfig holds the configuration for the SQLStore.
178
type SQLStoreConfig struct {
179
        // ChainHash is the genesis hash for the chain that all the gossip
180
        // messages in this store are aimed at.
181
        ChainHash chainhash.Hash
182
}
183

184
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
185
// storage backend.
186
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
187
        options ...StoreOptionModifier) (*SQLStore, error) {
×
188

×
189
        opts := DefaultOptions()
×
190
        for _, o := range options {
×
191
                o(opts)
×
192
        }
×
193

194
        if opts.NoMigration {
×
195
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
196
                        "supported for SQL stores")
×
197
        }
×
198

199
        s := &SQLStore{
×
NEW
200
                cfg:         cfg,
×
201
                db:          db,
×
202
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
203
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
204
        }
×
205

×
206
        s.chanScheduler = batch.NewTimeScheduler(
×
207
                db, &s.cacheMu, opts.BatchCommitInterval,
×
208
        )
×
209
        s.nodeScheduler = batch.NewTimeScheduler(
×
210
                db, nil, opts.BatchCommitInterval,
×
211
        )
×
212

×
213
        return s, nil
×
214
}
215

216
// AddLightningNode adds a vertex/node to the graph database. If the node is not
217
// in the database from before, this will add a new, unconnected one to the
218
// graph. If it is present from before, this will update that node's
219
// information.
220
//
221
// NOTE: part of the V1Store interface.
222
func (s *SQLStore) AddLightningNode(node *models.LightningNode,
223
        opts ...batch.SchedulerOption) error {
×
224

×
225
        ctx := context.TODO()
×
226

×
227
        r := &batch.Request[SQLQueries]{
×
228
                Opts: batch.NewSchedulerOptions(opts...),
×
229
                Do: func(queries SQLQueries) error {
×
230
                        _, err := upsertNode(ctx, queries, node)
×
231
                        return err
×
232
                },
×
233
        }
234

235
        return s.nodeScheduler.Execute(ctx, r)
×
236
}
237

238
// FetchLightningNode attempts to look up a target node by its identity public
239
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
240
// returned.
241
//
242
// NOTE: part of the V1Store interface.
243
func (s *SQLStore) FetchLightningNode(pubKey route.Vertex) (
244
        *models.LightningNode, error) {
×
245

×
246
        ctx := context.TODO()
×
247

×
248
        var node *models.LightningNode
×
249
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
250
                var err error
×
251
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
252

×
253
                return err
×
254
        }, sqldb.NoOpReset)
×
255
        if err != nil {
×
256
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
257
        }
×
258

259
        return node, nil
×
260
}
261

262
// HasLightningNode determines if the graph has a vertex identified by the
263
// target node identity public key. If the node exists in the database, a
264
// timestamp of when the data for the node was lasted updated is returned along
265
// with a true boolean. Otherwise, an empty time.Time is returned with a false
266
// boolean.
267
//
268
// NOTE: part of the V1Store interface.
269
func (s *SQLStore) HasLightningNode(pubKey [33]byte) (time.Time, bool,
270
        error) {
×
271

×
272
        ctx := context.TODO()
×
273

×
274
        var (
×
275
                exists     bool
×
276
                lastUpdate time.Time
×
277
        )
×
278
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
279
                dbNode, err := db.GetNodeByPubKey(
×
280
                        ctx, sqlc.GetNodeByPubKeyParams{
×
281
                                Version: int16(ProtocolV1),
×
282
                                PubKey:  pubKey[:],
×
283
                        },
×
284
                )
×
285
                if errors.Is(err, sql.ErrNoRows) {
×
286
                        return nil
×
287
                } else if err != nil {
×
288
                        return fmt.Errorf("unable to fetch node: %w", err)
×
289
                }
×
290

291
                exists = true
×
292

×
293
                if dbNode.LastUpdate.Valid {
×
294
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
295
                }
×
296

297
                return nil
×
298
        }, sqldb.NoOpReset)
299
        if err != nil {
×
300
                return time.Time{}, false,
×
301
                        fmt.Errorf("unable to fetch node: %w", err)
×
302
        }
×
303

304
        return lastUpdate, exists, nil
×
305
}
306

307
// AddrsForNode returns all known addresses for the target node public key
308
// that the graph DB is aware of. The returned boolean indicates if the
309
// given node is unknown to the graph DB or not.
310
//
311
// NOTE: part of the V1Store interface.
312
func (s *SQLStore) AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr,
313
        error) {
×
314

×
315
        ctx := context.TODO()
×
316

×
317
        var (
×
318
                addresses []net.Addr
×
319
                known     bool
×
320
        )
×
321
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
322
                var err error
×
323
                known, addresses, err = getNodeAddresses(
×
324
                        ctx, db, nodePub.SerializeCompressed(),
×
325
                )
×
326
                if err != nil {
×
327
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
328
                                err)
×
329
                }
×
330

331
                return nil
×
332
        }, sqldb.NoOpReset)
333
        if err != nil {
×
334
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
335
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
336
        }
×
337

338
        return known, addresses, nil
×
339
}
340

341
// DeleteLightningNode starts a new database transaction to remove a vertex/node
342
// from the database according to the node's public key.
343
//
344
// NOTE: part of the V1Store interface.
345
func (s *SQLStore) DeleteLightningNode(pubKey route.Vertex) error {
×
346
        ctx := context.TODO()
×
347

×
348
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
349
                res, err := db.DeleteNodeByPubKey(
×
350
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
351
                                Version: int16(ProtocolV1),
×
352
                                PubKey:  pubKey[:],
×
353
                        },
×
354
                )
×
355
                if err != nil {
×
356
                        return err
×
357
                }
×
358

359
                rows, err := res.RowsAffected()
×
360
                if err != nil {
×
361
                        return err
×
362
                }
×
363

364
                if rows == 0 {
×
365
                        return ErrGraphNodeNotFound
×
366
                } else if rows > 1 {
×
367
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
368
                }
×
369

370
                return err
×
371
        }, sqldb.NoOpReset)
372
        if err != nil {
×
373
                return fmt.Errorf("unable to delete node: %w", err)
×
374
        }
×
375

376
        return nil
×
377
}
378

379
// FetchNodeFeatures returns the features of the given node. If no features are
380
// known for the node, an empty feature vector is returned.
381
//
382
// NOTE: this is part of the graphdb.NodeTraverser interface.
383
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
384
        *lnwire.FeatureVector, error) {
×
385

×
386
        ctx := context.TODO()
×
387

×
388
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
389
}
×
390

391
// DisabledChannelIDs returns the channel ids of disabled channels.
392
// A channel is disabled when two of the associated ChanelEdgePolicies
393
// have their disabled bit on.
394
//
395
// NOTE: part of the V1Store interface.
NEW
396
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
NEW
397
        var (
×
NEW
398
                ctx     = context.TODO()
×
NEW
399
                chanIDs []uint64
×
NEW
400
        )
×
NEW
401
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
402
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
NEW
403
                if err != nil {
×
NEW
404
                        return fmt.Errorf("unable to fetch disabled "+
×
NEW
405
                                "channels: %w", err)
×
NEW
406
                }
×
407

NEW
408
                for _, dbChanID := range dbChanIDs {
×
NEW
409
                        chanIDs = append(chanIDs, byteOrder.Uint64(dbChanID))
×
NEW
410
                }
×
411

NEW
412
                return nil
×
NEW
413
        }, func() {})
×
NEW
414
        if err != nil {
×
NEW
415
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
NEW
416
                        err)
×
NEW
417
        }
×
418

NEW
419
        return chanIDs, nil
×
420
}
421

422
// LookupAlias attempts to return the alias as advertised by the target node.
423
//
424
// NOTE: part of the V1Store interface.
425
func (s *SQLStore) LookupAlias(pub *btcec.PublicKey) (string, error) {
×
426
        var (
×
427
                ctx   = context.TODO()
×
428
                alias string
×
429
        )
×
430
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
431
                dbNode, err := db.GetNodeByPubKey(
×
432
                        ctx, sqlc.GetNodeByPubKeyParams{
×
433
                                Version: int16(ProtocolV1),
×
434
                                PubKey:  pub.SerializeCompressed(),
×
435
                        },
×
436
                )
×
437
                if errors.Is(err, sql.ErrNoRows) {
×
438
                        return ErrNodeAliasNotFound
×
439
                } else if err != nil {
×
440
                        return fmt.Errorf("unable to fetch node: %w", err)
×
441
                }
×
442

443
                if !dbNode.Alias.Valid {
×
444
                        return ErrNodeAliasNotFound
×
445
                }
×
446

447
                alias = dbNode.Alias.String
×
448

×
449
                return nil
×
450
        }, sqldb.NoOpReset)
451
        if err != nil {
×
452
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
453
        }
×
454

455
        return alias, nil
×
456
}
457

458
// SourceNode returns the source node of the graph. The source node is treated
459
// as the center node within a star-graph. This method may be used to kick off
460
// a path finding algorithm in order to explore the reachability of another
461
// node based off the source node.
462
//
463
// NOTE: part of the V1Store interface.
464
func (s *SQLStore) SourceNode() (*models.LightningNode, error) {
×
465
        ctx := context.TODO()
×
466

×
467
        var node *models.LightningNode
×
468
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
469
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
470
                if err != nil {
×
471
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
472
                                err)
×
473
                }
×
474

475
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
476

×
477
                return err
×
478
        }, sqldb.NoOpReset)
479
        if err != nil {
×
480
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
481
        }
×
482

483
        return node, nil
×
484
}
485

486
// SetSourceNode sets the source node within the graph database. The source
487
// node is to be used as the center of a star-graph within path finding
488
// algorithms.
489
//
490
// NOTE: part of the V1Store interface.
491
func (s *SQLStore) SetSourceNode(node *models.LightningNode) error {
×
492
        ctx := context.TODO()
×
493

×
494
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
495
                id, err := upsertNode(ctx, db, node)
×
496
                if err != nil {
×
497
                        return fmt.Errorf("unable to upsert source node: %w",
×
498
                                err)
×
499
                }
×
500

501
                // Make sure that if a source node for this version is already
502
                // set, then the ID is the same as the one we are about to set.
NEW
503
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
504
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
505
                        return fmt.Errorf("unable to fetch source node: %w",
×
506
                                err)
×
507
                } else if err == nil {
×
508
                        if dbSourceNodeID != id {
×
509
                                return fmt.Errorf("v1 source node already "+
×
510
                                        "set to a different node: %d vs %d",
×
511
                                        dbSourceNodeID, id)
×
512
                        }
×
513

514
                        return nil
×
515
                }
516

517
                return db.AddSourceNode(ctx, id)
×
518
        }, sqldb.NoOpReset)
519
}
520

521
// NodeUpdatesInHorizon returns all the known lightning node which have an
522
// update timestamp within the passed range. This method can be used by two
523
// nodes to quickly determine if they have the same set of up to date node
524
// announcements.
525
//
526
// NOTE: This is part of the V1Store interface.
527
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
528
        endTime time.Time) ([]models.LightningNode, error) {
×
529

×
530
        ctx := context.TODO()
×
531

×
532
        var nodes []models.LightningNode
×
533
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
534
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
535
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
536
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
537
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
538
                        },
×
539
                )
×
540
                if err != nil {
×
541
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
542
                }
×
543

544
                for _, dbNode := range dbNodes {
×
545
                        node, err := buildNode(ctx, db, &dbNode)
×
546
                        if err != nil {
×
547
                                return fmt.Errorf("unable to build node: %w",
×
548
                                        err)
×
549
                        }
×
550

551
                        nodes = append(nodes, *node)
×
552
                }
553

554
                return nil
×
555
        }, sqldb.NoOpReset)
556
        if err != nil {
×
557
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
558
        }
×
559

560
        return nodes, nil
×
561
}
562

563
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
564
// undirected edge from the two target nodes are created. The information stored
565
// denotes the static attributes of the channel, such as the channelID, the keys
566
// involved in creation of the channel, and the set of features that the channel
567
// supports. The chanPoint and chanID are used to uniquely identify the edge
568
// globally within the database.
569
//
570
// NOTE: part of the V1Store interface.
571
func (s *SQLStore) AddChannelEdge(edge *models.ChannelEdgeInfo,
572
        opts ...batch.SchedulerOption) error {
×
573

×
574
        ctx := context.TODO()
×
575

×
576
        var alreadyExists bool
×
577
        r := &batch.Request[SQLQueries]{
×
578
                Opts: batch.NewSchedulerOptions(opts...),
×
579
                Reset: func() {
×
580
                        alreadyExists = false
×
581
                },
×
582
                Do: func(tx SQLQueries) error {
×
NEW
583
                        _, err := insertChannel(ctx, tx, edge)
×
584

×
585
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
586
                        // succeed, but propagate the error via local state.
×
587
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
588
                                alreadyExists = true
×
589
                                return nil
×
590
                        }
×
591

592
                        return err
×
593
                },
594
                OnCommit: func(err error) error {
×
595
                        switch {
×
596
                        case err != nil:
×
597
                                return err
×
598
                        case alreadyExists:
×
599
                                return ErrEdgeAlreadyExist
×
600
                        default:
×
601
                                s.rejectCache.remove(edge.ChannelID)
×
602
                                s.chanCache.remove(edge.ChannelID)
×
603
                                return nil
×
604
                        }
605
                },
606
        }
607

NEW
608
        return s.chanScheduler.Execute(ctx, r)
×
609
}
610

611
// HighestChanID returns the "highest" known channel ID in the channel graph.
612
// This represents the "newest" channel from the PoV of the chain. This method
613
// can be used by peers to quickly determine if their graphs are in sync.
614
//
615
// NOTE: This is part of the V1Store interface.
NEW
616
func (s *SQLStore) HighestChanID() (uint64, error) {
×
NEW
617
        ctx := context.TODO()
×
NEW
618

×
NEW
619
        var highestChanID uint64
×
NEW
620
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
621
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
NEW
622
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
623
                        return nil
×
NEW
624
                } else if err != nil {
×
NEW
625
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
NEW
626
                                err)
×
NEW
627
                }
×
628

NEW
629
                highestChanID = byteOrder.Uint64(chanID)
×
NEW
630

×
NEW
631
                return nil
×
632
        }, sqldb.NoOpReset)
NEW
633
        if err != nil {
×
NEW
634
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
NEW
635
        }
×
636

NEW
637
        return highestChanID, nil
×
638
}
639

640
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
641
// within the database for the referenced channel. The `flags` attribute within
642
// the ChannelEdgePolicy determines which of the directed edges are being
643
// updated. If the flag is 1, then the first node's information is being
644
// updated, otherwise it's the second node's information. The node ordering is
645
// determined by the lexicographical ordering of the identity public keys of the
646
// nodes on either side of the channel.
647
//
648
// NOTE: part of the V1Store interface.
649
func (s *SQLStore) UpdateEdgePolicy(edge *models.ChannelEdgePolicy,
NEW
650
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
NEW
651

×
NEW
652
        ctx := context.TODO()
×
NEW
653

×
NEW
654
        var (
×
NEW
655
                isUpdate1    bool
×
NEW
656
                edgeNotFound bool
×
NEW
657
                from, to     route.Vertex
×
NEW
658
        )
×
NEW
659

×
NEW
660
        r := &batch.Request[SQLQueries]{
×
NEW
661
                Opts: batch.NewSchedulerOptions(opts...),
×
NEW
662
                Reset: func() {
×
NEW
663
                        isUpdate1 = false
×
NEW
664
                        edgeNotFound = false
×
NEW
665
                },
×
NEW
666
                Do: func(tx SQLQueries) error {
×
NEW
667
                        var err error
×
NEW
668
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
NEW
669
                                ctx, tx, edge,
×
NEW
670
                        )
×
NEW
671
                        if err != nil {
×
NEW
672
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
NEW
673
                        }
×
674

675
                        // Silence ErrEdgeNotFound so that the batch can
676
                        // succeed, but propagate the error via local state.
NEW
677
                        if errors.Is(err, ErrEdgeNotFound) {
×
NEW
678
                                edgeNotFound = true
×
NEW
679
                                return nil
×
NEW
680
                        }
×
681

NEW
682
                        return err
×
683
                },
NEW
684
                OnCommit: func(err error) error {
×
NEW
685
                        switch {
×
NEW
686
                        case err != nil:
×
NEW
687
                                return err
×
NEW
688
                        case edgeNotFound:
×
NEW
689
                                return ErrEdgeNotFound
×
NEW
690
                        default:
×
NEW
691
                                s.updateEdgeCache(edge, isUpdate1)
×
NEW
692
                                return nil
×
693
                        }
694
                },
695
        }
696

NEW
697
        err := s.chanScheduler.Execute(ctx, r)
×
NEW
698

×
NEW
699
        return from, to, err
×
700
}
701

702
// updateEdgeCache updates our reject and channel caches with the new
703
// edge policy information.
704
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
NEW
705
        isUpdate1 bool) {
×
NEW
706

×
NEW
707
        // If an entry for this channel is found in reject cache, we'll modify
×
NEW
708
        // the entry with the updated timestamp for the direction that was just
×
NEW
709
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
NEW
710
        // during the next query for this edge.
×
NEW
711
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
NEW
712
                if isUpdate1 {
×
NEW
713
                        entry.upd1Time = e.LastUpdate.Unix()
×
NEW
714
                } else {
×
NEW
715
                        entry.upd2Time = e.LastUpdate.Unix()
×
NEW
716
                }
×
NEW
717
                s.rejectCache.insert(e.ChannelID, entry)
×
718
        }
719

720
        // If an entry for this channel is found in channel cache, we'll modify
721
        // the entry with the updated policy for the direction that was just
722
        // written. If the edge doesn't exist, we'll defer loading the info and
723
        // policies and lazily read from disk during the next query.
NEW
724
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
NEW
725
                if isUpdate1 {
×
NEW
726
                        channel.Policy1 = e
×
NEW
727
                } else {
×
NEW
728
                        channel.Policy2 = e
×
NEW
729
                }
×
NEW
730
                s.chanCache.insert(e.ChannelID, channel)
×
731
        }
732
}
733

734
// ForEachSourceNodeChannel iterates through all channels of the source node,
735
// executing the passed callback on each. The call-back is provided with the
736
// channel's outpoint, whether we have a policy for the channel and the channel
737
// peer's node information.
738
//
739
// NOTE: part of the V1Store interface.
740
func (s *SQLStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint,
NEW
741
        havePolicy bool, otherNode *models.LightningNode) error) error {
×
NEW
742

×
NEW
743
        var ctx = context.TODO()
×
NEW
744

×
NEW
745
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
746
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
NEW
747
                if err != nil {
×
NEW
748
                        return fmt.Errorf("unable to fetch source node: %w",
×
NEW
749
                                err)
×
NEW
750
                }
×
751

NEW
752
                return forEachNodeChannel(
×
NEW
753
                        ctx, db, s.cfg.ChainHash, nodeID,
×
NEW
754
                        func(info *models.ChannelEdgeInfo,
×
NEW
755
                                outPolicy *models.ChannelEdgePolicy,
×
NEW
756
                                _ *models.ChannelEdgePolicy) error {
×
NEW
757

×
NEW
758
                                // Fetch the other node.
×
NEW
759
                                var (
×
NEW
760
                                        otherNodePub [33]byte
×
NEW
761
                                        node1        = info.NodeKey1Bytes
×
NEW
762
                                        node2        = info.NodeKey2Bytes
×
NEW
763
                                )
×
NEW
764
                                switch {
×
NEW
765
                                case bytes.Equal(node1[:], nodePub[:]):
×
NEW
766
                                        otherNodePub = node2
×
NEW
767
                                case bytes.Equal(node2[:], nodePub[:]):
×
NEW
768
                                        otherNodePub = node1
×
NEW
769
                                default:
×
NEW
770
                                        return fmt.Errorf("node not " +
×
NEW
771
                                                "participating in this channel")
×
772
                                }
773

NEW
774
                                _, otherNode, err := getNodeByPubKey(
×
NEW
775
                                        ctx, db, otherNodePub,
×
NEW
776
                                )
×
NEW
777
                                if err != nil {
×
NEW
778
                                        return fmt.Errorf("unable to fetch "+
×
NEW
779
                                                "other node(%x): %w",
×
NEW
780
                                                otherNodePub, err)
×
NEW
781
                                }
×
782

NEW
783
                                return cb(
×
NEW
784
                                        info.ChannelPoint, outPolicy != nil,
×
NEW
785
                                        otherNode,
×
NEW
786
                                )
×
787
                        },
788
                )
NEW
789
        }, func() {})
×
790
}
791

792
// ForEachNode iterates through all the stored vertices/nodes in the graph,
793
// executing the passed callback with each node encountered. If the callback
794
// returns an error, then the transaction is aborted and the iteration stops
795
// early. Any operations performed on the NodeTx passed to the call-back are
796
// executed under the same read transaction and so, methods on the NodeTx object
797
// _MUST_ only be called from within the call-back.
798
//
799
// NOTE: part of the V1Store interface.
NEW
800
func (s *SQLStore) ForEachNode(cb func(tx NodeRTx) error) error {
×
NEW
801
        var ctx = context.TODO()
×
NEW
802

×
NEW
803
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
804
                nodes, err := db.ListNodes(ctx, int16(ProtocolV1))
×
NEW
805
                if err != nil {
×
NEW
806
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
NEW
807
                }
×
808

NEW
809
                for _, dbNode := range nodes {
×
NEW
810
                        node, err := buildNode(ctx, db, &dbNode)
×
NEW
811
                        if err != nil {
×
NEW
812
                                return fmt.Errorf("unable to build "+
×
NEW
813
                                        "node(id=%d): %w", dbNode.ID, err)
×
NEW
814
                        }
×
815

NEW
816
                        err = cb(newSQLGraphNodeTx(
×
NEW
817
                                db, s.cfg.ChainHash, dbNode.ID, node,
×
NEW
818
                        ))
×
NEW
819
                        if err != nil {
×
NEW
820
                                return fmt.Errorf("unable to execute "+
×
NEW
821
                                        "callback for node(id=%d): %w",
×
NEW
822
                                        dbNode.ID, err)
×
NEW
823
                        }
×
824
                }
825

NEW
826
                return nil
×
827
        }, sqldb.NoOpReset)
828
}
829

830
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
831
// SQLStore and a SQL transaction.
832
type sqlGraphNodeTx struct {
833
        db    SQLQueries
834
        id    int64
835
        node  *models.LightningNode
836
        chain chainhash.Hash
837
}
838

839
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
840
// interface.
841
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
842

843
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
NEW
844
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
NEW
845

×
NEW
846
        return &sqlGraphNodeTx{
×
NEW
847
                db:    db,
×
NEW
848
                chain: chain,
×
NEW
849
                id:    id,
×
NEW
850
                node:  node,
×
NEW
851
        }
×
NEW
852
}
×
853

854
// Node returns the raw information of the node.
855
//
856
// NOTE: This is a part of the NodeRTx interface.
NEW
857
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
NEW
858
        return s.node
×
NEW
859
}
×
860

861
// ForEachChannel can be used to iterate over the node's channels under the same
862
// transaction used to fetch the node.
863
//
864
// NOTE: This is a part of the NodeRTx interface.
865
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
NEW
866
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
NEW
867

×
NEW
868
        ctx := context.TODO()
×
NEW
869

×
NEW
870
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
NEW
871
}
×
872

873
// FetchNode fetches the node with the given pub key under the same transaction
874
// used to fetch the current node. The returned node is also a NodeRTx and any
875
// operations on that NodeRTx will also be done under the same transaction.
876
//
877
// NOTE: This is a part of the NodeRTx interface.
NEW
878
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
NEW
879
        ctx := context.TODO()
×
NEW
880

×
NEW
881
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
NEW
882
        if err != nil {
×
NEW
883
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
NEW
884
                        nodePub, err)
×
NEW
885
        }
×
886

NEW
887
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
888
}
889

890
// ForEachNodeDirectedChannel iterates through all channels of a given node,
891
// executing the passed callback on the directed edge representing the channel
892
// and its incoming policy. If the callback returns an error, then the iteration
893
// is halted with the error propagated back up to the caller.
894
//
895
// Unknown policies are passed into the callback as nil values.
896
//
897
// NOTE: this is part of the graphdb.NodeTraverser interface.
898
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
NEW
899
        cb func(channel *DirectedChannel) error) error {
×
NEW
900

×
NEW
901
        var ctx = context.TODO()
×
NEW
902

×
NEW
903
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
904
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
NEW
905
        }, func() {})
×
906
}
907

908
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
909
// graph, executing the passed callback with each node encountered. If the
910
// callback returns an error, then the transaction is aborted and the iteration
911
// stops early.
912
//
913
// NOTE: This is a part of the V1Store interface.
914
func (s *SQLStore) ForEachNodeCacheable(cb func(route.Vertex,
NEW
915
        *lnwire.FeatureVector) error) error {
×
NEW
916

×
NEW
917
        ctx := context.TODO()
×
NEW
918

×
NEW
919
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
920
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
NEW
921
                        nodePub route.Vertex) error {
×
NEW
922

×
NEW
923
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
NEW
924
                        if err != nil {
×
NEW
925
                                return fmt.Errorf("unable to fetch node "+
×
NEW
926
                                        "features: %w", err)
×
NEW
927
                        }
×
928

NEW
929
                        return cb(nodePub, features)
×
930
                })
NEW
931
        }, func() {})
×
NEW
932
        if err != nil {
×
NEW
933
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
NEW
934
        }
×
935

NEW
936
        return nil
×
937
}
938

939
// ForEachNodeChannel iterates through all channels of the given node,
940
// executing the passed callback with an edge info structure and the policies
941
// of each end of the channel. The first edge policy is the outgoing edge *to*
942
// the connecting node, while the second is the incoming edge *from* the
943
// connecting node. If the callback returns an error, then the iteration is
944
// halted with the error propagated back up to the caller.
945
//
946
// Unknown policies are passed into the callback as nil values.
947
//
948
// NOTE: part of the V1Store interface.
949
func (s *SQLStore) ForEachNodeChannel(nodePub route.Vertex,
950
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
NEW
951
                *models.ChannelEdgePolicy) error) error {
×
NEW
952

×
NEW
953
        var ctx = context.TODO()
×
NEW
954

×
NEW
955
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
956
                dbNode, err := db.GetNodeByPubKey(
×
NEW
957
                        ctx, sqlc.GetNodeByPubKeyParams{
×
NEW
958
                                Version: int16(ProtocolV1),
×
NEW
959
                                PubKey:  nodePub[:],
×
NEW
960
                        },
×
NEW
961
                )
×
NEW
962
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
963
                        return nil
×
NEW
964
                } else if err != nil {
×
NEW
965
                        return fmt.Errorf("unable to fetch node: %w", err)
×
NEW
966
                }
×
967

NEW
968
                return forEachNodeChannel(
×
NEW
969
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
NEW
970
                )
×
NEW
971
        }, func() {})
×
972
}
973

974
// ChanUpdatesInHorizon returns all the known channel edges which have at least
975
// one edge that has an update timestamp within the specified horizon.
976
//
977
// NOTE: This is part of the V1Store interface.
978
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
NEW
979
        endTime time.Time) ([]ChannelEdge, error) {
×
NEW
980

×
NEW
981
        s.cacheMu.Lock()
×
NEW
982
        defer s.cacheMu.Unlock()
×
NEW
983

×
NEW
984
        var (
×
NEW
985
                ctx = context.TODO()
×
NEW
986
                // To ensure we don't return duplicate ChannelEdges, we'll use an
×
NEW
987
                // additional map to keep track of the edges already seen to prevent
×
NEW
988
                // re-adding it.
×
NEW
989
                edgesSeen    = make(map[uint64]struct{})
×
NEW
990
                edgesToCache = make(map[uint64]ChannelEdge)
×
NEW
991
                edges        []ChannelEdge
×
NEW
992
                hits         int
×
NEW
993
        )
×
NEW
994
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
995
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
NEW
996
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
NEW
997
                                Version: int16(ProtocolV1),
×
NEW
998
                                StartTime: sql.NullInt64{
×
NEW
999
                                        Valid: true,
×
NEW
1000
                                        Int64: startTime.Unix(),
×
NEW
1001
                                },
×
NEW
1002
                                EndTime: sql.NullInt64{
×
NEW
1003
                                        Valid: true,
×
NEW
1004
                                        Int64: endTime.Unix(),
×
NEW
1005
                                },
×
NEW
1006
                        },
×
NEW
1007
                )
×
NEW
1008
                if err != nil {
×
NEW
1009
                        return err
×
NEW
1010
                }
×
1011

NEW
1012
                for _, row := range rows {
×
NEW
1013
                        // If we've already retrieved the info and policies for
×
NEW
1014
                        // this edge, then we can skip it as we don't need to do
×
NEW
1015
                        // so again.
×
NEW
1016
                        chanIDInt := byteOrder.Uint64(row.Scid)
×
NEW
1017
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
NEW
1018
                                continue
×
1019
                        }
1020

NEW
1021
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
NEW
1022
                                hits++
×
NEW
1023
                                edgesSeen[chanIDInt] = struct{}{}
×
NEW
1024
                                edges = append(edges, channel)
×
NEW
1025

×
NEW
1026
                                continue
×
1027
                        }
1028

NEW
1029
                        node1, node2, err := getAndBuildNodes(ctx, db, row)
×
NEW
1030
                        if err != nil {
×
NEW
1031
                                return err
×
NEW
1032
                        }
×
1033

NEW
1034
                        channel, p1, p2, err := getAndBuildEdgeInfoAndPolicies(
×
NEW
1035
                                ctx, db, s.cfg.ChainHash, row.ID, row,
×
NEW
1036
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
NEW
1037
                        )
×
NEW
1038
                        if err != nil {
×
NEW
1039
                                return err
×
NEW
1040
                        }
×
1041

NEW
1042
                        edgesSeen[chanIDInt] = struct{}{}
×
NEW
1043
                        chanEdge := ChannelEdge{
×
NEW
1044
                                Info:    channel,
×
NEW
1045
                                Policy1: p1,
×
NEW
1046
                                Policy2: p2,
×
NEW
1047
                                Node1:   node1,
×
NEW
1048
                                Node2:   node2,
×
NEW
1049
                        }
×
NEW
1050
                        edges = append(edges, chanEdge)
×
NEW
1051
                        edgesToCache[chanIDInt] = chanEdge
×
1052
                }
1053

NEW
1054
                return nil
×
NEW
1055
        }, func() {
×
NEW
1056
                edgesSeen = make(map[uint64]struct{})
×
NEW
1057
                edgesToCache = make(map[uint64]ChannelEdge)
×
NEW
1058
                edges = nil
×
NEW
1059
        })
×
NEW
1060
        if err != nil {
×
NEW
1061
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
NEW
1062
        }
×
1063

1064
        // Insert any edges loaded from disk into the cache.
NEW
1065
        for chanid, channel := range edgesToCache {
×
NEW
1066
                s.chanCache.insert(chanid, channel)
×
NEW
1067
        }
×
1068

NEW
1069
        log.Debugf("ChanUpdatesInHorizon hit percentage: %f (%d/%d)",
×
NEW
1070
                float64(hits)/float64(len(edges)), hits, len(edges))
×
NEW
1071

×
NEW
1072
        return edges, nil
×
1073
}
1074

1075
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1076
// data to the call-back.
1077
//
1078
// NOTE: The callback contents MUST not be modified.
1079
//
1080
// NOTE: part of the V1Store interface.
1081
func (s *SQLStore) ForEachNodeCached(cb func(node route.Vertex,
NEW
1082
        chans map[uint64]*DirectedChannel) error) error {
×
NEW
1083

×
NEW
1084
        var ctx = context.TODO()
×
NEW
1085

×
NEW
1086
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1087
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
NEW
1088
                        nodePub route.Vertex) error {
×
NEW
1089

×
NEW
1090
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
NEW
1091
                        if err != nil {
×
NEW
1092
                                return fmt.Errorf("unable to fetch "+
×
NEW
1093
                                        "node(id=%d) features: %w", nodeID, err)
×
NEW
1094
                        }
×
1095

NEW
1096
                        toNodeCallback := func() route.Vertex {
×
NEW
1097
                                return nodePub
×
NEW
1098
                        }
×
1099

NEW
1100
                        rows, err := db.ListChannelsByNodeID(
×
NEW
1101
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
NEW
1102
                                        Version: int16(ProtocolV1),
×
NEW
1103
                                        NodeID1: nodeID,
×
NEW
1104
                                },
×
NEW
1105
                        )
×
NEW
1106
                        if err != nil {
×
NEW
1107
                                return fmt.Errorf("unable to fetch channels "+
×
NEW
1108
                                        "of node(id=%d): %w", nodeID, err)
×
NEW
1109
                        }
×
1110

NEW
1111
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
NEW
1112
                        for _, row := range rows {
×
NEW
1113
                                node1, node2, err := buildNodeVertices(
×
NEW
1114
                                        row.Node1Pubkey, row.Node2Pubkey,
×
NEW
1115
                                )
×
NEW
1116
                                if err != nil {
×
NEW
1117
                                        return err
×
NEW
1118
                                }
×
1119

NEW
1120
                                e, p1, p2, err := getAndBuildEdgeInfoAndPolicies( //nolint:ll
×
NEW
1121
                                        ctx, db, s.cfg.ChainHash, row.ID, row,
×
NEW
1122
                                        node1, node2,
×
NEW
1123
                                )
×
NEW
1124
                                if err != nil {
×
NEW
1125
                                        return err
×
NEW
1126
                                }
×
1127

1128
                                // Determine the outgoing and incoming policy
1129
                                // for this channel and node combo.
NEW
1130
                                outPolicy, inPolicy := p1, p2
×
NEW
1131
                                if p1 != nil && p1.ToNode == nodePub {
×
NEW
1132
                                        outPolicy, inPolicy = p2, p1
×
NEW
1133
                                } else if p2 != nil && p2.ToNode != nodePub {
×
NEW
1134
                                        outPolicy, inPolicy = p2, p1
×
NEW
1135
                                }
×
1136

NEW
1137
                                var cachedInPolicy *models.CachedEdgePolicy
×
NEW
1138
                                if inPolicy != nil {
×
NEW
1139
                                        cachedInPolicy = models.NewCachedPolicy(
×
NEW
1140
                                                p2,
×
NEW
1141
                                        )
×
NEW
1142
                                        cachedInPolicy.ToNodePubKey =
×
NEW
1143
                                                toNodeCallback
×
NEW
1144
                                        cachedInPolicy.ToNodeFeatures =
×
NEW
1145
                                                features
×
NEW
1146
                                }
×
1147

NEW
1148
                                var inboundFee lnwire.Fee
×
NEW
1149
                                outPolicy.InboundFee.WhenSome(
×
NEW
1150
                                        func(fee lnwire.Fee) {
×
NEW
1151
                                                inboundFee = fee
×
NEW
1152
                                        },
×
1153
                                )
1154

NEW
1155
                                directedChannel := &DirectedChannel{
×
NEW
1156
                                        ChannelID: e.ChannelID,
×
NEW
1157
                                        IsNode1: nodePub ==
×
NEW
1158
                                                e.NodeKey1Bytes,
×
NEW
1159
                                        OtherNode:    e.NodeKey2Bytes,
×
NEW
1160
                                        Capacity:     e.Capacity,
×
NEW
1161
                                        OutPolicySet: p1 != nil,
×
NEW
1162
                                        InPolicy:     cachedInPolicy,
×
NEW
1163
                                        InboundFee:   inboundFee,
×
NEW
1164
                                }
×
NEW
1165

×
NEW
1166
                                if nodePub == e.NodeKey2Bytes {
×
NEW
1167
                                        directedChannel.OtherNode =
×
NEW
1168
                                                e.NodeKey1Bytes
×
NEW
1169
                                }
×
1170

NEW
1171
                                channels[e.ChannelID] = directedChannel
×
1172
                        }
1173

NEW
1174
                        return cb(nodePub, channels)
×
1175
                })
1176
        }, sqldb.NoOpReset)
1177
}
1178

1179
// ForEachChannelCacheable iterates through all the channel edges stored
1180
// within the graph and invokes the passed callback for each edge. The
1181
// callback takes two edges as since this is a directed graph, both the
1182
// in/out edges are visited. If the callback returns an error, then the
1183
// transaction is aborted and the iteration stops early.
1184
//
1185
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1186
// pointer for that particular channel edge routing policy will be
1187
// passed into the callback.
1188
//
1189
// NOTE: this method is like ForEachChannel but fetches only the data
1190
// required for the graph cache.
1191
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1192
        *models.CachedEdgePolicy,
NEW
1193
        *models.CachedEdgePolicy) error) error {
×
NEW
1194

×
NEW
1195
        ctx := context.TODO()
×
NEW
1196

×
NEW
1197
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1198
                rows, err := db.ListAllChannels(ctx, int16(ProtocolV1))
×
NEW
1199
                if err != nil {
×
NEW
1200
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
NEW
1201
                }
×
1202

NEW
1203
                for _, row := range rows {
×
NEW
1204
                        node1, node2, err := buildNodeVertices(
×
NEW
1205
                                row.Node1Pubkey, row.Node2Pubkey,
×
NEW
1206
                        )
×
NEW
1207
                        if err != nil {
×
NEW
1208
                                return err
×
NEW
1209
                        }
×
1210

NEW
1211
                        edge, err := buildCacheableChannelInfo(
×
NEW
1212
                                row, node1, node2,
×
NEW
1213
                        )
×
NEW
1214
                        if err != nil {
×
NEW
1215
                                return err
×
NEW
1216
                        }
×
1217

NEW
1218
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
NEW
1219
                        if err != nil {
×
NEW
1220
                                return err
×
NEW
1221
                        }
×
1222

NEW
1223
                        var pol1, pol2 *models.CachedEdgePolicy
×
NEW
1224
                        if dbPol1 != nil {
×
NEW
1225
                                policy1, err := buildChanPolicy(
×
NEW
1226
                                        *dbPol1, edge.ChannelID, nil,
×
NEW
1227
                                        node2, true,
×
NEW
1228
                                )
×
NEW
1229
                                if err != nil {
×
NEW
1230
                                        return err
×
NEW
1231
                                }
×
1232

NEW
1233
                                pol1 = models.NewCachedPolicy(policy1)
×
1234
                        }
NEW
1235
                        if dbPol2 != nil {
×
NEW
1236
                                policy2, err := buildChanPolicy(
×
NEW
1237
                                        *dbPol2, edge.ChannelID, nil,
×
NEW
1238
                                        node1, false,
×
NEW
1239
                                )
×
NEW
1240
                                if err != nil {
×
NEW
1241
                                        return err
×
NEW
1242
                                }
×
1243

NEW
1244
                                pol2 = models.NewCachedPolicy(policy2)
×
1245
                        }
1246

NEW
1247
                        if err := cb(edge, pol1, pol2); err != nil {
×
NEW
1248
                                return err
×
NEW
1249
                        }
×
1250
                }
1251

NEW
1252
                return nil
×
NEW
1253
        }, func() {})
×
1254
}
1255

1256
// ForEachChannel iterates through all the channel edges stored within the
1257
// graph and invokes the passed callback for each edge. The callback takes two
1258
// edges as since this is a directed graph, both the in/out edges are visited.
1259
// If the callback returns an error, then the transaction is aborted and the
1260
// iteration stops early.
1261
//
1262
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1263
// for that particular channel edge routing policy will be passed into the
1264
// callback.
1265
//
1266
// NOTE: part of the V1Store interface.
1267
func (s *SQLStore) ForEachChannel(cb func(*models.ChannelEdgeInfo,
NEW
1268
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
NEW
1269

×
NEW
1270
        var ctx = context.TODO()
×
NEW
1271

×
NEW
1272
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1273
                rows, err := db.ListAllChannels(ctx, int16(ProtocolV1))
×
NEW
1274
                if err != nil {
×
NEW
1275
                        return err
×
NEW
1276
                }
×
1277

NEW
1278
                for _, row := range rows {
×
NEW
1279
                        node1, node2, err := buildNodeVertices(
×
NEW
1280
                                row.Node1Pubkey, row.Node2Pubkey,
×
NEW
1281
                        )
×
NEW
1282
                        if err != nil {
×
NEW
1283
                                return err
×
NEW
1284
                        }
×
1285

NEW
1286
                        edge, p1, p2, err := getAndBuildEdgeInfoAndPolicies(
×
NEW
1287
                                ctx, db, s.cfg.ChainHash, row.ID, row,
×
NEW
1288
                                node1, node2,
×
NEW
1289
                        )
×
NEW
1290
                        if err != nil {
×
NEW
1291
                                return fmt.Errorf("unable to build edge "+
×
NEW
1292
                                        "info and policies: %w", err)
×
NEW
1293
                        }
×
1294

NEW
1295
                        if err := cb(edge, p1, p2); err != nil {
×
NEW
1296
                                return err
×
NEW
1297
                        }
×
1298
                }
1299

NEW
1300
                return nil
×
NEW
1301
        }, func() {})
×
1302
}
1303

1304
// FilterChannelRange returns the channel ID's of all known channels which were
1305
// mined in a block height within the passed range. The channel IDs are grouped
1306
// by their common block height. This method can be used to quickly share with a
1307
// peer the set of channels we know of within a particular range to catch them
1308
// up after a period of time offline. If withTimestamps is true then the
1309
// timestamp info of the latest received channel update messages of the channel
1310
// will be included in the response.
1311
//
1312
// NOTE: This is part of the V1Store interface.
1313
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
NEW
1314
        withTimestamps bool) ([]BlockChannelRange, error) {
×
NEW
1315

×
NEW
1316
        var (
×
NEW
1317
                ctx       = context.TODO()
×
NEW
1318
                startSCID = &lnwire.ShortChannelID{
×
NEW
1319
                        BlockHeight: startHeight,
×
NEW
1320
                }
×
NEW
1321
                endSCID = lnwire.ShortChannelID{
×
NEW
1322
                        BlockHeight: endHeight,
×
NEW
1323
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
NEW
1324
                        TxPosition:  math.MaxUint16,
×
NEW
1325
                }
×
NEW
1326
        )
×
NEW
1327

×
NEW
1328
        var chanIDStart [8]byte
×
NEW
1329
        byteOrder.PutUint64(chanIDStart[:], startSCID.ToUint64())
×
NEW
1330
        var chanIDEnd [8]byte
×
NEW
1331
        byteOrder.PutUint64(chanIDEnd[:], endSCID.ToUint64())
×
NEW
1332

×
NEW
1333
        // 1) get all channels where channelID is between start and end chan ID.
×
NEW
1334
        // 2) skip if not public (ie, no channel_proof)
×
NEW
1335
        // 3) collect that channel.
×
NEW
1336
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
NEW
1337
        //    and add those timestamps to the collected channel.
×
NEW
1338
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
NEW
1339
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1340
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
NEW
1341
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
NEW
1342
                                StartScid: chanIDStart[:],
×
NEW
1343
                                EndScid:   chanIDEnd[:],
×
NEW
1344
                        },
×
NEW
1345
                )
×
NEW
1346
                if err != nil {
×
NEW
1347
                        return fmt.Errorf("unable to fetch channel range: %w",
×
NEW
1348
                                err)
×
NEW
1349
                }
×
1350

NEW
1351
                for _, dbChan := range dbChans {
×
NEW
1352
                        cid := lnwire.NewShortChanIDFromInt(
×
NEW
1353
                                byteOrder.Uint64(dbChan.Scid),
×
NEW
1354
                        )
×
NEW
1355
                        chanInfo := NewChannelUpdateInfo(
×
NEW
1356
                                cid, time.Time{}, time.Time{},
×
NEW
1357
                        )
×
NEW
1358

×
NEW
1359
                        if !withTimestamps {
×
NEW
1360
                                channelsPerBlock[cid.BlockHeight] = append(
×
NEW
1361
                                        channelsPerBlock[cid.BlockHeight],
×
NEW
1362
                                        chanInfo,
×
NEW
1363
                                )
×
NEW
1364

×
NEW
1365
                                continue
×
1366
                        }
1367

1368
                        //nolint:ll
NEW
1369
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
NEW
1370
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
1371
                                        Version:   int16(ProtocolV1),
×
NEW
1372
                                        ChannelID: dbChan.ID,
×
NEW
1373
                                        NodeID:    dbChan.NodeID1,
×
NEW
1374
                                },
×
NEW
1375
                        )
×
NEW
1376
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
NEW
1377
                                return fmt.Errorf("unable to fetch node1 "+
×
NEW
1378
                                        "policy: %w", err)
×
NEW
1379
                        } else if err == nil {
×
NEW
1380
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
NEW
1381
                                        node1Policy.LastUpdate.Int64, 0,
×
NEW
1382
                                )
×
NEW
1383
                        }
×
1384

1385
                        //nolint:ll
NEW
1386
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
NEW
1387
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
1388
                                        Version:   int16(ProtocolV1),
×
NEW
1389
                                        ChannelID: dbChan.ID,
×
NEW
1390
                                        NodeID:    dbChan.NodeID2,
×
NEW
1391
                                },
×
NEW
1392
                        )
×
NEW
1393
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
NEW
1394
                                return fmt.Errorf("unable to fetch node2 "+
×
NEW
1395
                                        "policy: %w", err)
×
NEW
1396
                        } else if err == nil {
×
NEW
1397
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
NEW
1398
                                        node2Policy.LastUpdate.Int64, 0,
×
NEW
1399
                                )
×
NEW
1400
                        }
×
1401

NEW
1402
                        channelsPerBlock[cid.BlockHeight] = append(
×
NEW
1403
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
NEW
1404
                        )
×
1405
                }
1406

NEW
1407
                return nil
×
NEW
1408
        }, func() {
×
NEW
1409
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
NEW
1410
        })
×
NEW
1411
        if err != nil {
×
NEW
1412
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
NEW
1413
        }
×
1414

NEW
1415
        if len(channelsPerBlock) == 0 {
×
NEW
1416
                return nil, nil
×
NEW
1417
        }
×
1418

1419
        // Return the channel ranges in ascending block height order.
NEW
1420
        blocks := make([]uint32, 0, len(channelsPerBlock))
×
NEW
1421
        for block := range channelsPerBlock {
×
NEW
1422
                blocks = append(blocks, block)
×
NEW
1423
        }
×
NEW
1424
        sort.Slice(blocks, func(i, j int) bool {
×
NEW
1425
                return blocks[i] < blocks[j]
×
NEW
1426
        })
×
1427

NEW
1428
        channelRanges := make([]BlockChannelRange, 0, len(channelsPerBlock))
×
NEW
1429
        for _, block := range blocks {
×
NEW
1430
                channelRanges = append(channelRanges, BlockChannelRange{
×
NEW
1431
                        Height:   block,
×
NEW
1432
                        Channels: channelsPerBlock[block],
×
NEW
1433
                })
×
NEW
1434
        }
×
1435

NEW
1436
        return channelRanges, nil
×
1437
}
1438

1439
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1440
// zombie. This method is used on an ad-hoc basis, when channels need to be
1441
// marked as zombies outside the normal pruning cycle.
1442
//
1443
// NOTE: part of the V1Store interface.
1444
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
NEW
1445
        pubKey1, pubKey2 [33]byte) error {
×
NEW
1446

×
NEW
1447
        ctx := context.TODO()
×
NEW
1448

×
NEW
1449
        s.cacheMu.Lock()
×
NEW
1450
        defer s.cacheMu.Unlock()
×
NEW
1451

×
NEW
1452
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
1453
                return db.UpsertZombieChannel(
×
NEW
1454
                        ctx, sqlc.UpsertZombieChannelParams{
×
NEW
1455
                                Version:  int16(ProtocolV1),
×
NEW
1456
                                Scid:     int64(chanID),
×
NEW
1457
                                NodeKey1: pubKey1[:],
×
NEW
1458
                                NodeKey2: pubKey2[:],
×
NEW
1459
                        },
×
NEW
1460
                )
×
NEW
1461
        }, func() {})
×
NEW
1462
        if err != nil {
×
NEW
1463
                return err
×
NEW
1464
        }
×
1465

NEW
1466
        s.rejectCache.remove(chanID)
×
NEW
1467
        s.chanCache.remove(chanID)
×
NEW
1468

×
NEW
1469
        return nil
×
1470
}
1471

1472
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1473
//
1474
// NOTE: part of the V1Store interface.
NEW
1475
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
NEW
1476
        s.cacheMu.Lock()
×
NEW
1477
        defer s.cacheMu.Unlock()
×
NEW
1478

×
NEW
1479
        var ctx = context.TODO()
×
NEW
1480
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
1481
                _, err := db.GetZombieChannel(
×
NEW
1482
                        ctx, sqlc.GetZombieChannelParams{
×
NEW
1483
                                Scid:    int64(chanID),
×
NEW
1484
                                Version: int16(ProtocolV1),
×
NEW
1485
                        },
×
NEW
1486
                )
×
NEW
1487
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
1488
                        return ErrZombieEdgeNotFound
×
NEW
1489
                } else if err != nil {
×
NEW
1490
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
NEW
1491
                                err)
×
NEW
1492
                }
×
1493

NEW
1494
                return db.DeleteZombieChannel(
×
NEW
1495
                        ctx, sqlc.DeleteZombieChannelParams{
×
NEW
1496
                                Scid:    int64(chanID),
×
NEW
1497
                                Version: int16(ProtocolV1),
×
NEW
1498
                        },
×
NEW
1499
                )
×
NEW
1500
        }, func() {})
×
NEW
1501
        if err != nil {
×
NEW
1502
                return err
×
NEW
1503
        }
×
1504

NEW
1505
        s.rejectCache.remove(chanID)
×
NEW
1506
        s.chanCache.remove(chanID)
×
NEW
1507

×
NEW
1508
        return err
×
1509
}
1510

1511
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1512
// zombie, then the two node public keys corresponding to this edge are also
1513
// returned.
1514
//
1515
// NOTE: part of the V1Store interface.
NEW
1516
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte) {
×
NEW
1517
        var (
×
NEW
1518
                ctx              = context.TODO()
×
NEW
1519
                isZombie         bool
×
NEW
1520
                pubKey1, pubKey2 route.Vertex
×
NEW
1521
        )
×
NEW
1522
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1523
                zombie, err := db.GetZombieChannel(
×
NEW
1524
                        ctx, sqlc.GetZombieChannelParams{
×
NEW
1525
                                Scid:    int64(chanID),
×
NEW
1526
                                Version: int16(ProtocolV1),
×
NEW
1527
                        },
×
NEW
1528
                )
×
NEW
1529
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
1530
                        return nil
×
NEW
1531
                }
×
NEW
1532
                if err != nil {
×
NEW
1533
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
NEW
1534
                                err)
×
NEW
1535
                }
×
1536

NEW
1537
                copy(pubKey1[:], zombie.NodeKey1)
×
NEW
1538
                copy(pubKey2[:], zombie.NodeKey2)
×
NEW
1539
                isZombie = true
×
NEW
1540

×
NEW
1541
                return nil
×
NEW
1542
        }, func() {})
×
NEW
1543
        if err != nil {
×
NEW
1544
                return false, route.Vertex{}, route.Vertex{}
×
NEW
1545
        }
×
1546

NEW
1547
        return isZombie, pubKey1, pubKey2
×
1548
}
1549

1550
// NumZombies returns the current number of zombie channels in the graph.
1551
//
1552
// NOTE: part of the V1Store interface.
NEW
1553
func (s *SQLStore) NumZombies() (uint64, error) {
×
NEW
1554
        var (
×
NEW
1555
                ctx        = context.TODO()
×
NEW
1556
                numZombies uint64
×
NEW
1557
        )
×
NEW
1558
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1559
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
NEW
1560
                if err != nil {
×
NEW
1561
                        return fmt.Errorf("unable to count zombie channels: %w",
×
NEW
1562
                                err)
×
NEW
1563
                }
×
1564

NEW
1565
                numZombies = uint64(count)
×
NEW
1566

×
NEW
1567
                return nil
×
NEW
1568
        }, func() {})
×
NEW
1569
        if err != nil {
×
NEW
1570
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
NEW
1571
        }
×
1572

NEW
1573
        return numZombies, nil
×
1574
}
1575

1576
// DeleteChannelEdges removes edges with the given channel IDs from the
1577
// database and marks them as zombies. This ensures that we're unable to re-add
1578
// it to our database once again. If an edge does not exist within the
1579
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1580
// true, then when we mark these edges as zombies, we'll set up the keys such
1581
// that we require the node that failed to send the fresh update to be the one
1582
// that resurrects the channel from its zombie state. The markZombie bool
1583
// denotes whether to mark the channel as a zombie.
1584
//
1585
// NOTE: part of the V1Store interface.
1586
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
NEW
1587
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
NEW
1588

×
NEW
1589
        s.cacheMu.Lock()
×
NEW
1590
        defer s.cacheMu.Unlock()
×
NEW
1591

×
NEW
1592
        var (
×
NEW
1593
                ctx     = context.TODO()
×
NEW
1594
                deleted []*models.ChannelEdgeInfo
×
NEW
1595
        )
×
NEW
1596
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
1597
                for _, chanID := range chanIDs {
×
NEW
1598
                        var chanIDB [8]byte
×
NEW
1599
                        byteOrder.PutUint64(chanIDB[:], chanID)
×
NEW
1600

×
NEW
1601
                        row, err := db.GetChannelBySCIDWithPolicies(
×
NEW
1602
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
NEW
1603
                                        Scid:    chanIDB[:],
×
NEW
1604
                                        Version: int16(ProtocolV1),
×
NEW
1605
                                },
×
NEW
1606
                        )
×
NEW
1607
                        if errors.Is(err, sql.ErrNoRows) {
×
NEW
1608
                                return ErrEdgeNotFound
×
NEW
1609
                        } else if err != nil {
×
NEW
1610
                                return fmt.Errorf("unable to fetch channel: %w",
×
NEW
1611
                                        err)
×
NEW
1612
                        }
×
1613

NEW
1614
                        node1, node2, err := buildNodeVertices(
×
NEW
1615
                                row.Node1PubKey, row.Node2PubKey,
×
NEW
1616
                        )
×
NEW
1617
                        if err != nil {
×
NEW
1618
                                return err
×
NEW
1619
                        }
×
1620

NEW
1621
                        info, err := getAndBuildEdgeInfo(
×
NEW
1622
                                ctx, db, s.cfg.ChainHash, row.ID, row,
×
NEW
1623
                                node1, node2,
×
NEW
1624
                        )
×
NEW
1625
                        if err != nil {
×
NEW
1626
                                return err
×
NEW
1627
                        }
×
1628

NEW
1629
                        err = db.DeleteChannel(ctx, row.ID)
×
NEW
1630
                        if err != nil {
×
NEW
1631
                                return fmt.Errorf("unable to delete "+
×
NEW
1632
                                        "channel: %w", err)
×
NEW
1633
                        }
×
1634

NEW
1635
                        deleted = append(deleted, info)
×
NEW
1636

×
NEW
1637
                        if !markZombie {
×
NEW
1638
                                continue
×
1639
                        }
1640

NEW
1641
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
NEW
1642
                                info.NodeKey2Bytes
×
NEW
1643
                        if strictZombiePruning {
×
NEW
1644
                                var e1UpdateTime, e2UpdateTime *time.Time
×
NEW
1645
                                if row.Policy1LastUpdate.Valid {
×
NEW
1646
                                        e1Time := time.Unix(
×
NEW
1647
                                                row.Policy1LastUpdate.Int64, 0,
×
NEW
1648
                                        )
×
NEW
1649
                                        e1UpdateTime = &e1Time
×
NEW
1650
                                }
×
NEW
1651
                                if row.Policy2LastUpdate.Valid {
×
NEW
1652
                                        e2Time := time.Unix(
×
NEW
1653
                                                row.Policy2LastUpdate.Int64, 0,
×
NEW
1654
                                        )
×
NEW
1655
                                        e2UpdateTime = &e2Time
×
NEW
1656
                                }
×
1657

NEW
1658
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
NEW
1659
                                        info, e1UpdateTime, e2UpdateTime,
×
NEW
1660
                                )
×
1661
                        }
1662

NEW
1663
                        err = db.UpsertZombieChannel(
×
NEW
1664
                                ctx, sqlc.UpsertZombieChannelParams{
×
NEW
1665
                                        Version:  int16(ProtocolV1),
×
NEW
1666
                                        Scid:     int64(chanID),
×
NEW
1667
                                        NodeKey1: nodeKey1[:],
×
NEW
1668
                                        NodeKey2: nodeKey2[:],
×
NEW
1669
                                },
×
NEW
1670
                        )
×
NEW
1671
                        if err != nil {
×
NEW
1672
                                return fmt.Errorf("unable to mark channel as "+
×
NEW
1673
                                        "zombie: %w", err)
×
NEW
1674
                        }
×
1675
                }
1676

NEW
1677
                return nil
×
NEW
1678
        }, func() {
×
NEW
1679
                deleted = nil
×
NEW
1680
        })
×
NEW
1681
        if err != nil {
×
NEW
1682
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
NEW
1683
                        err)
×
NEW
1684
        }
×
1685

NEW
1686
        for _, chanID := range chanIDs {
×
NEW
1687
                s.rejectCache.remove(chanID)
×
NEW
1688
                s.chanCache.remove(chanID)
×
NEW
1689
        }
×
1690

NEW
1691
        return deleted, nil
×
1692
}
1693

1694
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1695
// channel identified by the channel ID. If the channel can't be found, then
1696
// ErrEdgeNotFound is returned. A struct which houses the general information
1697
// for the channel itself is returned as well as two structs that contain the
1698
// routing policies for the channel in either direction.
1699
//
1700
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1701
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1702
// the ChannelEdgeInfo will only include the public keys of each node.
1703
//
1704
// NOTE: part of the V1Store interface.
1705
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1706
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
NEW
1707
        *models.ChannelEdgePolicy, error) {
×
NEW
1708

×
NEW
1709
        var (
×
NEW
1710
                ctx              = context.TODO()
×
NEW
1711
                edge             *models.ChannelEdgeInfo
×
NEW
1712
                policy1, policy2 *models.ChannelEdgePolicy
×
NEW
1713
        )
×
NEW
1714
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1715
                var chanIDB [8]byte
×
NEW
1716
                byteOrder.PutUint64(chanIDB[:], chanID)
×
NEW
1717

×
NEW
1718
                row, err := db.GetChannelBySCIDWithPolicies(
×
NEW
1719
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
NEW
1720
                                Scid:    chanIDB[:],
×
NEW
1721
                                Version: int16(ProtocolV1),
×
NEW
1722
                        },
×
NEW
1723
                )
×
NEW
1724
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
1725
                        // First check if this edge is perhaps in the zombie
×
NEW
1726
                        // index.
×
NEW
1727
                        isZombie, err := db.IsZombieChannel(
×
NEW
1728
                                ctx, sqlc.IsZombieChannelParams{
×
NEW
1729
                                        Scid:    int64(chanID),
×
NEW
1730
                                        Version: int16(ProtocolV1),
×
NEW
1731
                                },
×
NEW
1732
                        )
×
NEW
1733
                        if err != nil {
×
NEW
1734
                                return fmt.Errorf("unable to check if "+
×
NEW
1735
                                        "channel is zombie: %w", err)
×
NEW
1736
                        } else if isZombie {
×
NEW
1737
                                return ErrZombieEdge
×
NEW
1738
                        }
×
1739

NEW
1740
                        return ErrEdgeNotFound
×
NEW
1741
                } else if err != nil {
×
NEW
1742
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
NEW
1743
                }
×
1744

NEW
1745
                node1, node2, err := buildNodeVertices(
×
NEW
1746
                        row.Node1PubKey, row.Node2PubKey,
×
NEW
1747
                )
×
NEW
1748
                if err != nil {
×
NEW
1749
                        return err
×
NEW
1750
                }
×
1751

NEW
1752
                edge, policy1, policy2, err = getAndBuildEdgeInfoAndPolicies(
×
NEW
1753
                        ctx, db, s.cfg.ChainHash, row.ID, row, node1, node2,
×
NEW
1754
                )
×
NEW
1755

×
NEW
1756
                return err
×
NEW
1757
        }, func() {})
×
NEW
1758
        if err != nil {
×
NEW
1759
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
NEW
1760
                        err)
×
NEW
1761
        }
×
1762

NEW
1763
        return edge, policy1, policy2, nil
×
1764
}
1765

1766
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1767
// the channel identified by the funding outpoint. If the channel can't be
1768
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1769
// information for the channel itself is returned as well as two structs that
1770
// contain the routing policies for the channel in either direction.
1771
//
1772
// NOTE: part of the V1Store interface.
1773
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1774
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
NEW
1775
        *models.ChannelEdgePolicy, error) {
×
NEW
1776

×
NEW
1777
        var (
×
NEW
1778
                ctx              = context.TODO()
×
NEW
1779
                edge             *models.ChannelEdgeInfo
×
NEW
1780
                policy1, policy2 *models.ChannelEdgePolicy
×
NEW
1781
        )
×
NEW
1782
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1783
                row, err := db.GetChannelByOutpointWithPolicies(
×
NEW
1784
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
NEW
1785
                                Outpoint: op.String(),
×
NEW
1786
                                Version:  int16(ProtocolV1),
×
NEW
1787
                        },
×
NEW
1788
                )
×
NEW
1789
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
1790
                        return ErrEdgeNotFound
×
NEW
1791
                } else if err != nil {
×
NEW
1792
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
NEW
1793
                }
×
1794

NEW
1795
                node1, node2, err := buildNodeVertices(
×
NEW
1796
                        row.Node1Pubkey, row.Node2Pubkey,
×
NEW
1797
                )
×
NEW
1798
                if err != nil {
×
NEW
1799
                        return err
×
NEW
1800
                }
×
1801

NEW
1802
                edge, policy1, policy2, err = getAndBuildEdgeInfoAndPolicies(
×
NEW
1803
                        ctx, db, s.cfg.ChainHash, row.ID, row, node1, node2,
×
NEW
1804
                )
×
NEW
1805

×
NEW
1806
                return err
×
NEW
1807
        }, func() {})
×
NEW
1808
        if err != nil {
×
NEW
1809
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
NEW
1810
                        err)
×
NEW
1811
        }
×
1812

NEW
1813
        return edge, policy1, policy2, nil
×
1814
}
1815

1816
// HasChannelEdge returns true if the database knows of a channel edge with the
1817
// passed channel ID, and false otherwise. If an edge with that ID is found
1818
// within the graph, then two time stamps representing the last time the edge
1819
// was updated for both directed edges are returned along with the boolean. If
1820
// it is not found, then the zombie index is checked and its result is returned
1821
// as the second boolean.
1822
//
1823
// NOTE: part of the V1Store interface.
1824
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
NEW
1825
        bool, error) {
×
NEW
1826

×
NEW
1827
        ctx := context.TODO()
×
NEW
1828

×
NEW
1829
        var (
×
NEW
1830
                exists          bool
×
NEW
1831
                isZombie        bool
×
NEW
1832
                node1LastUpdate time.Time
×
NEW
1833
                node2LastUpdate time.Time
×
NEW
1834
        )
×
NEW
1835

×
NEW
1836
        // We'll query the cache with the shared lock held to allow multiple
×
NEW
1837
        // readers to access values in the cache concurrently if they exist.
×
NEW
1838
        s.cacheMu.RLock()
×
NEW
1839
        if entry, ok := s.rejectCache.get(chanID); ok {
×
NEW
1840
                s.cacheMu.RUnlock()
×
NEW
1841
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
NEW
1842
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
NEW
1843
                exists, isZombie = entry.flags.unpack()
×
NEW
1844

×
NEW
1845
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
NEW
1846
        }
×
NEW
1847
        s.cacheMu.RUnlock()
×
NEW
1848

×
NEW
1849
        s.cacheMu.Lock()
×
NEW
1850
        defer s.cacheMu.Unlock()
×
NEW
1851

×
NEW
1852
        // The item was not found with the shared lock, so we'll acquire the
×
NEW
1853
        // exclusive lock and check the cache again in case another method added
×
NEW
1854
        // the entry to the cache while no lock was held.
×
NEW
1855
        if entry, ok := s.rejectCache.get(chanID); ok {
×
NEW
1856
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
NEW
1857
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
NEW
1858
                exists, isZombie = entry.flags.unpack()
×
NEW
1859

×
NEW
1860
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
NEW
1861
        }
×
1862

NEW
1863
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1864
                var chanIDB [8]byte
×
NEW
1865
                byteOrder.PutUint64(chanIDB[:], chanID)
×
NEW
1866

×
NEW
1867
                channel, err := db.GetChannelBySCID(
×
NEW
1868
                        ctx, sqlc.GetChannelBySCIDParams{
×
NEW
1869
                                Scid:    chanIDB[:],
×
NEW
1870
                                Version: int16(ProtocolV1),
×
NEW
1871
                        },
×
NEW
1872
                )
×
NEW
1873
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
1874
                        // Check if it is a zombie channel.
×
NEW
1875
                        isZombie, err = db.IsZombieChannel(
×
NEW
1876
                                ctx, sqlc.IsZombieChannelParams{
×
NEW
1877
                                        Scid:    int64(chanID),
×
NEW
1878
                                        Version: int16(ProtocolV1),
×
NEW
1879
                                },
×
NEW
1880
                        )
×
NEW
1881
                        if err != nil {
×
NEW
1882
                                return fmt.Errorf("could not check if channel "+
×
NEW
1883
                                        "is zombie: %w", err)
×
NEW
1884
                        }
×
1885

NEW
1886
                        return nil
×
NEW
1887
                } else if err != nil {
×
NEW
1888
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
NEW
1889
                }
×
1890

NEW
1891
                exists = true
×
NEW
1892

×
NEW
1893
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
NEW
1894
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
1895
                                Version:   int16(ProtocolV1),
×
NEW
1896
                                ChannelID: channel.ID,
×
NEW
1897
                                NodeID:    channel.NodeID1,
×
NEW
1898
                        },
×
NEW
1899
                )
×
NEW
1900
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
NEW
1901
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
NEW
1902
                                err)
×
NEW
1903
                } else if err == nil {
×
NEW
1904
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
NEW
1905
                }
×
1906

NEW
1907
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
NEW
1908
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
1909
                                Version:   int16(ProtocolV1),
×
NEW
1910
                                ChannelID: channel.ID,
×
NEW
1911
                                NodeID:    channel.NodeID2,
×
NEW
1912
                        },
×
NEW
1913
                )
×
NEW
1914
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
NEW
1915
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
NEW
1916
                                err)
×
NEW
1917
                } else if err == nil {
×
NEW
1918
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
NEW
1919
                }
×
1920

NEW
1921
                return nil
×
NEW
1922
        }, func() {})
×
NEW
1923
        if err != nil {
×
NEW
1924
                return time.Time{}, time.Time{}, false, false,
×
NEW
1925
                        fmt.Errorf("unable to fetch channel: %w", err)
×
NEW
1926
        }
×
1927

NEW
1928
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
NEW
1929
                upd1Time: node1LastUpdate.Unix(),
×
NEW
1930
                upd2Time: node2LastUpdate.Unix(),
×
NEW
1931
                flags:    packRejectFlags(exists, isZombie),
×
NEW
1932
        })
×
NEW
1933

×
NEW
1934
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
1935
}
1936

1937
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
1938
// passed channel point (outpoint). If the passed channel doesn't exist within
1939
// the database, then ErrEdgeNotFound is returned.
1940
//
1941
// NOTE: part of the V1Store interface.
NEW
1942
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
NEW
1943
        var (
×
NEW
1944
                ctx       = context.TODO()
×
NEW
1945
                channelID uint64
×
NEW
1946
        )
×
NEW
1947
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1948
                chanID, err := db.GetSCIDByOutpoint(
×
NEW
1949
                        ctx, sqlc.GetSCIDByOutpointParams{
×
NEW
1950
                                Outpoint: chanPoint.String(),
×
NEW
1951
                                Version:  int16(ProtocolV1),
×
NEW
1952
                        },
×
NEW
1953
                )
×
NEW
1954
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
1955
                        return ErrEdgeNotFound
×
NEW
1956
                } else if err != nil {
×
NEW
1957
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
NEW
1958
                                err)
×
NEW
1959
                }
×
1960

NEW
1961
                channelID = byteOrder.Uint64(chanID)
×
NEW
1962

×
NEW
1963
                return nil
×
NEW
1964
        }, func() {})
×
NEW
1965
        if err != nil {
×
NEW
1966
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
NEW
1967
        }
×
1968

NEW
1969
        return channelID, nil
×
1970
}
1971

1972
// IsPublicNode is a helper method that determines whether the node with the
1973
// given public key is seen as a public node in the graph from the graph's
1974
// source node's point of view.
1975
//
1976
// NOTE: part of the V1Store interface.
NEW
1977
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
NEW
1978
        ctx := context.TODO()
×
NEW
1979

×
NEW
1980
        var isPublic bool
×
NEW
1981
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1982
                var err error
×
NEW
1983
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
NEW
1984

×
NEW
1985
                return err
×
NEW
1986
        }, func() {})
×
NEW
1987
        if err != nil {
×
NEW
1988
                return false, fmt.Errorf("unable to check if node is "+
×
NEW
1989
                        "public: %w", err)
×
NEW
1990
        }
×
1991

NEW
1992
        return isPublic, nil
×
1993
}
1994

1995
// FetchChanInfos returns the set of channel edges that correspond to the passed
1996
// channel ID's. If an edge is the query is unknown to the database, it will
1997
// skipped and the result will contain only those edges that exist at the time
1998
// of the query. This can be used to respond to peer queries that are seeking to
1999
// fill in gaps in their view of the channel graph.
2000
//
2001
// NOTE: part of the V1Store interface.
NEW
2002
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
NEW
2003
        var (
×
NEW
2004
                ctx   = context.TODO()
×
NEW
2005
                edges []ChannelEdge
×
NEW
2006
        )
×
NEW
2007
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2008
                for _, chanID := range chanIDs {
×
NEW
2009
                        var chanIDB [8]byte
×
NEW
2010
                        byteOrder.PutUint64(chanIDB[:], chanID)
×
NEW
2011

×
NEW
2012
                        row, err := db.GetChannelBySCIDWithPolicies(
×
NEW
2013
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
NEW
2014
                                        Scid:    chanIDB[:],
×
NEW
2015
                                        Version: int16(ProtocolV1),
×
NEW
2016
                                },
×
NEW
2017
                        )
×
NEW
2018
                        if errors.Is(err, sql.ErrNoRows) {
×
NEW
2019
                                continue
×
NEW
2020
                        } else if err != nil {
×
NEW
2021
                                return fmt.Errorf("unable to fetch channel: %w",
×
NEW
2022
                                        err)
×
NEW
2023
                        }
×
2024

NEW
2025
                        node1, node2, err := getAndBuildNodes(ctx, db, row)
×
NEW
2026
                        if err != nil {
×
NEW
2027
                                return fmt.Errorf("unable to fetch nodes: %w",
×
NEW
2028
                                        err)
×
NEW
2029
                        }
×
2030

NEW
2031
                        edge, p1, p2, err := getAndBuildEdgeInfoAndPolicies(
×
NEW
2032
                                ctx, db, s.cfg.ChainHash, row.ID, row,
×
NEW
2033
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
NEW
2034
                        )
×
NEW
2035
                        if err != nil {
×
NEW
2036
                                return fmt.Errorf("unable to build channel "+
×
NEW
2037
                                        "info and policies: %w", err)
×
NEW
2038
                        }
×
2039

NEW
2040
                        edges = append(edges, ChannelEdge{
×
NEW
2041
                                Info:    edge,
×
NEW
2042
                                Policy1: p1,
×
NEW
2043
                                Policy2: p2,
×
NEW
2044
                                Node1:   node1,
×
NEW
2045
                                Node2:   node2,
×
NEW
2046
                        })
×
2047
                }
2048

NEW
2049
                return nil
×
NEW
2050
        }, func() {})
×
NEW
2051
        if err != nil {
×
NEW
2052
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
NEW
2053
        }
×
2054

NEW
2055
        return edges, nil
×
2056
}
2057

2058
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2059
// ID's that we don't know and are not known zombies of the passed set. In other
2060
// words, we perform a set difference of our set of chan ID's and the ones
2061
// passed in. This method can be used by callers to determine the set of
2062
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2063
// known zombies is also returned.
2064
//
2065
// NOTE: part of the V1Store interface.
2066
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
NEW
2067
        []ChannelUpdateInfo, error) {
×
NEW
2068

×
NEW
2069
        var (
×
NEW
2070
                ctx          = context.TODO()
×
NEW
2071
                newChanIDs   []uint64
×
NEW
2072
                knownZombies []ChannelUpdateInfo
×
NEW
2073
        )
×
NEW
2074
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2075
                for _, chanInfo := range chansInfo {
×
NEW
2076
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
NEW
2077
                        var chanIDB [8]byte
×
NEW
2078
                        byteOrder.PutUint64(chanIDB[:], channelID)
×
NEW
2079

×
NEW
2080
                        _, err := db.GetChannelBySCID(
×
NEW
2081
                                ctx, sqlc.GetChannelBySCIDParams{
×
NEW
2082
                                        Version: int16(ProtocolV1),
×
NEW
2083
                                        Scid:    chanIDB[:],
×
NEW
2084
                                },
×
NEW
2085
                        )
×
NEW
2086
                        if err == nil {
×
NEW
2087
                                continue
×
NEW
2088
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
NEW
2089
                                return fmt.Errorf("unable to fetch channel: %w",
×
NEW
2090
                                        err)
×
NEW
2091
                        }
×
2092

NEW
2093
                        isZombie, err := db.IsZombieChannel(
×
NEW
2094
                                ctx, sqlc.IsZombieChannelParams{
×
NEW
2095
                                        Scid:    int64(channelID),
×
NEW
2096
                                        Version: int16(ProtocolV1),
×
NEW
2097
                                },
×
NEW
2098
                        )
×
NEW
2099
                        if err != nil {
×
NEW
2100
                                return fmt.Errorf("unable to fetch zombie "+
×
NEW
2101
                                        "channel: %w", err)
×
NEW
2102
                        }
×
2103

NEW
2104
                        if isZombie {
×
NEW
2105
                                knownZombies = append(knownZombies, chanInfo)
×
NEW
2106

×
NEW
2107
                                continue
×
2108
                        }
2109

NEW
2110
                        newChanIDs = append(newChanIDs, channelID)
×
2111
                }
2112

NEW
2113
                return nil
×
NEW
2114
        }, func() {})
×
NEW
2115
        if err != nil {
×
NEW
2116
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
NEW
2117
        }
×
2118

NEW
2119
        return newChanIDs, knownZombies, nil
×
2120
}
2121

2122
// PruneGraphNodes is a garbage collection method which attempts to prune out
2123
// any nodes from the channel graph that are currently unconnected. This ensure
2124
// that we only maintain a graph of reachable nodes. In the event that a pruned
2125
// node gains more channels, it will be re-added back to the graph.
2126
//
2127
// NOTE: part of the V1Store interface.
NEW
2128
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
NEW
2129
        var ctx = context.TODO()
×
NEW
2130

×
NEW
2131
        var prunedNodes []route.Vertex
×
NEW
2132
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2133
                var err error
×
NEW
2134
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
NEW
2135

×
NEW
2136
                return err
×
NEW
2137
        }, func() {})
×
NEW
2138
        if err != nil {
×
NEW
2139
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
NEW
2140
        }
×
2141

NEW
2142
        return prunedNodes, nil
×
2143
}
2144

2145
// PruneGraph prunes newly closed channels from the channel graph in response
2146
// to a new block being solved on the network. Any transactions which spend the
2147
// funding output of any known channels within he graph will be deleted.
2148
// Additionally, the "prune tip", or the last block which has been used to
2149
// prune the graph is stored so callers can ensure the graph is fully in sync
2150
// with the current UTXO state. A slice of channels that have been closed by
2151
// the target block along with any pruned nodes are returned if the function
2152
// succeeds without error.
2153
//
2154
// NOTE: part of the V1Store interface.
2155
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2156
        blockHash *chainhash.Hash, blockHeight uint32) (
NEW
2157
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
NEW
2158

×
NEW
2159
        ctx := context.TODO()
×
NEW
2160

×
NEW
2161
        s.cacheMu.Lock()
×
NEW
2162
        defer s.cacheMu.Unlock()
×
NEW
2163

×
NEW
2164
        var (
×
NEW
2165
                closedChans []*models.ChannelEdgeInfo
×
NEW
2166
                prunedNodes []route.Vertex
×
NEW
2167
        )
×
NEW
2168
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2169
                for _, outpoint := range spentOutputs {
×
NEW
2170
                        row, err := db.GetChannelByOutpoint(
×
NEW
2171
                                ctx, sqlc.GetChannelByOutpointParams{
×
NEW
2172
                                        Outpoint: outpoint.String(),
×
NEW
2173
                                        Version:  int16(ProtocolV1),
×
NEW
2174
                                },
×
NEW
2175
                        )
×
NEW
2176
                        if errors.Is(err, sql.ErrNoRows) {
×
NEW
2177
                                continue
×
NEW
2178
                        } else if err != nil {
×
NEW
2179
                                return fmt.Errorf("unable to fetch channel: %w",
×
NEW
2180
                                        err)
×
NEW
2181
                        }
×
2182

NEW
2183
                        node1, node2, err := buildNodeVertices(
×
NEW
2184
                                row.Node1Pubkey, row.Node2Pubkey,
×
NEW
2185
                        )
×
NEW
2186
                        if err != nil {
×
NEW
2187
                                return err
×
NEW
2188
                        }
×
2189

NEW
2190
                        info, err := getAndBuildEdgeInfo(
×
NEW
2191
                                ctx, db, s.cfg.ChainHash, row.ID, row,
×
NEW
2192
                                node1, node2,
×
NEW
2193
                        )
×
NEW
2194
                        if err != nil {
×
NEW
2195
                                return err
×
NEW
2196
                        }
×
2197

NEW
2198
                        err = db.DeleteChannel(ctx, row.ID)
×
NEW
2199
                        if err != nil {
×
NEW
2200
                                return fmt.Errorf("unable to delete "+
×
NEW
2201
                                        "channel: %w", err)
×
NEW
2202
                        }
×
2203

NEW
2204
                        closedChans = append(closedChans, info)
×
2205
                }
2206

NEW
2207
                err := db.UpsertPruneLogEntry(
×
NEW
2208
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
NEW
2209
                                BlockHash:   blockHash[:],
×
NEW
2210
                                BlockHeight: int64(blockHeight),
×
NEW
2211
                        },
×
NEW
2212
                )
×
NEW
2213
                if err != nil {
×
NEW
2214
                        return fmt.Errorf("unable to insert prune log "+
×
NEW
2215
                                "entry: %w", err)
×
NEW
2216
                }
×
2217

2218
                // Now that we've pruned some channels, we'll also prune any
2219
                // nodes that no longer have any channels.
NEW
2220
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
NEW
2221
                if err != nil {
×
NEW
2222
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
NEW
2223
                                err)
×
NEW
2224
                }
×
2225

NEW
2226
                return nil
×
NEW
2227
        }, func() {
×
NEW
2228
                prunedNodes = nil
×
NEW
2229
                closedChans = nil
×
NEW
2230
        })
×
NEW
2231
        if err != nil {
×
NEW
2232
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
NEW
2233
        }
×
2234

NEW
2235
        for _, channel := range closedChans {
×
NEW
2236
                s.rejectCache.remove(channel.ChannelID)
×
NEW
2237
                s.chanCache.remove(channel.ChannelID)
×
NEW
2238
        }
×
2239

NEW
2240
        return closedChans, prunedNodes, nil
×
2241
}
2242

2243
// ChannelView returns the verifiable edge information for each active channel
2244
// within the known channel graph. The set of UTXO's (along with their scripts)
2245
// returned are the ones that need to be watched on chain to detect channel
2246
// closes on the resident blockchain.
2247
//
2248
// NOTE: part of the V1Store interface.
NEW
2249
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
NEW
2250
        var (
×
NEW
2251
                ctx        = context.TODO()
×
NEW
2252
                edgePoints []EdgePoint
×
NEW
2253
        )
×
NEW
2254
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2255
                dbChannel, err := db.ListAllChannels(ctx, int16(ProtocolV1))
×
NEW
2256
                if err != nil {
×
NEW
2257
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
NEW
2258
                }
×
2259

NEW
2260
                for _, dbChan := range dbChannel {
×
NEW
2261
                        if dbChan.BitcoinKey1 == nil {
×
NEW
2262
                                continue
×
2263
                        }
2264

NEW
2265
                        pkScript, err := genMultiSigP2WSH(
×
NEW
2266
                                dbChan.BitcoinKey1, dbChan.BitcoinKey2,
×
NEW
2267
                        )
×
NEW
2268
                        if err != nil {
×
NEW
2269
                                return err
×
NEW
2270
                        }
×
2271

NEW
2272
                        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
NEW
2273
                        if err != nil {
×
NEW
2274
                                return err
×
NEW
2275
                        }
×
2276

NEW
2277
                        edgePoints = append(edgePoints, EdgePoint{
×
NEW
2278
                                FundingPkScript: pkScript,
×
NEW
2279
                                OutPoint:        *op,
×
NEW
2280
                        })
×
2281
                }
2282

NEW
2283
                return nil
×
NEW
2284
        }, func() {})
×
NEW
2285
        if err != nil {
×
NEW
2286
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
NEW
2287
        }
×
2288

NEW
2289
        return edgePoints, nil
×
2290
}
2291

2292
// PruneTip returns the block height and hash of the latest block that has been
2293
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2294
// to tell if the graph is currently in sync with the current best known UTXO
2295
// state.
2296
//
2297
// NOTE: part of the V1Store interface.
NEW
2298
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
NEW
2299
        var (
×
NEW
2300
                ctx       = context.TODO()
×
NEW
2301
                tipHash   chainhash.Hash
×
NEW
2302
                tipHeight uint32
×
NEW
2303
        )
×
NEW
2304
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2305
                pruneTip, err := db.GetPruneTip(ctx)
×
NEW
2306
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
2307
                        return ErrGraphNeverPruned
×
NEW
2308
                } else if err != nil {
×
NEW
2309
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
NEW
2310
                }
×
2311

NEW
2312
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
NEW
2313
                tipHeight = uint32(pruneTip.BlockHeight)
×
NEW
2314

×
NEW
2315
                return nil
×
NEW
2316
        }, func() {})
×
NEW
2317
        if err != nil {
×
NEW
2318
                return nil, 0, err
×
NEW
2319
        }
×
2320

NEW
2321
        return &tipHash, tipHeight, nil
×
2322
}
2323

2324
// pruneGraphNodes deletes any node in the DB that doesn't have a channel
2325
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
NEW
2326
        db SQLQueries) ([]route.Vertex, error) {
×
NEW
2327

×
NEW
2328
        nodes, err := db.GetUnconnectedNodes(ctx)
×
NEW
2329
        if err != nil {
×
NEW
2330
                return nil, fmt.Errorf("unable to fetch unconnected nodes: %w",
×
NEW
2331
                        err)
×
NEW
2332
        }
×
2333

NEW
2334
        nodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
NEW
2335
        if err != nil {
×
NEW
2336
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
NEW
2337
        }
×
2338

NEW
2339
        prunedNodes := make([]route.Vertex, 0, len(nodes))
×
NEW
2340
        for _, node := range nodes {
×
NEW
2341
                // Don't delete the source node.
×
NEW
2342
                if node.ID == nodeID {
×
NEW
2343
                        continue
×
2344
                }
2345

NEW
2346
                _, err = db.DeleteNodeByPubKey(
×
NEW
2347
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
NEW
2348
                                PubKey:  node.PubKey,
×
NEW
2349
                                Version: int16(ProtocolV1),
×
NEW
2350
                        },
×
NEW
2351
                )
×
NEW
2352
                if err != nil {
×
NEW
2353
                        return nil, fmt.Errorf("unable to delete node: %w", err)
×
NEW
2354
                }
×
2355

NEW
2356
                var pubKey route.Vertex
×
NEW
2357
                copy(pubKey[:], node.PubKey)
×
NEW
2358
                prunedNodes = append(prunedNodes, pubKey)
×
2359
        }
2360

NEW
2361
        return prunedNodes, nil
×
2362
}
2363

2364
// DisconnectBlockAtHeight is used to indicate that the block specified
2365
// by the passed height has been disconnected from the main chain. This
2366
// will "rewind" the graph back to the height below, deleting channels
2367
// that are no longer confirmed from the graph. The prune log will be
2368
// set to the last prune height valid for the remaining chain.
2369
// Channels that were removed from the graph resulting from the
2370
// disconnected block are returned.
2371
//
2372
// NOTE: part of the V1Store interface.
2373
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
NEW
2374
        []*models.ChannelEdgeInfo, error) {
×
NEW
2375

×
NEW
2376
        ctx := context.TODO()
×
NEW
2377

×
NEW
2378
        var (
×
NEW
2379
                // Every channel having a ShortChannelID starting at 'height'
×
NEW
2380
                // will no longer be confirmed.
×
NEW
2381
                startShortChanID = lnwire.ShortChannelID{
×
NEW
2382
                        BlockHeight: height,
×
NEW
2383
                }
×
NEW
2384

×
NEW
2385
                // Delete everything after this height from the db up until the
×
NEW
2386
                // SCID alias range.
×
NEW
2387
                endShortChanID = aliasmgr.StartingAlias
×
NEW
2388

×
NEW
2389
                removedChans []*models.ChannelEdgeInfo
×
NEW
2390
        )
×
NEW
2391

×
NEW
2392
        var chanIDStart [8]byte
×
NEW
2393
        byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64())
×
NEW
2394
        var chanIDEnd [8]byte
×
NEW
2395
        byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64())
×
NEW
2396

×
NEW
2397
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2398
                rows, err := db.GetChannelsBySCIDRange(
×
NEW
2399
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
NEW
2400
                                StartScid: chanIDStart[:],
×
NEW
2401
                                EndScid:   chanIDEnd[:],
×
NEW
2402
                        },
×
NEW
2403
                )
×
NEW
2404
                if err != nil {
×
NEW
2405
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
NEW
2406
                }
×
2407

NEW
2408
                for _, row := range rows {
×
NEW
2409
                        node1, node2, err := buildNodeVertices(
×
NEW
2410
                                row.Node1PubKey, row.Node2PubKey,
×
NEW
2411
                        )
×
NEW
2412
                        if err != nil {
×
NEW
2413
                                return err
×
NEW
2414
                        }
×
2415

NEW
2416
                        channel, err := getAndBuildEdgeInfo(
×
NEW
2417
                                ctx, db, s.cfg.ChainHash, row.ID, row,
×
NEW
2418
                                node1, node2,
×
NEW
2419
                        )
×
NEW
2420
                        if err != nil {
×
NEW
2421
                                return err
×
NEW
2422
                        }
×
2423

NEW
2424
                        err = db.DeleteChannel(ctx, row.ID)
×
NEW
2425
                        if err != nil {
×
NEW
2426
                                return fmt.Errorf("unable to delete "+
×
NEW
2427
                                        "channel: %w", err)
×
NEW
2428
                        }
×
2429

NEW
2430
                        removedChans = append(removedChans, channel)
×
2431

2432
                }
2433

NEW
2434
                return db.DeletePruneLogEntriesInRange(
×
NEW
2435
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
NEW
2436
                                StartHeight: int64(height),
×
NEW
2437
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
NEW
2438
                        },
×
NEW
2439
                )
×
NEW
2440
        }, func() {
×
NEW
2441
                removedChans = nil
×
NEW
2442
        })
×
NEW
2443
        if err != nil {
×
NEW
2444
                return nil, fmt.Errorf("unable to disconnect block at "+
×
NEW
2445
                        "height: %w", err)
×
NEW
2446
        }
×
2447

NEW
2448
        for _, channel := range removedChans {
×
NEW
2449
                s.rejectCache.remove(channel.ChannelID)
×
NEW
2450
                s.chanCache.remove(channel.ChannelID)
×
NEW
2451
        }
×
2452

NEW
2453
        return removedChans, nil
×
2454
}
2455

2456
// AddEdgeProof sets the proof of an existing edge in the graph database.
2457
//
2458
// NOTE: part of the V1Store interface.
2459
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
NEW
2460
        proof *models.ChannelAuthProof) error {
×
NEW
2461

×
NEW
2462
        var (
×
NEW
2463
                ctx       = context.TODO()
×
NEW
2464
                scidBytes [8]byte
×
NEW
2465
        )
×
NEW
2466
        byteOrder.PutUint64(scidBytes[:], scid.ToUint64())
×
NEW
2467

×
NEW
2468
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2469
                dbChan, err := db.GetChannelBySCID(
×
NEW
2470
                        ctx, sqlc.GetChannelBySCIDParams{
×
NEW
2471
                                Scid:    scidBytes[:],
×
NEW
2472
                                Version: int16(ProtocolV1),
×
NEW
2473
                        },
×
NEW
2474
                )
×
NEW
2475
                if err != nil {
×
NEW
2476
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
NEW
2477
                }
×
2478

NEW
2479
                return db.AddV1ChannelProof(
×
NEW
2480
                        ctx, sqlc.AddV1ChannelProofParams{
×
NEW
2481
                                ID:                dbChan.ID,
×
NEW
2482
                                Node1Signature:    proof.NodeSig1Bytes,
×
NEW
2483
                                Node2Signature:    proof.NodeSig2Bytes,
×
NEW
2484
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
NEW
2485
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
NEW
2486
                        },
×
NEW
2487
                )
×
NEW
2488
        }, func() {})
×
NEW
2489
        if err != nil {
×
NEW
2490
                return fmt.Errorf("unable to add edge proof: %w", err)
×
NEW
2491
        }
×
2492

NEW
2493
        return nil
×
2494
}
2495

2496
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2497
// that we can ignore channel announcements that we know to be closed without
2498
// having to validate them and fetch a block.
2499
//
2500
// NOTE: part of the V1Store interface.
NEW
2501
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
NEW
2502
        ctx := context.TODO()
×
NEW
2503

×
NEW
2504
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2505
                var chanIDB [8]byte
×
NEW
2506
                byteOrder.PutUint64(chanIDB[:], scid.ToUint64())
×
NEW
2507

×
NEW
2508
                return db.InsertClosedChannel(ctx, chanIDB[:])
×
NEW
2509
        }, func() {})
×
2510
}
2511

2512
// IsClosedScid checks whether a channel identified by the passed in scid is
2513
// closed. This helps avoid having to perform expensive validation checks.
2514
//
2515
// NOTE: part of the V1Store interface.
NEW
2516
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
NEW
2517
        var (
×
NEW
2518
                ctx      = context.TODO()
×
NEW
2519
                isClosed bool
×
NEW
2520
        )
×
NEW
2521
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2522
                var chanIDB [8]byte
×
NEW
2523
                byteOrder.PutUint64(chanIDB[:], scid.ToUint64())
×
NEW
2524
                var err error
×
NEW
2525
                isClosed, err = db.IsClosedChannel(ctx, chanIDB[:])
×
NEW
2526
                if err != nil {
×
NEW
2527
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
NEW
2528
                                err)
×
NEW
2529
                }
×
2530

NEW
2531
                return nil
×
NEW
2532
        }, func() {})
×
NEW
2533
        if err != nil {
×
NEW
2534
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
NEW
2535
                        err)
×
NEW
2536
        }
×
2537

NEW
2538
        return isClosed, nil
×
2539
}
2540

2541
// GraphSession will provide the call-back with access to a NodeTraverser
2542
// instance which can be used to perform queries against the channel graph.
2543
//
2544
// NOTE: part of the V1Store interface.
NEW
2545
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error) error {
×
NEW
2546
        var ctx = context.TODO()
×
NEW
2547

×
NEW
2548
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2549
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
NEW
2550
        }, func() {})
×
2551
}
2552

2553
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2554
// read only transaction for a consistent view of the graph.
2555
type sqlNodeTraverser struct {
2556
        db    SQLQueries
2557
        chain chainhash.Hash
2558
}
2559

2560
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2561
// NodeTraverser interface.
2562
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2563

2564
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2565
func newSQLNodeTraverser(db SQLQueries,
NEW
2566
        chain chainhash.Hash) *sqlNodeTraverser {
×
NEW
2567

×
NEW
2568
        return &sqlNodeTraverser{
×
NEW
2569
                db:    db,
×
NEW
2570
                chain: chain,
×
NEW
2571
        }
×
NEW
2572
}
×
2573

2574
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2575
// node.
2576
//
2577
// NOTE: Part of the NodeTraverser interface.
2578
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
NEW
2579
        cb func(channel *DirectedChannel) error) error {
×
NEW
2580

×
NEW
2581
        ctx := context.TODO()
×
NEW
2582

×
NEW
2583
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
NEW
2584
}
×
2585

2586
// FetchNodeFeatures returns the features of the given node. If the node is
2587
// unknown, assume no additional features are supported.
2588
//
2589
// NOTE: Part of the NodeTraverser interface.
2590
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
NEW
2591
        *lnwire.FeatureVector, error) {
×
NEW
2592

×
NEW
2593
        ctx := context.TODO()
×
NEW
2594

×
NEW
2595
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
NEW
2596
}
×
2597

2598
// forEachNodeDirectedChannel iterates through all channels of a given
2599
// node, executing the passed callback on the directed edge representing the
2600
// channel and its incoming policy. If the node is not found, no error is
2601
// returned.
2602
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
NEW
2603
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
NEW
2604

×
NEW
2605
        toNodeCallback := func() route.Vertex {
×
NEW
2606
                return nodePub
×
NEW
2607
        }
×
2608

NEW
2609
        dbNode, err := db.GetNodeByPubKey(
×
NEW
2610
                ctx, sqlc.GetNodeByPubKeyParams{
×
NEW
2611
                        Version: int16(ProtocolV1),
×
NEW
2612
                        PubKey:  nodePub[:],
×
NEW
2613
                },
×
NEW
2614
        )
×
NEW
2615
        if errors.Is(err, sql.ErrNoRows) {
×
NEW
2616
                return nil
×
NEW
2617
        } else if err != nil {
×
NEW
2618
                return fmt.Errorf("unable to fetch node: %w", err)
×
NEW
2619
        }
×
2620

NEW
2621
        features, err := getNodeFeatures(ctx, db, dbNode.ID)
×
NEW
2622
        if err != nil {
×
NEW
2623
                return fmt.Errorf("unable to fetch node features: %w", err)
×
NEW
2624
        }
×
2625

NEW
2626
        rows, err := db.ListChannelsByNodeID(
×
NEW
2627
                ctx, sqlc.ListChannelsByNodeIDParams{
×
NEW
2628
                        Version: int16(ProtocolV1),
×
NEW
2629
                        NodeID1: dbNode.ID,
×
NEW
2630
                },
×
NEW
2631
        )
×
NEW
2632
        if err != nil {
×
NEW
2633
                return fmt.Errorf("unable to fetch channels: %w", err)
×
NEW
2634
        }
×
2635

NEW
2636
        for _, row := range rows {
×
NEW
2637
                node1, node2, err := buildNodeVertices(
×
NEW
2638
                        row.Node1Pubkey, row.Node2Pubkey,
×
NEW
2639
                )
×
NEW
2640
                if err != nil {
×
NEW
2641
                        return fmt.Errorf("unable to build node vertices: %w",
×
NEW
2642
                                err)
×
NEW
2643
                }
×
2644

NEW
2645
                edge, err := buildCacheableChannelInfo(row, node1, node2)
×
NEW
2646
                if err != nil {
×
NEW
2647
                        return err
×
NEW
2648
                }
×
2649

NEW
2650
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
NEW
2651
                if err != nil {
×
NEW
2652
                        return err
×
NEW
2653
                }
×
2654

NEW
2655
                var p1, p2 *models.CachedEdgePolicy
×
NEW
2656
                if dbPol1 != nil {
×
NEW
2657
                        policy1, err := buildChanPolicy(
×
NEW
2658
                                *dbPol1, edge.ChannelID, nil, node2, true,
×
NEW
2659
                        )
×
NEW
2660
                        if err != nil {
×
NEW
2661
                                return err
×
NEW
2662
                        }
×
2663

NEW
2664
                        p1 = models.NewCachedPolicy(policy1)
×
2665
                }
NEW
2666
                if dbPol2 != nil {
×
NEW
2667
                        policy2, err := buildChanPolicy(
×
NEW
2668
                                *dbPol2, edge.ChannelID, nil, node1, false,
×
NEW
2669
                        )
×
NEW
2670
                        if err != nil {
×
NEW
2671
                                return err
×
NEW
2672
                        }
×
2673

NEW
2674
                        p2 = models.NewCachedPolicy(policy2)
×
2675
                }
2676

2677
                // Determine the outgoing and incoming policy for this
2678
                // channel and node combo.
NEW
2679
                outPolicy, inPolicy := p1, p2
×
NEW
2680
                if p1 != nil && node2 == nodePub {
×
NEW
2681
                        outPolicy, inPolicy = p2, p1
×
NEW
2682
                } else if p2 != nil && node1 != nodePub {
×
NEW
2683
                        outPolicy, inPolicy = p2, p1
×
NEW
2684
                }
×
2685

NEW
2686
                var cachedInPolicy *models.CachedEdgePolicy
×
NEW
2687
                if inPolicy != nil {
×
NEW
2688
                        cachedInPolicy = inPolicy
×
NEW
2689
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
NEW
2690
                        cachedInPolicy.ToNodeFeatures = features
×
NEW
2691
                }
×
2692

NEW
2693
                directedChannel := &DirectedChannel{
×
NEW
2694
                        ChannelID:    edge.ChannelID,
×
NEW
2695
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
NEW
2696
                        OtherNode:    edge.NodeKey2Bytes,
×
NEW
2697
                        Capacity:     edge.Capacity,
×
NEW
2698
                        OutPolicySet: outPolicy != nil,
×
NEW
2699
                        InPolicy:     cachedInPolicy,
×
NEW
2700
                }
×
NEW
2701
                if outPolicy != nil {
×
NEW
2702
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
NEW
2703
                                directedChannel.InboundFee = fee
×
NEW
2704
                        })
×
2705
                }
2706

NEW
2707
                if nodePub == edge.NodeKey2Bytes {
×
NEW
2708
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
NEW
2709
                }
×
2710

NEW
2711
                if err := cb(directedChannel); err != nil {
×
NEW
2712
                        return err
×
NEW
2713
                }
×
2714
        }
2715

NEW
2716
        return nil
×
2717
}
2718

2719
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
2720
// and executes the provided callback for each node.
2721
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
NEW
2722
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
NEW
2723

×
NEW
2724
        nodes, err := db.ListNodeIDsAndPubKeys(ctx, int16(ProtocolV1))
×
NEW
2725
        if err != nil {
×
NEW
2726
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
NEW
2727
        }
×
2728

NEW
2729
        for _, node := range nodes {
×
NEW
2730
                var pub route.Vertex
×
NEW
2731
                copy(pub[:], node.PubKey)
×
NEW
2732

×
NEW
2733
                if err := cb(node.ID, pub); err != nil {
×
NEW
2734
                        return fmt.Errorf("callback failed: %w", err)
×
NEW
2735
                }
×
2736
        }
2737

NEW
2738
        return nil
×
2739
}
2740

2741
// forEachNodeChannel iterates through all channels of a node, executing
2742
// the passed callback on each. The call-back is provided with the channel's
2743
// edge information, the outgoing policy and the incoming policy for the
2744
// channel and node combo.
2745
func forEachNodeChannel(ctx context.Context, db SQLQueries,
2746
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
2747
                *models.ChannelEdgePolicy,
NEW
2748
                *models.ChannelEdgePolicy) error) error {
×
NEW
2749

×
NEW
2750
        // Get all the V1 channels for this node.
×
NEW
2751
        rows, err := db.ListChannelsByNodeID(
×
NEW
2752
                ctx, sqlc.ListChannelsByNodeIDParams{
×
NEW
2753
                        Version: int16(ProtocolV1),
×
NEW
2754
                        NodeID1: id,
×
NEW
2755
                },
×
NEW
2756
        )
×
NEW
2757
        if err != nil {
×
NEW
2758
                return fmt.Errorf("unable to fetch channels: %w", err)
×
NEW
2759
        }
×
2760

2761
        // Call the call-back for each channel and its known policies.
NEW
2762
        for _, row := range rows {
×
NEW
2763
                node1, node2, err := buildNodeVertices(
×
NEW
2764
                        row.Node1Pubkey, row.Node2Pubkey,
×
NEW
2765
                )
×
NEW
2766
                if err != nil {
×
NEW
2767
                        return fmt.Errorf("unable to build node vertices: %w",
×
NEW
2768
                                err)
×
NEW
2769
                }
×
2770

NEW
2771
                edge, p1, p2, err := getAndBuildEdgeInfoAndPolicies(
×
NEW
2772
                        ctx, db, chain, row.ID, row, node1, node2,
×
NEW
2773
                )
×
NEW
2774
                if err != nil {
×
NEW
2775
                        return fmt.Errorf("unable to build channel "+
×
NEW
2776
                                "info and policies: %w", err)
×
NEW
2777
                }
×
2778

2779
                // Determine the outgoing and incoming policy for this
2780
                // channel and node combo.
NEW
2781
                p1ToNode := row.NodeID2
×
NEW
2782
                p2ToNode := row.NodeID1
×
NEW
2783
                outPolicy, inPolicy := p1, p2
×
NEW
2784
                if (p1 != nil && p1ToNode == id) ||
×
NEW
2785
                        (p2 != nil && p2ToNode != id) {
×
NEW
2786

×
NEW
2787
                        outPolicy, inPolicy = p2, p1
×
NEW
2788
                }
×
2789

NEW
2790
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
NEW
2791
                        return err
×
NEW
2792
                }
×
2793
        }
2794

NEW
2795
        return nil
×
2796
}
2797

2798
// updateChanEdgePolicy upserts the channel policy info we have stored for
2799
// a channel we already know of.
2800
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
2801
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
NEW
2802
        error) {
×
NEW
2803

×
NEW
2804
        var (
×
NEW
2805
                node1Pub, node2Pub route.Vertex
×
NEW
2806
                isNode1            bool
×
NEW
2807
                chanIDB            [8]byte
×
NEW
2808
        )
×
NEW
2809
        byteOrder.PutUint64(chanIDB[:], edge.ChannelID)
×
NEW
2810

×
NEW
2811
        // Check that this edge policy refers to a channel that we already
×
NEW
2812
        // know of. We do this explicitly so that we can return the appropriate
×
NEW
2813
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
NEW
2814
        // abort the transaction which would abort the entire batch.
×
NEW
2815
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
NEW
2816
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
NEW
2817
                        Scid:    chanIDB[:],
×
NEW
2818
                        Version: int16(ProtocolV1),
×
NEW
2819
                },
×
NEW
2820
        )
×
NEW
2821
        if errors.Is(err, sql.ErrNoRows) {
×
NEW
2822
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
NEW
2823
        } else if err != nil {
×
NEW
2824
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
NEW
2825
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
NEW
2826
        }
×
2827

NEW
2828
        copy(node1Pub[:], dbChan.Node1PubKey)
×
NEW
2829
        copy(node2Pub[:], dbChan.Node2PubKey)
×
NEW
2830

×
NEW
2831
        // Figure out which node this edge is from.
×
NEW
2832
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
NEW
2833
        nodeID := dbChan.NodeID1
×
NEW
2834
        if !isNode1 {
×
NEW
2835
                nodeID = dbChan.NodeID2
×
NEW
2836
        }
×
2837

NEW
2838
        var (
×
NEW
2839
                inboundBase sql.NullInt64
×
NEW
2840
                inboundRate sql.NullInt64
×
NEW
2841
        )
×
NEW
2842
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
NEW
2843
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
NEW
2844
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
NEW
2845
        })
×
2846

NEW
2847
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
NEW
2848
                Version:     int16(ProtocolV1),
×
NEW
2849
                ChannelID:   dbChan.ID,
×
NEW
2850
                NodeID:      nodeID,
×
NEW
2851
                Timelock:    int32(edge.TimeLockDelta),
×
NEW
2852
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
NEW
2853
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
NEW
2854
                MinHtlcMsat: int64(edge.MinHTLC),
×
NEW
2855
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
NEW
2856
                Disabled: sql.NullBool{
×
NEW
2857
                        Valid: true,
×
NEW
2858
                        Bool:  edge.IsDisabled(),
×
NEW
2859
                },
×
NEW
2860
                MaxHtlcMsat: sql.NullInt64{
×
NEW
2861
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
NEW
2862
                        Int64: int64(edge.MaxHTLC),
×
NEW
2863
                },
×
NEW
2864
                InboundBaseFeeMsat:      inboundBase,
×
NEW
2865
                InboundFeeRateMilliMsat: inboundRate,
×
NEW
2866
                Signature:               edge.SigBytes,
×
NEW
2867
        })
×
NEW
2868
        if err != nil {
×
NEW
2869
                return node1Pub, node2Pub, isNode1,
×
NEW
2870
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
NEW
2871
        }
×
2872

2873
        // Convert the flat extra opaque data into a map of TLV types to
2874
        // values.
NEW
2875
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
NEW
2876
        if err != nil {
×
NEW
2877
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
NEW
2878
                        "marshal extra opaque data: %w", err)
×
NEW
2879
        }
×
2880

2881
        // Update the channel policy's extra signed fields.
NEW
2882
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
2883
        if err != nil {
×
NEW
2884
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
NEW
2885
                        "policy extra TLVs: %w", err)
×
UNCOV
2886
        }
×
2887

NEW
2888
        return node1Pub, node2Pub, isNode1, nil
×
2889
}
2890

2891
// getNodeByPubKey attempts to look up a target node by its public key.
2892
func getNodeByPubKey(ctx context.Context, db SQLQueries,
2893
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
2894

×
2895
        dbNode, err := db.GetNodeByPubKey(
×
2896
                ctx, sqlc.GetNodeByPubKeyParams{
×
2897
                        Version: int16(ProtocolV1),
×
2898
                        PubKey:  pubKey[:],
×
2899
                },
×
2900
        )
×
2901
        if errors.Is(err, sql.ErrNoRows) {
×
2902
                return 0, nil, ErrGraphNodeNotFound
×
2903
        } else if err != nil {
×
2904
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
2905
        }
×
2906

2907
        node, err := buildNode(ctx, db, &dbNode)
×
2908
        if err != nil {
×
2909
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
2910
        }
×
2911

2912
        return dbNode.ID, node, nil
×
2913
}
2914

2915
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
2916
// provided database channel row and the public keys of the two nodes
2917
// involved in the channel.
2918
func buildCacheableChannelInfo(row any, node1Pub,
NEW
2919
        node2Pub route.Vertex) (*models.CachedEdgeInfo, error) {
×
NEW
2920

×
NEW
2921
        dbChan, err := extractChannel(row)
×
NEW
2922
        if err != nil {
×
NEW
2923
                return nil, err
×
NEW
2924
        }
×
2925

NEW
2926
        return &models.CachedEdgeInfo{
×
NEW
2927
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
NEW
2928
                NodeKey1Bytes: node1Pub,
×
NEW
2929
                NodeKey2Bytes: node2Pub,
×
NEW
2930
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
NEW
2931
        }, nil
×
2932
}
2933

2934
// buildNode constructs a LightningNode instance from the given database node
2935
// record. The node's features, addresses and extra signed fields are also
2936
// fetched from the database and set on the node.
2937
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
2938
        *models.LightningNode, error) {
×
2939

×
2940
        if dbNode.Version != int16(ProtocolV1) {
×
2941
                return nil, fmt.Errorf("unsupported node version: %d",
×
2942
                        dbNode.Version)
×
2943
        }
×
2944

2945
        var pub [33]byte
×
2946
        copy(pub[:], dbNode.PubKey)
×
2947

×
2948
        node := &models.LightningNode{
×
2949
                PubKeyBytes: pub,
×
2950
                Features:    lnwire.EmptyFeatureVector(),
×
2951
                LastUpdate:  time.Unix(0, 0),
×
2952
        }
×
2953

×
2954
        if len(dbNode.Signature) == 0 {
×
2955
                return node, nil
×
2956
        }
×
2957

2958
        node.HaveNodeAnnouncement = true
×
2959
        node.AuthSigBytes = dbNode.Signature
×
2960
        node.Alias = dbNode.Alias.String
×
2961
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
2962

×
2963
        var err error
×
NEW
2964
        if dbNode.Color.Valid {
×
NEW
2965
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
NEW
2966
                if err != nil {
×
NEW
2967
                        return nil, fmt.Errorf("unable to decode color: %w",
×
NEW
2968
                                err)
×
NEW
2969
                }
×
2970
        }
2971

2972
        // Fetch the node's features.
2973
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
2974
        if err != nil {
×
2975
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
2976
                        "features: %w", dbNode.ID, err)
×
2977
        }
×
2978

2979
        // Fetch the node's addresses.
2980
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
2981
        if err != nil {
×
2982
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
2983
                        "addresses: %w", dbNode.ID, err)
×
2984
        }
×
2985

2986
        // Fetch the node's extra signed fields.
2987
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
2988
        if err != nil {
×
2989
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
2990
                        "extra signed fields: %w", dbNode.ID, err)
×
2991
        }
×
2992

2993
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
2994
        if err != nil {
×
2995
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
2996
                        "fields: %w", err)
×
2997
        }
×
2998

2999
        if len(recs) != 0 {
×
3000
                node.ExtraOpaqueData = recs
×
3001
        }
×
3002

3003
        return node, nil
×
3004
}
3005

3006
// getNodeFeatures fetches the feature bits and constructs the feature vector
3007
// for a node with the given DB ID.
3008
func getNodeFeatures(ctx context.Context, db SQLQueries,
3009
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3010

×
3011
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3012
        if err != nil {
×
3013
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3014
                        nodeID, err)
×
3015
        }
×
3016

3017
        features := lnwire.EmptyFeatureVector()
×
3018
        for _, feature := range rows {
×
3019
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3020
        }
×
3021

3022
        return features, nil
×
3023
}
3024

3025
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
3026
// given DB ID.
3027
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3028
        nodeID int64) (map[uint64][]byte, error) {
×
3029

×
3030
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3031
        if err != nil {
×
3032
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
3033
                        "signed fields: %w", nodeID, err)
×
3034
        }
×
3035

3036
        extraFields := make(map[uint64][]byte)
×
3037
        for _, field := range fields {
×
3038
                extraFields[uint64(field.Type)] = field.Value
×
3039
        }
×
3040

3041
        return extraFields, nil
×
3042
}
3043

3044
// upsertNode upserts the node record into the database. If the node already
3045
// exists, then the node's information is updated. If the node doesn't exist,
3046
// then a new node is created. The node's features, addresses and extra TLV
3047
// types are also updated. The node's DB ID is returned.
3048
func upsertNode(ctx context.Context, db SQLQueries,
3049
        node *models.LightningNode) (int64, error) {
×
3050

×
3051
        params := sqlc.UpsertNodeParams{
×
3052
                Version: int16(ProtocolV1),
×
3053
                PubKey:  node.PubKeyBytes[:],
×
3054
        }
×
3055

×
3056
        if node.HaveNodeAnnouncement {
×
3057
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3058
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3059
                params.Alias = sqldb.SQLStr(node.Alias)
×
3060
                params.Signature = node.AuthSigBytes
×
3061
        }
×
3062

3063
        nodeID, err := db.UpsertNode(ctx, params)
×
3064
        if err != nil {
×
3065
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3066
                        err)
×
3067
        }
×
3068

3069
        // We can exit here if we don't have the announcement yet.
3070
        if !node.HaveNodeAnnouncement {
×
3071
                return nodeID, nil
×
3072
        }
×
3073

3074
        // Update the node's features.
3075
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3076
        if err != nil {
×
3077
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3078
        }
×
3079

3080
        // Update the node's addresses.
3081
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3082
        if err != nil {
×
3083
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3084
        }
×
3085

3086
        // Convert the flat extra opaque data into a map of TLV types to
3087
        // values.
3088
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3089
        if err != nil {
×
3090
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3091
                        err)
×
3092
        }
×
3093

3094
        // Update the node's extra signed fields.
3095
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3096
        if err != nil {
×
3097
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3098
        }
×
3099

3100
        return nodeID, nil
×
3101
}
3102

3103
// upsertNodeFeatures updates the node's features node_features table. This
3104
// includes deleting any feature bits no longer present and inserting any new
3105
// feature bits. If the feature bit does not yet exist in the features table,
3106
// then an entry is created in that table first.
3107
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3108
        features *lnwire.FeatureVector) error {
×
3109

×
3110
        // Get any existing features for the node.
×
3111
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3112
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3113
                return err
×
3114
        }
×
3115

3116
        // Copy the nodes latest set of feature bits.
3117
        newFeatures := make(map[int32]struct{})
×
3118
        if features != nil {
×
3119
                for feature := range features.Features() {
×
3120
                        newFeatures[int32(feature)] = struct{}{}
×
3121
                }
×
3122
        }
3123

3124
        // For any current feature that already exists in the DB, remove it from
3125
        // the in-memory map. For any existing feature that does not exist in
3126
        // the in-memory map, delete it from the database.
3127
        for _, feature := range existingFeatures {
×
3128
                // The feature is still present, so there are no updates to be
×
3129
                // made.
×
3130
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3131
                        delete(newFeatures, feature.FeatureBit)
×
3132
                        continue
×
3133
                }
3134

3135
                // The feature is no longer present, so we remove it from the
3136
                // database.
3137
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3138
                        NodeID:     nodeID,
×
3139
                        FeatureBit: feature.FeatureBit,
×
3140
                })
×
3141
                if err != nil {
×
3142
                        return fmt.Errorf("unable to delete node(%d) "+
×
3143
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3144
                                err)
×
3145
                }
×
3146
        }
3147

3148
        // Any remaining entries in newFeatures are new features that need to be
3149
        // added to the database for the first time.
3150
        for feature := range newFeatures {
×
3151
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3152
                        NodeID:     nodeID,
×
3153
                        FeatureBit: feature,
×
3154
                })
×
3155
                if err != nil {
×
3156
                        return fmt.Errorf("unable to insert node(%d) "+
×
3157
                                "feature(%v): %w", nodeID, feature, err)
×
3158
                }
×
3159
        }
3160

3161
        return nil
×
3162
}
3163

3164
// fetchNodeFeatures fetches the features for a node with the given public key.
3165
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3166
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3167

×
3168
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3169
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3170
                        PubKey:  nodePub[:],
×
3171
                        Version: int16(ProtocolV1),
×
3172
                },
×
3173
        )
×
3174
        if err != nil {
×
3175
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3176
                        nodePub, err)
×
3177
        }
×
3178

3179
        features := lnwire.EmptyFeatureVector()
×
3180
        for _, bit := range rows {
×
3181
                features.Set(lnwire.FeatureBit(bit))
×
3182
        }
×
3183

3184
        return features, nil
×
3185
}
3186

3187
// dbAddressType is an enum type that represents the different address types
3188
// that we store in the node_addresses table. The address type determines how
3189
// the address is to be serialised/deserialize.
3190
type dbAddressType uint8
3191

3192
const (
3193
        addressTypeIPv4   dbAddressType = 1
3194
        addressTypeIPv6   dbAddressType = 2
3195
        addressTypeTorV2  dbAddressType = 3
3196
        addressTypeTorV3  dbAddressType = 4
3197
        addressTypeOpaque dbAddressType = math.MaxInt8
3198
)
3199

3200
// upsertNodeAddresses updates the node's addresses in the database. This
3201
// includes deleting any existing addresses and inserting the new set of
3202
// addresses. The deletion is necessary since the ordering of the addresses may
3203
// change, and we need to ensure that the database reflects the latest set of
3204
// addresses so that at the time of reconstructing the node announcement, the
3205
// order is preserved and the signature over the message remains valid.
3206
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3207
        addresses []net.Addr) error {
×
3208

×
3209
        // Delete any existing addresses for the node. This is required since
×
3210
        // even if the new set of addresses is the same, the ordering may have
×
3211
        // changed for a given address type.
×
3212
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3213
        if err != nil {
×
3214
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3215
                        nodeID, err)
×
3216
        }
×
3217

3218
        // Copy the nodes latest set of addresses.
3219
        newAddresses := map[dbAddressType][]string{
×
3220
                addressTypeIPv4:   {},
×
3221
                addressTypeIPv6:   {},
×
3222
                addressTypeTorV2:  {},
×
3223
                addressTypeTorV3:  {},
×
3224
                addressTypeOpaque: {},
×
3225
        }
×
3226
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3227
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3228
        }
×
3229

3230
        for _, address := range addresses {
×
3231
                switch addr := address.(type) {
×
3232
                case *net.TCPAddr:
×
3233
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3234
                                addAddr(addressTypeIPv4, addr)
×
3235
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3236
                                addAddr(addressTypeIPv6, addr)
×
3237
                        } else {
×
3238
                                return fmt.Errorf("unhandled IP address: %v",
×
3239
                                        addr)
×
3240
                        }
×
3241

3242
                case *tor.OnionAddr:
×
3243
                        switch len(addr.OnionService) {
×
3244
                        case tor.V2Len:
×
3245
                                addAddr(addressTypeTorV2, addr)
×
3246
                        case tor.V3Len:
×
3247
                                addAddr(addressTypeTorV3, addr)
×
3248
                        default:
×
3249
                                return fmt.Errorf("invalid length for a tor " +
×
3250
                                        "address")
×
3251
                        }
3252

3253
                case *lnwire.OpaqueAddrs:
×
3254
                        addAddr(addressTypeOpaque, addr)
×
3255

3256
                default:
×
3257
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3258
                }
3259
        }
3260

3261
        // Any remaining entries in newAddresses are new addresses that need to
3262
        // be added to the database for the first time.
3263
        for addrType, addrList := range newAddresses {
×
3264
                for position, addr := range addrList {
×
3265
                        err := db.InsertNodeAddress(
×
3266
                                ctx, sqlc.InsertNodeAddressParams{
×
3267
                                        NodeID:   nodeID,
×
3268
                                        Type:     int16(addrType),
×
3269
                                        Address:  addr,
×
3270
                                        Position: int32(position),
×
3271
                                },
×
3272
                        )
×
3273
                        if err != nil {
×
3274
                                return fmt.Errorf("unable to insert "+
×
3275
                                        "node(%d) address(%v): %w", nodeID,
×
3276
                                        addr, err)
×
3277
                        }
×
3278
                }
3279
        }
3280

3281
        return nil
×
3282
}
3283

3284
// getNodeAddresses fetches the addresses for a node with the given public key.
3285
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3286
        []net.Addr, error) {
×
3287

×
3288
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
3289
        // are returned in the same order as they were inserted.
×
3290
        rows, err := db.GetNodeAddressesByPubKey(
×
3291
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
3292
                        Version: int16(ProtocolV1),
×
3293
                        PubKey:  nodePub,
×
3294
                },
×
3295
        )
×
3296
        if err != nil {
×
3297
                return false, nil, err
×
3298
        }
×
3299

3300
        // GetNodeAddressesByPubKey uses a left join so there should always be
3301
        // at least one row returned if the node exists even if it has no
3302
        // addresses.
3303
        if len(rows) == 0 {
×
3304
                return false, nil, nil
×
3305
        }
×
3306

3307
        addresses := make([]net.Addr, 0, len(rows))
×
3308
        for _, addr := range rows {
×
3309
                if !(addr.Type.Valid && addr.Address.Valid) {
×
3310
                        continue
×
3311
                }
3312

3313
                address := addr.Address.String
×
3314

×
3315
                switch dbAddressType(addr.Type.Int16) {
×
3316
                case addressTypeIPv4:
×
3317
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
3318
                        if err != nil {
×
3319
                                return false, nil, nil
×
3320
                        }
×
3321
                        tcp.IP = tcp.IP.To4()
×
3322

×
3323
                        addresses = append(addresses, tcp)
×
3324

3325
                case addressTypeIPv6:
×
3326
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
3327
                        if err != nil {
×
3328
                                return false, nil, nil
×
3329
                        }
×
3330
                        addresses = append(addresses, tcp)
×
3331

3332
                case addressTypeTorV3, addressTypeTorV2:
×
3333
                        service, portStr, err := net.SplitHostPort(address)
×
3334
                        if err != nil {
×
3335
                                return false, nil, fmt.Errorf("unable to "+
×
3336
                                        "split tor v3 address: %v",
×
3337
                                        addr.Address)
×
3338
                        }
×
3339

3340
                        port, err := strconv.Atoi(portStr)
×
3341
                        if err != nil {
×
3342
                                return false, nil, err
×
3343
                        }
×
3344

3345
                        addresses = append(addresses, &tor.OnionAddr{
×
3346
                                OnionService: service,
×
3347
                                Port:         port,
×
3348
                        })
×
3349

3350
                case addressTypeOpaque:
×
3351
                        opaque, err := hex.DecodeString(address)
×
3352
                        if err != nil {
×
3353
                                return false, nil, fmt.Errorf("unable to "+
×
3354
                                        "decode opaque address: %v", addr)
×
3355
                        }
×
3356

3357
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
3358
                                Payload: opaque,
×
3359
                        })
×
3360

3361
                default:
×
3362
                        return false, nil, fmt.Errorf("unknown address "+
×
3363
                                "type: %v", addr.Type)
×
3364
                }
3365
        }
3366

3367
        return true, addresses, nil
×
3368
}
3369

3370
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3371
// database. This includes updating any existing types, inserting any new types,
3372
// and deleting any types that are no longer present.
3373
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3374
        nodeID int64, extraFields map[uint64][]byte) error {
×
3375

×
3376
        // Get any existing extra signed fields for the node.
×
3377
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3378
        if err != nil {
×
3379
                return err
×
3380
        }
×
3381

3382
        // Make a lookup map of the existing field types so that we can use it
3383
        // to keep track of any fields we should delete.
3384
        m := make(map[uint64]bool)
×
3385
        for _, field := range existingFields {
×
3386
                m[uint64(field.Type)] = true
×
3387
        }
×
3388

3389
        // For all the new fields, we'll upsert them and remove them from the
3390
        // map of existing fields.
3391
        for tlvType, value := range extraFields {
×
3392
                err = db.UpsertNodeExtraType(
×
3393
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3394
                                NodeID: nodeID,
×
3395
                                Type:   int64(tlvType),
×
3396
                                Value:  value,
×
3397
                        },
×
3398
                )
×
3399
                if err != nil {
×
3400
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3401
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3402
                }
×
3403

3404
                // Remove the field from the map of existing fields if it was
3405
                // present.
3406
                delete(m, tlvType)
×
3407
        }
3408

3409
        // For all the fields that are left in the map of existing fields, we'll
3410
        // delete them as they are no longer present in the new set of fields.
3411
        for tlvType := range m {
×
3412
                err = db.DeleteExtraNodeType(
×
3413
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3414
                                NodeID: nodeID,
×
3415
                                Type:   int64(tlvType),
×
3416
                        },
×
3417
                )
×
3418
                if err != nil {
×
3419
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3420
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3421
                }
×
3422
        }
3423

3424
        return nil
×
3425
}
3426

3427
// getSourceNode returns the DB node ID and pub key of the source node for the
3428
// specified protocol version.
3429
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3430
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3431

×
NEW
3432
        s.srcNodeMu.Lock()
×
NEW
3433
        defer s.srcNodeMu.Unlock()
×
NEW
3434

×
NEW
3435
        // If we already have the source node ID and pub key cached, then
×
NEW
3436
        // return them.
×
NEW
3437
        if s.srcNodeID != 0 {
×
NEW
3438
                return s.srcNodeID, s.srcNodePub, nil
×
NEW
3439
        }
×
3440

3441
        var pubKey route.Vertex
×
3442

×
3443
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3444
        if err != nil {
×
3445
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3446
                        err)
×
3447
        }
×
3448

3449
        if len(nodes) == 0 {
×
3450
                return 0, pubKey, ErrSourceNodeNotSet
×
3451
        } else if len(nodes) > 1 {
×
3452
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3453
                        "protocol %s found", version)
×
3454
        }
×
3455

3456
        copy(pubKey[:], nodes[0].PubKey)
×
3457

×
NEW
3458
        s.srcNodeID = nodes[0].NodeID
×
NEW
3459
        s.srcNodePub = pubKey
×
NEW
3460

×
UNCOV
3461
        return nodes[0].NodeID, pubKey, nil
×
3462
}
3463

3464
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3465
// This then produces a map from TLV type to value. If the input is not a
3466
// valid TLV stream, then an error is returned.
3467
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3468
        r := bytes.NewReader(data)
×
3469

×
3470
        tlvStream, err := tlv.NewStream()
×
3471
        if err != nil {
×
3472
                return nil, err
×
3473
        }
×
3474

3475
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3476
        // pass it into the P2P decoding variant.
3477
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3478
        if err != nil {
×
NEW
3479
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3480
        }
×
3481
        if len(parsedTypes) == 0 {
×
3482
                return nil, nil
×
3483
        }
×
3484

3485
        records := make(map[uint64][]byte)
×
3486
        for k, v := range parsedTypes {
×
3487
                records[uint64(k)] = v
×
3488
        }
×
3489

3490
        return records, nil
×
3491
}
3492

3493
// dbChanInfo holds the DB-level identifiers produced when a channel is
// inserted into the SQL store.
type dbChanInfo struct {
	// channelID is the DB ID of the channel row.
	channelID int64

	// node1ID and node2ID are the DB IDs of the channel's two node rows.
	node1ID, node2ID int64
}
3498

3499
// insertChannel inserts a new channel record into the database.
3500
func insertChannel(ctx context.Context, db SQLQueries,
NEW
3501
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
3502

×
3503
        var chanIDB [8]byte
×
3504
        byteOrder.PutUint64(chanIDB[:], edge.ChannelID)
×
3505

×
3506
        // Make sure that the channel doesn't already exist. We do this
×
3507
        // explicitly instead of relying on catching a unique constraint error
×
3508
        // because relying on SQL to throw that error would abort the entire
×
3509
        // batch of transactions.
×
3510
        _, err := db.GetChannelBySCID(
×
3511
                ctx, sqlc.GetChannelBySCIDParams{
×
3512
                        Scid:    chanIDB[:],
×
3513
                        Version: int16(ProtocolV1),
×
3514
                },
×
3515
        )
×
3516
        if err == nil {
×
NEW
3517
                return nil, ErrEdgeAlreadyExist
×
3518
        } else if !errors.Is(err, sql.ErrNoRows) {
×
NEW
3519
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
×
3520
        }
×
3521

3522
        // Make sure that at least a "shell" entry for each node is present in
3523
        // the nodes table.
3524
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3525
        if err != nil {
×
NEW
3526
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3527
        }
×
3528

3529
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3530
        if err != nil {
×
NEW
3531
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3532
        }
×
3533

3534
        var capacity sql.NullInt64
×
3535
        if edge.Capacity != 0 {
×
3536
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3537
        }
×
3538

3539
        createParams := sqlc.CreateChannelParams{
×
3540
                Version:     int16(ProtocolV1),
×
3541
                Scid:        chanIDB[:],
×
3542
                NodeID1:     node1DBID,
×
3543
                NodeID2:     node2DBID,
×
3544
                Outpoint:    edge.ChannelPoint.String(),
×
3545
                Capacity:    capacity,
×
3546
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3547
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3548
        }
×
3549

×
3550
        if edge.AuthProof != nil {
×
3551
                proof := edge.AuthProof
×
3552

×
3553
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3554
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3555
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3556
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3557
        }
×
3558

3559
        // Insert the new channel record.
3560
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3561
        if err != nil {
×
NEW
3562
                return nil, err
×
3563
        }
×
3564

3565
        // Insert any channel features.
3566
        if len(edge.Features) != 0 {
×
3567
                chanFeatures := lnwire.NewRawFeatureVector()
×
3568
                err := chanFeatures.Decode(bytes.NewReader(edge.Features))
×
3569
                if err != nil {
×
NEW
3570
                        return nil, err
×
3571
                }
×
3572

3573
                fv := lnwire.NewFeatureVector(chanFeatures, lnwire.Features)
×
3574
                for feature := range fv.Features() {
×
3575
                        err = db.InsertChannelFeature(
×
3576
                                ctx, sqlc.InsertChannelFeatureParams{
×
3577
                                        ChannelID:  dbChanID,
×
3578
                                        FeatureBit: int32(feature),
×
3579
                                },
×
3580
                        )
×
3581
                        if err != nil {
×
NEW
3582
                                return nil, fmt.Errorf("unable to insert "+
×
3583
                                        "channel(%d) feature(%v): %w", dbChanID,
×
3584
                                        feature, err)
×
3585
                        }
×
3586
                }
3587
        }
3588

3589
        // Finally, insert any extra TLV fields in the channel announcement.
3590
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3591
        if err != nil {
×
NEW
3592
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
NEW
3593
                        "data: %w", err)
×
UNCOV
3594
        }
×
3595

3596
        for tlvType, value := range extra {
×
3597
                err := db.CreateChannelExtraType(
×
3598
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3599
                                ChannelID: dbChanID,
×
3600
                                Type:      int64(tlvType),
×
3601
                                Value:     value,
×
3602
                        },
×
3603
                )
×
3604
                if err != nil {
×
NEW
3605
                        return nil, fmt.Errorf("unable to upsert "+
×
NEW
3606
                                "channel(%d) extra signed field(%v): %w",
×
NEW
3607
                                edge.ChannelID, tlvType, err)
×
UNCOV
3608
                }
×
3609
        }
3610

NEW
3611
        return &dbChanInfo{
×
NEW
3612
                channelID: dbChanID,
×
NEW
3613
                node1ID:   node1DBID,
×
NEW
3614
                node2ID:   node2DBID,
×
NEW
3615
        }, nil
×
3616
}
3617

3618
// maybeCreateShellNode checks if a shell node entry exists for the
3619
// given public key. If it does not exist, then a new shell node entry is
3620
// created. The ID of the node is returned. A shell node only has a protocol
3621
// version and public key persisted.
3622
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3623
        pubKey route.Vertex) (int64, error) {
×
3624

×
3625
        dbNode, err := db.GetNodeByPubKey(
×
3626
                ctx, sqlc.GetNodeByPubKeyParams{
×
3627
                        PubKey:  pubKey[:],
×
3628
                        Version: int16(ProtocolV1),
×
3629
                },
×
3630
        )
×
3631
        // The node exists. Return the ID.
×
3632
        if err == nil {
×
3633
                return dbNode.ID, nil
×
3634
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3635
                return 0, err
×
3636
        }
×
3637

3638
        // Otherwise, the node does not exist, so we create a shell entry for
3639
        // it.
3640
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3641
                Version: int16(ProtocolV1),
×
3642
                PubKey:  pubKey[:],
×
3643
        })
×
3644
        if err != nil {
×
3645
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3646
        }
×
3647

3648
        return id, nil
×
3649
}
3650

3651
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3652
// the database. This includes updating any existing types, inserting any new
3653
// types, and deleting any types that are no longer present.
3654
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
NEW
3655
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
NEW
3656

×
NEW
3657
        // Get any existing extra signed fields for the channel policy.
×
NEW
3658
        existingFields, err := db.GetChannelPolicyExtraTypes(
×
NEW
3659
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
NEW
3660
                        ID: chanPolicyID,
×
NEW
3661
                },
×
NEW
3662
        )
×
NEW
3663
        if err != nil {
×
NEW
3664
                return err
×
NEW
3665
        }
×
3666

3667
        // Make a lookup map of the existing field types so that we can use it
3668
        // to keep track of any fields we should delete.
NEW
3669
        m := make(map[uint64]bool)
×
NEW
3670
        for _, field := range existingFields {
×
NEW
3671
                if field.PolicyID != chanPolicyID {
×
NEW
3672
                        return fmt.Errorf("channel policy ID mismatch: "+
×
NEW
3673
                                "expected %d, got %d", chanPolicyID,
×
NEW
3674
                                field.PolicyID)
×
NEW
3675
                }
×
3676

NEW
3677
                m[uint64(field.Type)] = true
×
3678
        }
3679

3680
        // For all the new fields, we'll upsert them and remove them from the
3681
        // map of existing fields.
NEW
3682
        for tlvType, value := range extraFields {
×
NEW
3683
                err = db.UpsertChanPolicyExtraType(
×
NEW
3684
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
×
NEW
3685
                                ChannelPolicyID: chanPolicyID,
×
NEW
3686
                                Type:            int64(tlvType),
×
NEW
3687
                                Value:           value,
×
NEW
3688
                        },
×
NEW
3689
                )
×
NEW
3690
                if err != nil {
×
NEW
3691
                        return fmt.Errorf("unable to upsert "+
×
NEW
3692
                                "channel_policy(%d) extra signed field(%v): %w",
×
NEW
3693
                                chanPolicyID, tlvType, err)
×
NEW
3694
                }
×
3695

3696
                // Remove the field from the map of existing fields if it was
3697
                // present.
NEW
3698
                delete(m, tlvType)
×
3699
        }
3700

3701
        // For all the fields that are left in the map of existing fields, we'll
3702
        // delete them as they are no longer present in the new set of fields.
NEW
3703
        for tlvType := range m {
×
NEW
3704
                err = db.DeleteChannelPolicyExtraType(
×
NEW
3705
                        ctx, sqlc.DeleteChannelPolicyExtraTypeParams{
×
NEW
3706
                                ChannelPolicyID: chanPolicyID,
×
NEW
3707
                                Type:            int64(tlvType),
×
NEW
3708
                        },
×
NEW
3709
                )
×
NEW
3710
                if err != nil {
×
NEW
3711
                        return fmt.Errorf("unable to delete "+
×
NEW
3712
                                "channel_policy(%d) extra signed field(%v): %w",
×
NEW
3713
                                chanPolicyID, tlvType, err)
×
NEW
3714
                }
×
3715
        }
3716

NEW
3717
        return nil
×
3718
}
3719

3720
// getAndBuildEdgeInfoAndPolicies fetches all the data from the DB required to
3721
// build complete models.ChannelEdgeInfo and  models.ChannelEdgePolicy instances
3722
// for a channel with the given DB ID.
3723
func getAndBuildEdgeInfoAndPolicies(ctx context.Context, db SQLQueries,
3724
        chain chainhash.Hash, dbChanID int64, dbChanRow any, node1,
3725
        node2 route.Vertex) (*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
NEW
3726
        *models.ChannelEdgePolicy, error) {
×
NEW
3727

×
NEW
3728
        edge, err := getAndBuildEdgeInfo(
×
NEW
3729
                ctx, db, chain, dbChanID, dbChanRow, node1, node2,
×
NEW
3730
        )
×
NEW
3731
        if err != nil {
×
NEW
3732
                return nil, nil, nil, err
×
NEW
3733
        }
×
3734

NEW
3735
        dbPol1, dbPol2, err := extractChannelPolicies(dbChanRow)
×
NEW
3736
        if err != nil {
×
NEW
3737
                return nil, nil, nil, err
×
NEW
3738
        }
×
3739

NEW
3740
        p1, p2, err := getAndBuildChanPolicies(
×
NEW
3741
                ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
NEW
3742
        )
×
NEW
3743

×
NEW
3744
        return edge, p1, p2, err
×
3745
}
3746

3747
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3748
// provided dbChanRow and also fetches any other required information
3749
// to construct the edge info.
3750
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
3751
        chain chainhash.Hash, dbChanID int64, dbChanRow any, node1,
NEW
3752
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
NEW
3753

×
NEW
3754
        dbChan, err := extractChannel(dbChanRow)
×
NEW
3755
        if err != nil {
×
NEW
3756
                return nil, err
×
NEW
3757
        }
×
3758

NEW
3759
        fv, extras, err := getChanFeaturesAndExtras(
×
NEW
3760
                ctx, db, dbChanID,
×
NEW
3761
        )
×
NEW
3762
        if err != nil {
×
NEW
3763
                return nil, err
×
NEW
3764
        }
×
3765

NEW
3766
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
NEW
3767
        if err != nil {
×
NEW
3768
                return nil, err
×
NEW
3769
        }
×
3770

NEW
3771
        var featureBuf bytes.Buffer
×
NEW
3772
        if err := fv.Encode(&featureBuf); err != nil {
×
NEW
3773
                return nil, fmt.Errorf("unable to encode features: %w", err)
×
NEW
3774
        }
×
3775

NEW
3776
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
NEW
3777
        if err != nil {
×
NEW
3778
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
NEW
3779
                        "fields: %w", err)
×
NEW
3780
        }
×
NEW
3781
        if recs == nil {
×
NEW
3782
                recs = make([]byte, 0)
×
NEW
3783
        }
×
3784

NEW
3785
        var btcKey1, btcKey2 route.Vertex
×
NEW
3786
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
NEW
3787
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
NEW
3788

×
NEW
3789
        channel := &models.ChannelEdgeInfo{
×
NEW
3790
                ChainHash:        chain,
×
NEW
3791
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
NEW
3792
                NodeKey1Bytes:    node1,
×
NEW
3793
                NodeKey2Bytes:    node2,
×
NEW
3794
                BitcoinKey1Bytes: btcKey1,
×
NEW
3795
                BitcoinKey2Bytes: btcKey2,
×
NEW
3796
                ChannelPoint:     *op,
×
NEW
3797
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
NEW
3798
                Features:         featureBuf.Bytes(),
×
NEW
3799
                ExtraOpaqueData:  recs,
×
NEW
3800
        }
×
NEW
3801

×
NEW
3802
        if dbChan.Bitcoin1Signature != nil {
×
NEW
3803
                channel.AuthProof = &models.ChannelAuthProof{
×
NEW
3804
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
NEW
3805
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
NEW
3806
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
NEW
3807
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
NEW
3808
                }
×
NEW
3809
        }
×
3810

NEW
3811
        return channel, nil
×
3812
}
3813

3814
// buildNodeVertices is a helper that converts raw node public keys
3815
// into route.Vertex instances.
3816
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
NEW
3817
        route.Vertex, error) {
×
NEW
3818

×
NEW
3819
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
NEW
3820
        if err != nil {
×
NEW
3821
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
NEW
3822
                        "create vertex from node1 pubkey: %w", err)
×
NEW
3823
        }
×
3824

NEW
3825
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
NEW
3826
        if err != nil {
×
NEW
3827
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
NEW
3828
                        "create vertex from node2 pubkey: %w", err)
×
NEW
3829
        }
×
3830

NEW
3831
        return node1Vertex, node2Vertex, nil
×
3832
}
3833

3834
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
3835
// for a channel with the given ID.
3836
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
NEW
3837
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
NEW
3838

×
NEW
3839
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
NEW
3840
        if err != nil {
×
NEW
3841
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
NEW
3842
                        "features and extras: %w", err)
×
NEW
3843
        }
×
3844

NEW
3845
        var (
×
NEW
3846
                fv     = lnwire.EmptyFeatureVector()
×
NEW
3847
                extras = make(map[uint64][]byte)
×
NEW
3848
        )
×
NEW
3849
        for _, row := range rows {
×
NEW
3850
                switch row.Kind {
×
NEW
3851
                case "feature":
×
NEW
3852
                        featureBit, err := strconv.Atoi(row.Key)
×
NEW
3853
                        if err != nil {
×
NEW
3854
                                return nil, nil, err
×
NEW
3855
                        }
×
NEW
3856
                        fv.Set(lnwire.FeatureBit(featureBit))
×
3857

NEW
3858
                case "extra":
×
NEW
3859
                        tlvType, err := strconv.ParseInt(row.Key, 10, 64)
×
NEW
3860
                        if err != nil {
×
NEW
3861
                                return nil, nil, err
×
NEW
3862
                        }
×
NEW
3863
                        valueBytes, ok := row.Value.([]byte)
×
NEW
3864
                        if !ok {
×
NEW
3865
                                return nil, nil, fmt.Errorf("unexpected type "+
×
NEW
3866
                                        "for Value: %T", row.Value)
×
NEW
3867
                        }
×
NEW
3868
                        extras[uint64(tlvType)] = valueBytes
×
3869
                }
3870
        }
3871

NEW
3872
        return fv, extras, nil
×
3873
}
3874

3875
// getAndBuildChanPolicies uses the given sqlc.ChannelPolicy and also retrieves
3876
// all the extra info required to build the complete models.ChannelEdgePolicy
3877
// types. It returns two policies, which may be nil if the provided
3878
// sqlc.ChannelPolicy records are nil.
3879
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
3880
        dbPol1, dbPol2 *sqlc.ChannelPolicy, channelID uint64, node1,
3881
        node2 route.Vertex) (*models.ChannelEdgePolicy,
NEW
3882
        *models.ChannelEdgePolicy, error) {
×
NEW
3883

×
NEW
3884
        if dbPol1 == nil && dbPol2 == nil {
×
NEW
3885
                return nil, nil, nil
×
NEW
3886
        }
×
3887

NEW
3888
        var (
×
NEW
3889
                policy1ID int64
×
NEW
3890
                policy2ID int64
×
NEW
3891
        )
×
NEW
3892
        if dbPol1 != nil {
×
NEW
3893
                policy1ID = dbPol1.ID
×
NEW
3894
        }
×
NEW
3895
        if dbPol2 != nil {
×
NEW
3896
                policy2ID = dbPol2.ID
×
NEW
3897
        }
×
NEW
3898
        rows, err := db.GetChannelPolicyExtraTypes(
×
NEW
3899
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
NEW
3900
                        ID:   policy1ID,
×
NEW
3901
                        ID_2: policy2ID,
×
NEW
3902
                },
×
NEW
3903
        )
×
NEW
3904
        if err != nil {
×
NEW
3905
                return nil, nil, err
×
NEW
3906
        }
×
3907

NEW
3908
        var (
×
NEW
3909
                dbPol1Extras = make(map[uint64][]byte)
×
NEW
3910
                dbPol2Extras = make(map[uint64][]byte)
×
NEW
3911
        )
×
NEW
3912
        for _, row := range rows {
×
NEW
3913
                switch row.PolicyID {
×
NEW
3914
                case policy1ID:
×
NEW
3915
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
NEW
3916
                case policy2ID:
×
NEW
3917
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
NEW
3918
                default:
×
NEW
3919
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
NEW
3920
                                "in row: %v", row.PolicyID, row)
×
3921
                }
3922
        }
3923

NEW
3924
        var pol1, pol2 *models.ChannelEdgePolicy
×
NEW
3925
        if dbPol1 != nil {
×
NEW
3926
                pol1, err = buildChanPolicy(
×
NEW
3927
                        *dbPol1, channelID, dbPol1Extras, node2, true,
×
NEW
3928
                )
×
NEW
3929
                if err != nil {
×
NEW
3930
                        return nil, nil, err
×
NEW
3931
                }
×
3932
        }
NEW
3933
        if dbPol2 != nil {
×
NEW
3934
                pol2, err = buildChanPolicy(
×
NEW
3935
                        *dbPol2, channelID, dbPol2Extras, node1, false,
×
NEW
3936
                )
×
NEW
3937
                if err != nil {
×
NEW
3938
                        return nil, nil, err
×
NEW
3939
                }
×
3940
        }
3941

NEW
3942
        return pol1, pol2, nil
×
3943
}
3944

3945
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
3946
// provided sqlc.ChannelPolicy and other required information.
3947
func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
3948
        extras map[uint64][]byte, toNode route.Vertex,
NEW
3949
        isNode1 bool) (*models.ChannelEdgePolicy, error) {
×
NEW
3950

×
NEW
3951
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
NEW
3952
        if err != nil {
×
NEW
3953
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
NEW
3954
                        "fields: %w", err)
×
NEW
3955
        }
×
3956

NEW
3957
        var msgFlags lnwire.ChanUpdateMsgFlags
×
NEW
3958
        if dbPolicy.MaxHtlcMsat.Valid {
×
NEW
3959
                msgFlags |= lnwire.ChanUpdateRequiredMaxHtlc
×
NEW
3960
        }
×
3961

NEW
3962
        var chanFlags lnwire.ChanUpdateChanFlags
×
NEW
3963
        if !isNode1 {
×
NEW
3964
                chanFlags |= lnwire.ChanUpdateDirection
×
NEW
3965
        }
×
NEW
3966
        if dbPolicy.Disabled.Bool {
×
NEW
3967
                chanFlags |= lnwire.ChanUpdateDisabled
×
NEW
3968
        }
×
3969

NEW
3970
        var inboundFee fn.Option[lnwire.Fee]
×
NEW
3971
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
NEW
3972
                dbPolicy.InboundBaseFeeMsat.Valid {
×
NEW
3973

×
NEW
3974
                inboundFee = fn.Some(lnwire.Fee{
×
NEW
3975
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
NEW
3976
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
NEW
3977
                })
×
NEW
3978
        }
×
3979

NEW
3980
        return &models.ChannelEdgePolicy{
×
NEW
3981
                SigBytes:  dbPolicy.Signature,
×
NEW
3982
                ChannelID: channelID,
×
NEW
3983
                LastUpdate: time.Unix(
×
NEW
3984
                        dbPolicy.LastUpdate.Int64, 0,
×
NEW
3985
                ),
×
NEW
3986
                MessageFlags:  msgFlags,
×
NEW
3987
                ChannelFlags:  chanFlags,
×
NEW
3988
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
NEW
3989
                MinHTLC: lnwire.MilliSatoshi(
×
NEW
3990
                        dbPolicy.MinHtlcMsat,
×
NEW
3991
                ),
×
NEW
3992
                MaxHTLC: lnwire.MilliSatoshi(
×
NEW
3993
                        dbPolicy.MaxHtlcMsat.Int64,
×
NEW
3994
                ),
×
NEW
3995
                FeeBaseMSat: lnwire.MilliSatoshi(
×
NEW
3996
                        dbPolicy.BaseFeeMsat,
×
NEW
3997
                ),
×
NEW
3998
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
NEW
3999
                ToNode:                    toNode,
×
NEW
4000
                InboundFee:                inboundFee,
×
NEW
4001
                ExtraOpaqueData:           recs,
×
NEW
4002
        }, nil
×
4003
}
4004

4005
// getAndBuildNodes builds the models.LightningNode instances for the
4006
// given row which is expected to be a sqlc type that contains node information.
4007
func getAndBuildNodes(ctx context.Context, db SQLQueries,
NEW
4008
        row any) (*models.LightningNode, *models.LightningNode, error) {
×
NEW
4009

×
NEW
4010
        dbNode1, dbNode2, err := extractNodes(row)
×
NEW
4011
        if err != nil {
×
NEW
4012
                return nil, nil, err
×
NEW
4013
        }
×
4014

NEW
4015
        node1, err := buildNode(ctx, db, &dbNode1)
×
NEW
4016
        if err != nil {
×
NEW
4017
                return nil, nil, err
×
NEW
4018
        }
×
4019

NEW
4020
        node2, err := buildNode(ctx, db, &dbNode2)
×
NEW
4021
        if err != nil {
×
NEW
4022
                return nil, nil, err
×
NEW
4023
        }
×
4024

NEW
4025
        return node1, node2, nil
×
4026
}
4027

4028
// extractChannelPolicies extracts the sqlc.ChannelPolicy records from the give
4029
// row which is expected to be a sqlc type that contains channel policy
4030
// information. It returns two policies, which may be nil if the policy
4031
// information is not present in the row.
4032
//
4033
//nolint:ll
4034
func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
NEW
4035
        error) {
×
NEW
4036

×
NEW
4037
        var policy1, policy2 *sqlc.ChannelPolicy
×
NEW
4038
        switch r := row.(type) {
×
NEW
4039
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
NEW
4040
                if r.Policy1ID.Valid {
×
NEW
4041
                        policy1 = &sqlc.ChannelPolicy{
×
NEW
4042
                                ID:                      r.Policy1ID.Int64,
×
NEW
4043
                                Version:                 r.Policy1Version.Int16,
×
NEW
4044
                                ChannelID:               r.ID,
×
NEW
4045
                                NodeID:                  r.Policy1NodeID.Int64,
×
NEW
4046
                                Timelock:                r.Policy1Timelock.Int32,
×
NEW
4047
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
NEW
4048
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
NEW
4049
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
NEW
4050
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
NEW
4051
                                LastUpdate:              r.Policy1LastUpdate,
×
NEW
4052
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
NEW
4053
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
NEW
4054
                                Disabled:                r.Policy1Disabled,
×
NEW
4055
                                Signature:               r.Policy1Signature,
×
NEW
4056
                        }
×
NEW
4057
                }
×
NEW
4058
                if r.Policy2ID.Valid {
×
NEW
4059
                        policy2 = &sqlc.ChannelPolicy{
×
NEW
4060
                                ID:                      r.Policy2ID.Int64,
×
NEW
4061
                                Version:                 r.Policy2Version.Int16,
×
NEW
4062
                                ChannelID:               r.ID,
×
NEW
4063
                                NodeID:                  r.Policy2NodeID.Int64,
×
NEW
4064
                                Timelock:                r.Policy2Timelock.Int32,
×
NEW
4065
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
NEW
4066
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
NEW
4067
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
NEW
4068
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
NEW
4069
                                LastUpdate:              r.Policy2LastUpdate,
×
NEW
4070
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
NEW
4071
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
NEW
4072
                                Disabled:                r.Policy2Disabled,
×
NEW
4073
                                Signature:               r.Policy2Signature,
×
NEW
4074
                        }
×
NEW
4075
                }
×
NEW
4076
                return policy1, policy2, nil
×
4077

NEW
4078
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
NEW
4079
                if r.Policy1ID.Valid {
×
NEW
4080
                        policy1 = &sqlc.ChannelPolicy{
×
NEW
4081
                                ID:                      r.Policy1ID.Int64,
×
NEW
4082
                                Version:                 r.Policy1Version.Int16,
×
NEW
4083
                                ChannelID:               r.ID,
×
NEW
4084
                                NodeID:                  r.Policy1NodeID.Int64,
×
NEW
4085
                                Timelock:                r.Policy1Timelock.Int32,
×
NEW
4086
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
NEW
4087
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
NEW
4088
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
NEW
4089
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
NEW
4090
                                LastUpdate:              r.Policy1LastUpdate,
×
NEW
4091
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
NEW
4092
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
NEW
4093
                                Disabled:                r.Policy1Disabled,
×
NEW
4094
                                Signature:               r.Policy1Signature,
×
NEW
4095
                        }
×
NEW
4096
                }
×
NEW
4097
                if r.Policy2ID.Valid {
×
NEW
4098
                        policy2 = &sqlc.ChannelPolicy{
×
NEW
4099
                                ID:                      r.Policy2ID.Int64,
×
NEW
4100
                                Version:                 r.Policy2Version.Int16,
×
NEW
4101
                                ChannelID:               r.ID,
×
NEW
4102
                                NodeID:                  r.Policy2NodeID.Int64,
×
NEW
4103
                                Timelock:                r.Policy2Timelock.Int32,
×
NEW
4104
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
NEW
4105
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
NEW
4106
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
NEW
4107
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
NEW
4108
                                LastUpdate:              r.Policy2LastUpdate,
×
NEW
4109
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
NEW
4110
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
NEW
4111
                                Disabled:                r.Policy2Disabled,
×
NEW
4112
                                Signature:               r.Policy2Signature,
×
NEW
4113
                        }
×
NEW
4114
                }
×
NEW
4115
                return policy1, policy2, nil
×
4116

NEW
4117
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
NEW
4118
                if r.Policy1ID.Valid {
×
NEW
4119
                        policy1 = &sqlc.ChannelPolicy{
×
NEW
4120
                                ID:                      r.Policy1ID.Int64,
×
NEW
4121
                                Version:                 r.Policy1Version.Int16,
×
NEW
4122
                                ChannelID:               r.ID,
×
NEW
4123
                                NodeID:                  r.Policy1NodeID.Int64,
×
NEW
4124
                                Timelock:                r.Policy1Timelock.Int32,
×
NEW
4125
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
NEW
4126
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
NEW
4127
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
NEW
4128
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
NEW
4129
                                LastUpdate:              r.Policy1LastUpdate,
×
NEW
4130
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
NEW
4131
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
NEW
4132
                                Disabled:                r.Policy1Disabled,
×
NEW
4133
                                Signature:               r.Policy1Signature,
×
NEW
4134
                        }
×
NEW
4135
                }
×
NEW
4136
                if r.Policy2ID.Valid {
×
NEW
4137
                        policy2 = &sqlc.ChannelPolicy{
×
NEW
4138
                                ID:                      r.Policy2ID.Int64,
×
NEW
4139
                                Version:                 r.Policy2Version.Int16,
×
NEW
4140
                                ChannelID:               r.ID,
×
NEW
4141
                                NodeID:                  r.Policy2NodeID.Int64,
×
NEW
4142
                                Timelock:                r.Policy2Timelock.Int32,
×
NEW
4143
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
NEW
4144
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
NEW
4145
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
NEW
4146
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
NEW
4147
                                LastUpdate:              r.Policy2LastUpdate,
×
NEW
4148
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
NEW
4149
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
NEW
4150
                                Disabled:                r.Policy2Disabled,
×
NEW
4151
                                Signature:               r.Policy2Signature,
×
NEW
4152
                        }
×
NEW
4153
                }
×
NEW
4154
                return policy1, policy2, nil
×
4155

NEW
4156
        case sqlc.ListChannelsByNodeIDRow:
×
NEW
4157
                if r.Policy1ID.Valid {
×
NEW
4158
                        policy1 = &sqlc.ChannelPolicy{
×
NEW
4159
                                ID:                      r.Policy1ID.Int64,
×
NEW
4160
                                Version:                 r.Policy1Version.Int16,
×
NEW
4161
                                ChannelID:               r.ID,
×
NEW
4162
                                NodeID:                  r.Policy1NodeID.Int64,
×
NEW
4163
                                Timelock:                r.Policy1Timelock.Int32,
×
NEW
4164
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
NEW
4165
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
NEW
4166
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
NEW
4167
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
NEW
4168
                                LastUpdate:              r.Policy1LastUpdate,
×
NEW
4169
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
NEW
4170
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
NEW
4171
                                Disabled:                r.Policy1Disabled,
×
NEW
4172
                                Signature:               r.Policy1Signature,
×
NEW
4173
                        }
×
NEW
4174
                }
×
NEW
4175
                if r.Policy2ID.Valid {
×
NEW
4176
                        policy2 = &sqlc.ChannelPolicy{
×
NEW
4177
                                ID:                      r.Policy2ID.Int64,
×
NEW
4178
                                Version:                 r.Policy2Version.Int16,
×
NEW
4179
                                ChannelID:               r.ID,
×
NEW
4180
                                NodeID:                  r.Policy2NodeID.Int64,
×
NEW
4181
                                Timelock:                r.Policy2Timelock.Int32,
×
NEW
4182
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
NEW
4183
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
NEW
4184
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
NEW
4185
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
NEW
4186
                                LastUpdate:              r.Policy2LastUpdate,
×
NEW
4187
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
NEW
4188
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
NEW
4189
                                Disabled:                r.Policy2Disabled,
×
NEW
4190
                                Signature:               r.Policy2Signature,
×
NEW
4191
                        }
×
NEW
4192
                }
×
4193

NEW
4194
                return policy1, policy2, nil
×
4195

NEW
4196
        case sqlc.ListAllChannelsRow:
×
NEW
4197
                if r.Policy1ID.Valid {
×
NEW
4198
                        policy1 = &sqlc.ChannelPolicy{
×
NEW
4199
                                ID:                      r.Policy1ID.Int64,
×
NEW
4200
                                Version:                 r.Policy1Version.Int16,
×
NEW
4201
                                ChannelID:               r.ID,
×
NEW
4202
                                NodeID:                  r.Policy1NodeID.Int64,
×
NEW
4203
                                Timelock:                r.Policy1Timelock.Int32,
×
NEW
4204
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
NEW
4205
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
NEW
4206
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
NEW
4207
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
NEW
4208
                                LastUpdate:              r.Policy1LastUpdate,
×
NEW
4209
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
NEW
4210
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
NEW
4211
                                Disabled:                r.Policy1Disabled,
×
NEW
4212
                                Signature:               r.Policy1Signature,
×
NEW
4213
                        }
×
NEW
4214
                }
×
NEW
4215
                if r.Policy2ID.Valid {
×
NEW
4216
                        policy2 = &sqlc.ChannelPolicy{
×
NEW
4217
                                ID:                      r.Policy2ID.Int64,
×
NEW
4218
                                Version:                 r.Policy2Version.Int16,
×
NEW
4219
                                ChannelID:               r.ID,
×
NEW
4220
                                NodeID:                  r.Policy2NodeID.Int64,
×
NEW
4221
                                Timelock:                r.Policy2Timelock.Int32,
×
NEW
4222
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
NEW
4223
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
NEW
4224
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
NEW
4225
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
NEW
4226
                                LastUpdate:              r.Policy2LastUpdate,
×
NEW
4227
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
NEW
4228
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
NEW
4229
                                Disabled:                r.Policy2Disabled,
×
NEW
4230
                                Signature:               r.Policy2Signature,
×
NEW
4231
                        }
×
NEW
4232
                }
×
4233

NEW
4234
                return policy1, policy2, nil
×
NEW
4235
        default:
×
NEW
4236
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
NEW
4237
                        "extractChannelPolicies: %T", r)
×
4238
        }
4239
}
4240

4241
// extractChannel extracts the sqlc.Channel record from the given row
4242
// which is expected to be a sqlc type that contains channel information.
NEW
4243
func extractChannel(row any) (sqlc.Channel, error) {
×
NEW
4244
        switch r := row.(type) {
×
NEW
4245
        case sqlc.GetChannelsBySCIDRangeRow:
×
NEW
4246
                return sqlc.Channel{
×
NEW
4247
                        ID:                r.ID,
×
NEW
4248
                        Version:           r.Version,
×
NEW
4249
                        Scid:              r.Scid,
×
NEW
4250
                        NodeID1:           r.NodeID1,
×
NEW
4251
                        NodeID2:           r.NodeID2,
×
NEW
4252
                        Outpoint:          r.Outpoint,
×
NEW
4253
                        Capacity:          r.Capacity,
×
NEW
4254
                        BitcoinKey1:       r.BitcoinKey1,
×
NEW
4255
                        BitcoinKey2:       r.BitcoinKey2,
×
NEW
4256
                        Node1Signature:    r.Node1Signature,
×
NEW
4257
                        Node2Signature:    r.Node2Signature,
×
NEW
4258
                        Bitcoin1Signature: r.Bitcoin1Signature,
×
NEW
4259
                        Bitcoin2Signature: r.Bitcoin2Signature,
×
NEW
4260
                }, nil
×
4261

NEW
4262
        case sqlc.GetChannelByOutpointRow:
×
NEW
4263
                return sqlc.Channel{
×
NEW
4264
                        ID:                r.ID,
×
NEW
4265
                        Version:           r.Version,
×
NEW
4266
                        Scid:              r.Scid,
×
NEW
4267
                        NodeID1:           r.NodeID1,
×
NEW
4268
                        NodeID2:           r.NodeID2,
×
NEW
4269
                        Outpoint:          r.Outpoint,
×
NEW
4270
                        Capacity:          r.Capacity,
×
NEW
4271
                        BitcoinKey1:       r.BitcoinKey1,
×
NEW
4272
                        BitcoinKey2:       r.BitcoinKey2,
×
NEW
4273
                        Node1Signature:    r.Node1Signature,
×
NEW
4274
                        Node2Signature:    r.Node2Signature,
×
NEW
4275
                        Bitcoin1Signature: r.Bitcoin1Signature,
×
NEW
4276
                        Bitcoin2Signature: r.Bitcoin2Signature,
×
NEW
4277
                }, nil
×
4278

NEW
4279
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
NEW
4280
                return sqlc.Channel{
×
NEW
4281
                        ID:                r.ID,
×
NEW
4282
                        Version:           r.Version,
×
NEW
4283
                        Scid:              r.Scid,
×
NEW
4284
                        NodeID1:           r.NodeID1,
×
NEW
4285
                        NodeID2:           r.NodeID2,
×
NEW
4286
                        Outpoint:          r.Outpoint,
×
NEW
4287
                        Capacity:          r.Capacity,
×
NEW
4288
                        BitcoinKey1:       r.BitcoinKey1,
×
NEW
4289
                        BitcoinKey2:       r.BitcoinKey2,
×
NEW
4290
                        Node1Signature:    r.Node1Signature,
×
NEW
4291
                        Node2Signature:    r.Node2Signature,
×
NEW
4292
                        Bitcoin1Signature: r.Bitcoin1Signature,
×
NEW
4293
                        Bitcoin2Signature: r.Bitcoin2Signature,
×
NEW
4294
                }, nil
×
4295

NEW
4296
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
NEW
4297
                return sqlc.Channel{
×
NEW
4298
                        ID:                r.ID,
×
NEW
4299
                        Version:           r.Version,
×
NEW
4300
                        Scid:              r.Scid,
×
NEW
4301
                        NodeID1:           r.NodeID1,
×
NEW
4302
                        NodeID2:           r.NodeID2,
×
NEW
4303
                        Outpoint:          r.Outpoint,
×
NEW
4304
                        Capacity:          r.Capacity,
×
NEW
4305
                        BitcoinKey1:       r.BitcoinKey1,
×
NEW
4306
                        BitcoinKey2:       r.BitcoinKey2,
×
NEW
4307
                        Node1Signature:    r.Node1Signature,
×
NEW
4308
                        Node2Signature:    r.Node2Signature,
×
NEW
4309
                        Bitcoin1Signature: r.Bitcoin1Signature,
×
NEW
4310
                        Bitcoin2Signature: r.Bitcoin2Signature,
×
NEW
4311
                }, nil
×
4312

NEW
4313
        case sqlc.ListAllChannelsRow:
×
NEW
4314
                return sqlc.Channel{
×
NEW
4315
                        ID:                r.ID,
×
NEW
4316
                        Version:           r.Version,
×
NEW
4317
                        Scid:              r.Scid,
×
NEW
4318
                        NodeID1:           r.NodeID1,
×
NEW
4319
                        NodeID2:           r.NodeID2,
×
NEW
4320
                        Outpoint:          r.Outpoint,
×
NEW
4321
                        Capacity:          r.Capacity,
×
NEW
4322
                        BitcoinKey1:       r.BitcoinKey1,
×
NEW
4323
                        BitcoinKey2:       r.BitcoinKey2,
×
NEW
4324
                        Node1Signature:    r.Node1Signature,
×
NEW
4325
                        Node2Signature:    r.Node2Signature,
×
NEW
4326
                        Bitcoin1Signature: r.Bitcoin1Signature,
×
NEW
4327
                        Bitcoin2Signature: r.Bitcoin2Signature,
×
NEW
4328
                }, nil
×
4329

NEW
4330
        case sqlc.ListChannelsByNodeIDRow:
×
NEW
4331
                return sqlc.Channel{
×
NEW
4332
                        ID:                r.ID,
×
NEW
4333
                        Version:           r.Version,
×
NEW
4334
                        Scid:              r.Scid,
×
NEW
4335
                        NodeID1:           r.NodeID1,
×
NEW
4336
                        NodeID2:           r.NodeID2,
×
NEW
4337
                        Outpoint:          r.Outpoint,
×
NEW
4338
                        Capacity:          r.Capacity,
×
NEW
4339
                        BitcoinKey1:       r.BitcoinKey1,
×
NEW
4340
                        BitcoinKey2:       r.BitcoinKey2,
×
NEW
4341
                        Node1Signature:    r.Node1Signature,
×
NEW
4342
                        Node2Signature:    r.Node2Signature,
×
NEW
4343
                        Bitcoin1Signature: r.Bitcoin1Signature,
×
NEW
4344
                        Bitcoin2Signature: r.Bitcoin2Signature,
×
NEW
4345
                }, nil
×
4346

NEW
4347
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
NEW
4348
                return sqlc.Channel{
×
NEW
4349
                        ID:                r.ID,
×
NEW
4350
                        Version:           r.Version,
×
NEW
4351
                        Scid:              r.Scid,
×
NEW
4352
                        NodeID1:           r.NodeID1,
×
NEW
4353
                        NodeID2:           r.NodeID2,
×
NEW
4354
                        Outpoint:          r.Outpoint,
×
NEW
4355
                        Capacity:          r.Capacity,
×
NEW
4356
                        BitcoinKey1:       r.BitcoinKey1,
×
NEW
4357
                        BitcoinKey2:       r.BitcoinKey2,
×
NEW
4358
                        Node1Signature:    r.Node1Signature,
×
NEW
4359
                        Node2Signature:    r.Node2Signature,
×
NEW
4360
                        Bitcoin1Signature: r.Bitcoin1Signature,
×
NEW
4361
                        Bitcoin2Signature: r.Bitcoin2Signature,
×
NEW
4362
                }, nil
×
4363

NEW
4364
        default:
×
NEW
4365
                return sqlc.Channel{}, fmt.Errorf("unexpected row type in "+
×
NEW
4366
                        "extractChannel: %T", r)
×
4367
        }
4368
}
4369

4370
// extractNodes extracts the sqlc.Node records from the given row
4371
// which is expected to be a sqlc type that contains node information.
NEW
4372
func extractNodes(row any) (sqlc.Node, sqlc.Node, error) {
×
NEW
4373
        switch r := row.(type) {
×
NEW
4374
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
NEW
4375
                return sqlc.Node{
×
NEW
4376
                                ID:         r.Node1ID,
×
NEW
4377
                                Version:    r.Node1Version,
×
NEW
4378
                                PubKey:     r.Node1PubKey,
×
NEW
4379
                                Alias:      r.Node1Alias,
×
NEW
4380
                                LastUpdate: r.Node1LastUpdate,
×
NEW
4381
                                Color:      r.Node1Color,
×
NEW
4382
                                Signature:  r.Node1AnnSignature,
×
NEW
4383
                        }, sqlc.Node{
×
NEW
4384
                                ID:         r.Node2ID,
×
NEW
4385
                                Version:    r.Node2Version,
×
NEW
4386
                                PubKey:     r.Node2PubKey,
×
NEW
4387
                                Alias:      r.Node2Alias,
×
NEW
4388
                                LastUpdate: r.Node2LastUpdate,
×
NEW
4389
                                Color:      r.Node2Color,
×
NEW
4390
                                Signature:  r.Node2AnnSignature,
×
NEW
4391
                        }, nil
×
4392

NEW
4393
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
NEW
4394
                return sqlc.Node{
×
NEW
4395
                                ID:         r.Node1ID,
×
NEW
4396
                                Version:    r.Node1Version,
×
NEW
4397
                                PubKey:     r.Node1PubKey,
×
NEW
4398
                                Alias:      r.Node1Alias,
×
NEW
4399
                                LastUpdate: r.Node1LastUpdate,
×
NEW
4400
                                Color:      r.Node1Color,
×
NEW
4401
                                Signature:  r.Node1AnnSignature,
×
NEW
4402
                        }, sqlc.Node{
×
NEW
4403
                                ID:         r.Node2ID,
×
NEW
4404
                                Version:    r.Node2Version,
×
NEW
4405
                                PubKey:     r.Node2PubKey,
×
NEW
4406
                                Alias:      r.Node2Alias,
×
NEW
4407
                                LastUpdate: r.Node2LastUpdate,
×
NEW
4408
                                Color:      r.Node2Color,
×
NEW
4409
                                Signature:  r.Node2AnnSignature,
×
NEW
4410
                        }, nil
×
4411

NEW
4412
        default:
×
NEW
4413
                return sqlc.Node{}, sqlc.Node{}, fmt.Errorf("unexpected row "+
×
NEW
4414
                        "type in extractNodes: %T", r)
×
4415
        }
4416
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc