• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15602296443

12 Jun 2025 05:19AM UTC coverage: 57.49% (-0.8%) from 58.333%
15602296443

Pull #9932

github

web-flow
Merge 1c1dadc69 into 35102e7c3
Pull Request #9932: [draft] graph/db+sqldb: graph store SQL implementation + migration

7 of 2587 new or added lines in 6 files covered. (0.27%)

240 existing lines in 11 files now uncovered.

97770 of 170065 relevant lines covered (57.49%)

1.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "math"
11
        "net"
12
        "sort"
13
        "strconv"
14
        "sync"
15
        "time"
16

17
        "github.com/btcsuite/btcd/btcec/v2"
18
        "github.com/btcsuite/btcd/btcutil"
19
        "github.com/btcsuite/btcd/chaincfg/chainhash"
20
        "github.com/btcsuite/btcd/wire"
21
        "github.com/lightningnetwork/lnd/aliasmgr"
22
        "github.com/lightningnetwork/lnd/batch"
23
        "github.com/lightningnetwork/lnd/fn/v2"
24
        "github.com/lightningnetwork/lnd/graph/db/models"
25
        "github.com/lightningnetwork/lnd/lnwire"
26
        "github.com/lightningnetwork/lnd/routing/route"
27
        "github.com/lightningnetwork/lnd/sqldb"
28
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
29
        "github.com/lightningnetwork/lnd/tlv"
30
        "github.com/lightningnetwork/lnd/tor"
31
)
32

33
// ProtocolVersion identifies which version of the gossip protocol a message
// belongs to.
type ProtocolVersion uint8

// ProtocolV1 is the gossip protocol version defined in BOLT #7.
const ProtocolV1 ProtocolVersion = 1

// String renders the protocol version as the letter "V" followed by its
// numeric value, e.g. "V1".
func (v ProtocolVersion) String() string {
	return fmt.Sprintf("V%d", uint8(v))
}
×
46

47
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
48
// execute queries against the SQL graph tables.
49
//
50
//nolint:ll,interfacebloat
51
type SQLQueries interface {
52
        /*
53
                Node queries.
54
        */
55
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
56
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
57
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
58
        ListNodes(ctx context.Context, version int16) ([]sqlc.ListNodesRow, error)
59
        ListNodeIDsAndPubKeys(ctx context.Context, version int16) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
60
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
61
        GetUnconnectedNodes(ctx context.Context) ([]sqlc.GetUnconnectedNodesRow, error)
62
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
63

64
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
65
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
66
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
67

68
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
69
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
70
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
71

72
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
73
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
74
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
75
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
76

77
        /*
78
                Source node queries.
79
        */
80
        AddSourceNode(ctx context.Context, nodeID int64) error
81
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
82

83
        /*
84
                Channel queries.
85
        */
86
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
87
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) error
88
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
89
        GetChannelByOutpoint(ctx context.Context, arg sqlc.GetChannelByOutpointParams) (sqlc.GetChannelByOutpointRow, error)
90
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
91
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
92
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
93
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
94
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
95
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
96
        ListAllChannels(ctx context.Context, version int16) ([]sqlc.ListAllChannelsRow, error)
97
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
98
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
99
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
100
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
101
        DeleteChannel(ctx context.Context, id int64) error
102

103
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
104
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
105

106
        /*
107
                Channel Policy table queries.
108
        */
109
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
110
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.ChannelPolicy, error)
111
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
112

113
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
114
        DeleteChannelPolicyExtraType(ctx context.Context, arg sqlc.DeleteChannelPolicyExtraTypeParams) error
115
        UpsertChanPolicyExtraType(ctx context.Context, arg sqlc.UpsertChanPolicyExtraTypeParams) error
116

117
        /*
118
                Zombie index queries.
119
        */
120
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
121
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.ZombieChannel, error)
122
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
123
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) error
124
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
125

126
        /*
127
                Prune log table queries.
128
        */
129
        GetPruneTip(ctx context.Context) (sqlc.PruneLog, error)
130
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
131
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
132

133
        /*
134
                Closed SCID table queries.
135
        */
136
        InsertClosedChannel(ctx context.Context, scid []byte) error
137
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
138
}
139

140
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
141
// database operations.
142
type BatchedSQLQueries interface {
143
        SQLQueries
144
        sqldb.BatchedTx[SQLQueries]
145
}
146

147
// SQLStore is an implementation of the V1Store interface that uses a SQL
148
// database as the backend.
149
//
150
// NOTE: currently, this temporarily embeds the KVStore struct so that we can
151
// implement the V1Store interface incrementally. For any method not
152
// implemented,  things will fall back to the KVStore. This is ONLY the case
153
// for the time being while this struct is purely used in unit tests only.
154
type SQLStore struct {
155
        cfg *SQLStoreConfig
156
        db  BatchedSQLQueries
157

158
        // cacheMu guards all caches (rejectCache and chanCache). If
159
        // this mutex will be acquired at the same time as the DB mutex then
160
        // the cacheMu MUST be acquired first to prevent deadlock.
161
        cacheMu     sync.RWMutex
162
        rejectCache *rejectCache
163
        chanCache   *channelCache
164

165
        chanScheduler batch.Scheduler[SQLQueries]
166
        nodeScheduler batch.Scheduler[SQLQueries]
167

168
        srcNodeID  int64
169
        srcNodePub route.Vertex
170
        srcNodeMu  sync.Mutex
171
}
172

173
// A compile-time assertion to ensure that SQLStore implements the V1Store
174
// interface.
175
var _ V1Store = (*SQLStore)(nil)
176

177
// SQLStoreConfig holds the configuration for the SQLStore.
178
type SQLStoreConfig struct {
179
        // ChainHash is the genesis hash for the chain that all the gossip
180
        // messages in this store are aimed at.
181
        ChainHash chainhash.Hash
182
}
183

184
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
185
// storage backend.
186
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
187
        options ...StoreOptionModifier) (*SQLStore, error) {
×
188

×
189
        opts := DefaultOptions()
×
190
        for _, o := range options {
×
191
                o(opts)
×
192
        }
×
193

194
        if opts.NoMigration {
×
195
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
196
                        "supported for SQL stores")
×
197
        }
×
198

199
        s := &SQLStore{
×
NEW
200
                cfg:         cfg,
×
201
                db:          db,
×
202
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
203
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
204
        }
×
205

×
206
        s.chanScheduler = batch.NewTimeScheduler(
×
207
                db, &s.cacheMu, opts.BatchCommitInterval,
×
208
        )
×
209
        s.nodeScheduler = batch.NewTimeScheduler(
×
210
                db, nil, opts.BatchCommitInterval,
×
211
        )
×
212

×
213
        return s, nil
×
214
}
215

216
// AddLightningNode adds a vertex/node to the graph database. If the node is not
217
// in the database from before, this will add a new, unconnected one to the
218
// graph. If it is present from before, this will update that node's
219
// information.
220
//
221
// NOTE: part of the V1Store interface.
222
func (s *SQLStore) AddLightningNode(node *models.LightningNode,
223
        opts ...batch.SchedulerOption) error {
×
224

×
225
        ctx := context.TODO()
×
226

×
227
        r := &batch.Request[SQLQueries]{
×
228
                Opts: batch.NewSchedulerOptions(opts...),
×
229
                Do: func(queries SQLQueries) error {
×
230
                        _, err := upsertNode(ctx, queries, node)
×
231
                        return err
×
232
                },
×
233
        }
234

235
        return s.nodeScheduler.Execute(ctx, r)
×
236
}
237

238
// FetchLightningNode attempts to look up a target node by its identity public
239
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
240
// returned.
241
//
242
// NOTE: part of the V1Store interface.
243
func (s *SQLStore) FetchLightningNode(pubKey route.Vertex) (
244
        *models.LightningNode, error) {
×
245

×
246
        ctx := context.TODO()
×
247

×
248
        var node *models.LightningNode
×
249
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
250
                var err error
×
251
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
252

×
253
                return err
×
254
        }, sqldb.NoOpReset)
×
255
        if err != nil {
×
256
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
257
        }
×
258

259
        return node, nil
×
260
}
261

262
// HasLightningNode determines if the graph has a vertex identified by the
263
// target node identity public key. If the node exists in the database, a
264
// timestamp of when the data for the node was lasted updated is returned along
265
// with a true boolean. Otherwise, an empty time.Time is returned with a false
266
// boolean.
267
//
268
// NOTE: part of the V1Store interface.
269
func (s *SQLStore) HasLightningNode(pubKey [33]byte) (time.Time, bool,
270
        error) {
×
271

×
272
        ctx := context.TODO()
×
273

×
274
        var (
×
275
                exists     bool
×
276
                lastUpdate time.Time
×
277
        )
×
278
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
279
                dbNode, err := db.GetNodeByPubKey(
×
280
                        ctx, sqlc.GetNodeByPubKeyParams{
×
281
                                Version: int16(ProtocolV1),
×
282
                                PubKey:  pubKey[:],
×
283
                        },
×
284
                )
×
285
                if errors.Is(err, sql.ErrNoRows) {
×
286
                        return nil
×
287
                } else if err != nil {
×
288
                        return fmt.Errorf("unable to fetch node: %w", err)
×
289
                }
×
290

291
                exists = true
×
292

×
293
                if dbNode.LastUpdate.Valid {
×
294
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
295
                }
×
296

297
                return nil
×
298
        }, sqldb.NoOpReset)
299
        if err != nil {
×
300
                return time.Time{}, false,
×
301
                        fmt.Errorf("unable to fetch node: %w", err)
×
302
        }
×
303

304
        return lastUpdate, exists, nil
×
305
}
306

307
// AddrsForNode returns all known addresses for the target node public key
308
// that the graph DB is aware of. The returned boolean indicates if the
309
// given node is unknown to the graph DB or not.
310
//
311
// NOTE: part of the V1Store interface.
312
func (s *SQLStore) AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr,
313
        error) {
×
314

×
315
        ctx := context.TODO()
×
316

×
317
        var (
×
318
                addresses []net.Addr
×
319
                known     bool
×
320
        )
×
321
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
322
                var err error
×
323
                known, addresses, err = getNodeAddresses(
×
324
                        ctx, db, nodePub.SerializeCompressed(),
×
325
                )
×
326
                if err != nil {
×
327
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
328
                                err)
×
329
                }
×
330

331
                return nil
×
332
        }, sqldb.NoOpReset)
333
        if err != nil {
×
334
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
335
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
336
        }
×
337

338
        return known, addresses, nil
×
339
}
340

341
// DeleteLightningNode starts a new database transaction to remove a vertex/node
342
// from the database according to the node's public key.
343
//
344
// NOTE: part of the V1Store interface.
345
func (s *SQLStore) DeleteLightningNode(pubKey route.Vertex) error {
×
346
        ctx := context.TODO()
×
347

×
348
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
349
                res, err := db.DeleteNodeByPubKey(
×
350
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
351
                                Version: int16(ProtocolV1),
×
352
                                PubKey:  pubKey[:],
×
353
                        },
×
354
                )
×
355
                if err != nil {
×
356
                        return err
×
357
                }
×
358

359
                rows, err := res.RowsAffected()
×
360
                if err != nil {
×
361
                        return err
×
362
                }
×
363

364
                if rows == 0 {
×
365
                        return ErrGraphNodeNotFound
×
366
                } else if rows > 1 {
×
367
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
368
                }
×
369

370
                return err
×
371
        }, sqldb.NoOpReset)
372
        if err != nil {
×
373
                return fmt.Errorf("unable to delete node: %w", err)
×
374
        }
×
375

376
        return nil
×
377
}
378

379
// FetchNodeFeatures returns the features of the given node. If no features are
380
// known for the node, an empty feature vector is returned.
381
//
382
// NOTE: this is part of the graphdb.NodeTraverser interface.
383
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
384
        *lnwire.FeatureVector, error) {
×
385

×
386
        ctx := context.TODO()
×
387

×
388
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
389
}
×
390

391
// DisabledChannelIDs returns the channel ids of disabled channels.
392
// A channel is disabled when two of the associated ChanelEdgePolicies
393
// have their disabled bit on.
394
//
395
// NOTE: part of the V1Store interface.
NEW
396
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
NEW
397
        var (
×
NEW
398
                ctx     = context.TODO()
×
NEW
399
                chanIDs []uint64
×
NEW
400
        )
×
NEW
401
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
402
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
NEW
403
                if err != nil {
×
NEW
404
                        return fmt.Errorf("unable to fetch disabled "+
×
NEW
405
                                "channels: %w", err)
×
NEW
406
                }
×
407

NEW
408
                for _, dbChanID := range dbChanIDs {
×
NEW
409
                        chanIDs = append(chanIDs, byteOrder.Uint64(dbChanID))
×
NEW
410
                }
×
411

NEW
412
                return nil
×
NEW
413
        }, func() {})
×
NEW
414
        if err != nil {
×
NEW
415
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
NEW
416
                        err)
×
NEW
417
        }
×
418

NEW
419
        return chanIDs, nil
×
420
}
421

422
// LookupAlias attempts to return the alias as advertised by the target node.
423
//
424
// NOTE: part of the V1Store interface.
425
func (s *SQLStore) LookupAlias(pub *btcec.PublicKey) (string, error) {
×
426
        var (
×
427
                ctx   = context.TODO()
×
428
                alias string
×
429
        )
×
430
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
431
                dbNode, err := db.GetNodeByPubKey(
×
432
                        ctx, sqlc.GetNodeByPubKeyParams{
×
433
                                Version: int16(ProtocolV1),
×
434
                                PubKey:  pub.SerializeCompressed(),
×
435
                        },
×
436
                )
×
437
                if errors.Is(err, sql.ErrNoRows) {
×
438
                        return ErrNodeAliasNotFound
×
439
                } else if err != nil {
×
440
                        return fmt.Errorf("unable to fetch node: %w", err)
×
441
                }
×
442

443
                if !dbNode.Alias.Valid {
×
444
                        return ErrNodeAliasNotFound
×
445
                }
×
446

447
                alias = dbNode.Alias.String
×
448

×
449
                return nil
×
450
        }, sqldb.NoOpReset)
451
        if err != nil {
×
452
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
453
        }
×
454

455
        return alias, nil
×
456
}
457

458
// SourceNode returns the source node of the graph. The source node is treated
459
// as the center node within a star-graph. This method may be used to kick off
460
// a path finding algorithm in order to explore the reachability of another
461
// node based off the source node.
462
//
463
// NOTE: part of the V1Store interface.
464
func (s *SQLStore) SourceNode() (*models.LightningNode, error) {
×
465
        ctx := context.TODO()
×
466

×
467
        var node *models.LightningNode
×
468
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
469
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
470
                if err != nil {
×
471
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
472
                                err)
×
473
                }
×
474

475
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
476

×
477
                return err
×
478
        }, sqldb.NoOpReset)
479
        if err != nil {
×
480
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
481
        }
×
482

483
        return node, nil
×
484
}
485

486
// SetSourceNode sets the source node within the graph database. The source
487
// node is to be used as the center of a star-graph within path finding
488
// algorithms.
489
//
490
// NOTE: part of the V1Store interface.
491
func (s *SQLStore) SetSourceNode(node *models.LightningNode) error {
×
492
        ctx := context.TODO()
×
493

×
494
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
495
                id, err := upsertNode(ctx, db, node)
×
496
                if err != nil {
×
497
                        return fmt.Errorf("unable to upsert source node: %w",
×
498
                                err)
×
499
                }
×
500

501
                // Make sure that if a source node for this version is already
502
                // set, then the ID is the same as the one we are about to set.
NEW
503
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
504
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
505
                        return fmt.Errorf("unable to fetch source node: %w",
×
506
                                err)
×
507
                } else if err == nil {
×
508
                        if dbSourceNodeID != id {
×
509
                                return fmt.Errorf("v1 source node already "+
×
510
                                        "set to a different node: %d vs %d",
×
511
                                        dbSourceNodeID, id)
×
512
                        }
×
513

514
                        return nil
×
515
                }
516

517
                return db.AddSourceNode(ctx, id)
×
518
        }, sqldb.NoOpReset)
519
}
520

521
// NodeUpdatesInHorizon returns all the known lightning node which have an
522
// update timestamp within the passed range. This method can be used by two
523
// nodes to quickly determine if they have the same set of up to date node
524
// announcements.
525
//
526
// NOTE: This is part of the V1Store interface.
527
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
528
        endTime time.Time) ([]models.LightningNode, error) {
×
529

×
530
        ctx := context.TODO()
×
531

×
532
        var nodes []models.LightningNode
×
533
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
534
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
535
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
536
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
537
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
538
                        },
×
539
                )
×
540
                if err != nil {
×
541
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
542
                }
×
543

544
                for _, dbNode := range dbNodes {
×
545
                        node, err := buildNode(ctx, db, &dbNode)
×
546
                        if err != nil {
×
547
                                return fmt.Errorf("unable to build node: %w",
×
548
                                        err)
×
549
                        }
×
550

551
                        nodes = append(nodes, *node)
×
552
                }
553

554
                return nil
×
555
        }, sqldb.NoOpReset)
556
        if err != nil {
×
557
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
558
        }
×
559

560
        return nodes, nil
×
561
}
562

563
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
564
// undirected edge from the two target nodes are created. The information stored
565
// denotes the static attributes of the channel, such as the channelID, the keys
566
// involved in creation of the channel, and the set of features that the channel
567
// supports. The chanPoint and chanID are used to uniquely identify the edge
568
// globally within the database.
569
//
570
// NOTE: part of the V1Store interface.
571
func (s *SQLStore) AddChannelEdge(edge *models.ChannelEdgeInfo,
572
        opts ...batch.SchedulerOption) error {
×
573

×
574
        ctx := context.TODO()
×
575

×
576
        var alreadyExists bool
×
577
        r := &batch.Request[SQLQueries]{
×
578
                Opts: batch.NewSchedulerOptions(opts...),
×
579
                Reset: func() {
×
580
                        alreadyExists = false
×
581
                },
×
582
                Do: func(tx SQLQueries) error {
×
NEW
583
                        _, err := insertChannel(ctx, tx, edge)
×
584

×
585
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
586
                        // succeed, but propagate the error via local state.
×
587
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
588
                                alreadyExists = true
×
589
                                return nil
×
590
                        }
×
591

592
                        return err
×
593
                },
594
                OnCommit: func(err error) error {
×
595
                        switch {
×
596
                        case err != nil:
×
597
                                return err
×
598
                        case alreadyExists:
×
599
                                return ErrEdgeAlreadyExist
×
600
                        default:
×
601
                                s.rejectCache.remove(edge.ChannelID)
×
602
                                s.chanCache.remove(edge.ChannelID)
×
603
                                return nil
×
604
                        }
605
                },
606
        }
607

NEW
608
        return s.chanScheduler.Execute(ctx, r)
×
609
}
610

611
// HighestChanID returns the "highest" known channel ID in the channel graph.
612
// This represents the "newest" channel from the PoV of the chain. This method
613
// can be used by peers to quickly determine if their graphs are in sync.
614
//
615
// NOTE: This is part of the V1Store interface.
NEW
616
func (s *SQLStore) HighestChanID() (uint64, error) {
×
NEW
617
        ctx := context.TODO()
×
NEW
618

×
NEW
619
        var highestChanID uint64
×
NEW
620
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
621
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
NEW
622
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
623
                        return nil
×
NEW
624
                } else if err != nil {
×
NEW
625
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
NEW
626
                                err)
×
NEW
627
                }
×
628

NEW
629
                highestChanID = byteOrder.Uint64(chanID)
×
NEW
630

×
NEW
631
                return nil
×
632
        }, sqldb.NoOpReset)
NEW
633
        if err != nil {
×
NEW
634
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
NEW
635
        }
×
636

NEW
637
        return highestChanID, nil
×
638
}
639

640
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
641
// within the database for the referenced channel. The `flags` attribute within
642
// the ChannelEdgePolicy determines which of the directed edges are being
643
// updated. If the flag is 1, then the first node's information is being
644
// updated, otherwise it's the second node's information. The node ordering is
645
// determined by the lexicographical ordering of the identity public keys of the
646
// nodes on either side of the channel.
647
//
648
// NOTE: part of the V1Store interface.
649
func (s *SQLStore) UpdateEdgePolicy(edge *models.ChannelEdgePolicy,
NEW
650
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
NEW
651

×
NEW
652
        ctx := context.TODO()
×
NEW
653

×
NEW
654
        var (
×
NEW
655
                isUpdate1    bool
×
NEW
656
                edgeNotFound bool
×
NEW
657
                from, to     route.Vertex
×
NEW
658
        )
×
NEW
659

×
NEW
660
        r := &batch.Request[SQLQueries]{
×
NEW
661
                Opts: batch.NewSchedulerOptions(opts...),
×
NEW
662
                Reset: func() {
×
NEW
663
                        isUpdate1 = false
×
NEW
664
                        edgeNotFound = false
×
NEW
665
                },
×
NEW
666
                Do: func(tx SQLQueries) error {
×
NEW
667
                        var err error
×
NEW
668
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
NEW
669
                                ctx, tx, edge,
×
NEW
670
                        )
×
NEW
671
                        if err != nil {
×
NEW
672
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
NEW
673
                        }
×
674

675
                        // Silence ErrEdgeNotFound so that the batch can
676
                        // succeed, but propagate the error via local state.
NEW
677
                        if errors.Is(err, ErrEdgeNotFound) {
×
NEW
678
                                edgeNotFound = true
×
NEW
679
                                return nil
×
NEW
680
                        }
×
681

NEW
682
                        return err
×
683
                },
NEW
684
                OnCommit: func(err error) error {
×
NEW
685
                        switch {
×
NEW
686
                        case err != nil:
×
NEW
687
                                return err
×
NEW
688
                        case edgeNotFound:
×
NEW
689
                                return ErrEdgeNotFound
×
NEW
690
                        default:
×
NEW
691
                                s.updateEdgeCache(edge, isUpdate1)
×
NEW
692
                                return nil
×
693
                        }
694
                },
695
        }
696

NEW
697
        err := s.chanScheduler.Execute(ctx, r)
×
NEW
698

×
NEW
699
        return from, to, err
×
700
}
701

702
// updateEdgeCache updates our reject and channel caches with the new
703
// edge policy information.
704
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
NEW
705
        isUpdate1 bool) {
×
NEW
706

×
NEW
707
        // If an entry for this channel is found in reject cache, we'll modify
×
NEW
708
        // the entry with the updated timestamp for the direction that was just
×
NEW
709
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
NEW
710
        // during the next query for this edge.
×
NEW
711
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
NEW
712
                if isUpdate1 {
×
NEW
713
                        entry.upd1Time = e.LastUpdate.Unix()
×
NEW
714
                } else {
×
NEW
715
                        entry.upd2Time = e.LastUpdate.Unix()
×
NEW
716
                }
×
NEW
717
                s.rejectCache.insert(e.ChannelID, entry)
×
718
        }
719

720
        // If an entry for this channel is found in channel cache, we'll modify
721
        // the entry with the updated policy for the direction that was just
722
        // written. If the edge doesn't exist, we'll defer loading the info and
723
        // policies and lazily read from disk during the next query.
NEW
724
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
NEW
725
                if isUpdate1 {
×
NEW
726
                        channel.Policy1 = e
×
NEW
727
                } else {
×
NEW
728
                        channel.Policy2 = e
×
NEW
729
                }
×
NEW
730
                s.chanCache.insert(e.ChannelID, channel)
×
731
        }
732
}
733

734
// ForEachSourceNodeChannel iterates through all channels of the source node,
// invoking cb once per channel with the channel's outpoint, a flag reporting
// whether the source node's outgoing policy for the channel is known, and the
// fully loaded node on the other side of the channel. All reads happen inside
// a single read transaction; an error from cb aborts the iteration and is
// returned unchanged.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint,
	havePolicy bool, otherNode *models.LightningNode) error) error {

	ctx := context.TODO()

	txBody := func(db SQLQueries) error {
		nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
		if err != nil {
			return fmt.Errorf("unable to fetch source node: %w",
				err)
		}

		handleChannel := func(info *models.ChannelEdgeInfo,
			outPolicy *models.ChannelEdgePolicy,
			_ *models.ChannelEdgePolicy) error {

			// Whichever key is not ours belongs to the peer.
			key1, key2 := info.NodeKey1Bytes, info.NodeKey2Bytes

			var peerPub [33]byte
			if bytes.Equal(key1[:], nodePub[:]) {
				peerPub = key2
			} else if bytes.Equal(key2[:], nodePub[:]) {
				peerPub = key1
			} else {
				return fmt.Errorf("node not " +
					"participating in this channel")
			}

			_, peer, err := getNodeByPubKey(ctx, db, peerPub)
			if err != nil {
				return fmt.Errorf("unable to fetch "+
					"other node(%x): %w", peerPub, err)
			}

			return cb(info.ChannelPoint, outPolicy != nil, peer)
		}

		return forEachNodeChannel(
			ctx, db, s.cfg.ChainHash, nodeID, handleChannel,
		)
	}

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), txBody, func() {})
}
791

792
// ForEachNode iterates through all the stored vertices/nodes in the graph,
793
// executing the passed callback with each node encountered. If the callback
794
// returns an error, then the transaction is aborted and the iteration stops
795
// early. Any operations performed on the NodeTx passed to the call-back are
796
// executed under the same read transaction and so, methods on the NodeTx object
797
// _MUST_ only be called from within the call-back.
798
//
799
// NOTE: part of the V1Store interface.
NEW
800
func (s *SQLStore) ForEachNode(cb func(tx NodeRTx) error) error {
×
NEW
801
        var ctx = context.TODO()
×
NEW
802

×
NEW
803
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
804
                return forEachNode(ctx, db,
×
NEW
805
                        func(nodeID int64, nodePub route.Vertex) error {
×
NEW
806
                                dbNode, err := db.GetNodeByPubKey(
×
NEW
807
                                        ctx, sqlc.GetNodeByPubKeyParams{
×
NEW
808
                                                PubKey:  nodePub[:],
×
NEW
809
                                                Version: int16(ProtocolV1),
×
NEW
810
                                        },
×
NEW
811
                                )
×
NEW
812
                                if err != nil {
×
NEW
813
                                        return fmt.Errorf("unable to get "+
×
NEW
814
                                                "node(id=%d): %w", nodeID, err)
×
NEW
815
                                }
×
816

NEW
817
                                node, err := buildNode(ctx, db, &dbNode)
×
NEW
818
                                if err != nil {
×
NEW
819
                                        return fmt.Errorf("unable to build "+
×
NEW
820
                                                "node(id=%d): %w", nodeID, err)
×
NEW
821
                                }
×
822

NEW
823
                                return cb(newSQLGraphNodeTx(
×
NEW
824
                                        db, s.cfg.ChainHash, nodeID, node,
×
NEW
825
                                ))
×
826
                        },
827
                )
NEW
828
        }, func() {})
×
829
}
830

831
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
// SQLStore and a SQL transaction. It bundles an already-loaded node with the
// query handle and identifiers needed to perform further reads (its channels,
// or other nodes) through that same handle.
type sqlGraphNodeTx struct {
	// db is the query handle used for every read made through this
	// NodeRTx. ForEachNode passes in the handle of its enclosing read
	// transaction, and FetchNode passes the same handle on to the
	// NodeRTx it returns.
	db    SQLQueries
	// id is the node's database ID. ForEachChannel uses it to look up
	// the node's channels.
	id    int64
	// node is the loaded node returned as-is by Node().
	node  *models.LightningNode
	// chain is the chain hash passed to forEachNodeChannel when the
	// node's channels are loaded.
	chain chainhash.Hash
}

// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
// interface.
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
843

844
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
NEW
845
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
NEW
846

×
NEW
847
        return &sqlGraphNodeTx{
×
NEW
848
                db:    db,
×
NEW
849
                chain: chain,
×
NEW
850
                id:    id,
×
NEW
851
                node:  node,
×
NEW
852
        }
×
NEW
853
}
×
854

855
// Node returns the raw information of the node.
856
//
857
// NOTE: This is a part of the NodeRTx interface.
NEW
858
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
NEW
859
        return s.node
×
NEW
860
}
×
861

862
// ForEachChannel can be used to iterate over the node's channels under the same
863
// transaction used to fetch the node.
864
//
865
// NOTE: This is a part of the NodeRTx interface.
866
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
NEW
867
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
NEW
868

×
NEW
869
        ctx := context.TODO()
×
NEW
870

×
NEW
871
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
NEW
872
}
×
873

874
// FetchNode fetches the node with the given pub key under the same transaction
875
// used to fetch the current node. The returned node is also a NodeRTx and any
876
// operations on that NodeRTx will also be done under the same transaction.
877
//
878
// NOTE: This is a part of the NodeRTx interface.
NEW
879
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
NEW
880
        ctx := context.TODO()
×
NEW
881

×
NEW
882
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
NEW
883
        if err != nil {
×
NEW
884
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
NEW
885
                        nodePub, err)
×
NEW
886
        }
×
887

NEW
888
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
889
}
890

891
// ForEachNodeDirectedChannel iterates through all channels of a given node,
892
// executing the passed callback on the directed edge representing the channel
893
// and its incoming policy. If the callback returns an error, then the iteration
894
// is halted with the error propagated back up to the caller.
895
//
896
// Unknown policies are passed into the callback as nil values.
897
//
898
// NOTE: this is part of the graphdb.NodeTraverser interface.
899
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
NEW
900
        cb func(channel *DirectedChannel) error) error {
×
NEW
901

×
NEW
902
        var ctx = context.TODO()
×
NEW
903

×
NEW
904
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
905
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
NEW
906
        }, func() {})
×
907
}
908

909
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
910
// graph, executing the passed callback with each node encountered. If the
911
// callback returns an error, then the transaction is aborted and the iteration
912
// stops early.
913
//
914
// NOTE: This is a part of the V1Store interface.
915
func (s *SQLStore) ForEachNodeCacheable(cb func(route.Vertex,
NEW
916
        *lnwire.FeatureVector) error) error {
×
NEW
917

×
NEW
918
        ctx := context.TODO()
×
NEW
919

×
NEW
920
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
921
                return forEachNode(ctx, db, func(nodeID int64,
×
NEW
922
                        nodePub route.Vertex) error {
×
NEW
923

×
NEW
924
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
NEW
925
                        if err != nil {
×
NEW
926
                                return fmt.Errorf("unable to fetch node "+
×
NEW
927
                                        "features: %w", err)
×
NEW
928
                        }
×
929

NEW
930
                        return cb(nodePub, features)
×
931
                })
NEW
932
        }, func() {})
×
NEW
933
        if err != nil {
×
NEW
934
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
NEW
935
        }
×
936

NEW
937
        return nil
×
938
}
939

940
// ForEachNodeChannel iterates through all channels of the given node,
941
// executing the passed callback with an edge info structure and the policies
942
// of each end of the channel. The first edge policy is the outgoing edge *to*
943
// the connecting node, while the second is the incoming edge *from* the
944
// connecting node. If the callback returns an error, then the iteration is
945
// halted with the error propagated back up to the caller.
946
//
947
// Unknown policies are passed into the callback as nil values.
948
//
949
// NOTE: part of the V1Store interface.
950
func (s *SQLStore) ForEachNodeChannel(nodePub route.Vertex,
951
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
NEW
952
                *models.ChannelEdgePolicy) error) error {
×
NEW
953

×
NEW
954
        var ctx = context.TODO()
×
NEW
955

×
NEW
956
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
957
                dbNode, err := db.GetNodeByPubKey(
×
NEW
958
                        ctx, sqlc.GetNodeByPubKeyParams{
×
NEW
959
                                Version: int16(ProtocolV1),
×
NEW
960
                                PubKey:  nodePub[:],
×
NEW
961
                        },
×
NEW
962
                )
×
NEW
963
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
964
                        return nil
×
NEW
965
                } else if err != nil {
×
NEW
966
                        return fmt.Errorf("unable to fetch node: %w", err)
×
NEW
967
                }
×
968

NEW
969
                return forEachNodeChannel(
×
NEW
970
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
NEW
971
                )
×
NEW
972
        }, func() {})
×
973
}
974

975
// ChanUpdatesInHorizon returns all the known channel edges which have at least
// one edge that has an update timestamp within the specified horizon.
//
// Edges already present in the channel cache are served from the cache. Edges
// that had to be loaded from the database are inserted into the cache only
// after the read transaction has completed successfully. The cache mutex is
// held for the entire call.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
	endTime time.Time) ([]ChannelEdge, error) {

	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	var (
		ctx = context.TODO()
		// To ensure we don't return duplicate ChannelEdges, we'll use an
		// additional map to keep track of the edges already seen to prevent
		// re-adding it.
		edgesSeen    = make(map[uint64]struct{})
		edgesToCache = make(map[uint64]ChannelEdge)
		edges        []ChannelEdge
		hits         int
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		rows, err := db.GetChannelsByPolicyLastUpdateRange(
			ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
				Version: int16(ProtocolV1),
				StartTime: sql.NullInt64{
					Valid: true,
					Int64: startTime.Unix(),
				},
				EndTime: sql.NullInt64{
					Valid: true,
					Int64: endTime.Unix(),
				},
			},
		)
		if err != nil {
			return err
		}

		for _, row := range rows {
			// If we've already retrieved the info and policies for
			// this edge, then we can skip it as we don't need to do
			// so again.
			chanIDInt := byteOrder.Uint64(row.Scid)
			if _, ok := edgesSeen[chanIDInt]; ok {
				continue
			}

			if channel, ok := s.chanCache.get(chanIDInt); ok {
				hits++
				edgesSeen[chanIDInt] = struct{}{}
				edges = append(edges, channel)

				continue
			}

			node1, node2, err := getAndBuildNodes(ctx, db, row)
			if err != nil {
				return err
			}

			channel, p1, p2, err := getAndBuildEdgeInfoAndPolicies(
				ctx, db, s.cfg.ChainHash, row.ID, row,
				node1.PubKeyBytes, node2.PubKeyBytes,
			)
			if err != nil {
				return err
			}

			edgesSeen[chanIDInt] = struct{}{}
			chanEdge := ChannelEdge{
				Info:    channel,
				Policy1: p1,
				Policy2: p2,
				Node1:   node1,
				Node2:   node2,
			}
			edges = append(edges, chanEdge)
			edgesToCache[chanIDInt] = chanEdge
		}

		return nil
	}, func() {
		// Discard everything accumulated by a previous attempt so a
		// retried transaction starts from scratch. This includes the
		// cache-hit counter, which would otherwise be inflated.
		edgesSeen = make(map[uint64]struct{})
		edgesToCache = make(map[uint64]ChannelEdge)
		edges = nil
		hits = 0
	})
	if err != nil {
		return nil, fmt.Errorf("unable to fetch channels: %w", err)
	}

	// Insert any edges loaded from disk into the cache.
	for chanid, channel := range edgesToCache {
		s.chanCache.insert(chanid, channel)
	}

	// Avoid dividing by zero (which would log NaN) when no edges fall
	// within the horizon.
	var hitRatio float64
	if len(edges) > 0 {
		hitRatio = float64(hits) / float64(len(edges))
	}
	log.Debugf("ChanUpdatesInHorizon hit percentage: %f (%d/%d)",
		hitRatio, hits, len(edges))

	return edges, nil
}
1075

1076
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
// data to the call-back.
//
// For every V1 node, the node's feature vector and channels are loaded. Each
// channel is then converted into a DirectedChannel from that node's point of
// view:
//   - the node's own policy is the outgoing policy;
//   - the peer's policy is the incoming one.
//
// Either policy may be unknown. A missing outgoing policy yields
// OutPolicySet == false and a zero InboundFee. A missing incoming policy
// yields a nil InPolicy.
//
// NOTE: The callback contents MUST not be modified.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNodeCached(cb func(node route.Vertex,
	chans map[uint64]*DirectedChannel) error) error {

	var ctx = context.TODO()

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		nodes, err := db.ListNodes(ctx, int16(ProtocolV1))
		if err != nil {
			return fmt.Errorf("unable to fetch node ids: %w", err)
		}

		for _, node := range nodes {
			features, err := getNodeFeatures(ctx, db, node.ID)
			if err != nil {
				return fmt.Errorf("unable to fetch node "+
					"features: %w", err)
			}

			var nodePub route.Vertex
			copy(nodePub[:], node.PubKey)

			toNodeCallback := func() route.Vertex {
				return nodePub
			}

			rows, err := db.ListChannelsByNodeID(
				ctx, sqlc.ListChannelsByNodeIDParams{
					Version: int16(ProtocolV1),
					NodeID1: node.ID,
				},
			)
			if err != nil {
				return fmt.Errorf("unable to fetch channels: "+
					"%w", err)
			}

			channels := make(map[uint64]*DirectedChannel, len(rows))
			for _, row := range rows {
				node1, node2, err := buildNodeVertices(
					row.Node1Pubkey, row.Node2Pubkey,
				)
				if err != nil {
					return err
				}

				e, p1, p2, err := getAndBuildEdgeInfoAndPolicies(
					ctx, db, s.cfg.ChainHash, row.ID, row,
					node1, node2,
				)
				if err != nil {
					return err
				}

				// p1 is node 1's policy and p2 is node 2's.
				// Seen from this node, its own policy is the
				// outgoing one and the peer's policy is the
				// incoming one. Swap them (and the "other"
				// node) when this node is node 2.
				isNode2 := nodePub == e.NodeKey2Bytes
				outPolicy, inPolicy := p1, p2
				otherNode := e.NodeKey2Bytes
				if isNode2 {
					outPolicy, inPolicy = p2, p1
					otherNode = e.NodeKey1Bytes
				}

				// Build the cached policy from the incoming
				// policy for this direction, not blindly
				// from p2.
				var cachedInPolicy *models.CachedEdgePolicy
				if inPolicy != nil {
					cachedInPolicy = models.NewCachedPolicy(
						inPolicy,
					)
					cachedInPolicy.ToNodePubKey =
						toNodeCallback
					cachedInPolicy.ToNodeFeatures =
						features
				}

				// The outgoing policy may be unknown. Only
				// read its inbound fee when it is present,
				// to avoid a nil pointer dereference.
				var inboundFee lnwire.Fee
				if outPolicy != nil {
					outPolicy.InboundFee.WhenSome(
						func(fee lnwire.Fee) {
							inboundFee = fee
						},
					)
				}

				channels[e.ChannelID] = &DirectedChannel{
					ChannelID:    e.ChannelID,
					IsNode1:      nodePub == e.NodeKey1Bytes,
					OtherNode:    otherNode,
					Capacity:     e.Capacity,
					OutPolicySet: outPolicy != nil,
					InPolicy:     cachedInPolicy,
					InboundFee:   inboundFee,
				}
			}

			if err := cb(nodePub, channels); err != nil {
				return err
			}
		}

		return nil
	}, func() {})
}
1189

1190
// ForEachChannelCacheable visits every V1 channel in the graph and calls cb
// with the channel's cacheable info plus the cached form of both directional
// policies. A policy that has not been advertised is passed as nil. All reads
// happen in one read transaction, and the first error (from a lookup or from
// cb) aborts the iteration and is returned.
//
// NOTE: this method is like ForEachChannel but fetches only the data
// required for the graph cache.
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
	*models.CachedEdgePolicy,
	*models.CachedEdgePolicy) error) error {

	ctx := context.TODO()

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		rows, err := db.ListAllChannels(ctx, int16(ProtocolV1))
		if err != nil {
			return fmt.Errorf("unable to fetch channels: %w", err)
		}

		for _, row := range rows {
			node1, node2, err := buildNodeVertices(
				row.Node1Pubkey, row.Node2Pubkey,
			)
			if err != nil {
				return err
			}

			info, err := buildCacheableChannelInfo(row, node1, node2)
			if err != nil {
				return err
			}

			dbPol1, dbPol2, err := extractChannelPolicies(row)
			if err != nil {
				return err
			}

			var cached1, cached2 *models.CachedEdgePolicy

			// The first policy is built with node2 as the
			// counterpart; the second with node1.
			if dbPol1 != nil {
				built, err := buildChanPolicy(
					*dbPol1, info.ChannelID, nil, node2,
					true,
				)
				if err != nil {
					return err
				}
				cached1 = models.NewCachedPolicy(built)
			}

			if dbPol2 != nil {
				built, err := buildChanPolicy(
					*dbPol2, info.ChannelID, nil, node1,
					false,
				)
				if err != nil {
					return err
				}
				cached2 = models.NewCachedPolicy(built)
			}

			if err = cb(info, cached1, cached2); err != nil {
				return err
			}
		}

		return nil
	}, func() {})
}
1266

1267
// ForEachChannel iterates through all the channel edges stored within the
1268
// graph and invokes the passed callback for each edge. The callback takes two
1269
// edges as since this is a directed graph, both the in/out edges are visited.
1270
// If the callback returns an error, then the transaction is aborted and the
1271
// iteration stops early.
1272
//
1273
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1274
// for that particular channel edge routing policy will be passed into the
1275
// callback.
1276
//
1277
// NOTE: part of the V1Store interface.
1278
func (s *SQLStore) ForEachChannel(cb func(*models.ChannelEdgeInfo,
NEW
1279
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
NEW
1280

×
NEW
1281
        var ctx = context.TODO()
×
NEW
1282

×
NEW
1283
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1284
                rows, err := db.ListAllChannels(ctx, int16(ProtocolV1))
×
NEW
1285
                if err != nil {
×
NEW
1286
                        return err
×
NEW
1287
                }
×
1288

NEW
1289
                for _, row := range rows {
×
NEW
1290
                        node1, node2, err := buildNodeVertices(
×
NEW
1291
                                row.Node1Pubkey, row.Node2Pubkey,
×
NEW
1292
                        )
×
NEW
1293
                        if err != nil {
×
NEW
1294
                                return err
×
NEW
1295
                        }
×
1296

NEW
1297
                        edge, p1, p2, err := getAndBuildEdgeInfoAndPolicies(
×
NEW
1298
                                ctx, db, s.cfg.ChainHash, row.ID, row,
×
NEW
1299
                                node1, node2,
×
NEW
1300
                        )
×
NEW
1301
                        if err != nil {
×
NEW
1302
                                return fmt.Errorf("unable to build edge "+
×
NEW
1303
                                        "info and policies: %w", err)
×
NEW
1304
                        }
×
1305

NEW
1306
                        if err := cb(edge, p1, p2); err != nil {
×
NEW
1307
                                return err
×
NEW
1308
                        }
×
1309
                }
1310

NEW
1311
                return nil
×
NEW
1312
        }, func() {})
×
1313
}
1314

1315
// FilterChannelRange returns the channel ID's of all known channels which were
1316
// mined in a block height within the passed range. The channel IDs are grouped
1317
// by their common block height. This method can be used to quickly share with a
1318
// peer the set of channels we know of within a particular range to catch them
1319
// up after a period of time offline. If withTimestamps is true then the
1320
// timestamp info of the latest received channel update messages of the channel
1321
// will be included in the response.
1322
//
1323
// NOTE: This is part of the V1Store interface.
1324
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
NEW
1325
        withTimestamps bool) ([]BlockChannelRange, error) {
×
NEW
1326

×
NEW
1327
        var (
×
NEW
1328
                ctx       = context.TODO()
×
NEW
1329
                startSCID = &lnwire.ShortChannelID{
×
NEW
1330
                        BlockHeight: startHeight,
×
NEW
1331
                }
×
NEW
1332
                endSCID = lnwire.ShortChannelID{
×
NEW
1333
                        BlockHeight: endHeight,
×
NEW
1334
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
NEW
1335
                        TxPosition:  math.MaxUint16,
×
NEW
1336
                }
×
NEW
1337
        )
×
NEW
1338

×
NEW
1339
        var chanIDStart [8]byte
×
NEW
1340
        byteOrder.PutUint64(chanIDStart[:], startSCID.ToUint64())
×
NEW
1341
        var chanIDEnd [8]byte
×
NEW
1342
        byteOrder.PutUint64(chanIDEnd[:], endSCID.ToUint64())
×
NEW
1343

×
NEW
1344
        // 1) get all channels where channelID is between start and end chan ID.
×
NEW
1345
        // 2) skip if not public (ie, no channel_proof)
×
NEW
1346
        // 3) collect that channel.
×
NEW
1347
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
NEW
1348
        //    and add those timestamps to the collected channel.
×
NEW
1349
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
NEW
1350
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1351
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
NEW
1352
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
NEW
1353
                                StartScid: chanIDStart[:],
×
NEW
1354
                                EndScid:   chanIDEnd[:],
×
NEW
1355
                        },
×
NEW
1356
                )
×
NEW
1357
                if err != nil {
×
NEW
1358
                        return fmt.Errorf("unable to fetch channel range: %w",
×
NEW
1359
                                err)
×
NEW
1360
                }
×
1361

NEW
1362
                for _, dbChan := range dbChans {
×
NEW
1363
                        cid := lnwire.NewShortChanIDFromInt(
×
NEW
1364
                                byteOrder.Uint64(dbChan.Scid),
×
NEW
1365
                        )
×
NEW
1366
                        chanInfo := NewChannelUpdateInfo(
×
NEW
1367
                                cid, time.Time{}, time.Time{},
×
NEW
1368
                        )
×
NEW
1369

×
NEW
1370
                        if !withTimestamps {
×
NEW
1371
                                channelsPerBlock[cid.BlockHeight] = append(
×
NEW
1372
                                        channelsPerBlock[cid.BlockHeight],
×
NEW
1373
                                        chanInfo,
×
NEW
1374
                                )
×
NEW
1375

×
NEW
1376
                                continue
×
1377
                        }
1378

1379
                        //nolint:ll
NEW
1380
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
NEW
1381
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
1382
                                        Version:   int16(ProtocolV1),
×
NEW
1383
                                        ChannelID: dbChan.ID,
×
NEW
1384
                                        NodeID:    dbChan.NodeID1,
×
NEW
1385
                                },
×
NEW
1386
                        )
×
NEW
1387
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
NEW
1388
                                return fmt.Errorf("unable to fetch node1 "+
×
NEW
1389
                                        "policy: %w", err)
×
NEW
1390
                        } else if err == nil {
×
NEW
1391
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
NEW
1392
                                        node1Policy.LastUpdate.Int64, 0,
×
NEW
1393
                                )
×
NEW
1394
                        }
×
1395

1396
                        //nolint:ll
NEW
1397
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
NEW
1398
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
1399
                                        Version:   int16(ProtocolV1),
×
NEW
1400
                                        ChannelID: dbChan.ID,
×
NEW
1401
                                        NodeID:    dbChan.NodeID2,
×
NEW
1402
                                },
×
NEW
1403
                        )
×
NEW
1404
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
NEW
1405
                                return fmt.Errorf("unable to fetch node2 "+
×
NEW
1406
                                        "policy: %w", err)
×
NEW
1407
                        } else if err == nil {
×
NEW
1408
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
NEW
1409
                                        node2Policy.LastUpdate.Int64, 0,
×
NEW
1410
                                )
×
NEW
1411
                        }
×
1412

NEW
1413
                        channelsPerBlock[cid.BlockHeight] = append(
×
NEW
1414
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
NEW
1415
                        )
×
1416
                }
1417

NEW
1418
                return nil
×
NEW
1419
        }, func() {
×
NEW
1420
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
NEW
1421
        })
×
NEW
1422
        if err != nil {
×
NEW
1423
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
NEW
1424
        }
×
1425

NEW
1426
        if len(channelsPerBlock) == 0 {
×
NEW
1427
                return nil, nil
×
NEW
1428
        }
×
1429

1430
        // Return the channel ranges in ascending block height order.
NEW
1431
        blocks := make([]uint32, 0, len(channelsPerBlock))
×
NEW
1432
        for block := range channelsPerBlock {
×
NEW
1433
                blocks = append(blocks, block)
×
NEW
1434
        }
×
NEW
1435
        sort.Slice(blocks, func(i, j int) bool {
×
NEW
1436
                return blocks[i] < blocks[j]
×
NEW
1437
        })
×
1438

NEW
1439
        channelRanges := make([]BlockChannelRange, 0, len(channelsPerBlock))
×
NEW
1440
        for _, block := range blocks {
×
NEW
1441
                channelRanges = append(channelRanges, BlockChannelRange{
×
NEW
1442
                        Height:   block,
×
NEW
1443
                        Channels: channelsPerBlock[block],
×
NEW
1444
                })
×
NEW
1445
        }
×
1446

NEW
1447
        return channelRanges, nil
×
1448
}
1449

1450
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1451
// zombie. This method is used on an ad-hoc basis, when channels need to be
1452
// marked as zombies outside the normal pruning cycle.
1453
//
1454
// NOTE: part of the V1Store interface.
1455
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
NEW
1456
        pubKey1, pubKey2 [33]byte) error {
×
NEW
1457

×
NEW
1458
        ctx := context.TODO()
×
NEW
1459

×
NEW
1460
        s.cacheMu.Lock()
×
NEW
1461
        defer s.cacheMu.Unlock()
×
NEW
1462

×
NEW
1463
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
1464
                return db.UpsertZombieChannel(
×
NEW
1465
                        ctx, sqlc.UpsertZombieChannelParams{
×
NEW
1466
                                Version:  int16(ProtocolV1),
×
NEW
1467
                                Scid:     int64(chanID),
×
NEW
1468
                                NodeKey1: pubKey1[:],
×
NEW
1469
                                NodeKey2: pubKey2[:],
×
NEW
1470
                        },
×
NEW
1471
                )
×
NEW
1472
        }, func() {})
×
NEW
1473
        if err != nil {
×
NEW
1474
                return err
×
NEW
1475
        }
×
1476

NEW
1477
        s.rejectCache.remove(chanID)
×
NEW
1478
        s.chanCache.remove(chanID)
×
NEW
1479

×
NEW
1480
        return nil
×
1481
}
1482

1483
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1484
//
1485
// NOTE: part of the V1Store interface.
NEW
1486
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
NEW
1487
        s.cacheMu.Lock()
×
NEW
1488
        defer s.cacheMu.Unlock()
×
NEW
1489

×
NEW
1490
        var ctx = context.TODO()
×
NEW
1491
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
1492
                _, err := db.GetZombieChannel(
×
NEW
1493
                        ctx, sqlc.GetZombieChannelParams{
×
NEW
1494
                                Scid:    int64(chanID),
×
NEW
1495
                                Version: int16(ProtocolV1),
×
NEW
1496
                        },
×
NEW
1497
                )
×
NEW
1498
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
1499
                        return ErrZombieEdgeNotFound
×
NEW
1500
                } else if err != nil {
×
NEW
1501
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
NEW
1502
                                err)
×
NEW
1503
                }
×
1504

NEW
1505
                return db.DeleteZombieChannel(
×
NEW
1506
                        ctx, sqlc.DeleteZombieChannelParams{
×
NEW
1507
                                Scid:    int64(chanID),
×
NEW
1508
                                Version: int16(ProtocolV1),
×
NEW
1509
                        },
×
NEW
1510
                )
×
NEW
1511
        }, func() {})
×
NEW
1512
        if err != nil {
×
NEW
1513
                return err
×
NEW
1514
        }
×
1515

NEW
1516
        s.rejectCache.remove(chanID)
×
NEW
1517
        s.chanCache.remove(chanID)
×
NEW
1518

×
NEW
1519
        return err
×
1520
}
1521

1522
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1523
// zombie, then the two node public keys corresponding to this edge are also
1524
// returned.
1525
//
1526
// NOTE: part of the V1Store interface.
NEW
1527
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte) {
×
NEW
1528
        var (
×
NEW
1529
                ctx              = context.TODO()
×
NEW
1530
                isZombie         bool
×
NEW
1531
                pubKey1, pubKey2 route.Vertex
×
NEW
1532
        )
×
NEW
1533
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1534
                zombie, err := db.GetZombieChannel(
×
NEW
1535
                        ctx, sqlc.GetZombieChannelParams{
×
NEW
1536
                                Scid:    int64(chanID),
×
NEW
1537
                                Version: int16(ProtocolV1),
×
NEW
1538
                        },
×
NEW
1539
                )
×
NEW
1540
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
1541
                        return nil
×
NEW
1542
                }
×
NEW
1543
                if err != nil {
×
NEW
1544
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
NEW
1545
                                err)
×
NEW
1546
                }
×
1547

NEW
1548
                copy(pubKey1[:], zombie.NodeKey1)
×
NEW
1549
                copy(pubKey2[:], zombie.NodeKey2)
×
NEW
1550
                isZombie = true
×
NEW
1551

×
NEW
1552
                return nil
×
NEW
1553
        }, func() {})
×
NEW
1554
        if err != nil {
×
NEW
1555
                return false, route.Vertex{}, route.Vertex{}
×
NEW
1556
        }
×
1557

NEW
1558
        return isZombie, pubKey1, pubKey2
×
1559
}
1560

1561
// NumZombies returns the current number of zombie channels in the graph.
1562
//
1563
// NOTE: part of the V1Store interface.
NEW
1564
func (s *SQLStore) NumZombies() (uint64, error) {
×
NEW
1565
        var (
×
NEW
1566
                ctx        = context.TODO()
×
NEW
1567
                numZombies uint64
×
NEW
1568
        )
×
NEW
1569
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1570
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
NEW
1571
                if err != nil {
×
NEW
1572
                        return fmt.Errorf("unable to count zombie channels: %w",
×
NEW
1573
                                err)
×
NEW
1574
                }
×
1575

NEW
1576
                numZombies = uint64(count)
×
NEW
1577

×
NEW
1578
                return nil
×
NEW
1579
        }, func() {})
×
NEW
1580
        if err != nil {
×
NEW
1581
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
NEW
1582
        }
×
1583

NEW
1584
        return numZombies, nil
×
1585
}
1586

1587
// DeleteChannelEdges removes edges with the given channel IDs from the
1588
// database and marks them as zombies. This ensures that we're unable to re-add
1589
// it to our database once again. If an edge does not exist within the
1590
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1591
// true, then when we mark these edges as zombies, we'll set up the keys such
1592
// that we require the node that failed to send the fresh update to be the one
1593
// that resurrects the channel from its zombie state. The markZombie bool
1594
// denotes whether to mark the channel as a zombie.
1595
//
1596
// NOTE: part of the V1Store interface.
1597
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
NEW
1598
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
NEW
1599

×
NEW
1600
        s.cacheMu.Lock()
×
NEW
1601
        defer s.cacheMu.Unlock()
×
NEW
1602

×
NEW
1603
        var (
×
NEW
1604
                ctx     = context.TODO()
×
NEW
1605
                deleted []*models.ChannelEdgeInfo
×
NEW
1606
        )
×
NEW
1607
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
1608
                for _, chanID := range chanIDs {
×
NEW
1609
                        var chanIDB [8]byte
×
NEW
1610
                        byteOrder.PutUint64(chanIDB[:], chanID)
×
NEW
1611

×
NEW
1612
                        row, err := db.GetChannelBySCIDWithPolicies(
×
NEW
1613
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
NEW
1614
                                        Scid:    chanIDB[:],
×
NEW
1615
                                        Version: int16(ProtocolV1),
×
NEW
1616
                                },
×
NEW
1617
                        )
×
NEW
1618
                        if errors.Is(err, sql.ErrNoRows) {
×
NEW
1619
                                return ErrEdgeNotFound
×
NEW
1620
                        } else if err != nil {
×
NEW
1621
                                return fmt.Errorf("unable to fetch channel: %w",
×
NEW
1622
                                        err)
×
NEW
1623
                        }
×
1624

NEW
1625
                        node1, node2, err := buildNodeVertices(
×
NEW
1626
                                row.Node1PubKey, row.Node2PubKey,
×
NEW
1627
                        )
×
NEW
1628
                        if err != nil {
×
NEW
1629
                                return err
×
NEW
1630
                        }
×
1631

NEW
1632
                        info, err := getAndBuildEdgeInfo(
×
NEW
1633
                                ctx, db, s.cfg.ChainHash, row.ID, row,
×
NEW
1634
                                node1, node2,
×
NEW
1635
                        )
×
NEW
1636
                        if err != nil {
×
NEW
1637
                                return err
×
NEW
1638
                        }
×
1639

NEW
1640
                        err = db.DeleteChannel(ctx, row.ID)
×
NEW
1641
                        if err != nil {
×
NEW
1642
                                return fmt.Errorf("unable to delete "+
×
NEW
1643
                                        "channel: %w", err)
×
NEW
1644
                        }
×
1645

NEW
1646
                        deleted = append(deleted, info)
×
NEW
1647

×
NEW
1648
                        if !markZombie {
×
NEW
1649
                                continue
×
1650
                        }
1651

NEW
1652
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
NEW
1653
                                info.NodeKey2Bytes
×
NEW
1654
                        if strictZombiePruning {
×
NEW
1655
                                var e1UpdateTime, e2UpdateTime *time.Time
×
NEW
1656
                                if row.Policy1LastUpdate.Valid {
×
NEW
1657
                                        e1Time := time.Unix(
×
NEW
1658
                                                row.Policy1LastUpdate.Int64, 0,
×
NEW
1659
                                        )
×
NEW
1660
                                        e1UpdateTime = &e1Time
×
NEW
1661
                                }
×
NEW
1662
                                if row.Policy2LastUpdate.Valid {
×
NEW
1663
                                        e2Time := time.Unix(
×
NEW
1664
                                                row.Policy2LastUpdate.Int64, 0,
×
NEW
1665
                                        )
×
NEW
1666
                                        e2UpdateTime = &e2Time
×
NEW
1667
                                }
×
1668

NEW
1669
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
NEW
1670
                                        info, e1UpdateTime, e2UpdateTime,
×
NEW
1671
                                )
×
1672
                        }
1673

NEW
1674
                        err = db.UpsertZombieChannel(
×
NEW
1675
                                ctx, sqlc.UpsertZombieChannelParams{
×
NEW
1676
                                        Version:  int16(ProtocolV1),
×
NEW
1677
                                        Scid:     int64(chanID),
×
NEW
1678
                                        NodeKey1: nodeKey1[:],
×
NEW
1679
                                        NodeKey2: nodeKey2[:],
×
NEW
1680
                                },
×
NEW
1681
                        )
×
NEW
1682
                        if err != nil {
×
NEW
1683
                                return fmt.Errorf("unable to mark channel as "+
×
NEW
1684
                                        "zombie: %w", err)
×
NEW
1685
                        }
×
1686
                }
1687

NEW
1688
                return nil
×
NEW
1689
        }, func() {
×
NEW
1690
                deleted = nil
×
NEW
1691
        })
×
NEW
1692
        if err != nil {
×
NEW
1693
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
NEW
1694
                        err)
×
NEW
1695
        }
×
1696

NEW
1697
        for _, chanID := range chanIDs {
×
NEW
1698
                s.rejectCache.remove(chanID)
×
NEW
1699
                s.chanCache.remove(chanID)
×
NEW
1700
        }
×
1701

NEW
1702
        return deleted, nil
×
1703
}
1704

1705
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1706
// channel identified by the channel ID. If the channel can't be found, then
1707
// ErrEdgeNotFound is returned. A struct which houses the general information
1708
// for the channel itself is returned as well as two structs that contain the
1709
// routing policies for the channel in either direction.
1710
//
1711
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1712
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1713
// the ChannelEdgeInfo will only include the public keys of each node.
1714
//
1715
// NOTE: part of the V1Store interface.
1716
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1717
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
NEW
1718
        *models.ChannelEdgePolicy, error) {
×
NEW
1719

×
NEW
1720
        var (
×
NEW
1721
                ctx              = context.TODO()
×
NEW
1722
                edge             *models.ChannelEdgeInfo
×
NEW
1723
                policy1, policy2 *models.ChannelEdgePolicy
×
NEW
1724
        )
×
NEW
1725
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1726
                var chanIDB [8]byte
×
NEW
1727
                byteOrder.PutUint64(chanIDB[:], chanID)
×
NEW
1728

×
NEW
1729
                row, err := db.GetChannelBySCIDWithPolicies(
×
NEW
1730
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
NEW
1731
                                Scid:    chanIDB[:],
×
NEW
1732
                                Version: int16(ProtocolV1),
×
NEW
1733
                        },
×
NEW
1734
                )
×
NEW
1735
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
1736
                        // First check if this edge is perhaps in the zombie
×
NEW
1737
                        // index.
×
NEW
1738
                        isZombie, err := db.IsZombieChannel(
×
NEW
1739
                                ctx, sqlc.IsZombieChannelParams{
×
NEW
1740
                                        Scid:    int64(chanID),
×
NEW
1741
                                        Version: int16(ProtocolV1),
×
NEW
1742
                                },
×
NEW
1743
                        )
×
NEW
1744
                        if err != nil {
×
NEW
1745
                                return fmt.Errorf("unable to check if "+
×
NEW
1746
                                        "channel is zombie: %w", err)
×
NEW
1747
                        } else if isZombie {
×
NEW
1748
                                return ErrZombieEdge
×
NEW
1749
                        }
×
1750

NEW
1751
                        return ErrEdgeNotFound
×
NEW
1752
                } else if err != nil {
×
NEW
1753
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
NEW
1754
                }
×
1755

NEW
1756
                node1, node2, err := buildNodeVertices(
×
NEW
1757
                        row.Node1PubKey, row.Node2PubKey,
×
NEW
1758
                )
×
NEW
1759
                if err != nil {
×
NEW
1760
                        return err
×
NEW
1761
                }
×
1762

NEW
1763
                edge, policy1, policy2, err = getAndBuildEdgeInfoAndPolicies(
×
NEW
1764
                        ctx, db, s.cfg.ChainHash, row.ID, row, node1, node2,
×
NEW
1765
                )
×
NEW
1766

×
NEW
1767
                return err
×
NEW
1768
        }, func() {})
×
NEW
1769
        if err != nil {
×
NEW
1770
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
NEW
1771
                        err)
×
NEW
1772
        }
×
1773

NEW
1774
        return edge, policy1, policy2, nil
×
1775
}
1776

1777
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1778
// the channel identified by the funding outpoint. If the channel can't be
1779
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1780
// information for the channel itself is returned as well as two structs that
1781
// contain the routing policies for the channel in either direction.
1782
//
1783
// NOTE: part of the V1Store interface.
1784
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1785
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
NEW
1786
        *models.ChannelEdgePolicy, error) {
×
NEW
1787

×
NEW
1788
        var (
×
NEW
1789
                ctx              = context.TODO()
×
NEW
1790
                edge             *models.ChannelEdgeInfo
×
NEW
1791
                policy1, policy2 *models.ChannelEdgePolicy
×
NEW
1792
        )
×
NEW
1793
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1794
                row, err := db.GetChannelByOutpointWithPolicies(
×
NEW
1795
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
NEW
1796
                                Outpoint: op.String(),
×
NEW
1797
                                Version:  int16(ProtocolV1),
×
NEW
1798
                        },
×
NEW
1799
                )
×
NEW
1800
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
1801
                        return ErrEdgeNotFound
×
NEW
1802
                } else if err != nil {
×
NEW
1803
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
NEW
1804
                }
×
1805

NEW
1806
                node1, node2, err := buildNodeVertices(
×
NEW
1807
                        row.Node1Pubkey, row.Node2Pubkey,
×
NEW
1808
                )
×
NEW
1809
                if err != nil {
×
NEW
1810
                        return err
×
NEW
1811
                }
×
1812

NEW
1813
                edge, policy1, policy2, err = getAndBuildEdgeInfoAndPolicies(
×
NEW
1814
                        ctx, db, s.cfg.ChainHash, row.ID, row, node1, node2,
×
NEW
1815
                )
×
NEW
1816

×
NEW
1817
                return err
×
NEW
1818
        }, func() {})
×
NEW
1819
        if err != nil {
×
NEW
1820
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
NEW
1821
                        err)
×
NEW
1822
        }
×
1823

NEW
1824
        return edge, policy1, policy2, nil
×
1825
}
1826

1827
// HasChannelEdge returns true if the database knows of a channel edge with the
1828
// passed channel ID, and false otherwise. If an edge with that ID is found
1829
// within the graph, then two time stamps representing the last time the edge
1830
// was updated for both directed edges are returned along with the boolean. If
1831
// it is not found, then the zombie index is checked and its result is returned
1832
// as the second boolean.
1833
//
1834
// NOTE: part of the V1Store interface.
1835
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
NEW
1836
        bool, error) {
×
NEW
1837

×
NEW
1838
        ctx := context.TODO()
×
NEW
1839

×
NEW
1840
        var (
×
NEW
1841
                exists          bool
×
NEW
1842
                isZombie        bool
×
NEW
1843
                node1LastUpdate time.Time
×
NEW
1844
                node2LastUpdate time.Time
×
NEW
1845
        )
×
NEW
1846

×
NEW
1847
        // We'll query the cache with the shared lock held to allow multiple
×
NEW
1848
        // readers to access values in the cache concurrently if they exist.
×
NEW
1849
        s.cacheMu.RLock()
×
NEW
1850
        if entry, ok := s.rejectCache.get(chanID); ok {
×
NEW
1851
                s.cacheMu.RUnlock()
×
NEW
1852
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
NEW
1853
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
NEW
1854
                exists, isZombie = entry.flags.unpack()
×
NEW
1855

×
NEW
1856
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
NEW
1857
        }
×
NEW
1858
        s.cacheMu.RUnlock()
×
NEW
1859

×
NEW
1860
        s.cacheMu.Lock()
×
NEW
1861
        defer s.cacheMu.Unlock()
×
NEW
1862

×
NEW
1863
        // The item was not found with the shared lock, so we'll acquire the
×
NEW
1864
        // exclusive lock and check the cache again in case another method added
×
NEW
1865
        // the entry to the cache while no lock was held.
×
NEW
1866
        if entry, ok := s.rejectCache.get(chanID); ok {
×
NEW
1867
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
NEW
1868
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
NEW
1869
                exists, isZombie = entry.flags.unpack()
×
NEW
1870

×
NEW
1871
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
NEW
1872
        }
×
1873

NEW
1874
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1875
                var chanIDB [8]byte
×
NEW
1876
                byteOrder.PutUint64(chanIDB[:], chanID)
×
NEW
1877

×
NEW
1878
                channel, err := db.GetChannelBySCID(
×
NEW
1879
                        ctx, sqlc.GetChannelBySCIDParams{
×
NEW
1880
                                Scid:    chanIDB[:],
×
NEW
1881
                                Version: int16(ProtocolV1),
×
NEW
1882
                        },
×
NEW
1883
                )
×
NEW
1884
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
1885
                        // Check if it is a zombie channel.
×
NEW
1886
                        isZombie, err = db.IsZombieChannel(
×
NEW
1887
                                ctx, sqlc.IsZombieChannelParams{
×
NEW
1888
                                        Scid:    int64(chanID),
×
NEW
1889
                                        Version: int16(ProtocolV1),
×
NEW
1890
                                },
×
NEW
1891
                        )
×
NEW
1892
                        if err != nil {
×
NEW
1893
                                return fmt.Errorf("could not check if channel "+
×
NEW
1894
                                        "is zombie: %w", err)
×
NEW
1895
                        }
×
1896

NEW
1897
                        return nil
×
NEW
1898
                } else if err != nil {
×
NEW
1899
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
NEW
1900
                }
×
1901

NEW
1902
                exists = true
×
NEW
1903

×
NEW
1904
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
NEW
1905
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
1906
                                Version:   int16(ProtocolV1),
×
NEW
1907
                                ChannelID: channel.ID,
×
NEW
1908
                                NodeID:    channel.NodeID1,
×
NEW
1909
                        },
×
NEW
1910
                )
×
NEW
1911
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
NEW
1912
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
NEW
1913
                                err)
×
NEW
1914
                } else if err == nil {
×
NEW
1915
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
NEW
1916
                }
×
1917

NEW
1918
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
NEW
1919
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
1920
                                Version:   int16(ProtocolV1),
×
NEW
1921
                                ChannelID: channel.ID,
×
NEW
1922
                                NodeID:    channel.NodeID2,
×
NEW
1923
                        },
×
NEW
1924
                )
×
NEW
1925
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
NEW
1926
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
NEW
1927
                                err)
×
NEW
1928
                } else if err == nil {
×
NEW
1929
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
NEW
1930
                }
×
1931

NEW
1932
                return nil
×
NEW
1933
        }, func() {})
×
NEW
1934
        if err != nil {
×
NEW
1935
                return time.Time{}, time.Time{}, false, false,
×
NEW
1936
                        fmt.Errorf("unable to fetch channel: %w", err)
×
NEW
1937
        }
×
1938

NEW
1939
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
NEW
1940
                upd1Time: node1LastUpdate.Unix(),
×
NEW
1941
                upd2Time: node2LastUpdate.Unix(),
×
NEW
1942
                flags:    packRejectFlags(exists, isZombie),
×
NEW
1943
        })
×
NEW
1944

×
NEW
1945
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
1946
}
1947

1948
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
1949
// passed channel point (outpoint). If the passed channel doesn't exist within
1950
// the database, then ErrEdgeNotFound is returned.
1951
//
1952
// NOTE: part of the V1Store interface.
NEW
1953
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
NEW
1954
        var (
×
NEW
1955
                ctx       = context.TODO()
×
NEW
1956
                channelID uint64
×
NEW
1957
        )
×
NEW
1958
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1959
                chanID, err := db.GetSCIDByOutpoint(
×
NEW
1960
                        ctx, sqlc.GetSCIDByOutpointParams{
×
NEW
1961
                                Outpoint: chanPoint.String(),
×
NEW
1962
                                Version:  int16(ProtocolV1),
×
NEW
1963
                        },
×
NEW
1964
                )
×
NEW
1965
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
1966
                        return ErrEdgeNotFound
×
NEW
1967
                } else if err != nil {
×
NEW
1968
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
NEW
1969
                                err)
×
NEW
1970
                }
×
1971

NEW
1972
                channelID = byteOrder.Uint64(chanID)
×
NEW
1973

×
NEW
1974
                return nil
×
NEW
1975
        }, func() {})
×
NEW
1976
        if err != nil {
×
NEW
1977
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
NEW
1978
        }
×
1979

NEW
1980
        return channelID, nil
×
1981
}
1982

1983
// IsPublicNode is a helper method that determines whether the node with the
1984
// given public key is seen as a public node in the graph from the graph's
1985
// source node's point of view.
1986
//
1987
// NOTE: part of the V1Store interface.
NEW
1988
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
NEW
1989
        ctx := context.TODO()
×
NEW
1990

×
NEW
1991
        var isPublic bool
×
NEW
1992
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1993
                var err error
×
NEW
1994
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
NEW
1995

×
NEW
1996
                return err
×
NEW
1997
        }, func() {})
×
NEW
1998
        if err != nil {
×
NEW
1999
                return false, fmt.Errorf("unable to check if node is "+
×
NEW
2000
                        "public: %w", err)
×
NEW
2001
        }
×
2002

NEW
2003
        return isPublic, nil
×
2004
}
2005

2006
// FetchChanInfos returns the set of channel edges that correspond to the passed
2007
// channel ID's. If an edge is the query is unknown to the database, it will
2008
// skipped and the result will contain only those edges that exist at the time
2009
// of the query. This can be used to respond to peer queries that are seeking to
2010
// fill in gaps in their view of the channel graph.
2011
//
2012
// NOTE: part of the V1Store interface.
NEW
2013
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
NEW
2014
        var (
×
NEW
2015
                ctx   = context.TODO()
×
NEW
2016
                edges []ChannelEdge
×
NEW
2017
        )
×
NEW
2018
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2019
                for _, chanID := range chanIDs {
×
NEW
2020
                        var chanIDB [8]byte
×
NEW
2021
                        byteOrder.PutUint64(chanIDB[:], chanID)
×
NEW
2022

×
NEW
2023
                        row, err := db.GetChannelBySCIDWithPolicies(
×
NEW
2024
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
NEW
2025
                                        Scid:    chanIDB[:],
×
NEW
2026
                                        Version: int16(ProtocolV1),
×
NEW
2027
                                },
×
NEW
2028
                        )
×
NEW
2029
                        if errors.Is(err, sql.ErrNoRows) {
×
NEW
2030
                                continue
×
NEW
2031
                        } else if err != nil {
×
NEW
2032
                                return fmt.Errorf("unable to fetch channel: %w",
×
NEW
2033
                                        err)
×
NEW
2034
                        }
×
2035

NEW
2036
                        node1, node2, err := getAndBuildNodes(ctx, db, row)
×
NEW
2037
                        if err != nil {
×
NEW
2038
                                return fmt.Errorf("unable to fetch nodes: %w",
×
NEW
2039
                                        err)
×
NEW
2040
                        }
×
2041

NEW
2042
                        edge, p1, p2, err := getAndBuildEdgeInfoAndPolicies(
×
NEW
2043
                                ctx, db, s.cfg.ChainHash, row.ID, row,
×
NEW
2044
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
NEW
2045
                        )
×
NEW
2046
                        if err != nil {
×
NEW
2047
                                return fmt.Errorf("unable to build channel "+
×
NEW
2048
                                        "info and policies: %w", err)
×
NEW
2049
                        }
×
2050

NEW
2051
                        edges = append(edges, ChannelEdge{
×
NEW
2052
                                Info:    edge,
×
NEW
2053
                                Policy1: p1,
×
NEW
2054
                                Policy2: p2,
×
NEW
2055
                                Node1:   node1,
×
NEW
2056
                                Node2:   node2,
×
NEW
2057
                        })
×
2058
                }
2059

NEW
2060
                return nil
×
NEW
2061
        }, func() {})
×
NEW
2062
        if err != nil {
×
NEW
2063
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
NEW
2064
        }
×
2065

NEW
2066
        return edges, nil
×
2067
}
2068

2069
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2070
// ID's that we don't know and are not known zombies of the passed set. In other
2071
// words, we perform a set difference of our set of chan ID's and the ones
2072
// passed in. This method can be used by callers to determine the set of
2073
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2074
// known zombies is also returned.
2075
//
2076
// NOTE: part of the V1Store interface.
2077
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
NEW
2078
        []ChannelUpdateInfo, error) {
×
NEW
2079

×
NEW
2080
        var (
×
NEW
2081
                ctx          = context.TODO()
×
NEW
2082
                newChanIDs   []uint64
×
NEW
2083
                knownZombies []ChannelUpdateInfo
×
NEW
2084
        )
×
NEW
2085
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2086
                for _, chanInfo := range chansInfo {
×
NEW
2087
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
NEW
2088
                        var chanIDB [8]byte
×
NEW
2089
                        byteOrder.PutUint64(chanIDB[:], channelID)
×
NEW
2090

×
NEW
2091
                        _, err := db.GetChannelBySCID(
×
NEW
2092
                                ctx, sqlc.GetChannelBySCIDParams{
×
NEW
2093
                                        Version: int16(ProtocolV1),
×
NEW
2094
                                        Scid:    chanIDB[:],
×
NEW
2095
                                },
×
NEW
2096
                        )
×
NEW
2097
                        if err == nil {
×
NEW
2098
                                continue
×
NEW
2099
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
NEW
2100
                                return fmt.Errorf("unable to fetch channel: %w",
×
NEW
2101
                                        err)
×
NEW
2102
                        }
×
2103

NEW
2104
                        isZombie, err := db.IsZombieChannel(
×
NEW
2105
                                ctx, sqlc.IsZombieChannelParams{
×
NEW
2106
                                        Scid:    int64(channelID),
×
NEW
2107
                                        Version: int16(ProtocolV1),
×
NEW
2108
                                },
×
NEW
2109
                        )
×
NEW
2110
                        if err != nil {
×
NEW
2111
                                return fmt.Errorf("unable to fetch zombie "+
×
NEW
2112
                                        "channel: %w", err)
×
NEW
2113
                        }
×
2114

NEW
2115
                        if isZombie {
×
NEW
2116
                                knownZombies = append(knownZombies, chanInfo)
×
NEW
2117

×
NEW
2118
                                continue
×
2119
                        }
2120

NEW
2121
                        newChanIDs = append(newChanIDs, channelID)
×
2122
                }
2123

NEW
2124
                return nil
×
NEW
2125
        }, func() {})
×
NEW
2126
        if err != nil {
×
NEW
2127
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
NEW
2128
        }
×
2129

NEW
2130
        return newChanIDs, knownZombies, nil
×
2131
}
2132

2133
// PruneGraphNodes is a garbage collection method which attempts to prune out
2134
// any nodes from the channel graph that are currently unconnected. This ensure
2135
// that we only maintain a graph of reachable nodes. In the event that a pruned
2136
// node gains more channels, it will be re-added back to the graph.
2137
//
2138
// NOTE: part of the V1Store interface.
NEW
2139
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
NEW
2140
        var ctx = context.TODO()
×
NEW
2141

×
NEW
2142
        var prunedNodes []route.Vertex
×
NEW
2143
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2144
                var err error
×
NEW
2145
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
NEW
2146

×
NEW
2147
                return err
×
NEW
2148
        }, func() {})
×
NEW
2149
        if err != nil {
×
NEW
2150
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
NEW
2151
        }
×
2152

NEW
2153
        return prunedNodes, nil
×
2154
}
2155

2156
// PruneGraph prunes newly closed channels from the channel graph in response
2157
// to a new block being solved on the network. Any transactions which spend the
2158
// funding output of any known channels within he graph will be deleted.
2159
// Additionally, the "prune tip", or the last block which has been used to
2160
// prune the graph is stored so callers can ensure the graph is fully in sync
2161
// with the current UTXO state. A slice of channels that have been closed by
2162
// the target block along with any pruned nodes are returned if the function
2163
// succeeds without error.
2164
//
2165
// NOTE: part of the V1Store interface.
2166
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2167
        blockHash *chainhash.Hash, blockHeight uint32) (
NEW
2168
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
NEW
2169

×
NEW
2170
        ctx := context.TODO()
×
NEW
2171

×
NEW
2172
        s.cacheMu.Lock()
×
NEW
2173
        defer s.cacheMu.Unlock()
×
NEW
2174

×
NEW
2175
        var (
×
NEW
2176
                closedChans []*models.ChannelEdgeInfo
×
NEW
2177
                prunedNodes []route.Vertex
×
NEW
2178
        )
×
NEW
2179
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2180
                for _, outpoint := range spentOutputs {
×
NEW
2181
                        row, err := db.GetChannelByOutpoint(
×
NEW
2182
                                ctx, sqlc.GetChannelByOutpointParams{
×
NEW
2183
                                        Outpoint: outpoint.String(),
×
NEW
2184
                                        Version:  int16(ProtocolV1),
×
NEW
2185
                                },
×
NEW
2186
                        )
×
NEW
2187
                        if errors.Is(err, sql.ErrNoRows) {
×
NEW
2188
                                continue
×
NEW
2189
                        } else if err != nil {
×
NEW
2190
                                return fmt.Errorf("unable to fetch channel: %w",
×
NEW
2191
                                        err)
×
NEW
2192
                        }
×
2193

NEW
2194
                        node1, node2, err := buildNodeVertices(
×
NEW
2195
                                row.Node1Pubkey, row.Node2Pubkey,
×
NEW
2196
                        )
×
NEW
2197
                        if err != nil {
×
NEW
2198
                                return err
×
NEW
2199
                        }
×
2200

NEW
2201
                        info, err := getAndBuildEdgeInfo(
×
NEW
2202
                                ctx, db, s.cfg.ChainHash, row.ID, row,
×
NEW
2203
                                node1, node2,
×
NEW
2204
                        )
×
NEW
2205
                        if err != nil {
×
NEW
2206
                                return err
×
NEW
2207
                        }
×
2208

NEW
2209
                        err = db.DeleteChannel(ctx, row.ID)
×
NEW
2210
                        if err != nil {
×
NEW
2211
                                return fmt.Errorf("unable to delete "+
×
NEW
2212
                                        "channel: %w", err)
×
NEW
2213
                        }
×
2214

NEW
2215
                        closedChans = append(closedChans, info)
×
2216
                }
2217

NEW
2218
                err := db.UpsertPruneLogEntry(
×
NEW
2219
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
NEW
2220
                                BlockHash:   blockHash[:],
×
NEW
2221
                                BlockHeight: int64(blockHeight),
×
NEW
2222
                        },
×
NEW
2223
                )
×
NEW
2224
                if err != nil {
×
NEW
2225
                        return fmt.Errorf("unable to insert prune log "+
×
NEW
2226
                                "entry: %w", err)
×
NEW
2227
                }
×
2228

2229
                // Now that we've pruned some channels, we'll also prune any
2230
                // nodes that no longer have any channels.
NEW
2231
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
NEW
2232
                if err != nil {
×
NEW
2233
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
NEW
2234
                                err)
×
NEW
2235
                }
×
2236

NEW
2237
                return nil
×
NEW
2238
        }, func() {
×
NEW
2239
                prunedNodes = nil
×
NEW
2240
                closedChans = nil
×
NEW
2241
        })
×
NEW
2242
        if err != nil {
×
NEW
2243
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
NEW
2244
        }
×
2245

NEW
2246
        for _, channel := range closedChans {
×
NEW
2247
                s.rejectCache.remove(channel.ChannelID)
×
NEW
2248
                s.chanCache.remove(channel.ChannelID)
×
NEW
2249
        }
×
2250

NEW
2251
        return closedChans, prunedNodes, nil
×
2252
}
2253

2254
// ChannelView returns the verifiable edge information for each active channel
2255
// within the known channel graph. The set of UTXO's (along with their scripts)
2256
// returned are the ones that need to be watched on chain to detect channel
2257
// closes on the resident blockchain.
2258
//
2259
// NOTE: part of the V1Store interface.
NEW
2260
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
NEW
2261
        var (
×
NEW
2262
                ctx        = context.TODO()
×
NEW
2263
                edgePoints []EdgePoint
×
NEW
2264
        )
×
NEW
2265
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2266
                dbChannel, err := db.ListAllChannels(ctx, int16(ProtocolV1))
×
NEW
2267
                if err != nil {
×
NEW
2268
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
NEW
2269
                }
×
2270

NEW
2271
                for _, dbChan := range dbChannel {
×
NEW
2272
                        if dbChan.BitcoinKey1 == nil {
×
NEW
2273
                                continue
×
2274
                        }
2275

NEW
2276
                        pkScript, err := genMultiSigP2WSH(
×
NEW
2277
                                dbChan.BitcoinKey1, dbChan.BitcoinKey2,
×
NEW
2278
                        )
×
NEW
2279
                        if err != nil {
×
NEW
2280
                                return err
×
NEW
2281
                        }
×
2282

NEW
2283
                        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
NEW
2284
                        if err != nil {
×
NEW
2285
                                return err
×
NEW
2286
                        }
×
2287

NEW
2288
                        edgePoints = append(edgePoints, EdgePoint{
×
NEW
2289
                                FundingPkScript: pkScript,
×
NEW
2290
                                OutPoint:        *op,
×
NEW
2291
                        })
×
2292
                }
2293

NEW
2294
                return nil
×
NEW
2295
        }, func() {})
×
NEW
2296
        if err != nil {
×
NEW
2297
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
NEW
2298
        }
×
2299

NEW
2300
        return edgePoints, nil
×
2301
}
2302

2303
// PruneTip returns the block height and hash of the latest block that has been
2304
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2305
// to tell if the graph is currently in sync with the current best known UTXO
2306
// state.
2307
//
2308
// NOTE: part of the V1Store interface.
NEW
2309
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
NEW
2310
        var (
×
NEW
2311
                ctx       = context.TODO()
×
NEW
2312
                tipHash   chainhash.Hash
×
NEW
2313
                tipHeight uint32
×
NEW
2314
        )
×
NEW
2315
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2316
                pruneTip, err := db.GetPruneTip(ctx)
×
NEW
2317
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
2318
                        return ErrGraphNeverPruned
×
NEW
2319
                } else if err != nil {
×
NEW
2320
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
NEW
2321
                }
×
2322

NEW
2323
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
NEW
2324
                tipHeight = uint32(pruneTip.BlockHeight)
×
NEW
2325

×
NEW
2326
                return nil
×
NEW
2327
        }, func() {})
×
NEW
2328
        if err != nil {
×
NEW
2329
                return nil, 0, err
×
NEW
2330
        }
×
2331

NEW
2332
        return &tipHash, tipHeight, nil
×
2333
}
2334

2335
// pruneGraphNodes deletes any node in the DB that doesn't have a channel
2336
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
NEW
2337
        db SQLQueries) ([]route.Vertex, error) {
×
NEW
2338

×
NEW
2339
        nodes, err := db.GetUnconnectedNodes(ctx)
×
NEW
2340
        if err != nil {
×
NEW
2341
                return nil, fmt.Errorf("unable to fetch unconnected nodes: %w",
×
NEW
2342
                        err)
×
NEW
2343
        }
×
2344

NEW
2345
        nodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
NEW
2346
        if err != nil {
×
NEW
2347
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
NEW
2348
        }
×
2349

NEW
2350
        prunedNodes := make([]route.Vertex, 0, len(nodes))
×
NEW
2351
        for _, node := range nodes {
×
NEW
2352
                // Don't delete the source node.
×
NEW
2353
                if node.ID == nodeID {
×
NEW
2354
                        continue
×
2355
                }
2356

NEW
2357
                _, err = db.DeleteNodeByPubKey(
×
NEW
2358
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
NEW
2359
                                PubKey:  node.PubKey,
×
NEW
2360
                                Version: int16(ProtocolV1),
×
NEW
2361
                        },
×
NEW
2362
                )
×
NEW
2363
                if err != nil {
×
NEW
2364
                        return nil, fmt.Errorf("unable to delete node: %w", err)
×
NEW
2365
                }
×
2366

NEW
2367
                var pubKey route.Vertex
×
NEW
2368
                copy(pubKey[:], node.PubKey)
×
NEW
2369
                prunedNodes = append(prunedNodes, pubKey)
×
2370
        }
2371

NEW
2372
        return prunedNodes, nil
×
2373
}
2374

2375
// DisconnectBlockAtHeight is used to indicate that the block specified
2376
// by the passed height has been disconnected from the main chain. This
2377
// will "rewind" the graph back to the height below, deleting channels
2378
// that are no longer confirmed from the graph. The prune log will be
2379
// set to the last prune height valid for the remaining chain.
2380
// Channels that were removed from the graph resulting from the
2381
// disconnected block are returned.
2382
//
2383
// NOTE: part of the V1Store interface.
2384
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
NEW
2385
        []*models.ChannelEdgeInfo, error) {
×
NEW
2386

×
NEW
2387
        ctx := context.TODO()
×
NEW
2388

×
NEW
2389
        var (
×
NEW
2390
                // Every channel having a ShortChannelID starting at 'height'
×
NEW
2391
                // will no longer be confirmed.
×
NEW
2392
                startShortChanID = lnwire.ShortChannelID{
×
NEW
2393
                        BlockHeight: height,
×
NEW
2394
                }
×
NEW
2395

×
NEW
2396
                // Delete everything after this height from the db up until the
×
NEW
2397
                // SCID alias range.
×
NEW
2398
                endShortChanID = aliasmgr.StartingAlias
×
NEW
2399

×
NEW
2400
                removedChans []*models.ChannelEdgeInfo
×
NEW
2401
        )
×
NEW
2402

×
NEW
2403
        var chanIDStart [8]byte
×
NEW
2404
        byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64())
×
NEW
2405
        var chanIDEnd [8]byte
×
NEW
2406
        byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64())
×
NEW
2407

×
NEW
2408
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2409
                rows, err := db.GetChannelsBySCIDRange(
×
NEW
2410
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
NEW
2411
                                StartScid: chanIDStart[:],
×
NEW
2412
                                EndScid:   chanIDEnd[:],
×
NEW
2413
                        },
×
NEW
2414
                )
×
NEW
2415
                if err != nil {
×
NEW
2416
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
NEW
2417
                }
×
2418

NEW
2419
                for _, row := range rows {
×
NEW
2420
                        node1, node2, err := buildNodeVertices(
×
NEW
2421
                                row.Node1PubKey, row.Node2PubKey,
×
NEW
2422
                        )
×
NEW
2423
                        if err != nil {
×
NEW
2424
                                return err
×
NEW
2425
                        }
×
2426

NEW
2427
                        channel, err := getAndBuildEdgeInfo(
×
NEW
2428
                                ctx, db, s.cfg.ChainHash, row.ID, row,
×
NEW
2429
                                node1, node2,
×
NEW
2430
                        )
×
NEW
2431
                        if err != nil {
×
NEW
2432
                                return err
×
NEW
2433
                        }
×
2434

NEW
2435
                        err = db.DeleteChannel(ctx, row.ID)
×
NEW
2436
                        if err != nil {
×
NEW
2437
                                return fmt.Errorf("unable to delete "+
×
NEW
2438
                                        "channel: %w", err)
×
NEW
2439
                        }
×
2440

NEW
2441
                        removedChans = append(removedChans, channel)
×
2442

2443
                }
2444

NEW
2445
                return db.DeletePruneLogEntriesInRange(
×
NEW
2446
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
NEW
2447
                                StartHeight: int64(height),
×
NEW
2448
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
NEW
2449
                        },
×
NEW
2450
                )
×
NEW
2451
        }, func() {
×
NEW
2452
                removedChans = nil
×
NEW
2453
        })
×
NEW
2454
        if err != nil {
×
NEW
2455
                return nil, fmt.Errorf("unable to disconnect block at "+
×
NEW
2456
                        "height: %w", err)
×
NEW
2457
        }
×
2458

NEW
2459
        for _, channel := range removedChans {
×
NEW
2460
                s.rejectCache.remove(channel.ChannelID)
×
NEW
2461
                s.chanCache.remove(channel.ChannelID)
×
NEW
2462
        }
×
2463

NEW
2464
        return removedChans, nil
×
2465
}
2466

2467
// AddEdgeProof sets the proof of an existing edge in the graph database.
2468
//
2469
// NOTE: part of the V1Store interface.
2470
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
NEW
2471
        proof *models.ChannelAuthProof) error {
×
NEW
2472

×
NEW
2473
        var (
×
NEW
2474
                ctx       = context.TODO()
×
NEW
2475
                scidBytes [8]byte
×
NEW
2476
        )
×
NEW
2477
        byteOrder.PutUint64(scidBytes[:], scid.ToUint64())
×
NEW
2478

×
NEW
2479
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2480
                dbChan, err := db.GetChannelBySCID(
×
NEW
2481
                        ctx, sqlc.GetChannelBySCIDParams{
×
NEW
2482
                                Scid:    scidBytes[:],
×
NEW
2483
                                Version: int16(ProtocolV1),
×
NEW
2484
                        },
×
NEW
2485
                )
×
NEW
2486
                if err != nil {
×
NEW
2487
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
NEW
2488
                }
×
2489

NEW
2490
                return db.AddV1ChannelProof(
×
NEW
2491
                        ctx, sqlc.AddV1ChannelProofParams{
×
NEW
2492
                                ID:                dbChan.ID,
×
NEW
2493
                                Node1Signature:    proof.NodeSig1Bytes,
×
NEW
2494
                                Node2Signature:    proof.NodeSig2Bytes,
×
NEW
2495
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
NEW
2496
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
NEW
2497
                        },
×
NEW
2498
                )
×
NEW
2499
        }, func() {})
×
NEW
2500
        if err != nil {
×
NEW
2501
                return fmt.Errorf("unable to add edge proof: %w", err)
×
NEW
2502
        }
×
2503

NEW
2504
        return nil
×
2505
}
2506

2507
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2508
// that we can ignore channel announcements that we know to be closed without
2509
// having to validate them and fetch a block.
2510
//
2511
// NOTE: part of the V1Store interface.
NEW
2512
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
NEW
2513
        ctx := context.TODO()
×
NEW
2514

×
NEW
2515
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2516
                var chanIDB [8]byte
×
NEW
2517
                byteOrder.PutUint64(chanIDB[:], scid.ToUint64())
×
NEW
2518

×
NEW
2519
                return db.InsertClosedChannel(ctx, chanIDB[:])
×
NEW
2520
        }, func() {})
×
2521
}
2522

2523
// IsClosedScid checks whether a channel identified by the passed in scid is
2524
// closed. This helps avoid having to perform expensive validation checks.
2525
//
2526
// NOTE: part of the V1Store interface.
NEW
2527
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
NEW
2528
        var (
×
NEW
2529
                ctx      = context.TODO()
×
NEW
2530
                isClosed bool
×
NEW
2531
        )
×
NEW
2532
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2533
                var chanIDB [8]byte
×
NEW
2534
                byteOrder.PutUint64(chanIDB[:], scid.ToUint64())
×
NEW
2535
                var err error
×
NEW
2536
                isClosed, err = db.IsClosedChannel(ctx, chanIDB[:])
×
NEW
2537
                if err != nil {
×
NEW
2538
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
NEW
2539
                                err)
×
NEW
2540
                }
×
2541

NEW
2542
                return nil
×
NEW
2543
        }, func() {})
×
NEW
2544
        if err != nil {
×
NEW
2545
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
NEW
2546
                        err)
×
NEW
2547
        }
×
2548

NEW
2549
        return isClosed, nil
×
2550
}
2551

2552
// GraphSession will provide the call-back with access to a NodeTraverser
2553
// instance which can be used to perform queries against the channel graph.
2554
//
2555
// NOTE: part of the V1Store interface.
NEW
2556
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error) error {
×
NEW
2557
        var ctx = context.TODO()
×
NEW
2558

×
NEW
2559
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2560
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
NEW
2561
        }, func() {})
×
2562
}
2563

2564
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2565
// read only transaction for a consistent view of the graph.
2566
type sqlNodeTraverser struct {
2567
        db    SQLQueries
2568
        chain chainhash.Hash
2569
}
2570

2571
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2572
// NodeTraverser interface.
2573
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2574

2575
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2576
func newSQLNodeTraverser(db SQLQueries,
NEW
2577
        chain chainhash.Hash) *sqlNodeTraverser {
×
NEW
2578

×
NEW
2579
        return &sqlNodeTraverser{
×
NEW
2580
                db:    db,
×
NEW
2581
                chain: chain,
×
NEW
2582
        }
×
NEW
2583
}
×
2584

2585
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2586
// node.
2587
//
2588
// NOTE: Part of the NodeTraverser interface.
2589
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
NEW
2590
        cb func(channel *DirectedChannel) error) error {
×
NEW
2591

×
NEW
2592
        ctx := context.TODO()
×
NEW
2593

×
NEW
2594
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
NEW
2595
}
×
2596

2597
// FetchNodeFeatures returns the features of the given node. If the node is
2598
// unknown, assume no additional features are supported.
2599
//
2600
// NOTE: Part of the NodeTraverser interface.
2601
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
NEW
2602
        *lnwire.FeatureVector, error) {
×
NEW
2603

×
NEW
2604
        ctx := context.TODO()
×
NEW
2605

×
NEW
2606
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
NEW
2607
}
×
2608

2609
// forEachNodeDirectedChannel iterates through all channels of a given
2610
// node, executing the passed callback on the directed edge representing the
2611
// channel and its incoming policy. If the node is not found, no error is
2612
// returned.
2613
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
NEW
2614
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
NEW
2615

×
NEW
2616
        toNodeCallback := func() route.Vertex {
×
NEW
2617
                return nodePub
×
NEW
2618
        }
×
2619

NEW
2620
        dbNode, err := db.GetNodeByPubKey(
×
NEW
2621
                ctx, sqlc.GetNodeByPubKeyParams{
×
NEW
2622
                        Version: int16(ProtocolV1),
×
NEW
2623
                        PubKey:  nodePub[:],
×
NEW
2624
                },
×
NEW
2625
        )
×
NEW
2626
        if errors.Is(err, sql.ErrNoRows) {
×
NEW
2627
                return nil
×
NEW
2628
        } else if err != nil {
×
NEW
2629
                return fmt.Errorf("unable to fetch node: %w", err)
×
NEW
2630
        }
×
2631

NEW
2632
        features, err := getNodeFeatures(ctx, db, dbNode.ID)
×
NEW
2633
        if err != nil {
×
NEW
2634
                return fmt.Errorf("unable to fetch node features: %w", err)
×
NEW
2635
        }
×
2636

NEW
2637
        rows, err := db.ListChannelsByNodeID(
×
NEW
2638
                ctx, sqlc.ListChannelsByNodeIDParams{
×
NEW
2639
                        Version: int16(ProtocolV1),
×
NEW
2640
                        NodeID1: dbNode.ID,
×
NEW
2641
                },
×
NEW
2642
        )
×
NEW
2643
        if err != nil {
×
NEW
2644
                return fmt.Errorf("unable to fetch channels: %w", err)
×
NEW
2645
        }
×
2646

NEW
2647
        for _, row := range rows {
×
NEW
2648
                node1, node2, err := buildNodeVertices(
×
NEW
2649
                        row.Node1Pubkey, row.Node2Pubkey,
×
NEW
2650
                )
×
NEW
2651
                if err != nil {
×
NEW
2652
                        return fmt.Errorf("unable to build node vertices: %w",
×
NEW
2653
                                err)
×
NEW
2654
                }
×
2655

NEW
2656
                edge, err := buildCacheableChannelInfo(row, node1, node2)
×
NEW
2657
                if err != nil {
×
NEW
2658
                        return err
×
NEW
2659
                }
×
2660

NEW
2661
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
NEW
2662
                if err != nil {
×
NEW
2663
                        return err
×
NEW
2664
                }
×
2665

NEW
2666
                var p1, p2 *models.CachedEdgePolicy
×
NEW
2667
                if dbPol1 != nil {
×
NEW
2668
                        policy1, err := buildChanPolicy(
×
NEW
2669
                                *dbPol1, edge.ChannelID, nil, node2, true,
×
NEW
2670
                        )
×
NEW
2671
                        if err != nil {
×
NEW
2672
                                return err
×
NEW
2673
                        }
×
2674

NEW
2675
                        p1 = models.NewCachedPolicy(policy1)
×
2676
                }
NEW
2677
                if dbPol2 != nil {
×
NEW
2678
                        policy2, err := buildChanPolicy(
×
NEW
2679
                                *dbPol2, edge.ChannelID, nil, node1, false,
×
NEW
2680
                        )
×
NEW
2681
                        if err != nil {
×
NEW
2682
                                return err
×
NEW
2683
                        }
×
2684

NEW
2685
                        p2 = models.NewCachedPolicy(policy2)
×
2686
                }
2687

2688
                // Determine the outgoing and incoming policy for this
2689
                // channel and node combo.
NEW
2690
                outPolicy, inPolicy := p1, p2
×
NEW
2691
                if p1 != nil && node2 == nodePub {
×
NEW
2692
                        outPolicy, inPolicy = p2, p1
×
NEW
2693
                } else if p2 != nil && node1 != nodePub {
×
NEW
2694
                        outPolicy, inPolicy = p2, p1
×
NEW
2695
                }
×
2696

NEW
2697
                var cachedInPolicy *models.CachedEdgePolicy
×
NEW
2698
                if inPolicy != nil {
×
NEW
2699
                        cachedInPolicy = inPolicy
×
NEW
2700
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
NEW
2701
                        cachedInPolicy.ToNodeFeatures = features
×
NEW
2702
                }
×
2703

NEW
2704
                directedChannel := &DirectedChannel{
×
NEW
2705
                        ChannelID:    edge.ChannelID,
×
NEW
2706
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
NEW
2707
                        OtherNode:    edge.NodeKey2Bytes,
×
NEW
2708
                        Capacity:     edge.Capacity,
×
NEW
2709
                        OutPolicySet: outPolicy != nil,
×
NEW
2710
                        InPolicy:     cachedInPolicy,
×
NEW
2711
                }
×
NEW
2712
                if outPolicy != nil {
×
NEW
2713
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
NEW
2714
                                directedChannel.InboundFee = fee
×
NEW
2715
                        })
×
2716
                }
2717

NEW
2718
                if nodePub == edge.NodeKey2Bytes {
×
NEW
2719
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
NEW
2720
                }
×
2721

NEW
2722
                if err := cb(directedChannel); err != nil {
×
NEW
2723
                        return err
×
NEW
2724
                }
×
2725
        }
2726

NEW
2727
        return nil
×
2728
}
2729

2730
// forEachNode fetches all V2 nodes from the database, and executes the
2731
// provided callback for each node. The callback is provided with the node's
2732
// DB-assigned ID and public key.
2733
func forEachNode(ctx context.Context, db SQLQueries,
NEW
2734
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
NEW
2735

×
NEW
2736
        nodes, err := db.ListNodeIDsAndPubKeys(ctx, int16(ProtocolV1))
×
NEW
2737
        if err != nil {
×
NEW
2738
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
NEW
2739
        }
×
2740

NEW
2741
        for _, node := range nodes {
×
NEW
2742
                var pub route.Vertex
×
NEW
2743
                copy(pub[:], node.PubKey)
×
NEW
2744

×
NEW
2745
                if err := cb(node.ID, pub); err != nil {
×
NEW
2746
                        return fmt.Errorf("callback failed: %w", err)
×
NEW
2747
                }
×
2748
        }
2749

NEW
2750
        return nil
×
2751
}
2752

2753
// forEachNodeChannel iterates through all channels of a node, executing
2754
// the passed callback on each. The call-back is provided with the channel's
2755
// edge information, the outgoing policy and the incoming policy for the
2756
// channel and node combo.
2757
func forEachNodeChannel(ctx context.Context, db SQLQueries,
2758
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
2759
                *models.ChannelEdgePolicy,
NEW
2760
                *models.ChannelEdgePolicy) error) error {
×
NEW
2761

×
NEW
2762
        // Get all the V1 channels for this node.
×
NEW
2763
        rows, err := db.ListChannelsByNodeID(
×
NEW
2764
                ctx, sqlc.ListChannelsByNodeIDParams{
×
NEW
2765
                        Version: int16(ProtocolV1),
×
NEW
2766
                        NodeID1: id,
×
NEW
2767
                },
×
NEW
2768
        )
×
NEW
2769
        if err != nil {
×
NEW
2770
                return fmt.Errorf("unable to fetch channels: %w", err)
×
NEW
2771
        }
×
2772

2773
        // Call the call-back for each channel and its known policies.
NEW
2774
        for _, row := range rows {
×
NEW
2775
                node1, node2, err := buildNodeVertices(
×
NEW
2776
                        row.Node1Pubkey, row.Node2Pubkey,
×
NEW
2777
                )
×
NEW
2778
                if err != nil {
×
NEW
2779
                        return fmt.Errorf("unable to build node vertices: %w",
×
NEW
2780
                                err)
×
NEW
2781
                }
×
2782

NEW
2783
                edge, p1, p2, err := getAndBuildEdgeInfoAndPolicies(
×
NEW
2784
                        ctx, db, chain, row.ID, row, node1, node2,
×
NEW
2785
                )
×
NEW
2786
                if err != nil {
×
NEW
2787
                        return fmt.Errorf("unable to build channel "+
×
NEW
2788
                                "info and policies: %w", err)
×
NEW
2789
                }
×
2790

2791
                // Determine the outgoing and incoming policy for this
2792
                // channel and node combo.
NEW
2793
                p1ToNode := row.NodeID2
×
NEW
2794
                p2ToNode := row.NodeID1
×
NEW
2795
                outPolicy, inPolicy := p1, p2
×
NEW
2796
                if (p1 != nil && p1ToNode == id) ||
×
NEW
2797
                        (p2 != nil && p2ToNode != id) {
×
NEW
2798

×
NEW
2799
                        outPolicy, inPolicy = p2, p1
×
NEW
2800
                }
×
2801

NEW
2802
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
NEW
2803
                        return err
×
NEW
2804
                }
×
2805
        }
2806

NEW
2807
        return nil
×
2808
}
2809

2810
// updateChanEdgePolicy upserts the channel policy info we have stored for
2811
// a channel we already know of.
2812
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
2813
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
NEW
2814
        error) {
×
NEW
2815

×
NEW
2816
        var (
×
NEW
2817
                node1Pub, node2Pub route.Vertex
×
NEW
2818
                isNode1            bool
×
NEW
2819
                chanIDB            [8]byte
×
NEW
2820
        )
×
NEW
2821
        byteOrder.PutUint64(chanIDB[:], edge.ChannelID)
×
NEW
2822

×
NEW
2823
        // Check that this edge policy refers to a channel that we already
×
NEW
2824
        // know of. We do this explicitly so that we can return the appropriate
×
NEW
2825
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
NEW
2826
        // abort the transaction which would abort the entire batch.
×
NEW
2827
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
NEW
2828
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
NEW
2829
                        Scid:    chanIDB[:],
×
NEW
2830
                        Version: int16(ProtocolV1),
×
NEW
2831
                },
×
NEW
2832
        )
×
NEW
2833
        if errors.Is(err, sql.ErrNoRows) {
×
NEW
2834
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
NEW
2835
        } else if err != nil {
×
NEW
2836
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
NEW
2837
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
NEW
2838
        }
×
2839

NEW
2840
        copy(node1Pub[:], dbChan.Node1PubKey)
×
NEW
2841
        copy(node2Pub[:], dbChan.Node2PubKey)
×
NEW
2842

×
NEW
2843
        // Figure out which node this edge is from.
×
NEW
2844
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
NEW
2845
        nodeID := dbChan.NodeID1
×
NEW
2846
        if !isNode1 {
×
NEW
2847
                nodeID = dbChan.NodeID2
×
NEW
2848
        }
×
2849

NEW
2850
        var (
×
NEW
2851
                inboundBase sql.NullInt64
×
NEW
2852
                inboundRate sql.NullInt64
×
NEW
2853
        )
×
NEW
2854
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
NEW
2855
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
NEW
2856
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
NEW
2857
        })
×
2858

NEW
2859
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
NEW
2860
                Version:     int16(ProtocolV1),
×
NEW
2861
                ChannelID:   dbChan.ID,
×
NEW
2862
                NodeID:      nodeID,
×
NEW
2863
                Timelock:    int32(edge.TimeLockDelta),
×
NEW
2864
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
NEW
2865
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
NEW
2866
                MinHtlcMsat: int64(edge.MinHTLC),
×
NEW
2867
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
NEW
2868
                Disabled: sql.NullBool{
×
NEW
2869
                        Valid: true,
×
NEW
2870
                        Bool:  edge.IsDisabled(),
×
NEW
2871
                },
×
NEW
2872
                MaxHtlcMsat: sql.NullInt64{
×
NEW
2873
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
NEW
2874
                        Int64: int64(edge.MaxHTLC),
×
NEW
2875
                },
×
NEW
2876
                InboundBaseFeeMsat:      inboundBase,
×
NEW
2877
                InboundFeeRateMilliMsat: inboundRate,
×
NEW
2878
                Signature:               edge.SigBytes,
×
NEW
2879
        })
×
2880
        if err != nil {
×
NEW
2881
                return node1Pub, node2Pub, isNode1,
×
NEW
2882
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
2883
        }
×
2884

2885
        // Convert the flat extra opaque data into a map of TLV types to
2886
        // values.
UNCOV
2887
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
NEW
2888
        if err != nil {
×
UNCOV
2889
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
UNCOV
2890
                        "marshal extra opaque data: %w", err)
×
UNCOV
2891
        }
×
2892

2893
        // Update the channel policy's extra signed fields.
2894
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
2895
        if err != nil {
×
2896
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
2897
                        "policy extra TLVs: %w", err)
×
2898
        }
×
2899

2900
        return node1Pub, node2Pub, isNode1, nil
×
2901
}
2902

2903
// getNodeByPubKey attempts to look up a target node by its public key.
2904
func getNodeByPubKey(ctx context.Context, db SQLQueries,
2905
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
UNCOV
2906

×
2907
        dbNode, err := db.GetNodeByPubKey(
×
2908
                ctx, sqlc.GetNodeByPubKeyParams{
×
2909
                        Version: int16(ProtocolV1),
×
2910
                        PubKey:  pubKey[:],
×
UNCOV
2911
                },
×
2912
        )
×
UNCOV
2913
        if errors.Is(err, sql.ErrNoRows) {
×
UNCOV
2914
                return 0, nil, ErrGraphNodeNotFound
×
NEW
2915
        } else if err != nil {
×
NEW
2916
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
NEW
2917
        }
×
2918

NEW
2919
        node, err := buildNode(ctx, db, &dbNode)
×
NEW
2920
        if err != nil {
×
NEW
2921
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
NEW
2922
        }
×
2923

NEW
2924
        return dbNode.ID, node, nil
×
2925
}
2926

2927
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
2928
// provided database channel row and the public keys of the two nodes
2929
// involved in the channel.
2930
func buildCacheableChannelInfo(row any, node1Pub,
NEW
2931
        node2Pub route.Vertex) (*models.CachedEdgeInfo, error) {
×
NEW
2932

×
NEW
2933
        dbChan, err := extractChannel(row)
×
UNCOV
2934
        if err != nil {
×
UNCOV
2935
                return nil, err
×
UNCOV
2936
        }
×
2937

2938
        return &models.CachedEdgeInfo{
×
2939
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
2940
                NodeKey1Bytes: node1Pub,
×
2941
                NodeKey2Bytes: node2Pub,
×
2942
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
2943
        }, nil
×
2944
}
2945

2946
// buildNode constructs a LightningNode instance from the given database node
2947
// record. The node's features, addresses and extra signed fields are also
2948
// fetched from the database and set on the node.
2949
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
2950
        *models.LightningNode, error) {
×
2951

×
2952
        if dbNode.Version != int16(ProtocolV1) {
×
2953
                return nil, fmt.Errorf("unsupported node version: %d",
×
2954
                        dbNode.Version)
×
2955
        }
×
2956

UNCOV
2957
        var pub [33]byte
×
2958
        copy(pub[:], dbNode.PubKey)
×
2959

×
2960
        node := &models.LightningNode{
×
2961
                PubKeyBytes: pub,
×
2962
                Features:    lnwire.EmptyFeatureVector(),
×
2963
                LastUpdate:  time.Unix(0, 0),
×
NEW
2964
        }
×
NEW
2965

×
NEW
2966
        if len(dbNode.Signature) == 0 {
×
NEW
2967
                return node, nil
×
NEW
2968
        }
×
2969

UNCOV
2970
        node.HaveNodeAnnouncement = true
×
UNCOV
2971
        node.AuthSigBytes = dbNode.Signature
×
UNCOV
2972
        node.Alias = dbNode.Alias.String
×
2973
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
2974

×
2975
        var err error
×
2976
        if dbNode.Color.Valid {
×
2977
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
UNCOV
2978
                if err != nil {
×
UNCOV
2979
                        return nil, fmt.Errorf("unable to decode color: %w", err)
×
2980
                }
×
2981
        }
2982

2983
        // Fetch the node's features.
2984
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
UNCOV
2985
        if err != nil {
×
UNCOV
2986
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
2987
                        "features: %w", dbNode.ID, err)
×
2988
        }
×
2989

2990
        // Fetch the node's addresses.
2991
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
UNCOV
2992
        if err != nil {
×
2993
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
2994
                        "addresses: %w", dbNode.ID, err)
×
2995
        }
×
2996

2997
        // Fetch the node's extra signed fields.
UNCOV
2998
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
2999
        if err != nil {
×
3000
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3001
                        "extra signed fields: %w", dbNode.ID, err)
×
UNCOV
3002
        }
×
3003

UNCOV
3004
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
UNCOV
3005
        if err != nil {
×
UNCOV
3006
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
UNCOV
3007
                        "fields: %w", err)
×
UNCOV
3008
        }
×
3009

3010
        if len(recs) != 0 {
×
3011
                node.ExtraOpaqueData = recs
×
3012
        }
×
3013

3014
        return node, nil
×
3015
}
3016

3017
// getNodeFeatures fetches the feature bits and constructs the feature vector
3018
// for a node with the given DB ID.
3019
func getNodeFeatures(ctx context.Context, db SQLQueries,
3020
        nodeID int64) (*lnwire.FeatureVector, error) {
×
UNCOV
3021

×
3022
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
UNCOV
3023
        if err != nil {
×
UNCOV
3024
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
UNCOV
3025
                        nodeID, err)
×
UNCOV
3026
        }
×
3027

3028
        features := lnwire.EmptyFeatureVector()
×
3029
        for _, feature := range rows {
×
3030
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3031
        }
×
3032

3033
        return features, nil
×
3034
}
3035

3036
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
3037
// given DB ID.
3038
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3039
        nodeID int64) (map[uint64][]byte, error) {
×
UNCOV
3040

×
3041
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
UNCOV
3042
        if err != nil {
×
UNCOV
3043
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
UNCOV
3044
                        "signed fields: %w", nodeID, err)
×
UNCOV
3045
        }
×
3046

UNCOV
3047
        extraFields := make(map[uint64][]byte)
×
UNCOV
3048
        for _, field := range fields {
×
3049
                extraFields[uint64(field.Type)] = field.Value
×
3050
        }
×
3051

3052
        return extraFields, nil
×
3053
}
3054

3055
// upsertNode upserts the node record into the database. If the node already
3056
// exists, then the node's information is updated. If the node doesn't exist,
3057
// then a new node is created. The node's features, addresses and extra TLV
3058
// types are also updated. The node's DB ID is returned.
3059
func upsertNode(ctx context.Context, db SQLQueries,
3060
        node *models.LightningNode) (int64, error) {
×
3061

×
UNCOV
3062
        params := sqlc.UpsertNodeParams{
×
3063
                Version: int16(ProtocolV1),
×
3064
                PubKey:  node.PubKeyBytes[:],
×
3065
        }
×
3066

×
3067
        if node.HaveNodeAnnouncement {
×
UNCOV
3068
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
UNCOV
3069
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3070
                params.Alias = sqldb.SQLStr(node.Alias)
×
3071
                params.Signature = node.AuthSigBytes
×
3072
        }
×
3073

UNCOV
3074
        nodeID, err := db.UpsertNode(ctx, params)
×
3075
        if err != nil {
×
3076
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3077
                        err)
×
3078
        }
×
3079

3080
        // We can exit here if we don't have the announcement yet.
3081
        if !node.HaveNodeAnnouncement {
×
3082
                return nodeID, nil
×
3083
        }
×
3084

3085
        // Update the node's features.
UNCOV
3086
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
UNCOV
3087
        if err != nil {
×
3088
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3089
        }
×
3090

3091
        // Update the node's addresses.
3092
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
UNCOV
3093
        if err != nil {
×
UNCOV
3094
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3095
        }
×
3096

3097
        // Convert the flat extra opaque data into a map of TLV types to
3098
        // values.
UNCOV
3099
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3100
        if err != nil {
×
UNCOV
3101
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
UNCOV
3102
                        err)
×
UNCOV
3103
        }
×
3104

3105
        // Update the node's extra signed fields.
UNCOV
3106
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
UNCOV
3107
        if err != nil {
×
3108
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3109
        }
×
3110

3111
        return nodeID, nil
×
3112
}
3113

3114
// upsertNodeFeatures updates the node's features node_features table. This
3115
// includes deleting any feature bits no longer present and inserting any new
3116
// feature bits. If the feature bit does not yet exist in the features table,
3117
// then an entry is created in that table first.
3118
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3119
        features *lnwire.FeatureVector) error {
×
3120

×
3121
        // Get any existing features for the node.
×
UNCOV
3122
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
UNCOV
3123
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
UNCOV
3124
                return err
×
UNCOV
3125
        }
×
3126

3127
        // Copy the nodes latest set of feature bits.
3128
        newFeatures := make(map[int32]struct{})
×
3129
        if features != nil {
×
3130
                for feature := range features.Features() {
×
3131
                        newFeatures[int32(feature)] = struct{}{}
×
3132
                }
×
3133
        }
3134

3135
        // For any current feature that already exists in the DB, remove it from
3136
        // the in-memory map. For any existing feature that does not exist in
3137
        // the in-memory map, delete it from the database.
3138
        for _, feature := range existingFeatures {
×
3139
                // The feature is still present, so there are no updates to be
×
3140
                // made.
×
3141
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3142
                        delete(newFeatures, feature.FeatureBit)
×
3143
                        continue
×
3144
                }
3145

3146
                // The feature is no longer present, so we remove it from the
3147
                // database.
UNCOV
3148
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
UNCOV
3149
                        NodeID:     nodeID,
×
3150
                        FeatureBit: feature.FeatureBit,
×
3151
                })
×
3152
                if err != nil {
×
3153
                        return fmt.Errorf("unable to delete node(%d) "+
×
3154
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3155
                                err)
×
3156
                }
×
3157
        }
3158

3159
        // Any remaining entries in newFeatures are new features that need to be
3160
        // added to the database for the first time.
3161
        for feature := range newFeatures {
×
UNCOV
3162
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
UNCOV
3163
                        NodeID:     nodeID,
×
UNCOV
3164
                        FeatureBit: feature,
×
UNCOV
3165
                })
×
3166
                if err != nil {
×
3167
                        return fmt.Errorf("unable to insert node(%d) "+
×
3168
                                "feature(%v): %w", nodeID, feature, err)
×
3169
                }
×
3170
        }
3171

3172
        return nil
×
3173
}
3174

3175
// fetchNodeFeatures fetches the features for a node with the given public key.
3176
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3177
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
UNCOV
3178

×
3179
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3180
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3181
                        PubKey:  nodePub[:],
×
3182
                        Version: int16(ProtocolV1),
×
UNCOV
3183
                },
×
3184
        )
×
UNCOV
3185
        if err != nil {
×
UNCOV
3186
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
UNCOV
3187
                        nodePub, err)
×
UNCOV
3188
        }
×
3189

UNCOV
3190
        features := lnwire.EmptyFeatureVector()
×
UNCOV
3191
        for _, bit := range rows {
×
UNCOV
3192
                features.Set(lnwire.FeatureBit(bit))
×
UNCOV
3193
        }
×
3194

UNCOV
3195
        return features, nil
×
3196
}
3197

3198
// dbAddressType is an enum type that represents the different address types
// that we store in the node_addresses table. The address type determines how
// an address is serialised when it is written and deserialised when it is
// read back.
type dbAddressType uint8
3202

3203
const (
3204
        addressTypeIPv4   dbAddressType = 1
3205
        addressTypeIPv6   dbAddressType = 2
3206
        addressTypeTorV2  dbAddressType = 3
3207
        addressTypeTorV3  dbAddressType = 4
3208
        addressTypeOpaque dbAddressType = math.MaxInt8
3209
)
3210

3211
// upsertNodeAddresses updates the node's addresses in the database. This
3212
// includes deleting any existing addresses and inserting the new set of
3213
// addresses. The deletion is necessary since the ordering of the addresses may
3214
// change, and we need to ensure that the database reflects the latest set of
3215
// addresses so that at the time of reconstructing the node announcement, the
3216
// order is preserved and the signature over the message remains valid.
3217
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
UNCOV
3218
        addresses []net.Addr) error {
×
3219

×
3220
        // Delete any existing addresses for the node. This is required since
×
3221
        // even if the new set of addresses is the same, the ordering may have
×
3222
        // changed for a given address type.
×
3223
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3224
        if err != nil {
×
3225
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3226
                        nodeID, err)
×
3227
        }
×
3228

3229
        // Copy the nodes latest set of addresses.
3230
        newAddresses := map[dbAddressType][]string{
×
3231
                addressTypeIPv4:   {},
×
3232
                addressTypeIPv6:   {},
×
3233
                addressTypeTorV2:  {},
×
3234
                addressTypeTorV3:  {},
×
3235
                addressTypeOpaque: {},
×
3236
        }
×
3237
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3238
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3239
        }
×
3240

UNCOV
3241
        for _, address := range addresses {
×
3242
                switch addr := address.(type) {
×
3243
                case *net.TCPAddr:
×
3244
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3245
                                addAddr(addressTypeIPv4, addr)
×
3246
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3247
                                addAddr(addressTypeIPv6, addr)
×
3248
                        } else {
×
3249
                                return fmt.Errorf("unhandled IP address: %v",
×
3250
                                        addr)
×
UNCOV
3251
                        }
×
3252

3253
                case *tor.OnionAddr:
×
3254
                        switch len(addr.OnionService) {
×
UNCOV
3255
                        case tor.V2Len:
×
3256
                                addAddr(addressTypeTorV2, addr)
×
3257
                        case tor.V3Len:
×
UNCOV
3258
                                addAddr(addressTypeTorV3, addr)
×
UNCOV
3259
                        default:
×
UNCOV
3260
                                return fmt.Errorf("invalid length for a tor " +
×
UNCOV
3261
                                        "address")
×
3262
                        }
3263

3264
                case *lnwire.OpaqueAddrs:
×
3265
                        addAddr(addressTypeOpaque, addr)
×
3266

3267
                default:
×
3268
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3269
                }
3270
        }
3271

3272
        // Any remaining entries in newAddresses are new addresses that need to
3273
        // be added to the database for the first time.
3274
        for addrType, addrList := range newAddresses {
×
3275
                for position, addr := range addrList {
×
3276
                        err := db.InsertNodeAddress(
×
3277
                                ctx, sqlc.InsertNodeAddressParams{
×
UNCOV
3278
                                        NodeID:   nodeID,
×
UNCOV
3279
                                        Type:     int16(addrType),
×
UNCOV
3280
                                        Address:  addr,
×
3281
                                        Position: int32(position),
×
UNCOV
3282
                                },
×
UNCOV
3283
                        )
×
UNCOV
3284
                        if err != nil {
×
UNCOV
3285
                                return fmt.Errorf("unable to insert "+
×
3286
                                        "node(%d) address(%v): %w", nodeID,
×
3287
                                        addr, err)
×
3288
                        }
×
3289
                }
3290
        }
3291

3292
        return nil
×
3293
}
3294

3295
// getNodeAddresses fetches the addresses for a node with the given public key.
3296
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3297
        []net.Addr, error) {
×
3298

×
UNCOV
3299
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
UNCOV
3300
        // are returned in the same order as they were inserted.
×
UNCOV
3301
        rows, err := db.GetNodeAddressesByPubKey(
×
UNCOV
3302
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
3303
                        Version: int16(ProtocolV1),
×
3304
                        PubKey:  nodePub,
×
3305
                },
×
UNCOV
3306
        )
×
3307
        if err != nil {
×
3308
                return false, nil, err
×
3309
        }
×
3310

3311
        // GetNodeAddressesByPubKey uses a left join so there should always be
3312
        // at least one row returned if the node exists even if it has no
3313
        // addresses.
3314
        if len(rows) == 0 {
×
3315
                return false, nil, nil
×
3316
        }
×
3317

3318
        addresses := make([]net.Addr, 0, len(rows))
×
3319
        for _, addr := range rows {
×
3320
                if !(addr.Type.Valid && addr.Address.Valid) {
×
3321
                        continue
×
3322
                }
3323

UNCOV
3324
                address := addr.Address.String
×
3325

×
3326
                switch dbAddressType(addr.Type.Int16) {
×
3327
                case addressTypeIPv4:
×
3328
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
3329
                        if err != nil {
×
3330
                                return false, nil, nil
×
UNCOV
3331
                        }
×
3332
                        tcp.IP = tcp.IP.To4()
×
3333

×
3334
                        addresses = append(addresses, tcp)
×
3335

3336
                case addressTypeIPv6:
×
3337
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
3338
                        if err != nil {
×
UNCOV
3339
                                return false, nil, nil
×
3340
                        }
×
3341
                        addresses = append(addresses, tcp)
×
3342

3343
                case addressTypeTorV3, addressTypeTorV2:
×
UNCOV
3344
                        service, portStr, err := net.SplitHostPort(address)
×
3345
                        if err != nil {
×
3346
                                return false, nil, fmt.Errorf("unable to "+
×
3347
                                        "split tor v3 address: %v",
×
3348
                                        addr.Address)
×
UNCOV
3349
                        }
×
3350

3351
                        port, err := strconv.Atoi(portStr)
×
3352
                        if err != nil {
×
3353
                                return false, nil, err
×
3354
                        }
×
3355

UNCOV
3356
                        addresses = append(addresses, &tor.OnionAddr{
×
3357
                                OnionService: service,
×
3358
                                Port:         port,
×
3359
                        })
×
3360

3361
                case addressTypeOpaque:
×
3362
                        opaque, err := hex.DecodeString(address)
×
3363
                        if err != nil {
×
UNCOV
3364
                                return false, nil, fmt.Errorf("unable to "+
×
UNCOV
3365
                                        "decode opaque address: %v", addr)
×
UNCOV
3366
                        }
×
3367

UNCOV
3368
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
UNCOV
3369
                                Payload: opaque,
×
UNCOV
3370
                        })
×
3371

UNCOV
3372
                default:
×
UNCOV
3373
                        return false, nil, fmt.Errorf("unknown address "+
×
3374
                                "type: %v", addr.Type)
×
3375
                }
3376
        }
3377

3378
        return true, addresses, nil
×
3379
}
3380

3381
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3382
// database. This includes updating any existing types, inserting any new types,
3383
// and deleting any types that are no longer present.
3384
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3385
        nodeID int64, extraFields map[uint64][]byte) error {
×
3386

×
3387
        // Get any existing extra signed fields for the node.
×
UNCOV
3388
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
UNCOV
3389
        if err != nil {
×
UNCOV
3390
                return err
×
3391
        }
×
3392

3393
        // Make a lookup map of the existing field types so that we can use it
3394
        // to keep track of any fields we should delete.
3395
        m := make(map[uint64]bool)
×
3396
        for _, field := range existingFields {
×
3397
                m[uint64(field.Type)] = true
×
3398
        }
×
3399

3400
        // For all the new fields, we'll upsert them and remove them from the
3401
        // map of existing fields.
3402
        for tlvType, value := range extraFields {
×
UNCOV
3403
                err = db.UpsertNodeExtraType(
×
UNCOV
3404
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
UNCOV
3405
                                NodeID: nodeID,
×
3406
                                Type:   int64(tlvType),
×
UNCOV
3407
                                Value:  value,
×
UNCOV
3408
                        },
×
UNCOV
3409
                )
×
UNCOV
3410
                if err != nil {
×
3411
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3412
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3413
                }
×
3414

3415
                // Remove the field from the map of existing fields if it was
3416
                // present.
3417
                delete(m, tlvType)
×
3418
        }
3419

3420
        // For all the fields that are left in the map of existing fields, we'll
3421
        // delete them as they are no longer present in the new set of fields.
UNCOV
3422
        for tlvType := range m {
×
UNCOV
3423
                err = db.DeleteExtraNodeType(
×
3424
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
UNCOV
3425
                                NodeID: nodeID,
×
UNCOV
3426
                                Type:   int64(tlvType),
×
UNCOV
3427
                        },
×
UNCOV
3428
                )
×
NEW
3429
                if err != nil {
×
3430
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3431
                                "signed field(%v): %w", nodeID, tlvType, err)
×
NEW
3432
                }
×
3433
        }
3434

NEW
3435
        return nil
×
3436
}
3437

3438
// getSourceNode returns the DB node ID and pub key of the source node for the
3439
// specified protocol version.
3440
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3441
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3442

×
3443
        s.srcNodeMu.Lock()
×
3444
        defer s.srcNodeMu.Unlock()
×
3445

×
3446
        // If we already have the source node ID and pub key cached, then
×
3447
        // return them.
×
UNCOV
3448
        if s.srcNodeID != 0 {
×
3449
                return s.srcNodeID, s.srcNodePub, nil
×
3450
        }
×
3451

3452
        var pubKey route.Vertex
×
3453

×
3454
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
UNCOV
3455
        if err != nil {
×
3456
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3457
                        err)
×
NEW
3458
        }
×
3459

NEW
3460
        if len(nodes) == 0 {
×
UNCOV
3461
                return 0, pubKey, ErrSourceNodeNotSet
×
UNCOV
3462
        } else if len(nodes) > 1 {
×
UNCOV
3463
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
UNCOV
3464
                        "protocol %s found", version)
×
UNCOV
3465
        }
×
3466

3467
        copy(pubKey[:], nodes[0].PubKey)
×
3468

×
3469
        s.srcNodeID = nodes[0].NodeID
×
3470
        s.srcNodePub = pubKey
×
3471

×
3472
        return nodes[0].NodeID, pubKey, nil
×
3473
}
3474

3475
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3476
// This then produces a map from TLV type to value. If the input is not a
3477
// valid TLV stream, then an error is returned.
3478
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
NEW
3479
        r := bytes.NewReader(data)
×
3480

×
3481
        tlvStream, err := tlv.NewStream()
×
3482
        if err != nil {
×
3483
                return nil, err
×
UNCOV
3484
        }
×
3485

3486
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3487
        // pass it into the P2P decoding variant.
3488
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
UNCOV
3489
        if err != nil {
×
3490
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
UNCOV
3491
        }
×
UNCOV
3492
        if len(parsedTypes) == 0 {
×
NEW
3493
                return nil, nil
×
NEW
3494
        }
×
3495

NEW
3496
        records := make(map[uint64][]byte)
×
NEW
3497
        for k, v := range parsedTypes {
×
NEW
3498
                records[uint64(k)] = v
×
UNCOV
3499
        }
×
3500

NEW
3501
        return records, nil
×
3502
}
3503

3504
type dbChanInfo struct {
3505
        channelID int64
3506
        node1ID   int64
3507
        node2ID   int64
3508
}
3509

3510
// insertChannel inserts a new channel record into the database.
3511
func insertChannel(ctx context.Context, db SQLQueries,
3512
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
3513

×
3514
        var chanIDB [8]byte
×
3515
        byteOrder.PutUint64(chanIDB[:], edge.ChannelID)
×
3516

×
NEW
3517
        // Make sure that the channel doesn't already exist. We do this
×
3518
        // explicitly instead of relying on catching a unique constraint error
×
NEW
3519
        // because relying on SQL to throw that error would abort the entire
×
3520
        // batch of transactions.
×
UNCOV
3521
        _, err := db.GetChannelBySCID(
×
UNCOV
3522
                ctx, sqlc.GetChannelBySCIDParams{
×
UNCOV
3523
                        Scid:    chanIDB[:],
×
3524
                        Version: int16(ProtocolV1),
×
3525
                },
×
NEW
3526
        )
×
3527
        if err == nil {
×
UNCOV
3528
                return nil, ErrEdgeAlreadyExist
×
3529
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3530
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
×
NEW
3531
        }
×
3532

3533
        // Make sure that at least a "shell" entry for each node is present in
3534
        // the nodes table.
3535
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3536
        if err != nil {
×
3537
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
UNCOV
3538
        }
×
3539

3540
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3541
        if err != nil {
×
3542
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3543
        }
×
3544

3545
        var capacity sql.NullInt64
×
3546
        if edge.Capacity != 0 {
×
3547
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3548
        }
×
3549

3550
        createParams := sqlc.CreateChannelParams{
×
3551
                Version:     int16(ProtocolV1),
×
3552
                Scid:        chanIDB[:],
×
3553
                NodeID1:     node1DBID,
×
3554
                NodeID2:     node2DBID,
×
3555
                Outpoint:    edge.ChannelPoint.String(),
×
3556
                Capacity:    capacity,
×
3557
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
UNCOV
3558
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
UNCOV
3559
        }
×
3560

×
3561
        if edge.AuthProof != nil {
×
NEW
3562
                proof := edge.AuthProof
×
3563

×
UNCOV
3564
                createParams.Node1Signature = proof.NodeSig1Bytes
×
UNCOV
3565
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3566
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3567
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3568
        }
×
3569

3570
        // Insert the new channel record.
3571
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
UNCOV
3572
        if err != nil {
×
3573
                return nil, err
×
3574
        }
×
3575

3576
        // Insert any channel features.
3577
        if len(edge.Features) != 0 {
×
3578
                chanFeatures := lnwire.NewRawFeatureVector()
×
3579
                err := chanFeatures.Decode(bytes.NewReader(edge.Features))
×
3580
                if err != nil {
×
3581
                        return nil, err
×
NEW
3582
                }
×
3583

3584
                fv := lnwire.NewFeatureVector(chanFeatures, lnwire.Features)
×
3585
                for feature := range fv.Features() {
×
UNCOV
3586
                        err = db.InsertChannelFeature(
×
UNCOV
3587
                                ctx, sqlc.InsertChannelFeatureParams{
×
UNCOV
3588
                                        ChannelID:  dbChanID,
×
UNCOV
3589
                                        FeatureBit: int32(feature),
×
3590
                                },
×
3591
                        )
×
NEW
3592
                        if err != nil {
×
NEW
3593
                                return nil, fmt.Errorf("unable to insert "+
×
UNCOV
3594
                                        "channel(%d) feature(%v): %w", dbChanID,
×
UNCOV
3595
                                        feature, err)
×
3596
                        }
×
3597
                }
3598
        }
3599

3600
        // Finally, insert any extra TLV fields in the channel announcement.
3601
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3602
        if err != nil {
×
3603
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
3604
                        "data: %w", err)
×
NEW
3605
        }
×
3606

NEW
3607
        for tlvType, value := range extra {
×
UNCOV
3608
                err := db.CreateChannelExtraType(
×
UNCOV
3609
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
UNCOV
3610
                                ChannelID: dbChanID,
×
NEW
3611
                                Type:      int64(tlvType),
×
NEW
3612
                                Value:     value,
×
NEW
3613
                        },
×
NEW
3614
                )
×
NEW
3615
                if err != nil {
×
UNCOV
3616
                        return nil, fmt.Errorf("unable to upsert "+
×
UNCOV
3617
                                "channel(%d) extra signed field(%v): %w",
×
UNCOV
3618
                                edge.ChannelID, tlvType, err)
×
UNCOV
3619
                }
×
3620
        }
3621

UNCOV
3622
        return &dbChanInfo{
×
3623
                channelID: dbChanID,
×
3624
                node1ID:   node1DBID,
×
3625
                node2ID:   node2DBID,
×
3626
        }, nil
×
3627
}
3628

3629
// maybeCreateShellNode checks if a shell node entry exists for the
3630
// given public key. If it does not exist, then a new shell node entry is
3631
// created. The ID of the node is returned. A shell node only has a protocol
3632
// version and public key persisted.
3633
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3634
        pubKey route.Vertex) (int64, error) {
×
3635

×
3636
        dbNode, err := db.GetNodeByPubKey(
×
UNCOV
3637
                ctx, sqlc.GetNodeByPubKeyParams{
×
UNCOV
3638
                        PubKey:  pubKey[:],
×
UNCOV
3639
                        Version: int16(ProtocolV1),
×
3640
                },
×
3641
        )
×
3642
        // The node exists. Return the ID.
×
3643
        if err == nil {
×
3644
                return dbNode.ID, nil
×
3645
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3646
                return 0, err
×
UNCOV
3647
        }
×
3648

3649
        // Otherwise, the node does not exist, so we create a shell entry for
3650
        // it.
NEW
3651
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
NEW
3652
                Version: int16(ProtocolV1),
×
NEW
3653
                PubKey:  pubKey[:],
×
NEW
3654
        })
×
NEW
3655
        if err != nil {
×
NEW
3656
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
NEW
3657
        }
×
3658

NEW
3659
        return id, nil
×
3660
}
3661

3662
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3663
// the database. This includes updating any existing types, inserting any new
3664
// types, and deleting any types that are no longer present.
3665
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
NEW
3666
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
NEW
3667

×
NEW
3668
        // Get any existing extra signed fields for the channel policy.
×
NEW
3669
        existingFields, err := db.GetChannelPolicyExtraTypes(
×
NEW
3670
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
NEW
3671
                        ID: chanPolicyID,
×
NEW
3672
                },
×
NEW
3673
        )
×
NEW
3674
        if err != nil {
×
NEW
3675
                return err
×
NEW
3676
        }
×
3677

3678
        // Make a lookup map of the existing field types so that we can use it
3679
        // to keep track of any fields we should delete.
NEW
3680
        m := make(map[uint64]bool)
×
NEW
3681
        for _, field := range existingFields {
×
NEW
3682
                if field.PolicyID != chanPolicyID {
×
NEW
3683
                        return fmt.Errorf("channel policy ID mismatch: "+
×
NEW
3684
                                "expected %d, got %d", chanPolicyID,
×
NEW
3685
                                field.PolicyID)
×
NEW
3686
                }
×
3687

NEW
3688
                m[uint64(field.Type)] = true
×
3689
        }
3690

3691
        // For all the new fields, we'll upsert them and remove them from the
3692
        // map of existing fields.
NEW
3693
        for tlvType, value := range extraFields {
×
NEW
3694
                err = db.UpsertChanPolicyExtraType(
×
NEW
3695
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
×
NEW
3696
                                ChannelPolicyID: chanPolicyID,
×
NEW
3697
                                Type:            int64(tlvType),
×
NEW
3698
                                Value:           value,
×
NEW
3699
                        },
×
NEW
3700
                )
×
NEW
3701
                if err != nil {
×
NEW
3702
                        return fmt.Errorf("unable to upsert "+
×
NEW
3703
                                "channel_policy(%d) extra signed field(%v): %w",
×
NEW
3704
                                chanPolicyID, tlvType, err)
×
NEW
3705
                }
×
3706

3707
                // Remove the field from the map of existing fields if it was
3708
                // present.
NEW
3709
                delete(m, tlvType)
×
3710
        }
3711

3712
        // For all the fields that are left in the map of existing fields, we'll
3713
        // delete them as they are no longer present in the new set of fields.
NEW
3714
        for tlvType := range m {
×
NEW
3715
                err = db.DeleteChannelPolicyExtraType(
×
NEW
3716
                        ctx, sqlc.DeleteChannelPolicyExtraTypeParams{
×
NEW
3717
                                ChannelPolicyID: chanPolicyID,
×
NEW
3718
                                Type:            int64(tlvType),
×
NEW
3719
                        },
×
NEW
3720
                )
×
NEW
3721
                if err != nil {
×
NEW
3722
                        return fmt.Errorf("unable to delete "+
×
NEW
3723
                                "channel_policy(%d) extra signed field(%v): %w",
×
NEW
3724
                                chanPolicyID, tlvType, err)
×
NEW
3725
                }
×
3726
        }
3727

NEW
3728
        return nil
×
3729
}
3730

3731
// getAndBuildEdgeInfoAndPolicies fetches all the data from the DB required to
3732
// build complete models.ChannelEdgeInfo and  models.ChannelEdgePolicy instances
3733
// for a channel with the given DB ID.
3734
func getAndBuildEdgeInfoAndPolicies(ctx context.Context, db SQLQueries,
3735
        chain chainhash.Hash, dbChanID int64, dbChanRow any, node1,
3736
        node2 route.Vertex) (*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
NEW
3737
        *models.ChannelEdgePolicy, error) {
×
NEW
3738

×
NEW
3739
        edge, err := getAndBuildEdgeInfo(
×
NEW
3740
                ctx, db, chain, dbChanID, dbChanRow, node1, node2,
×
NEW
3741
        )
×
NEW
3742
        if err != nil {
×
NEW
3743
                return nil, nil, nil, err
×
NEW
3744
        }
×
3745

NEW
3746
        dbPol1, dbPol2, err := extractChannelPolicies(dbChanRow)
×
NEW
3747
        if err != nil {
×
NEW
3748
                return nil, nil, nil, err
×
NEW
3749
        }
×
3750

NEW
3751
        p1, p2, err := getAndBuildChanPolicies(
×
NEW
3752
                ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
NEW
3753
        )
×
NEW
3754

×
NEW
3755
        return edge, p1, p2, err
×
3756
}
3757

3758
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3759
// provided dbChanRow and also fetches any other required information
3760
// to construct the edge info.
3761
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
3762
        chain chainhash.Hash, dbChanID int64, dbChanRow any, node1,
NEW
3763
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
NEW
3764

×
NEW
3765
        dbChan, err := extractChannel(dbChanRow)
×
NEW
3766
        if err != nil {
×
NEW
3767
                return nil, err
×
NEW
3768
        }
×
3769

NEW
3770
        fv, extras, err := getChanFeaturesAndExtras(
×
NEW
3771
                ctx, db, dbChanID,
×
NEW
3772
        )
×
NEW
3773
        if err != nil {
×
NEW
3774
                return nil, err
×
NEW
3775
        }
×
3776

NEW
3777
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
NEW
3778
        if err != nil {
×
NEW
3779
                return nil, err
×
NEW
3780
        }
×
3781

NEW
3782
        var featureBuf bytes.Buffer
×
NEW
3783
        if err := fv.Encode(&featureBuf); err != nil {
×
NEW
3784
                return nil, fmt.Errorf("unable to encode features: %w", err)
×
NEW
3785
        }
×
3786

NEW
3787
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
NEW
3788
        if err != nil {
×
NEW
3789
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
NEW
3790
                        "fields: %w", err)
×
NEW
3791
        }
×
NEW
3792
        if recs == nil {
×
NEW
3793
                recs = make([]byte, 0)
×
NEW
3794
        }
×
3795

NEW
3796
        var btcKey1, btcKey2 route.Vertex
×
NEW
3797
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
NEW
3798
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
NEW
3799

×
NEW
3800
        channel := &models.ChannelEdgeInfo{
×
NEW
3801
                ChainHash:        chain,
×
NEW
3802
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
NEW
3803
                NodeKey1Bytes:    node1,
×
NEW
3804
                NodeKey2Bytes:    node2,
×
NEW
3805
                BitcoinKey1Bytes: btcKey1,
×
NEW
3806
                BitcoinKey2Bytes: btcKey2,
×
NEW
3807
                ChannelPoint:     *op,
×
NEW
3808
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
NEW
3809
                Features:         featureBuf.Bytes(),
×
NEW
3810
                ExtraOpaqueData:  recs,
×
NEW
3811
        }
×
NEW
3812

×
NEW
3813
        if dbChan.Bitcoin1Signature != nil {
×
NEW
3814
                channel.AuthProof = &models.ChannelAuthProof{
×
NEW
3815
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
NEW
3816
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
NEW
3817
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
NEW
3818
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
NEW
3819
                }
×
NEW
3820
        }
×
3821

NEW
3822
        return channel, nil
×
3823
}
3824

3825
// buildNodeVertices is a helper that converts raw node public keys
3826
// into route.Vertex instances.
3827
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
NEW
3828
        route.Vertex, error) {
×
NEW
3829
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
NEW
3830
        if err != nil {
×
NEW
3831
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
NEW
3832
                        "create vertex from node1 pubkey: %w", err)
×
NEW
3833
        }
×
3834

NEW
3835
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
NEW
3836
        if err != nil {
×
NEW
3837
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
NEW
3838
                        "create vertex from node2 pubkey: %w", err)
×
NEW
3839
        }
×
3840

NEW
3841
        return node1Vertex, node2Vertex, nil
×
3842
}
3843

3844
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
3845
// for a channel with the given ID.
3846
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
NEW
3847
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
NEW
3848

×
NEW
3849
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
NEW
3850
        if err != nil {
×
NEW
3851
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
NEW
3852
                        "features and extras: %w", err)
×
NEW
3853
        }
×
3854

NEW
3855
        var (
×
NEW
3856
                fv     = lnwire.EmptyFeatureVector()
×
NEW
3857
                extras = make(map[uint64][]byte)
×
NEW
3858
        )
×
NEW
3859
        for _, row := range rows {
×
NEW
3860
                switch row.Kind {
×
NEW
3861
                case "feature":
×
NEW
3862
                        featureBit, err := strconv.Atoi(row.Key)
×
NEW
3863
                        if err != nil {
×
NEW
3864
                                return nil, nil, err
×
NEW
3865
                        }
×
NEW
3866
                        fv.Set(lnwire.FeatureBit(featureBit))
×
3867

NEW
3868
                case "extra":
×
NEW
3869
                        tlvType, err := strconv.ParseInt(row.Key, 10, 64)
×
NEW
3870
                        if err != nil {
×
NEW
3871
                                return nil, nil, err
×
NEW
3872
                        }
×
NEW
3873
                        valueBytes, ok := row.Value.([]byte)
×
NEW
3874
                        if !ok {
×
NEW
3875
                                return nil, nil, fmt.Errorf("unexpected type "+
×
NEW
3876
                                        "for Value: %T", row.Value)
×
NEW
3877
                        }
×
NEW
3878
                        extras[uint64(tlvType)] = valueBytes
×
3879
                }
3880
        }
3881

NEW
3882
        return fv, extras, nil
×
3883
}
3884

3885
// getAndBuildChanPolicies uses the given sqlc.ChannelPolicy and also retrieves
3886
// all the extra info required to build the complete models.ChannelEdgePolicy
3887
// types. It returns two policies, which may be nil if the provided
3888
// sqlc.ChannelPolicy records are nil.
3889
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
3890
        dbPol1, dbPol2 *sqlc.ChannelPolicy, channelID uint64, node1,
3891
        node2 route.Vertex) (*models.ChannelEdgePolicy,
NEW
3892
        *models.ChannelEdgePolicy, error) {
×
NEW
3893

×
NEW
3894
        if dbPol1 == nil && dbPol2 == nil {
×
NEW
3895
                return nil, nil, nil
×
NEW
3896
        }
×
3897

NEW
3898
        var (
×
NEW
3899
                policy1ID int64
×
NEW
3900
                policy2ID int64
×
NEW
3901
        )
×
NEW
3902
        if dbPol1 != nil {
×
NEW
3903
                policy1ID = dbPol1.ID
×
NEW
3904
        }
×
NEW
3905
        if dbPol2 != nil {
×
NEW
3906
                policy2ID = dbPol2.ID
×
NEW
3907
        }
×
NEW
3908
        rows, err := db.GetChannelPolicyExtraTypes(
×
NEW
3909
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
NEW
3910
                        ID:   policy1ID,
×
NEW
3911
                        ID_2: policy2ID,
×
NEW
3912
                },
×
NEW
3913
        )
×
NEW
3914
        if err != nil {
×
NEW
3915
                return nil, nil, err
×
NEW
3916
        }
×
3917

NEW
3918
        var (
×
NEW
3919
                dbPol1Extras = make(map[uint64][]byte)
×
NEW
3920
                dbPol2Extras = make(map[uint64][]byte)
×
NEW
3921
        )
×
NEW
3922
        for _, row := range rows {
×
NEW
3923
                switch row.PolicyID {
×
NEW
3924
                case policy1ID:
×
NEW
3925
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
NEW
3926
                case policy2ID:
×
NEW
3927
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
NEW
3928
                default:
×
NEW
3929
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
NEW
3930
                                "in row: %v", row.PolicyID, row)
×
3931
                }
3932
        }
3933

NEW
3934
        var pol1, pol2 *models.ChannelEdgePolicy
×
NEW
3935
        if dbPol1 != nil {
×
NEW
3936
                pol1, err = buildChanPolicy(
×
NEW
3937
                        *dbPol1, channelID, dbPol1Extras, node2, true,
×
NEW
3938
                )
×
NEW
3939
                if err != nil {
×
NEW
3940
                        return nil, nil, err
×
NEW
3941
                }
×
3942
        }
NEW
3943
        if dbPol2 != nil {
×
NEW
3944
                pol2, err = buildChanPolicy(
×
NEW
3945
                        *dbPol2, channelID, dbPol2Extras, node1, false,
×
NEW
3946
                )
×
NEW
3947
                if err != nil {
×
NEW
3948
                        return nil, nil, err
×
NEW
3949
                }
×
3950
        }
3951

NEW
3952
        return pol1, pol2, nil
×
3953
}
3954

3955
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
3956
// provided sqlc.ChannelPolicy and other required information.
3957
func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
3958
        extras map[uint64][]byte, toNode route.Vertex,
NEW
3959
        isNode1 bool) (*models.ChannelEdgePolicy, error) {
×
NEW
3960

×
NEW
3961
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
NEW
3962
        if err != nil {
×
NEW
3963
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
NEW
3964
                        "fields: %w", err)
×
NEW
3965
        }
×
3966

NEW
3967
        var msgFlags lnwire.ChanUpdateMsgFlags
×
NEW
3968
        if dbPolicy.MaxHtlcMsat.Valid {
×
NEW
3969
                msgFlags |= lnwire.ChanUpdateRequiredMaxHtlc
×
NEW
3970
        }
×
3971

NEW
3972
        var chanFlags lnwire.ChanUpdateChanFlags
×
NEW
3973
        if !isNode1 {
×
NEW
3974
                chanFlags |= lnwire.ChanUpdateDirection
×
NEW
3975
        }
×
NEW
3976
        if dbPolicy.Disabled.Bool {
×
NEW
3977
                chanFlags |= lnwire.ChanUpdateDisabled
×
NEW
3978
        }
×
3979

NEW
3980
        var inboundFee fn.Option[lnwire.Fee]
×
NEW
3981
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
NEW
3982
                dbPolicy.InboundBaseFeeMsat.Valid {
×
NEW
3983

×
NEW
3984
                inboundFee = fn.Some(lnwire.Fee{
×
NEW
3985
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
NEW
3986
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
NEW
3987
                })
×
NEW
3988
        }
×
3989

NEW
3990
        return &models.ChannelEdgePolicy{
×
NEW
3991
                SigBytes:  dbPolicy.Signature,
×
NEW
3992
                ChannelID: channelID,
×
NEW
3993
                LastUpdate: time.Unix(
×
NEW
3994
                        dbPolicy.LastUpdate.Int64, 0,
×
NEW
3995
                ),
×
NEW
3996
                MessageFlags:  msgFlags,
×
NEW
3997
                ChannelFlags:  chanFlags,
×
NEW
3998
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
NEW
3999
                MinHTLC: lnwire.MilliSatoshi(
×
NEW
4000
                        dbPolicy.MinHtlcMsat,
×
NEW
4001
                ),
×
NEW
4002
                MaxHTLC: lnwire.MilliSatoshi(
×
NEW
4003
                        dbPolicy.MaxHtlcMsat.Int64,
×
NEW
4004
                ),
×
NEW
4005
                FeeBaseMSat: lnwire.MilliSatoshi(
×
NEW
4006
                        dbPolicy.BaseFeeMsat,
×
NEW
4007
                ),
×
NEW
4008
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
NEW
4009
                ToNode:                    toNode,
×
NEW
4010
                InboundFee:                inboundFee,
×
NEW
4011
                ExtraOpaqueData:           recs,
×
NEW
4012
        }, nil
×
4013
}
4014

4015
// getAndBuildNodes builds the models.LightningNode instances for the
4016
// given row which is expected to be a sqlc type that contains node information.
4017
func getAndBuildNodes(ctx context.Context, db SQLQueries,
NEW
4018
        row any) (*models.LightningNode, *models.LightningNode, error) {
×
NEW
4019

×
NEW
4020
        dbNode1, dbNode2, err := extractNodes(row)
×
NEW
4021
        if err != nil {
×
NEW
4022
                return nil, nil, err
×
NEW
4023
        }
×
4024

NEW
4025
        node1, err := buildNode(ctx, db, &dbNode1)
×
NEW
4026
        if err != nil {
×
NEW
4027
                return nil, nil, err
×
NEW
4028
        }
×
4029

NEW
4030
        node2, err := buildNode(ctx, db, &dbNode2)
×
NEW
4031
        if err != nil {
×
NEW
4032
                return nil, nil, err
×
NEW
4033
        }
×
4034

NEW
4035
        return node1, node2, nil
×
4036
}
4037

4038
// extractChannelPolicies extracts the sqlc.ChannelPolicy records from the give
4039
// row which is expected to be a sqlc type that contains channel policy
4040
// information. It returns two policies, which may be nil if the policy
4041
// information is not present in the row.
4042
//
4043
//nolint:ll
4044
func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
NEW
4045
        error) {
×
NEW
4046

×
NEW
4047
        var policy1, policy2 *sqlc.ChannelPolicy
×
NEW
4048
        switch r := row.(type) {
×
NEW
4049
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
NEW
4050
                if r.Policy1ID.Valid {
×
NEW
4051
                        policy1 = &sqlc.ChannelPolicy{
×
NEW
4052
                                ID:                      r.Policy1ID.Int64,
×
NEW
4053
                                Version:                 r.Policy1Version.Int16,
×
NEW
4054
                                ChannelID:               r.ID,
×
NEW
4055
                                NodeID:                  r.Policy1NodeID.Int64,
×
NEW
4056
                                Timelock:                r.Policy1Timelock.Int32,
×
NEW
4057
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
NEW
4058
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
NEW
4059
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
NEW
4060
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
NEW
4061
                                LastUpdate:              r.Policy1LastUpdate,
×
NEW
4062
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
NEW
4063
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
NEW
4064
                                Disabled:                r.Policy1Disabled,
×
NEW
4065
                                Signature:               r.Policy1Signature,
×
NEW
4066
                        }
×
NEW
4067
                }
×
NEW
4068
                if r.Policy2ID.Valid {
×
NEW
4069
                        policy2 = &sqlc.ChannelPolicy{
×
NEW
4070
                                ID:                      r.Policy2ID.Int64,
×
NEW
4071
                                Version:                 r.Policy2Version.Int16,
×
NEW
4072
                                ChannelID:               r.ID,
×
NEW
4073
                                NodeID:                  r.Policy2NodeID.Int64,
×
NEW
4074
                                Timelock:                r.Policy2Timelock.Int32,
×
NEW
4075
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
NEW
4076
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
NEW
4077
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
NEW
4078
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
NEW
4079
                                LastUpdate:              r.Policy2LastUpdate,
×
NEW
4080
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
NEW
4081
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
NEW
4082
                                Disabled:                r.Policy2Disabled,
×
NEW
4083
                                Signature:               r.Policy2Signature,
×
NEW
4084
                        }
×
NEW
4085
                }
×
NEW
4086
                return policy1, policy2, nil
×
4087

NEW
4088
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
NEW
4089
                if r.Policy1ID.Valid {
×
NEW
4090
                        policy1 = &sqlc.ChannelPolicy{
×
NEW
4091
                                ID:                      r.Policy1ID.Int64,
×
NEW
4092
                                Version:                 r.Policy1Version.Int16,
×
NEW
4093
                                ChannelID:               r.ID,
×
NEW
4094
                                NodeID:                  r.Policy1NodeID.Int64,
×
NEW
4095
                                Timelock:                r.Policy1Timelock.Int32,
×
NEW
4096
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
NEW
4097
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
NEW
4098
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
NEW
4099
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
NEW
4100
                                LastUpdate:              r.Policy1LastUpdate,
×
NEW
4101
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
NEW
4102
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
NEW
4103
                                Disabled:                r.Policy1Disabled,
×
NEW
4104
                                Signature:               r.Policy1Signature,
×
NEW
4105
                        }
×
NEW
4106
                }
×
NEW
4107
                if r.Policy2ID.Valid {
×
NEW
4108
                        policy2 = &sqlc.ChannelPolicy{
×
NEW
4109
                                ID:                      r.Policy2ID.Int64,
×
NEW
4110
                                Version:                 r.Policy2Version.Int16,
×
NEW
4111
                                ChannelID:               r.ID,
×
NEW
4112
                                NodeID:                  r.Policy2NodeID.Int64,
×
NEW
4113
                                Timelock:                r.Policy2Timelock.Int32,
×
NEW
4114
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
NEW
4115
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
NEW
4116
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
NEW
4117
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
NEW
4118
                                LastUpdate:              r.Policy2LastUpdate,
×
NEW
4119
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
NEW
4120
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
NEW
4121
                                Disabled:                r.Policy2Disabled,
×
NEW
4122
                                Signature:               r.Policy2Signature,
×
NEW
4123
                        }
×
NEW
4124
                }
×
NEW
4125
                return policy1, policy2, nil
×
4126

NEW
4127
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
NEW
4128
                if r.Policy1ID.Valid {
×
NEW
4129
                        policy1 = &sqlc.ChannelPolicy{
×
NEW
4130
                                ID:                      r.Policy1ID.Int64,
×
NEW
4131
                                Version:                 r.Policy1Version.Int16,
×
NEW
4132
                                ChannelID:               r.ID,
×
NEW
4133
                                NodeID:                  r.Policy1NodeID.Int64,
×
NEW
4134
                                Timelock:                r.Policy1Timelock.Int32,
×
NEW
4135
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
NEW
4136
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
NEW
4137
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
NEW
4138
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
NEW
4139
                                LastUpdate:              r.Policy1LastUpdate,
×
NEW
4140
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
NEW
4141
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
NEW
4142
                                Disabled:                r.Policy1Disabled,
×
NEW
4143
                                Signature:               r.Policy1Signature,
×
NEW
4144
                        }
×
NEW
4145
                }
×
NEW
4146
                if r.Policy2ID.Valid {
×
NEW
4147
                        policy2 = &sqlc.ChannelPolicy{
×
NEW
4148
                                ID:                      r.Policy2ID.Int64,
×
NEW
4149
                                Version:                 r.Policy2Version.Int16,
×
NEW
4150
                                ChannelID:               r.ID,
×
NEW
4151
                                NodeID:                  r.Policy2NodeID.Int64,
×
NEW
4152
                                Timelock:                r.Policy2Timelock.Int32,
×
NEW
4153
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
NEW
4154
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
NEW
4155
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
NEW
4156
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
NEW
4157
                                LastUpdate:              r.Policy2LastUpdate,
×
NEW
4158
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
NEW
4159
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
NEW
4160
                                Disabled:                r.Policy2Disabled,
×
NEW
4161
                                Signature:               r.Policy2Signature,
×
NEW
4162
                        }
×
NEW
4163
                }
×
NEW
4164
                return policy1, policy2, nil
×
4165

NEW
4166
        case sqlc.ListChannelsByNodeIDRow:
×
NEW
4167
                if r.Policy1ID.Valid {
×
NEW
4168
                        policy1 = &sqlc.ChannelPolicy{
×
NEW
4169
                                ID:                      r.Policy1ID.Int64,
×
NEW
4170
                                Version:                 r.Policy1Version.Int16,
×
NEW
4171
                                ChannelID:               r.ID,
×
NEW
4172
                                NodeID:                  r.Policy1NodeID.Int64,
×
NEW
4173
                                Timelock:                r.Policy1Timelock.Int32,
×
NEW
4174
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
NEW
4175
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
NEW
4176
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
NEW
4177
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
NEW
4178
                                LastUpdate:              r.Policy1LastUpdate,
×
NEW
4179
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
NEW
4180
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
NEW
4181
                                Disabled:                r.Policy1Disabled,
×
NEW
4182
                                Signature:               r.Policy1Signature,
×
NEW
4183
                        }
×
NEW
4184
                }
×
NEW
4185
                if r.Policy2ID.Valid {
×
NEW
4186
                        policy2 = &sqlc.ChannelPolicy{
×
NEW
4187
                                ID:                      r.Policy2ID.Int64,
×
NEW
4188
                                Version:                 r.Policy2Version.Int16,
×
NEW
4189
                                ChannelID:               r.ID,
×
NEW
4190
                                NodeID:                  r.Policy2NodeID.Int64,
×
NEW
4191
                                Timelock:                r.Policy2Timelock.Int32,
×
NEW
4192
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
NEW
4193
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
NEW
4194
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
NEW
4195
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
NEW
4196
                                LastUpdate:              r.Policy2LastUpdate,
×
NEW
4197
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
NEW
4198
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
NEW
4199
                                Disabled:                r.Policy2Disabled,
×
NEW
4200
                                Signature:               r.Policy2Signature,
×
NEW
4201
                        }
×
NEW
4202
                }
×
NEW
4203
                return policy1, policy2, nil
×
4204

NEW
4205
        case sqlc.ListAllChannelsRow:
×
NEW
4206
                if r.Policy1ID.Valid {
×
NEW
4207
                        policy1 = &sqlc.ChannelPolicy{
×
NEW
4208
                                ID:                      r.Policy1ID.Int64,
×
NEW
4209
                                Version:                 r.Policy1Version.Int16,
×
NEW
4210
                                ChannelID:               r.ID,
×
NEW
4211
                                NodeID:                  r.Policy1NodeID.Int64,
×
NEW
4212
                                Timelock:                r.Policy1Timelock.Int32,
×
NEW
4213
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
NEW
4214
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
NEW
4215
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
NEW
4216
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
NEW
4217
                                LastUpdate:              r.Policy1LastUpdate,
×
NEW
4218
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
NEW
4219
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
NEW
4220
                                Disabled:                r.Policy1Disabled,
×
NEW
4221
                                Signature:               r.Policy1Signature,
×
NEW
4222
                        }
×
NEW
4223
                }
×
NEW
4224
                if r.Policy2ID.Valid {
×
NEW
4225
                        policy2 = &sqlc.ChannelPolicy{
×
NEW
4226
                                ID:                      r.Policy2ID.Int64,
×
NEW
4227
                                Version:                 r.Policy2Version.Int16,
×
NEW
4228
                                ChannelID:               r.ID,
×
NEW
4229
                                NodeID:                  r.Policy2NodeID.Int64,
×
NEW
4230
                                Timelock:                r.Policy2Timelock.Int32,
×
NEW
4231
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
NEW
4232
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
NEW
4233
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
NEW
4234
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
NEW
4235
                                LastUpdate:              r.Policy2LastUpdate,
×
NEW
4236
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
NEW
4237
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
NEW
4238
                                Disabled:                r.Policy2Disabled,
×
NEW
4239
                                Signature:               r.Policy2Signature,
×
NEW
4240
                        }
×
NEW
4241
                }
×
4242

NEW
4243
                return policy1, policy2, nil
×
NEW
4244
        default:
×
NEW
4245
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
NEW
4246
                        "extractChannelPolicies: %T", r)
×
4247
        }
4248
}
4249

4250
// extractChannel extracts the sqlc.Channel record from the given row
4251
// which is expected to be a sqlc type that contains channel information.
NEW
4252
func extractChannel(row any) (sqlc.Channel, error) {
×
NEW
4253
        switch r := row.(type) {
×
NEW
4254
        case sqlc.GetChannelsBySCIDRangeRow:
×
NEW
4255
                return sqlc.Channel{
×
NEW
4256
                        ID:                r.ID,
×
NEW
4257
                        Version:           r.Version,
×
NEW
4258
                        Scid:              r.Scid,
×
NEW
4259
                        NodeID1:           r.NodeID1,
×
NEW
4260
                        NodeID2:           r.NodeID2,
×
NEW
4261
                        Outpoint:          r.Outpoint,
×
NEW
4262
                        Capacity:          r.Capacity,
×
NEW
4263
                        BitcoinKey1:       r.BitcoinKey1,
×
NEW
4264
                        BitcoinKey2:       r.BitcoinKey2,
×
NEW
4265
                        Node1Signature:    r.Node1Signature,
×
NEW
4266
                        Node2Signature:    r.Node2Signature,
×
NEW
4267
                        Bitcoin1Signature: r.Bitcoin1Signature,
×
NEW
4268
                        Bitcoin2Signature: r.Bitcoin2Signature,
×
NEW
4269
                }, nil
×
4270

NEW
4271
        case sqlc.GetChannelByOutpointRow:
×
NEW
4272
                return sqlc.Channel{
×
NEW
4273
                        ID:                r.ID,
×
NEW
4274
                        Version:           r.Version,
×
NEW
4275
                        Scid:              r.Scid,
×
NEW
4276
                        NodeID1:           r.NodeID1,
×
NEW
4277
                        NodeID2:           r.NodeID2,
×
NEW
4278
                        Outpoint:          r.Outpoint,
×
NEW
4279
                        Capacity:          r.Capacity,
×
NEW
4280
                        BitcoinKey1:       r.BitcoinKey1,
×
NEW
4281
                        BitcoinKey2:       r.BitcoinKey2,
×
NEW
4282
                        Node1Signature:    r.Node1Signature,
×
NEW
4283
                        Node2Signature:    r.Node2Signature,
×
NEW
4284
                        Bitcoin1Signature: r.Bitcoin1Signature,
×
NEW
4285
                        Bitcoin2Signature: r.Bitcoin2Signature,
×
NEW
4286
                }, nil
×
4287

NEW
4288
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
NEW
4289
                return sqlc.Channel{
×
NEW
4290
                        ID:                r.ID,
×
NEW
4291
                        Version:           r.Version,
×
NEW
4292
                        Scid:              r.Scid,
×
NEW
4293
                        NodeID1:           r.NodeID1,
×
NEW
4294
                        NodeID2:           r.NodeID2,
×
NEW
4295
                        Outpoint:          r.Outpoint,
×
NEW
4296
                        Capacity:          r.Capacity,
×
NEW
4297
                        BitcoinKey1:       r.BitcoinKey1,
×
NEW
4298
                        BitcoinKey2:       r.BitcoinKey2,
×
NEW
4299
                        Node1Signature:    r.Node1Signature,
×
NEW
4300
                        Node2Signature:    r.Node2Signature,
×
NEW
4301
                        Bitcoin1Signature: r.Bitcoin1Signature,
×
NEW
4302
                        Bitcoin2Signature: r.Bitcoin2Signature,
×
NEW
4303
                }, nil
×
4304

NEW
4305
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
NEW
4306
                return sqlc.Channel{
×
NEW
4307
                        ID:                r.ID,
×
NEW
4308
                        Version:           r.Version,
×
NEW
4309
                        Scid:              r.Scid,
×
NEW
4310
                        NodeID1:           r.NodeID1,
×
NEW
4311
                        NodeID2:           r.NodeID2,
×
NEW
4312
                        Outpoint:          r.Outpoint,
×
NEW
4313
                        Capacity:          r.Capacity,
×
NEW
4314
                        BitcoinKey1:       r.BitcoinKey1,
×
NEW
4315
                        BitcoinKey2:       r.BitcoinKey2,
×
NEW
4316
                        Node1Signature:    r.Node1Signature,
×
NEW
4317
                        Node2Signature:    r.Node2Signature,
×
NEW
4318
                        Bitcoin1Signature: r.Bitcoin1Signature,
×
NEW
4319
                        Bitcoin2Signature: r.Bitcoin2Signature,
×
NEW
4320
                }, nil
×
4321

NEW
4322
        case sqlc.ListAllChannelsRow:
×
NEW
4323
                return sqlc.Channel{
×
NEW
4324
                        ID:                r.ID,
×
NEW
4325
                        Version:           r.Version,
×
NEW
4326
                        Scid:              r.Scid,
×
NEW
4327
                        NodeID1:           r.NodeID1,
×
NEW
4328
                        NodeID2:           r.NodeID2,
×
NEW
4329
                        Outpoint:          r.Outpoint,
×
NEW
4330
                        Capacity:          r.Capacity,
×
NEW
4331
                        BitcoinKey1:       r.BitcoinKey1,
×
NEW
4332
                        BitcoinKey2:       r.BitcoinKey2,
×
NEW
4333
                        Node1Signature:    r.Node1Signature,
×
NEW
4334
                        Node2Signature:    r.Node2Signature,
×
NEW
4335
                        Bitcoin1Signature: r.Bitcoin1Signature,
×
NEW
4336
                        Bitcoin2Signature: r.Bitcoin2Signature,
×
NEW
4337
                }, nil
×
4338

NEW
4339
        case sqlc.ListChannelsByNodeIDRow:
×
NEW
4340
                return sqlc.Channel{
×
NEW
4341
                        ID:                r.ID,
×
NEW
4342
                        Version:           r.Version,
×
NEW
4343
                        Scid:              r.Scid,
×
NEW
4344
                        NodeID1:           r.NodeID1,
×
NEW
4345
                        NodeID2:           r.NodeID2,
×
NEW
4346
                        Outpoint:          r.Outpoint,
×
NEW
4347
                        Capacity:          r.Capacity,
×
NEW
4348
                        BitcoinKey1:       r.BitcoinKey1,
×
NEW
4349
                        BitcoinKey2:       r.BitcoinKey2,
×
NEW
4350
                        Node1Signature:    r.Node1Signature,
×
NEW
4351
                        Node2Signature:    r.Node2Signature,
×
NEW
4352
                        Bitcoin1Signature: r.Bitcoin1Signature,
×
NEW
4353
                        Bitcoin2Signature: r.Bitcoin2Signature,
×
NEW
4354
                }, nil
×
4355

NEW
4356
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
NEW
4357
                return sqlc.Channel{
×
NEW
4358
                        ID:                r.ID,
×
NEW
4359
                        Version:           r.Version,
×
NEW
4360
                        Scid:              r.Scid,
×
NEW
4361
                        NodeID1:           r.NodeID1,
×
NEW
4362
                        NodeID2:           r.NodeID2,
×
NEW
4363
                        Outpoint:          r.Outpoint,
×
NEW
4364
                        Capacity:          r.Capacity,
×
NEW
4365
                        BitcoinKey1:       r.BitcoinKey1,
×
NEW
4366
                        BitcoinKey2:       r.BitcoinKey2,
×
NEW
4367
                        Node1Signature:    r.Node1Signature,
×
NEW
4368
                        Node2Signature:    r.Node2Signature,
×
NEW
4369
                        Bitcoin1Signature: r.Bitcoin1Signature,
×
NEW
4370
                        Bitcoin2Signature: r.Bitcoin2Signature,
×
NEW
4371
                }, nil
×
4372

NEW
4373
        default:
×
NEW
4374
                return sqlc.Channel{}, fmt.Errorf("unexpected row type in "+
×
NEW
4375
                        "extractChannel: %T", r)
×
4376
        }
4377
}
4378

4379
// extractNodes extracts the sqlc.Node records from the given row
4380
// which is expected to be a sqlc type that contains node information.
NEW
4381
func extractNodes(row any) (sqlc.Node, sqlc.Node, error) {
×
NEW
4382
        switch r := row.(type) {
×
NEW
4383
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
NEW
4384
                return sqlc.Node{
×
NEW
4385
                                ID:         r.Node1ID,
×
NEW
4386
                                Version:    r.Node1Version,
×
NEW
4387
                                PubKey:     r.Node1PubKey,
×
NEW
4388
                                Alias:      r.Node1Alias,
×
NEW
4389
                                LastUpdate: r.Node1LastUpdate,
×
NEW
4390
                                Color:      r.Node1Color,
×
NEW
4391
                                Signature:  r.Node1AnnSignature,
×
NEW
4392
                        }, sqlc.Node{
×
NEW
4393
                                ID:         r.Node2ID,
×
NEW
4394
                                Version:    r.Node2Version,
×
NEW
4395
                                PubKey:     r.Node2PubKey,
×
NEW
4396
                                Alias:      r.Node2Alias,
×
NEW
4397
                                LastUpdate: r.Node2LastUpdate,
×
NEW
4398
                                Color:      r.Node2Color,
×
NEW
4399
                                Signature:  r.Node2AnnSignature,
×
NEW
4400
                        }, nil
×
4401

NEW
4402
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
NEW
4403
                return sqlc.Node{
×
NEW
4404
                                ID:         r.Node1ID,
×
NEW
4405
                                Version:    r.Node1Version,
×
NEW
4406
                                PubKey:     r.Node1PubKey,
×
NEW
4407
                                Alias:      r.Node1Alias,
×
NEW
4408
                                LastUpdate: r.Node1LastUpdate,
×
NEW
4409
                                Color:      r.Node1Color,
×
NEW
4410
                                Signature:  r.Node1AnnSignature,
×
NEW
4411
                        }, sqlc.Node{
×
NEW
4412
                                ID:         r.Node2ID,
×
NEW
4413
                                Version:    r.Node2Version,
×
NEW
4414
                                PubKey:     r.Node2PubKey,
×
NEW
4415
                                Alias:      r.Node2Alias,
×
NEW
4416
                                LastUpdate: r.Node2LastUpdate,
×
UNCOV
4417
                                Color:      r.Node2Color,
×
UNCOV
4418
                                Signature:  r.Node2AnnSignature,
×
UNCOV
4419
                        }, nil
×
4420

UNCOV
4421
        default:
×
UNCOV
4422
                return sqlc.Node{}, sqlc.Node{}, fmt.Errorf("unexpected row "+
×
UNCOV
4423
                        "type in extractNodes: %T", r)
×
4424
        }
4425
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc