• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 16032370868

02 Jul 2025 05:59PM UTC coverage: 67.589% (+0.05%) from 67.54%
16032370868

push

github

web-flow
Merge pull request #10017 from ellemouton/strictTypeForChanFeatures

refactor+multi: use *lnwire.FeatureVector for ChannelEdgeInfo features

16 of 31 new or added lines in 5 files covered. (51.61%)

41 existing lines in 13 files now uncovered.

135179 of 200001 relevant lines covered (67.59%)

21871.33 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// pageSize caps how many rows a single paginated query may return. The value
// is a starting point and may be tuned once benchmarks are available.
const pageSize = 2000

// ProtocolVersion identifies which gossip protocol version a message belongs
// to.
type ProtocolVersion uint8

// ProtocolV1 is the gossip protocol version defined in BOLT #7.
const ProtocolV1 ProtocolVersion = 1

// String renders the protocol version as "V" followed by its numeric value,
// for example "V1".
func (v ProtocolVersion) String() string {
	// Convert to the underlying uint8 so fmt prints the number instead of
	// recursing back into this String method.
	return fmt.Sprint("V", uint8(v))
}
×
51

52
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
53
// execute queries against the SQL graph tables.
54
//
55
//nolint:ll,interfacebloat
56
type SQLQueries interface {
57
        /*
58
                Node queries.
59
        */
60
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
61
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
62
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
63
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
64
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error)
65
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
66
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
67
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
68
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
69
        DeleteNode(ctx context.Context, id int64) error
70

71
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
72
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
73
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
74

75
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
76
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
77
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
78

79
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
80
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
81
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
82
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
83

84
        /*
85
                Source node queries.
86
        */
87
        AddSourceNode(ctx context.Context, nodeID int64) error
88
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
89

90
        /*
91
                Channel queries.
92
        */
93
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
94
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
95
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
96
        GetChannelByOutpoint(ctx context.Context, outpoint string) (sqlc.GetChannelByOutpointRow, error)
97
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
98
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
99
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
100
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
101
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
102
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
103
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
104
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
105
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
106
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
107
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
108
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
109
        DeleteChannel(ctx context.Context, id int64) error
110

111
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
112
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
113

114
        /*
115
                Channel Policy table queries.
116
        */
117
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
118
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.ChannelPolicy, error)
119
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
120

121
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
122
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
123
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
124

125
        /*
126
                Zombie index queries.
127
        */
128
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
129
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.ZombieChannel, error)
130
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
131
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
132
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
133

134
        /*
135
                Prune log table queries.
136
        */
137
        GetPruneTip(ctx context.Context) (sqlc.PruneLog, error)
138
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
139
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
140

141
        /*
142
                Closed SCID table queries.
143
        */
144
        InsertClosedChannel(ctx context.Context, scid []byte) error
145
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
146
}
147

148
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
149
// database operations.
150
type BatchedSQLQueries interface {
151
        SQLQueries
152
        sqldb.BatchedTx[SQLQueries]
153
}
154

155
// SQLStore is an implementation of the V1Store interface that uses a SQL
156
// database as the backend.
157
type SQLStore struct {
158
        cfg *SQLStoreConfig
159
        db  BatchedSQLQueries
160

161
        // cacheMu guards all caches (rejectCache and chanCache). If
162
        // this mutex will be acquired at the same time as the DB mutex then
163
        // the cacheMu MUST be acquired first to prevent deadlock.
164
        cacheMu     sync.RWMutex
165
        rejectCache *rejectCache
166
        chanCache   *channelCache
167

168
        chanScheduler batch.Scheduler[SQLQueries]
169
        nodeScheduler batch.Scheduler[SQLQueries]
170

171
        srcNodes  map[ProtocolVersion]*srcNodeInfo
172
        srcNodeMu sync.Mutex
173
}
174

175
// A compile-time assertion to ensure that SQLStore implements the V1Store
176
// interface.
177
var _ V1Store = (*SQLStore)(nil)
178

179
// SQLStoreConfig holds the configuration for the SQLStore.
180
type SQLStoreConfig struct {
181
        // ChainHash is the genesis hash for the chain that all the gossip
182
        // messages in this store are aimed at.
183
        ChainHash chainhash.Hash
184
}
185

186
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
187
// storage backend.
188
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
189
        options ...StoreOptionModifier) (*SQLStore, error) {
×
190

×
191
        opts := DefaultOptions()
×
192
        for _, o := range options {
×
193
                o(opts)
×
194
        }
×
195

196
        if opts.NoMigration {
×
197
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
198
                        "supported for SQL stores")
×
199
        }
×
200

201
        s := &SQLStore{
×
202
                cfg:         cfg,
×
203
                db:          db,
×
204
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
205
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
206
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
207
        }
×
208

×
209
        s.chanScheduler = batch.NewTimeScheduler(
×
210
                db, &s.cacheMu, opts.BatchCommitInterval,
×
211
        )
×
212
        s.nodeScheduler = batch.NewTimeScheduler(
×
213
                db, nil, opts.BatchCommitInterval,
×
214
        )
×
215

×
216
        return s, nil
×
217
}
218

219
// AddLightningNode adds a vertex/node to the graph database. If the node is not
220
// in the database from before, this will add a new, unconnected one to the
221
// graph. If it is present from before, this will update that node's
222
// information.
223
//
224
// NOTE: part of the V1Store interface.
225
func (s *SQLStore) AddLightningNode(ctx context.Context,
226
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
227

×
228
        r := &batch.Request[SQLQueries]{
×
229
                Opts: batch.NewSchedulerOptions(opts...),
×
230
                Do: func(queries SQLQueries) error {
×
231
                        _, err := upsertNode(ctx, queries, node)
×
232
                        return err
×
233
                },
×
234
        }
235

236
        return s.nodeScheduler.Execute(ctx, r)
×
237
}
238

239
// FetchLightningNode attempts to look up a target node by its identity public
240
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
241
// returned.
242
//
243
// NOTE: part of the V1Store interface.
244
func (s *SQLStore) FetchLightningNode(ctx context.Context,
245
        pubKey route.Vertex) (*models.LightningNode, error) {
×
246

×
247
        var node *models.LightningNode
×
248
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
249
                var err error
×
250
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
251

×
252
                return err
×
253
        }, sqldb.NoOpReset)
×
254
        if err != nil {
×
255
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
256
        }
×
257

258
        return node, nil
×
259
}
260

261
// HasLightningNode determines if the graph has a vertex identified by the
262
// target node identity public key. If the node exists in the database, a
263
// timestamp of when the data for the node was lasted updated is returned along
264
// with a true boolean. Otherwise, an empty time.Time is returned with a false
265
// boolean.
266
//
267
// NOTE: part of the V1Store interface.
268
func (s *SQLStore) HasLightningNode(ctx context.Context,
269
        pubKey [33]byte) (time.Time, bool, error) {
×
270

×
271
        var (
×
272
                exists     bool
×
273
                lastUpdate time.Time
×
274
        )
×
275
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
276
                dbNode, err := db.GetNodeByPubKey(
×
277
                        ctx, sqlc.GetNodeByPubKeyParams{
×
278
                                Version: int16(ProtocolV1),
×
279
                                PubKey:  pubKey[:],
×
280
                        },
×
281
                )
×
282
                if errors.Is(err, sql.ErrNoRows) {
×
283
                        return nil
×
284
                } else if err != nil {
×
285
                        return fmt.Errorf("unable to fetch node: %w", err)
×
286
                }
×
287

288
                exists = true
×
289

×
290
                if dbNode.LastUpdate.Valid {
×
291
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
292
                }
×
293

294
                return nil
×
295
        }, sqldb.NoOpReset)
296
        if err != nil {
×
297
                return time.Time{}, false,
×
298
                        fmt.Errorf("unable to fetch node: %w", err)
×
299
        }
×
300

301
        return lastUpdate, exists, nil
×
302
}
303

304
// AddrsForNode returns all known addresses for the target node public key
305
// that the graph DB is aware of. The returned boolean indicates if the
306
// given node is unknown to the graph DB or not.
307
//
308
// NOTE: part of the V1Store interface.
309
func (s *SQLStore) AddrsForNode(ctx context.Context,
310
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
311

×
312
        var (
×
313
                addresses []net.Addr
×
314
                known     bool
×
315
        )
×
316
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
317
                var err error
×
318
                known, addresses, err = getNodeAddresses(
×
319
                        ctx, db, nodePub.SerializeCompressed(),
×
320
                )
×
321
                if err != nil {
×
322
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
323
                                err)
×
324
                }
×
325

326
                return nil
×
327
        }, sqldb.NoOpReset)
328
        if err != nil {
×
329
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
330
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
331
        }
×
332

333
        return known, addresses, nil
×
334
}
335

336
// DeleteLightningNode starts a new database transaction to remove a vertex/node
337
// from the database according to the node's public key.
338
//
339
// NOTE: part of the V1Store interface.
340
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
341
        pubKey route.Vertex) error {
×
342

×
343
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
344
                res, err := db.DeleteNodeByPubKey(
×
345
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
346
                                Version: int16(ProtocolV1),
×
347
                                PubKey:  pubKey[:],
×
348
                        },
×
349
                )
×
350
                if err != nil {
×
351
                        return err
×
352
                }
×
353

354
                rows, err := res.RowsAffected()
×
355
                if err != nil {
×
356
                        return err
×
357
                }
×
358

359
                if rows == 0 {
×
360
                        return ErrGraphNodeNotFound
×
361
                } else if rows > 1 {
×
362
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
363
                }
×
364

365
                return err
×
366
        }, sqldb.NoOpReset)
367
        if err != nil {
×
368
                return fmt.Errorf("unable to delete node: %w", err)
×
369
        }
×
370

371
        return nil
×
372
}
373

374
// FetchNodeFeatures returns the features of the given node. If no features are
375
// known for the node, an empty feature vector is returned.
376
//
377
// NOTE: this is part of the graphdb.NodeTraverser interface.
378
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
379
        *lnwire.FeatureVector, error) {
×
380

×
381
        ctx := context.TODO()
×
382

×
383
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
384
}
×
385

386
// DisabledChannelIDs returns the channel ids of disabled channels.
387
// A channel is disabled when two of the associated ChanelEdgePolicies
388
// have their disabled bit on.
389
//
390
// NOTE: part of the V1Store interface.
391
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
392
        var (
×
393
                ctx     = context.TODO()
×
394
                chanIDs []uint64
×
395
        )
×
396
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
397
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
398
                if err != nil {
×
399
                        return fmt.Errorf("unable to fetch disabled "+
×
400
                                "channels: %w", err)
×
401
                }
×
402

403
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
404

×
405
                return nil
×
406
        }, sqldb.NoOpReset)
407
        if err != nil {
×
408
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
409
                        err)
×
410
        }
×
411

412
        return chanIDs, nil
×
413
}
414

415
// LookupAlias attempts to return the alias as advertised by the target node.
416
//
417
// NOTE: part of the V1Store interface.
418
func (s *SQLStore) LookupAlias(ctx context.Context,
419
        pub *btcec.PublicKey) (string, error) {
×
420

×
421
        var alias string
×
422
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
423
                dbNode, err := db.GetNodeByPubKey(
×
424
                        ctx, sqlc.GetNodeByPubKeyParams{
×
425
                                Version: int16(ProtocolV1),
×
426
                                PubKey:  pub.SerializeCompressed(),
×
427
                        },
×
428
                )
×
429
                if errors.Is(err, sql.ErrNoRows) {
×
430
                        return ErrNodeAliasNotFound
×
431
                } else if err != nil {
×
432
                        return fmt.Errorf("unable to fetch node: %w", err)
×
433
                }
×
434

435
                if !dbNode.Alias.Valid {
×
436
                        return ErrNodeAliasNotFound
×
437
                }
×
438

439
                alias = dbNode.Alias.String
×
440

×
441
                return nil
×
442
        }, sqldb.NoOpReset)
443
        if err != nil {
×
444
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
445
        }
×
446

447
        return alias, nil
×
448
}
449

450
// SourceNode returns the source node of the graph. The source node is treated
451
// as the center node within a star-graph. This method may be used to kick off
452
// a path finding algorithm in order to explore the reachability of another
453
// node based off the source node.
454
//
455
// NOTE: part of the V1Store interface.
456
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
457
        error) {
×
458

×
459
        var node *models.LightningNode
×
460
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
461
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
462
                if err != nil {
×
463
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
464
                                err)
×
465
                }
×
466

467
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
468

×
469
                return err
×
470
        }, sqldb.NoOpReset)
471
        if err != nil {
×
472
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
473
        }
×
474

475
        return node, nil
×
476
}
477

478
// SetSourceNode sets the source node within the graph database. The source
479
// node is to be used as the center of a star-graph within path finding
480
// algorithms.
481
//
482
// NOTE: part of the V1Store interface.
483
func (s *SQLStore) SetSourceNode(ctx context.Context,
484
        node *models.LightningNode) error {
×
485

×
486
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
487
                id, err := upsertNode(ctx, db, node)
×
488
                if err != nil {
×
489
                        return fmt.Errorf("unable to upsert source node: %w",
×
490
                                err)
×
491
                }
×
492

493
                // Make sure that if a source node for this version is already
494
                // set, then the ID is the same as the one we are about to set.
495
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
496
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
497
                        return fmt.Errorf("unable to fetch source node: %w",
×
498
                                err)
×
499
                } else if err == nil {
×
500
                        if dbSourceNodeID != id {
×
501
                                return fmt.Errorf("v1 source node already "+
×
502
                                        "set to a different node: %d vs %d",
×
503
                                        dbSourceNodeID, id)
×
504
                        }
×
505

506
                        return nil
×
507
                }
508

509
                return db.AddSourceNode(ctx, id)
×
510
        }, sqldb.NoOpReset)
511
}
512

513
// NodeUpdatesInHorizon returns all the known lightning node which have an
514
// update timestamp within the passed range. This method can be used by two
515
// nodes to quickly determine if they have the same set of up to date node
516
// announcements.
517
//
518
// NOTE: This is part of the V1Store interface.
519
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
520
        endTime time.Time) ([]models.LightningNode, error) {
×
521

×
522
        ctx := context.TODO()
×
523

×
524
        var nodes []models.LightningNode
×
525
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
526
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
527
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
528
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
529
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
530
                        },
×
531
                )
×
532
                if err != nil {
×
533
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
534
                }
×
535

536
                for _, dbNode := range dbNodes {
×
537
                        node, err := buildNode(ctx, db, &dbNode)
×
538
                        if err != nil {
×
539
                                return fmt.Errorf("unable to build node: %w",
×
540
                                        err)
×
541
                        }
×
542

543
                        nodes = append(nodes, *node)
×
544
                }
545

546
                return nil
×
547
        }, sqldb.NoOpReset)
548
        if err != nil {
×
549
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
550
        }
×
551

552
        return nodes, nil
×
553
}
554

555
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
556
// undirected edge from the two target nodes are created. The information stored
557
// denotes the static attributes of the channel, such as the channelID, the keys
558
// involved in creation of the channel, and the set of features that the channel
559
// supports. The chanPoint and chanID are used to uniquely identify the edge
560
// globally within the database.
561
//
562
// NOTE: part of the V1Store interface.
563
func (s *SQLStore) AddChannelEdge(ctx context.Context,
564
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
565

×
566
        var alreadyExists bool
×
567
        r := &batch.Request[SQLQueries]{
×
568
                Opts: batch.NewSchedulerOptions(opts...),
×
569
                Reset: func() {
×
570
                        alreadyExists = false
×
571
                },
×
572
                Do: func(tx SQLQueries) error {
×
573
                        _, err := insertChannel(ctx, tx, edge)
×
574

×
575
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
576
                        // succeed, but propagate the error via local state.
×
577
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
578
                                alreadyExists = true
×
579
                                return nil
×
580
                        }
×
581

582
                        return err
×
583
                },
584
                OnCommit: func(err error) error {
×
585
                        switch {
×
586
                        case err != nil:
×
587
                                return err
×
588
                        case alreadyExists:
×
589
                                return ErrEdgeAlreadyExist
×
590
                        default:
×
591
                                s.rejectCache.remove(edge.ChannelID)
×
592
                                s.chanCache.remove(edge.ChannelID)
×
593
                                return nil
×
594
                        }
595
                },
596
        }
597

598
        return s.chanScheduler.Execute(ctx, r)
×
599
}
600

601
// HighestChanID returns the "highest" known channel ID in the channel graph.
602
// This represents the "newest" channel from the PoV of the chain. This method
603
// can be used by peers to quickly determine if their graphs are in sync.
604
//
605
// NOTE: This is part of the V1Store interface.
606
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
607
        var highestChanID uint64
×
608
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
609
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
610
                if errors.Is(err, sql.ErrNoRows) {
×
611
                        return nil
×
612
                } else if err != nil {
×
613
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
614
                                err)
×
615
                }
×
616

617
                highestChanID = byteOrder.Uint64(chanID)
×
618

×
619
                return nil
×
620
        }, sqldb.NoOpReset)
621
        if err != nil {
×
622
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
623
        }
×
624

625
        return highestChanID, nil
×
626
}
627

628
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
629
// within the database for the referenced channel. The `flags` attribute within
630
// the ChannelEdgePolicy determines which of the directed edges are being
631
// updated. If the flag is 1, then the first node's information is being
632
// updated, otherwise it's the second node's information. The node ordering is
633
// determined by the lexicographical ordering of the identity public keys of the
634
// nodes on either side of the channel.
635
//
636
// NOTE: part of the V1Store interface.
637
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
638
        edge *models.ChannelEdgePolicy,
639
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
640

×
641
        var (
×
642
                isUpdate1    bool
×
643
                edgeNotFound bool
×
644
                from, to     route.Vertex
×
645
        )
×
646

×
647
        r := &batch.Request[SQLQueries]{
×
648
                Opts: batch.NewSchedulerOptions(opts...),
×
649
                Reset: func() {
×
650
                        isUpdate1 = false
×
651
                        edgeNotFound = false
×
652
                },
×
653
                Do: func(tx SQLQueries) error {
×
654
                        var err error
×
655
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
656
                                ctx, tx, edge,
×
657
                        )
×
658
                        if err != nil {
×
659
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
660
                        }
×
661

662
                        // Silence ErrEdgeNotFound so that the batch can
663
                        // succeed, but propagate the error via local state.
664
                        if errors.Is(err, ErrEdgeNotFound) {
×
665
                                edgeNotFound = true
×
666
                                return nil
×
667
                        }
×
668

669
                        return err
×
670
                },
671
                OnCommit: func(err error) error {
×
672
                        switch {
×
673
                        case err != nil:
×
674
                                return err
×
675
                        case edgeNotFound:
×
676
                                return ErrEdgeNotFound
×
677
                        default:
×
678
                                s.updateEdgeCache(edge, isUpdate1)
×
679
                                return nil
×
680
                        }
681
                },
682
        }
683

684
        err := s.chanScheduler.Execute(ctx, r)
×
685

×
686
        return from, to, err
×
687
}
688

689
// updateEdgeCache updates our reject and channel caches with the new
690
// edge policy information.
691
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
692
        isUpdate1 bool) {
×
693

×
694
        // If an entry for this channel is found in reject cache, we'll modify
×
695
        // the entry with the updated timestamp for the direction that was just
×
696
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
697
        // during the next query for this edge.
×
698
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
699
                if isUpdate1 {
×
700
                        entry.upd1Time = e.LastUpdate.Unix()
×
701
                } else {
×
702
                        entry.upd2Time = e.LastUpdate.Unix()
×
703
                }
×
704
                s.rejectCache.insert(e.ChannelID, entry)
×
705
        }
706

707
        // If an entry for this channel is found in channel cache, we'll modify
708
        // the entry with the updated policy for the direction that was just
709
        // written. If the edge doesn't exist, we'll defer loading the info and
710
        // policies and lazily read from disk during the next query.
711
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
712
                if isUpdate1 {
×
713
                        channel.Policy1 = e
×
714
                } else {
×
715
                        channel.Policy2 = e
×
716
                }
×
717
                s.chanCache.insert(e.ChannelID, channel)
×
718
        }
719
}
720

721
// ForEachSourceNodeChannel walks every channel of the source node and invokes
// cb once per channel. The callback receives the channel's outpoint, a flag
// telling whether an outgoing policy is known for the channel, and the node
// record of the channel peer (the side that is not the source node).
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint,
	havePolicy bool, otherNode *models.LightningNode) error) error {

	ctx := context.TODO()

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
		if err != nil {
			return fmt.Errorf("unable to fetch source node: %w",
				err)
		}

		visitChannel := func(info *models.ChannelEdgeInfo,
			outPolicy *models.ChannelEdgePolicy,
			_ *models.ChannelEdgePolicy) error {

			key1, key2 := info.NodeKey1Bytes, info.NodeKey2Bytes

			// The peer is whichever channel endpoint is not the
			// source node.
			var peerPub [33]byte
			if bytes.Equal(key1[:], nodePub[:]) {
				peerPub = key2
			} else if bytes.Equal(key2[:], nodePub[:]) {
				peerPub = key1
			} else {
				return fmt.Errorf("node not " +
					"participating in this channel")
			}

			_, peer, err := getNodeByPubKey(ctx, db, peerPub)
			if err != nil {
				return fmt.Errorf("unable to fetch "+
					"other node(%x): %w",
					peerPub, err)
			}

			return cb(info.ChannelPoint, outPolicy != nil, peer)
		}

		return forEachNodeChannel(
			ctx, db, s.cfg.ChainHash, nodeID, visitChannel,
		)
	}, sqldb.NoOpReset)
}
778

779
// ForEachNode iterates through all the stored vertices/nodes in the graph,
// executing the passed callback with each node encountered. If the callback
// returns an error, then the transaction is aborted and the iteration stops
// early. Any operations performed on the NodeTx passed to the call-back are
// executed under the same read transaction and so, methods on the NodeTx object
// _MUST_ only be called from within the call-back.
//
// Nodes are read in pages of pageSize rows, ordered by their database ID.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNode(cb func(tx NodeRTx) error) error {
	ctx := context.TODO()

	// handleNode builds the full node for one DB row and hands the caller
	// a NodeRTx bound to the current transaction.
	handleNode := func(db SQLQueries, dbNode sqlc.Node) error {
		node, err := buildNode(ctx, db, &dbNode)
		if err != nil {
			return fmt.Errorf("unable to build node(id=%d): %w",
				dbNode.ID, err)
		}

		err = cb(
			newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
		)
		if err != nil {
			return fmt.Errorf("callback failed for node(id=%d): %w",
				dbNode.ID, err)
		}

		return nil
	}

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		// The pagination cursor lives inside the transaction body (as
		// in ForEachChannelCacheable). ExecTx takes a reset callback,
		// which suggests the body may be re-run. If so, each attempt
		// pages from the first node within its own transaction rather
		// than resuming where a rolled-back attempt stopped.
		// NOTE(review): a re-run attempt may re-invoke cb for nodes
		// already visited by the failed attempt.
		var lastID int64
		for {
			nodes, err := db.ListNodesPaginated(
				ctx, sqlc.ListNodesPaginatedParams{
					Version: int16(ProtocolV1),
					ID:      lastID,
					Limit:   pageSize,
				},
			)
			if err != nil {
				return fmt.Errorf("unable to fetch nodes: %w",
					err)
			}

			if len(nodes) == 0 {
				return nil
			}

			for _, dbNode := range nodes {
				if err := handleNode(db, dbNode); err != nil {
					return err
				}

				lastID = dbNode.ID
			}
		}
	}, sqldb.NoOpReset)
}
842

843
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
844
// SQLStore and a SQL transaction.
845
type sqlGraphNodeTx struct {
846
        db    SQLQueries
847
        id    int64
848
        node  *models.LightningNode
849
        chain chainhash.Hash
850
}
851

852
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
853
// interface.
854
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
855

856
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
857
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
858

×
859
        return &sqlGraphNodeTx{
×
860
                db:    db,
×
861
                chain: chain,
×
862
                id:    id,
×
863
                node:  node,
×
864
        }
×
865
}
×
866

867
// Node returns the raw information of the node.
868
//
869
// NOTE: This is a part of the NodeRTx interface.
870
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
871
        return s.node
×
872
}
×
873

874
// ForEachChannel can be used to iterate over the node's channels under the same
875
// transaction used to fetch the node.
876
//
877
// NOTE: This is a part of the NodeRTx interface.
878
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
879
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
880

×
881
        ctx := context.TODO()
×
882

×
883
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
884
}
×
885

886
// FetchNode fetches the node with the given pub key under the same transaction
887
// used to fetch the current node. The returned node is also a NodeRTx and any
888
// operations on that NodeRTx will also be done under the same transaction.
889
//
890
// NOTE: This is a part of the NodeRTx interface.
891
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
892
        ctx := context.TODO()
×
893

×
894
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
895
        if err != nil {
×
896
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
897
                        nodePub, err)
×
898
        }
×
899

900
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
901
}
902

903
// ForEachNodeDirectedChannel iterates through all channels of a given node,
904
// executing the passed callback on the directed edge representing the channel
905
// and its incoming policy. If the callback returns an error, then the iteration
906
// is halted with the error propagated back up to the caller.
907
//
908
// Unknown policies are passed into the callback as nil values.
909
//
910
// NOTE: this is part of the graphdb.NodeTraverser interface.
911
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
912
        cb func(channel *DirectedChannel) error) error {
×
913

×
914
        var ctx = context.TODO()
×
915

×
916
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
917
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
918
        }, sqldb.NoOpReset)
×
919
}
920

921
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
922
// graph, executing the passed callback with each node encountered. If the
923
// callback returns an error, then the transaction is aborted and the iteration
924
// stops early.
925
//
926
// NOTE: This is a part of the V1Store interface.
927
func (s *SQLStore) ForEachNodeCacheable(cb func(route.Vertex,
928
        *lnwire.FeatureVector) error) error {
×
929

×
930
        ctx := context.TODO()
×
931

×
932
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
933
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
934
                        nodePub route.Vertex) error {
×
935

×
936
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
937
                        if err != nil {
×
938
                                return fmt.Errorf("unable to fetch node "+
×
939
                                        "features: %w", err)
×
940
                        }
×
941

942
                        return cb(nodePub, features)
×
943
                })
944
        }, sqldb.NoOpReset)
945
        if err != nil {
×
946
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
947
        }
×
948

949
        return nil
×
950
}
951

952
// ForEachNodeChannel iterates through all channels of the given node,
953
// executing the passed callback with an edge info structure and the policies
954
// of each end of the channel. The first edge policy is the outgoing edge *to*
955
// the connecting node, while the second is the incoming edge *from* the
956
// connecting node. If the callback returns an error, then the iteration is
957
// halted with the error propagated back up to the caller.
958
//
959
// Unknown policies are passed into the callback as nil values.
960
//
961
// NOTE: part of the V1Store interface.
962
func (s *SQLStore) ForEachNodeChannel(nodePub route.Vertex,
963
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
964
                *models.ChannelEdgePolicy) error) error {
×
965

×
966
        var ctx = context.TODO()
×
967

×
968
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
969
                dbNode, err := db.GetNodeByPubKey(
×
970
                        ctx, sqlc.GetNodeByPubKeyParams{
×
971
                                Version: int16(ProtocolV1),
×
972
                                PubKey:  nodePub[:],
×
973
                        },
×
974
                )
×
975
                if errors.Is(err, sql.ErrNoRows) {
×
976
                        return nil
×
977
                } else if err != nil {
×
978
                        return fmt.Errorf("unable to fetch node: %w", err)
×
979
                }
×
980

981
                return forEachNodeChannel(
×
982
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
983
                )
×
984
        }, sqldb.NoOpReset)
985
}
986

987
// ChanUpdatesInHorizon returns all the known channel edges which have at least
// one edge that has an update timestamp within the specified horizon.
//
// Edges already present in the channel cache are served from there. Edges that
// had to be built from the database are inserted into the cache only after the
// read transaction has completed successfully. s.cacheMu is held for the whole
// call.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
	endTime time.Time) ([]ChannelEdge, error) {

	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	var (
		ctx = context.TODO()
		// To ensure we don't return duplicate ChannelEdges, we'll use
		// an additional map to keep track of the edges already seen to
		// prevent re-adding it.
		edgesSeen    = make(map[uint64]struct{})
		edgesToCache = make(map[uint64]ChannelEdge)
		edges        []ChannelEdge
		hits         int
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		rows, err := db.GetChannelsByPolicyLastUpdateRange(
			ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
				Version:   int16(ProtocolV1),
				StartTime: sqldb.SQLInt64(startTime.Unix()),
				EndTime:   sqldb.SQLInt64(endTime.Unix()),
			},
		)
		if err != nil {
			return err
		}

		for _, row := range rows {
			// If we've already retrieved the info and policies for
			// this edge, then we can skip it as we don't need to do
			// so again.
			chanIDInt := byteOrder.Uint64(row.Channel.Scid)
			if _, ok := edgesSeen[chanIDInt]; ok {
				continue
			}

			if channel, ok := s.chanCache.get(chanIDInt); ok {
				hits++
				edgesSeen[chanIDInt] = struct{}{}
				edges = append(edges, channel)

				continue
			}

			node1, node2, err := buildNodes(
				ctx, db, row.Node, row.Node_2,
			)
			if err != nil {
				return err
			}

			channel, err := getAndBuildEdgeInfo(
				ctx, db, s.cfg.ChainHash, row.Channel.ID,
				row.Channel, node1.PubKeyBytes,
				node2.PubKeyBytes,
			)
			if err != nil {
				return fmt.Errorf("unable to build channel "+
					"info: %w", err)
			}

			dbPol1, dbPol2, err := extractChannelPolicies(row)
			if err != nil {
				return fmt.Errorf("unable to extract channel "+
					"policies: %w", err)
			}

			p1, p2, err := getAndBuildChanPolicies(
				ctx, db, dbPol1, dbPol2, channel.ChannelID,
				node1.PubKeyBytes, node2.PubKeyBytes,
			)
			if err != nil {
				return fmt.Errorf("unable to build channel "+
					"policies: %w", err)
			}

			edgesSeen[chanIDInt] = struct{}{}
			chanEdge := ChannelEdge{
				Info:    channel,
				Policy1: p1,
				Policy2: p2,
				Node1:   node1,
				Node2:   node2,
			}
			edges = append(edges, chanEdge)
			edgesToCache[chanIDInt] = chanEdge
		}

		return nil
	}, func() {
		// Clear everything the transaction body accumulates, including
		// the cache-hit counter, so that a re-run attempt starts from a
		// clean slate and the logged hit percentage stays within
		// 0-100%.
		edgesSeen = make(map[uint64]struct{})
		edgesToCache = make(map[uint64]ChannelEdge)
		edges = nil
		hits = 0
	})
	if err != nil {
		return nil, fmt.Errorf("unable to fetch channels: %w", err)
	}

	// Insert any edges loaded from disk into the cache.
	for chanid, channel := range edgesToCache {
		s.chanCache.insert(chanid, channel)
	}

	if len(edges) > 0 {
		log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
			float64(hits)*100/float64(len(edges)), hits, len(edges))
	} else {
		log.Debugf("ChanUpdatesInHorizon returned no edges in "+
			"horizon (%s, %s)", startTime, endTime)
	}

	return edges, nil
}
1105

1106
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
// data to the call-back.
//
// For every node, the callback receives the node's public key and a map from
// channel ID to a DirectedChannel oriented from that node's point of view. The
// out policy is the one the node itself advertised, and the in policy is its
// peer's.
//
// NOTE: The callback contents MUST not be modified.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNodeCached(cb func(node route.Vertex,
	chans map[uint64]*DirectedChannel) error) error {

	var ctx = context.TODO()

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		return forEachNodeCacheable(ctx, db, func(nodeID int64,
			nodePub route.Vertex) error {

			features, err := getNodeFeatures(ctx, db, nodeID)
			if err != nil {
				return fmt.Errorf("unable to fetch "+
					"node(id=%d) features: %w", nodeID, err)
			}

			toNodeCallback := func() route.Vertex {
				return nodePub
			}

			rows, err := db.ListChannelsByNodeID(
				ctx, sqlc.ListChannelsByNodeIDParams{
					Version: int16(ProtocolV1),
					NodeID1: nodeID,
				},
			)
			if err != nil {
				return fmt.Errorf("unable to fetch channels "+
					"of node(id=%d): %w", nodeID, err)
			}

			channels := make(map[uint64]*DirectedChannel, len(rows))
			for _, row := range rows {
				node1, node2, err := buildNodeVertices(
					row.Node1Pubkey, row.Node2Pubkey,
				)
				if err != nil {
					return err
				}

				e, err := getAndBuildEdgeInfo(
					ctx, db, s.cfg.ChainHash,
					row.Channel.ID, row.Channel, node1,
					node2,
				)
				if err != nil {
					return fmt.Errorf("unable to build "+
						"channel info: %w", err)
				}

				dbPol1, dbPol2, err := extractChannelPolicies(
					row,
				)
				if err != nil {
					return fmt.Errorf("unable to "+
						"extract channel "+
						"policies: %w", err)
				}

				p1, p2, err := getAndBuildChanPolicies(
					ctx, db, dbPol1, dbPol2, e.ChannelID,
					node1, node2,
				)
				if err != nil {
					return fmt.Errorf("unable to "+
						"build channel policies: %w",
						err)
				}

				// p1 is node1's policy (towards node2) and p2
				// is node2's. Orient both relative to nodePub.
				// When nodePub is node2, its own (outgoing)
				// policy is p2, and the peer is node1.
				isNode1 := nodePub == e.NodeKey1Bytes
				outPolicy, inPolicy := p1, p2
				otherNode := e.NodeKey2Bytes
				if nodePub == e.NodeKey2Bytes {
					outPolicy, inPolicy = p2, p1
					otherNode = e.NodeKey1Bytes
				}

				// Cache the incoming policy, i.e. the one that
				// routes *to* nodePub. Use inPolicy rather than
				// p2, which may be nil when the policies were
				// swapped above.
				var cachedInPolicy *models.CachedEdgePolicy
				if inPolicy != nil {
					cachedInPolicy = models.NewCachedPolicy(
						inPolicy,
					)
					cachedInPolicy.ToNodePubKey =
						toNodeCallback
					cachedInPolicy.ToNodeFeatures =
						features
				}

				// The inbound fee is read from the outgoing
				// policy. That policy may be unknown (nil), so
				// guard the dereference.
				var inboundFee lnwire.Fee
				if outPolicy != nil {
					outPolicy.InboundFee.WhenSome(
						func(fee lnwire.Fee) {
							inboundFee = fee
						},
					)
				}

				directedChannel := &DirectedChannel{
					ChannelID:    e.ChannelID,
					IsNode1:      isNode1,
					OtherNode:    otherNode,
					Capacity:     e.Capacity,
					OutPolicySet: outPolicy != nil,
					InPolicy:     cachedInPolicy,
					InboundFee:   inboundFee,
				}

				channels[e.ChannelID] = directedChannel
			}

			return cb(nodePub, channels)
		})
	}, sqldb.NoOpReset)
}
1230

1231
// ForEachChannelCacheable walks every channel edge in the graph, paging through
// the database by channel ID, and calls cb with the cache-oriented view of the
// edge plus both of its directed policies. The first error (from the DB or from
// cb) aborts the read transaction and is returned.
//
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
// pointer for that particular channel edge routing policy will be
// passed into the callback.
//
// NOTE: this method is like ForEachChannel but fetches only the data
// required for the graph cache.
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
	*models.CachedEdgePolicy,
	*models.CachedEdgePolicy) error) error {

	ctx := context.TODO()

	// processRow converts a single paginated row into its cached form
	// and forwards it to cb.
	processRow := func(db SQLQueries,
		row sqlc.ListChannelsWithPoliciesPaginatedRow) error {

		node1, node2, err := buildNodeVertices(
			row.Node1Pubkey, row.Node2Pubkey,
		)
		if err != nil {
			return err
		}

		edge := buildCacheableChannelInfo(row.Channel, node1, node2)

		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return err
		}

		// Policies missing from the DB are reported as nil. Policy 1
		// points towards node2 and policy 2 towards node1.
		var pol1, pol2 *models.CachedEdgePolicy
		if dbPol1 != nil {
			built, err := buildChanPolicy(
				*dbPol1, edge.ChannelID, nil, node2,
			)
			if err != nil {
				return err
			}
			pol1 = models.NewCachedPolicy(built)
		}
		if dbPol2 != nil {
			built, err := buildChanPolicy(
				*dbPol2, edge.ChannelID, nil, node1,
			)
			if err != nil {
				return err
			}
			pol2 = models.NewCachedPolicy(built)
		}

		return cb(edge, pol1, pol2)
	}

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		cursor := int64(-1)
		for {
			//nolint:ll
			page, err := db.ListChannelsWithPoliciesPaginated(
				ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
					Version: int16(ProtocolV1),
					ID:      cursor,
					Limit:   pageSize,
				},
			)
			if err != nil {
				return err
			}

			if len(page) == 0 {
				return nil
			}

			for _, row := range page {
				if err := processRow(db, row); err != nil {
					return err
				}

				cursor = row.Channel.ID
			}
		}
	}, sqldb.NoOpReset)
}
1327

1328
// ForEachChannel iterates through all the channel edges stored within the
1329
// graph and invokes the passed callback for each edge. The callback takes two
1330
// edges as since this is a directed graph, both the in/out edges are visited.
1331
// If the callback returns an error, then the transaction is aborted and the
1332
// iteration stops early.
1333
//
1334
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1335
// for that particular channel edge routing policy will be passed into the
1336
// callback.
1337
//
1338
// NOTE: part of the V1Store interface.
1339
func (s *SQLStore) ForEachChannel(cb func(*models.ChannelEdgeInfo,
1340
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
1341

×
1342
        ctx := context.TODO()
×
1343

×
1344
        handleChannel := func(db SQLQueries,
×
1345
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1346

×
1347
                node1, node2, err := buildNodeVertices(
×
1348
                        row.Node1Pubkey, row.Node2Pubkey,
×
1349
                )
×
1350
                if err != nil {
×
1351
                        return fmt.Errorf("unable to build node vertices: %w",
×
1352
                                err)
×
1353
                }
×
1354

1355
                edge, err := getAndBuildEdgeInfo(
×
1356
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1357
                        node1, node2,
×
1358
                )
×
1359
                if err != nil {
×
1360
                        return fmt.Errorf("unable to build channel info: %w",
×
1361
                                err)
×
1362
                }
×
1363

1364
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1365
                if err != nil {
×
1366
                        return fmt.Errorf("unable to extract channel "+
×
1367
                                "policies: %w", err)
×
1368
                }
×
1369

1370
                p1, p2, err := getAndBuildChanPolicies(
×
1371
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1372
                )
×
1373
                if err != nil {
×
1374
                        return fmt.Errorf("unable to build channel "+
×
1375
                                "policies: %w", err)
×
1376
                }
×
1377

1378
                err = cb(edge, p1, p2)
×
1379
                if err != nil {
×
1380
                        return fmt.Errorf("callback failed for channel "+
×
1381
                                "id=%d: %w", edge.ChannelID, err)
×
1382
                }
×
1383

1384
                return nil
×
1385
        }
1386

1387
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1388
                lastID := int64(-1)
×
1389
                for {
×
1390
                        //nolint:ll
×
1391
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1392
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1393
                                        Version: int16(ProtocolV1),
×
1394
                                        ID:      lastID,
×
1395
                                        Limit:   pageSize,
×
1396
                                },
×
1397
                        )
×
1398
                        if err != nil {
×
1399
                                return err
×
1400
                        }
×
1401

1402
                        if len(rows) == 0 {
×
1403
                                break
×
1404
                        }
1405

1406
                        for _, row := range rows {
×
1407
                                err := handleChannel(db, row)
×
1408
                                if err != nil {
×
1409
                                        return err
×
1410
                                }
×
1411

1412
                                lastID = row.Channel.ID
×
1413
                        }
1414
                }
1415

1416
                return nil
×
1417
        }, sqldb.NoOpReset)
1418
}
1419

1420
// FilterChannelRange returns the channel ID's of all known channels which were
1421
// mined in a block height within the passed range. The channel IDs are grouped
1422
// by their common block height. This method can be used to quickly share with a
1423
// peer the set of channels we know of within a particular range to catch them
1424
// up after a period of time offline. If withTimestamps is true then the
1425
// timestamp info of the latest received channel update messages of the channel
1426
// will be included in the response.
1427
//
1428
// NOTE: This is part of the V1Store interface.
1429
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1430
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1431

×
1432
        var (
×
1433
                ctx       = context.TODO()
×
1434
                startSCID = &lnwire.ShortChannelID{
×
1435
                        BlockHeight: startHeight,
×
1436
                }
×
1437
                endSCID = lnwire.ShortChannelID{
×
1438
                        BlockHeight: endHeight,
×
1439
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1440
                        TxPosition:  math.MaxUint16,
×
1441
                }
×
1442
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1443
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1444
        )
×
1445

×
1446
        // 1) get all channels where channelID is between start and end chan ID.
×
1447
        // 2) skip if not public (ie, no channel_proof)
×
1448
        // 3) collect that channel.
×
1449
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1450
        //    and add those timestamps to the collected channel.
×
1451
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1452
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1453
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1454
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1455
                                StartScid: chanIDStart,
×
1456
                                EndScid:   chanIDEnd,
×
1457
                        },
×
1458
                )
×
1459
                if err != nil {
×
1460
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1461
                                err)
×
1462
                }
×
1463

1464
                for _, dbChan := range dbChans {
×
1465
                        cid := lnwire.NewShortChanIDFromInt(
×
1466
                                byteOrder.Uint64(dbChan.Scid),
×
1467
                        )
×
1468
                        chanInfo := NewChannelUpdateInfo(
×
1469
                                cid, time.Time{}, time.Time{},
×
1470
                        )
×
1471

×
1472
                        if !withTimestamps {
×
1473
                                channelsPerBlock[cid.BlockHeight] = append(
×
1474
                                        channelsPerBlock[cid.BlockHeight],
×
1475
                                        chanInfo,
×
1476
                                )
×
1477

×
1478
                                continue
×
1479
                        }
1480

1481
                        //nolint:ll
1482
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1483
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1484
                                        Version:   int16(ProtocolV1),
×
1485
                                        ChannelID: dbChan.ID,
×
1486
                                        NodeID:    dbChan.NodeID1,
×
1487
                                },
×
1488
                        )
×
1489
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1490
                                return fmt.Errorf("unable to fetch node1 "+
×
1491
                                        "policy: %w", err)
×
1492
                        } else if err == nil {
×
1493
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1494
                                        node1Policy.LastUpdate.Int64, 0,
×
1495
                                )
×
1496
                        }
×
1497

1498
                        //nolint:ll
1499
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1500
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1501
                                        Version:   int16(ProtocolV1),
×
1502
                                        ChannelID: dbChan.ID,
×
1503
                                        NodeID:    dbChan.NodeID2,
×
1504
                                },
×
1505
                        )
×
1506
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1507
                                return fmt.Errorf("unable to fetch node2 "+
×
1508
                                        "policy: %w", err)
×
1509
                        } else if err == nil {
×
1510
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1511
                                        node2Policy.LastUpdate.Int64, 0,
×
1512
                                )
×
1513
                        }
×
1514

1515
                        channelsPerBlock[cid.BlockHeight] = append(
×
1516
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1517
                        )
×
1518
                }
1519

1520
                return nil
×
1521
        }, func() {
×
1522
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1523
        })
×
1524
        if err != nil {
×
1525
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1526
        }
×
1527

1528
        if len(channelsPerBlock) == 0 {
×
1529
                return nil, nil
×
1530
        }
×
1531

1532
        // Return the channel ranges in ascending block height order.
1533
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1534
        slices.Sort(blocks)
×
1535

×
1536
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1537
                return BlockChannelRange{
×
1538
                        Height:   block,
×
1539
                        Channels: channelsPerBlock[block],
×
1540
                }
×
1541
        }), nil
×
1542
}
1543

1544
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1545
// zombie. This method is used on an ad-hoc basis, when channels need to be
1546
// marked as zombies outside the normal pruning cycle.
1547
//
1548
// NOTE: part of the V1Store interface.
1549
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1550
        pubKey1, pubKey2 [33]byte) error {
×
1551

×
1552
        ctx := context.TODO()
×
1553

×
1554
        s.cacheMu.Lock()
×
1555
        defer s.cacheMu.Unlock()
×
1556

×
1557
        chanIDB := channelIDToBytes(chanID)
×
1558

×
1559
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1560
                return db.UpsertZombieChannel(
×
1561
                        ctx, sqlc.UpsertZombieChannelParams{
×
1562
                                Version:  int16(ProtocolV1),
×
1563
                                Scid:     chanIDB,
×
1564
                                NodeKey1: pubKey1[:],
×
1565
                                NodeKey2: pubKey2[:],
×
1566
                        },
×
1567
                )
×
1568
        }, sqldb.NoOpReset)
×
1569
        if err != nil {
×
1570
                return fmt.Errorf("unable to upsert zombie channel "+
×
1571
                        "(channel_id=%d): %w", chanID, err)
×
1572
        }
×
1573

1574
        s.rejectCache.remove(chanID)
×
1575
        s.chanCache.remove(chanID)
×
1576

×
1577
        return nil
×
1578
}
1579

1580
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1581
//
1582
// NOTE: part of the V1Store interface.
1583
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1584
        s.cacheMu.Lock()
×
1585
        defer s.cacheMu.Unlock()
×
1586

×
1587
        var (
×
1588
                ctx     = context.TODO()
×
1589
                chanIDB = channelIDToBytes(chanID)
×
1590
        )
×
1591

×
1592
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1593
                res, err := db.DeleteZombieChannel(
×
1594
                        ctx, sqlc.DeleteZombieChannelParams{
×
1595
                                Scid:    chanIDB,
×
1596
                                Version: int16(ProtocolV1),
×
1597
                        },
×
1598
                )
×
1599
                if err != nil {
×
1600
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1601
                                err)
×
1602
                }
×
1603

1604
                rows, err := res.RowsAffected()
×
1605
                if err != nil {
×
1606
                        return err
×
1607
                }
×
1608

1609
                if rows == 0 {
×
1610
                        return ErrZombieEdgeNotFound
×
1611
                } else if rows > 1 {
×
1612
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1613
                                "expected 1", rows)
×
1614
                }
×
1615

1616
                return nil
×
1617
        }, sqldb.NoOpReset)
1618
        if err != nil {
×
1619
                return fmt.Errorf("unable to mark edge live "+
×
1620
                        "(channel_id=%d): %w", chanID, err)
×
1621
        }
×
1622

1623
        s.rejectCache.remove(chanID)
×
1624
        s.chanCache.remove(chanID)
×
1625

×
1626
        return err
×
1627
}
1628

1629
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1630
// zombie, then the two node public keys corresponding to this edge are also
1631
// returned.
1632
//
1633
// NOTE: part of the V1Store interface.
1634
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1635
        error) {
×
1636

×
1637
        var (
×
1638
                ctx              = context.TODO()
×
1639
                isZombie         bool
×
1640
                pubKey1, pubKey2 route.Vertex
×
1641
                chanIDB          = channelIDToBytes(chanID)
×
1642
        )
×
1643

×
1644
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1645
                zombie, err := db.GetZombieChannel(
×
1646
                        ctx, sqlc.GetZombieChannelParams{
×
1647
                                Scid:    chanIDB,
×
1648
                                Version: int16(ProtocolV1),
×
1649
                        },
×
1650
                )
×
1651
                if errors.Is(err, sql.ErrNoRows) {
×
1652
                        return nil
×
1653
                }
×
1654
                if err != nil {
×
1655
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1656
                                err)
×
1657
                }
×
1658

1659
                copy(pubKey1[:], zombie.NodeKey1)
×
1660
                copy(pubKey2[:], zombie.NodeKey2)
×
1661
                isZombie = true
×
1662

×
1663
                return nil
×
1664
        }, sqldb.NoOpReset)
1665
        if err != nil {
×
1666
                return false, route.Vertex{}, route.Vertex{},
×
1667
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1668
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1669
        }
×
1670

1671
        return isZombie, pubKey1, pubKey2, nil
×
1672
}
1673

1674
// NumZombies returns the current number of zombie channels in the graph.
1675
//
1676
// NOTE: part of the V1Store interface.
1677
func (s *SQLStore) NumZombies() (uint64, error) {
×
1678
        var (
×
1679
                ctx        = context.TODO()
×
1680
                numZombies uint64
×
1681
        )
×
1682
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1683
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1684
                if err != nil {
×
1685
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1686
                                err)
×
1687
                }
×
1688

1689
                numZombies = uint64(count)
×
1690

×
1691
                return nil
×
1692
        }, sqldb.NoOpReset)
1693
        if err != nil {
×
1694
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1695
        }
×
1696

1697
        return numZombies, nil
×
1698
}
1699

1700
// DeleteChannelEdges removes edges with the given channel IDs from the
1701
// database and marks them as zombies. This ensures that we're unable to re-add
1702
// it to our database once again. If an edge does not exist within the
1703
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1704
// true, then when we mark these edges as zombies, we'll set up the keys such
1705
// that we require the node that failed to send the fresh update to be the one
1706
// that resurrects the channel from its zombie state. The markZombie bool
1707
// denotes whether to mark the channel as a zombie.
1708
//
1709
// NOTE: part of the V1Store interface.
1710
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1711
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1712

×
1713
        s.cacheMu.Lock()
×
1714
        defer s.cacheMu.Unlock()
×
1715

×
1716
        var (
×
1717
                ctx     = context.TODO()
×
1718
                deleted []*models.ChannelEdgeInfo
×
1719
        )
×
1720
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1721
                for _, chanID := range chanIDs {
×
1722
                        chanIDB := channelIDToBytes(chanID)
×
1723

×
1724
                        row, err := db.GetChannelBySCIDWithPolicies(
×
1725
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1726
                                        Scid:    chanIDB,
×
1727
                                        Version: int16(ProtocolV1),
×
1728
                                },
×
1729
                        )
×
1730
                        if errors.Is(err, sql.ErrNoRows) {
×
1731
                                return ErrEdgeNotFound
×
1732
                        } else if err != nil {
×
1733
                                return fmt.Errorf("unable to fetch channel: %w",
×
1734
                                        err)
×
1735
                        }
×
1736

1737
                        node1, node2, err := buildNodeVertices(
×
1738
                                row.Node.PubKey, row.Node_2.PubKey,
×
1739
                        )
×
1740
                        if err != nil {
×
1741
                                return err
×
1742
                        }
×
1743

1744
                        info, err := getAndBuildEdgeInfo(
×
1745
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1746
                                row.Channel, node1, node2,
×
1747
                        )
×
1748
                        if err != nil {
×
1749
                                return err
×
1750
                        }
×
1751

1752
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
1753
                        if err != nil {
×
1754
                                return fmt.Errorf("unable to delete "+
×
1755
                                        "channel: %w", err)
×
1756
                        }
×
1757

1758
                        deleted = append(deleted, info)
×
1759

×
1760
                        if !markZombie {
×
1761
                                continue
×
1762
                        }
1763

1764
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1765
                                info.NodeKey2Bytes
×
1766
                        if strictZombiePruning {
×
1767
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1768
                                if row.Policy1LastUpdate.Valid {
×
1769
                                        e1Time := time.Unix(
×
1770
                                                row.Policy1LastUpdate.Int64, 0,
×
1771
                                        )
×
1772
                                        e1UpdateTime = &e1Time
×
1773
                                }
×
1774
                                if row.Policy2LastUpdate.Valid {
×
1775
                                        e2Time := time.Unix(
×
1776
                                                row.Policy2LastUpdate.Int64, 0,
×
1777
                                        )
×
1778
                                        e2UpdateTime = &e2Time
×
1779
                                }
×
1780

1781
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1782
                                        info, e1UpdateTime, e2UpdateTime,
×
1783
                                )
×
1784
                        }
1785

1786
                        err = db.UpsertZombieChannel(
×
1787
                                ctx, sqlc.UpsertZombieChannelParams{
×
1788
                                        Version:  int16(ProtocolV1),
×
1789
                                        Scid:     chanIDB,
×
1790
                                        NodeKey1: nodeKey1[:],
×
1791
                                        NodeKey2: nodeKey2[:],
×
1792
                                },
×
1793
                        )
×
1794
                        if err != nil {
×
1795
                                return fmt.Errorf("unable to mark channel as "+
×
1796
                                        "zombie: %w", err)
×
1797
                        }
×
1798
                }
1799

1800
                return nil
×
1801
        }, func() {
×
1802
                deleted = nil
×
1803
        })
×
1804
        if err != nil {
×
1805
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1806
                        err)
×
1807
        }
×
1808

1809
        for _, chanID := range chanIDs {
×
1810
                s.rejectCache.remove(chanID)
×
1811
                s.chanCache.remove(chanID)
×
1812
        }
×
1813

1814
        return deleted, nil
×
1815
}
1816

1817
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1818
// channel identified by the channel ID. If the channel can't be found, then
1819
// ErrEdgeNotFound is returned. A struct which houses the general information
1820
// for the channel itself is returned as well as two structs that contain the
1821
// routing policies for the channel in either direction.
1822
//
1823
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1824
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1825
// the ChannelEdgeInfo will only include the public keys of each node.
1826
//
1827
// NOTE: part of the V1Store interface.
1828
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1829
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1830
        *models.ChannelEdgePolicy, error) {
×
1831

×
1832
        var (
×
1833
                ctx              = context.TODO()
×
1834
                edge             *models.ChannelEdgeInfo
×
1835
                policy1, policy2 *models.ChannelEdgePolicy
×
1836
                chanIDB          = channelIDToBytes(chanID)
×
1837
        )
×
1838
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1839
                row, err := db.GetChannelBySCIDWithPolicies(
×
1840
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1841
                                Scid:    chanIDB,
×
1842
                                Version: int16(ProtocolV1),
×
1843
                        },
×
1844
                )
×
1845
                if errors.Is(err, sql.ErrNoRows) {
×
1846
                        // First check if this edge is perhaps in the zombie
×
1847
                        // index.
×
1848
                        zombie, err := db.GetZombieChannel(
×
1849
                                ctx, sqlc.GetZombieChannelParams{
×
1850
                                        Scid:    chanIDB,
×
1851
                                        Version: int16(ProtocolV1),
×
1852
                                },
×
1853
                        )
×
1854
                        if errors.Is(err, sql.ErrNoRows) {
×
1855
                                return ErrEdgeNotFound
×
1856
                        } else if err != nil {
×
1857
                                return fmt.Errorf("unable to check if "+
×
1858
                                        "channel is zombie: %w", err)
×
1859
                        }
×
1860

1861
                        // At this point, we know the channel is a zombie, so
1862
                        // we'll return an error indicating this, and we will
1863
                        // populate the edge info with the public keys of each
1864
                        // party as this is the only information we have about
1865
                        // it.
1866
                        edge = &models.ChannelEdgeInfo{}
×
1867
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1868
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1869

×
1870
                        return ErrZombieEdge
×
1871
                } else if err != nil {
×
1872
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1873
                }
×
1874

1875
                node1, node2, err := buildNodeVertices(
×
1876
                        row.Node.PubKey, row.Node_2.PubKey,
×
1877
                )
×
1878
                if err != nil {
×
1879
                        return err
×
1880
                }
×
1881

1882
                edge, err = getAndBuildEdgeInfo(
×
1883
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1884
                        node1, node2,
×
1885
                )
×
1886
                if err != nil {
×
1887
                        return fmt.Errorf("unable to build channel info: %w",
×
1888
                                err)
×
1889
                }
×
1890

1891
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1892
                if err != nil {
×
1893
                        return fmt.Errorf("unable to extract channel "+
×
1894
                                "policies: %w", err)
×
1895
                }
×
1896

1897
                policy1, policy2, err = getAndBuildChanPolicies(
×
1898
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1899
                )
×
1900
                if err != nil {
×
1901
                        return fmt.Errorf("unable to build channel "+
×
1902
                                "policies: %w", err)
×
1903
                }
×
1904

1905
                return nil
×
1906
        }, sqldb.NoOpReset)
1907
        if err != nil {
×
1908
                // If we are returning the ErrZombieEdge, then we also need to
×
1909
                // return the edge info as the method comment indicates that
×
1910
                // this will be populated when the edge is a zombie.
×
1911
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1912
                        err)
×
1913
        }
×
1914

1915
        return edge, policy1, policy2, nil
×
1916
}
1917

1918
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1919
// the channel identified by the funding outpoint. If the channel can't be
1920
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1921
// information for the channel itself is returned as well as two structs that
1922
// contain the routing policies for the channel in either direction.
1923
//
1924
// NOTE: part of the V1Store interface.
1925
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1926
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1927
        *models.ChannelEdgePolicy, error) {
×
1928

×
1929
        var (
×
1930
                ctx              = context.TODO()
×
1931
                edge             *models.ChannelEdgeInfo
×
1932
                policy1, policy2 *models.ChannelEdgePolicy
×
1933
        )
×
1934
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1935
                row, err := db.GetChannelByOutpointWithPolicies(
×
1936
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1937
                                Outpoint: op.String(),
×
1938
                                Version:  int16(ProtocolV1),
×
1939
                        },
×
1940
                )
×
1941
                if errors.Is(err, sql.ErrNoRows) {
×
1942
                        return ErrEdgeNotFound
×
1943
                } else if err != nil {
×
1944
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1945
                }
×
1946

1947
                node1, node2, err := buildNodeVertices(
×
1948
                        row.Node1Pubkey, row.Node2Pubkey,
×
1949
                )
×
1950
                if err != nil {
×
1951
                        return err
×
1952
                }
×
1953

1954
                edge, err = getAndBuildEdgeInfo(
×
1955
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1956
                        node1, node2,
×
1957
                )
×
1958
                if err != nil {
×
1959
                        return fmt.Errorf("unable to build channel info: %w",
×
1960
                                err)
×
1961
                }
×
1962

1963
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1964
                if err != nil {
×
1965
                        return fmt.Errorf("unable to extract channel "+
×
1966
                                "policies: %w", err)
×
1967
                }
×
1968

1969
                policy1, policy2, err = getAndBuildChanPolicies(
×
1970
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1971
                )
×
1972
                if err != nil {
×
1973
                        return fmt.Errorf("unable to build channel "+
×
1974
                                "policies: %w", err)
×
1975
                }
×
1976

1977
                return nil
×
1978
        }, sqldb.NoOpReset)
1979
        if err != nil {
×
1980
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1981
                        err)
×
1982
        }
×
1983

1984
        return edge, policy1, policy2, nil
×
1985
}
1986

1987
// HasChannelEdge returns true if the database knows of a channel edge with the
1988
// passed channel ID, and false otherwise. If an edge with that ID is found
1989
// within the graph, then two time stamps representing the last time the edge
1990
// was updated for both directed edges are returned along with the boolean. If
1991
// it is not found, then the zombie index is checked and its result is returned
1992
// as the second boolean.
1993
//
1994
// NOTE: part of the V1Store interface.
1995
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1996
        bool, error) {
×
1997

×
1998
        ctx := context.TODO()
×
1999

×
2000
        var (
×
2001
                exists          bool
×
2002
                isZombie        bool
×
2003
                node1LastUpdate time.Time
×
2004
                node2LastUpdate time.Time
×
2005
        )
×
2006

×
2007
        // We'll query the cache with the shared lock held to allow multiple
×
2008
        // readers to access values in the cache concurrently if they exist.
×
2009
        s.cacheMu.RLock()
×
2010
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2011
                s.cacheMu.RUnlock()
×
2012
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2013
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2014
                exists, isZombie = entry.flags.unpack()
×
2015

×
2016
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2017
        }
×
2018
        s.cacheMu.RUnlock()
×
2019

×
2020
        s.cacheMu.Lock()
×
2021
        defer s.cacheMu.Unlock()
×
2022

×
2023
        // The item was not found with the shared lock, so we'll acquire the
×
2024
        // exclusive lock and check the cache again in case another method added
×
2025
        // the entry to the cache while no lock was held.
×
2026
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2027
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2028
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2029
                exists, isZombie = entry.flags.unpack()
×
2030

×
2031
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2032
        }
×
2033

2034
        chanIDB := channelIDToBytes(chanID)
×
2035
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2036
                channel, err := db.GetChannelBySCID(
×
2037
                        ctx, sqlc.GetChannelBySCIDParams{
×
2038
                                Scid:    chanIDB,
×
2039
                                Version: int16(ProtocolV1),
×
2040
                        },
×
2041
                )
×
2042
                if errors.Is(err, sql.ErrNoRows) {
×
2043
                        // Check if it is a zombie channel.
×
2044
                        isZombie, err = db.IsZombieChannel(
×
2045
                                ctx, sqlc.IsZombieChannelParams{
×
2046
                                        Scid:    chanIDB,
×
2047
                                        Version: int16(ProtocolV1),
×
2048
                                },
×
2049
                        )
×
2050
                        if err != nil {
×
2051
                                return fmt.Errorf("could not check if channel "+
×
2052
                                        "is zombie: %w", err)
×
2053
                        }
×
2054

2055
                        return nil
×
2056
                } else if err != nil {
×
2057
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2058
                }
×
2059

2060
                exists = true
×
2061

×
2062
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2063
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2064
                                Version:   int16(ProtocolV1),
×
2065
                                ChannelID: channel.ID,
×
2066
                                NodeID:    channel.NodeID1,
×
2067
                        },
×
2068
                )
×
2069
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2070
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2071
                                err)
×
2072
                } else if err == nil {
×
2073
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2074
                }
×
2075

2076
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2077
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2078
                                Version:   int16(ProtocolV1),
×
2079
                                ChannelID: channel.ID,
×
2080
                                NodeID:    channel.NodeID2,
×
2081
                        },
×
2082
                )
×
2083
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2084
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2085
                                err)
×
2086
                } else if err == nil {
×
2087
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2088
                }
×
2089

2090
                return nil
×
2091
        }, sqldb.NoOpReset)
2092
        if err != nil {
×
2093
                return time.Time{}, time.Time{}, false, false,
×
2094
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2095
        }
×
2096

2097
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2098
                upd1Time: node1LastUpdate.Unix(),
×
2099
                upd2Time: node2LastUpdate.Unix(),
×
2100
                flags:    packRejectFlags(exists, isZombie),
×
2101
        })
×
2102

×
2103
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2104
}
2105

2106
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2107
// passed channel point (outpoint). If the passed channel doesn't exist within
2108
// the database, then ErrEdgeNotFound is returned.
2109
//
2110
// NOTE: part of the V1Store interface.
2111
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2112
        var (
×
2113
                ctx       = context.TODO()
×
2114
                channelID uint64
×
2115
        )
×
2116
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2117
                chanID, err := db.GetSCIDByOutpoint(
×
2118
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2119
                                Outpoint: chanPoint.String(),
×
2120
                                Version:  int16(ProtocolV1),
×
2121
                        },
×
2122
                )
×
2123
                if errors.Is(err, sql.ErrNoRows) {
×
2124
                        return ErrEdgeNotFound
×
2125
                } else if err != nil {
×
2126
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2127
                                err)
×
2128
                }
×
2129

2130
                channelID = byteOrder.Uint64(chanID)
×
2131

×
2132
                return nil
×
2133
        }, sqldb.NoOpReset)
2134
        if err != nil {
×
2135
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2136
        }
×
2137

2138
        return channelID, nil
×
2139
}
2140

2141
// IsPublicNode is a helper method that determines whether the node with the
2142
// given public key is seen as a public node in the graph from the graph's
2143
// source node's point of view.
2144
//
2145
// NOTE: part of the V1Store interface.
2146
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2147
        ctx := context.TODO()
×
2148

×
2149
        var isPublic bool
×
2150
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2151
                var err error
×
2152
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2153

×
2154
                return err
×
2155
        }, sqldb.NoOpReset)
×
2156
        if err != nil {
×
2157
                return false, fmt.Errorf("unable to check if node is "+
×
2158
                        "public: %w", err)
×
2159
        }
×
2160

2161
        return isPublic, nil
×
2162
}
2163

2164
// FetchChanInfos returns the set of channel edges that correspond to the passed
2165
// channel ID's. If an edge is the query is unknown to the database, it will
2166
// skipped and the result will contain only those edges that exist at the time
2167
// of the query. This can be used to respond to peer queries that are seeking to
2168
// fill in gaps in their view of the channel graph.
2169
//
2170
// NOTE: part of the V1Store interface.
2171
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2172
        var (
×
2173
                ctx   = context.TODO()
×
2174
                edges []ChannelEdge
×
2175
        )
×
2176
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2177
                for _, chanID := range chanIDs {
×
2178
                        chanIDB := channelIDToBytes(chanID)
×
2179

×
2180
                        // TODO(elle): potentially optimize this by using
×
2181
                        //  sqlc.slice() once that works for both SQLite and
×
2182
                        //  Postgres.
×
2183
                        row, err := db.GetChannelBySCIDWithPolicies(
×
2184
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
2185
                                        Scid:    chanIDB,
×
2186
                                        Version: int16(ProtocolV1),
×
2187
                                },
×
2188
                        )
×
2189
                        if errors.Is(err, sql.ErrNoRows) {
×
2190
                                continue
×
2191
                        } else if err != nil {
×
2192
                                return fmt.Errorf("unable to fetch channel: %w",
×
2193
                                        err)
×
2194
                        }
×
2195

2196
                        node1, node2, err := buildNodes(
×
2197
                                ctx, db, row.Node, row.Node_2,
×
2198
                        )
×
2199
                        if err != nil {
×
2200
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2201
                                        err)
×
2202
                        }
×
2203

2204
                        edge, err := getAndBuildEdgeInfo(
×
2205
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2206
                                row.Channel, node1.PubKeyBytes,
×
2207
                                node2.PubKeyBytes,
×
2208
                        )
×
2209
                        if err != nil {
×
2210
                                return fmt.Errorf("unable to build "+
×
2211
                                        "channel info: %w", err)
×
2212
                        }
×
2213

2214
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2215
                        if err != nil {
×
2216
                                return fmt.Errorf("unable to extract channel "+
×
2217
                                        "policies: %w", err)
×
2218
                        }
×
2219

2220
                        p1, p2, err := getAndBuildChanPolicies(
×
2221
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2222
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2223
                        )
×
2224
                        if err != nil {
×
2225
                                return fmt.Errorf("unable to build channel "+
×
2226
                                        "policies: %w", err)
×
2227
                        }
×
2228

2229
                        edges = append(edges, ChannelEdge{
×
2230
                                Info:    edge,
×
2231
                                Policy1: p1,
×
2232
                                Policy2: p2,
×
2233
                                Node1:   node1,
×
2234
                                Node2:   node2,
×
2235
                        })
×
2236
                }
2237

2238
                return nil
×
2239
        }, func() {
×
2240
                edges = nil
×
2241
        })
×
2242
        if err != nil {
×
2243
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2244
        }
×
2245

2246
        return edges, nil
×
2247
}
2248

2249
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2250
// ID's that we don't know and are not known zombies of the passed set. In other
2251
// words, we perform a set difference of our set of chan ID's and the ones
2252
// passed in. This method can be used by callers to determine the set of
2253
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2254
// known zombies is also returned.
2255
//
2256
// NOTE: part of the V1Store interface.
2257
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2258
        []ChannelUpdateInfo, error) {
×
2259

×
2260
        var (
×
2261
                ctx          = context.TODO()
×
2262
                newChanIDs   []uint64
×
2263
                knownZombies []ChannelUpdateInfo
×
2264
        )
×
2265
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2266
                for _, chanInfo := range chansInfo {
×
2267
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2268
                        chanIDB := channelIDToBytes(channelID)
×
2269

×
2270
                        // TODO(elle): potentially optimize this by using
×
2271
                        //  sqlc.slice() once that works for both SQLite and
×
2272
                        //  Postgres.
×
2273
                        _, err := db.GetChannelBySCID(
×
2274
                                ctx, sqlc.GetChannelBySCIDParams{
×
2275
                                        Version: int16(ProtocolV1),
×
2276
                                        Scid:    chanIDB,
×
2277
                                },
×
2278
                        )
×
2279
                        if err == nil {
×
2280
                                continue
×
2281
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
2282
                                return fmt.Errorf("unable to fetch channel: %w",
×
2283
                                        err)
×
2284
                        }
×
2285

2286
                        isZombie, err := db.IsZombieChannel(
×
2287
                                ctx, sqlc.IsZombieChannelParams{
×
2288
                                        Scid:    chanIDB,
×
2289
                                        Version: int16(ProtocolV1),
×
2290
                                },
×
2291
                        )
×
2292
                        if err != nil {
×
2293
                                return fmt.Errorf("unable to fetch zombie "+
×
2294
                                        "channel: %w", err)
×
2295
                        }
×
2296

2297
                        if isZombie {
×
2298
                                knownZombies = append(knownZombies, chanInfo)
×
2299

×
2300
                                continue
×
2301
                        }
2302

2303
                        newChanIDs = append(newChanIDs, channelID)
×
2304
                }
2305

2306
                return nil
×
2307
        }, func() {
×
2308
                newChanIDs = nil
×
2309
                knownZombies = nil
×
2310
        })
×
2311
        if err != nil {
×
2312
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2313
        }
×
2314

2315
        return newChanIDs, knownZombies, nil
×
2316
}
2317

2318
// PruneGraphNodes is a garbage collection method which attempts to prune out
2319
// any nodes from the channel graph that are currently unconnected. This ensure
2320
// that we only maintain a graph of reachable nodes. In the event that a pruned
2321
// node gains more channels, it will be re-added back to the graph.
2322
//
2323
// NOTE: this prunes nodes across protocol versions. It will never prune the
2324
// source nodes.
2325
//
2326
// NOTE: part of the V1Store interface.
2327
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2328
        var ctx = context.TODO()
×
2329

×
2330
        var prunedNodes []route.Vertex
×
2331
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2332
                var err error
×
2333
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2334

×
2335
                return err
×
2336
        }, func() {
×
2337
                prunedNodes = nil
×
2338
        })
×
2339
        if err != nil {
×
2340
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2341
        }
×
2342

2343
        return prunedNodes, nil
×
2344
}
2345

2346
// PruneGraph prunes newly closed channels from the channel graph in response
2347
// to a new block being solved on the network. Any transactions which spend the
2348
// funding output of any known channels within he graph will be deleted.
2349
// Additionally, the "prune tip", or the last block which has been used to
2350
// prune the graph is stored so callers can ensure the graph is fully in sync
2351
// with the current UTXO state. A slice of channels that have been closed by
2352
// the target block along with any pruned nodes are returned if the function
2353
// succeeds without error.
2354
//
2355
// NOTE: part of the V1Store interface.
2356
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2357
        blockHash *chainhash.Hash, blockHeight uint32) (
2358
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2359

×
2360
        ctx := context.TODO()
×
2361

×
2362
        s.cacheMu.Lock()
×
2363
        defer s.cacheMu.Unlock()
×
2364

×
2365
        var (
×
2366
                closedChans []*models.ChannelEdgeInfo
×
2367
                prunedNodes []route.Vertex
×
2368
        )
×
2369
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2370
                for _, outpoint := range spentOutputs {
×
2371
                        // TODO(elle): potentially optimize this by using
×
2372
                        //  sqlc.slice() once that works for both SQLite and
×
2373
                        //  Postgres.
×
2374
                        //
×
2375
                        // NOTE: this fetches channels for all protocol
×
2376
                        // versions.
×
2377
                        row, err := db.GetChannelByOutpoint(
×
2378
                                ctx, outpoint.String(),
×
2379
                        )
×
2380
                        if errors.Is(err, sql.ErrNoRows) {
×
2381
                                continue
×
2382
                        } else if err != nil {
×
2383
                                return fmt.Errorf("unable to fetch channel: %w",
×
2384
                                        err)
×
2385
                        }
×
2386

2387
                        node1, node2, err := buildNodeVertices(
×
2388
                                row.Node1Pubkey, row.Node2Pubkey,
×
2389
                        )
×
2390
                        if err != nil {
×
2391
                                return err
×
2392
                        }
×
2393

2394
                        info, err := getAndBuildEdgeInfo(
×
2395
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2396
                                row.Channel, node1, node2,
×
2397
                        )
×
2398
                        if err != nil {
×
2399
                                return err
×
2400
                        }
×
2401

2402
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2403
                        if err != nil {
×
2404
                                return fmt.Errorf("unable to delete "+
×
2405
                                        "channel: %w", err)
×
2406
                        }
×
2407

2408
                        closedChans = append(closedChans, info)
×
2409
                }
2410

2411
                err := db.UpsertPruneLogEntry(
×
2412
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2413
                                BlockHash:   blockHash[:],
×
2414
                                BlockHeight: int64(blockHeight),
×
2415
                        },
×
2416
                )
×
2417
                if err != nil {
×
2418
                        return fmt.Errorf("unable to insert prune log "+
×
2419
                                "entry: %w", err)
×
2420
                }
×
2421

2422
                // Now that we've pruned some channels, we'll also prune any
2423
                // nodes that no longer have any channels.
2424
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2425
                if err != nil {
×
2426
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2427
                                err)
×
2428
                }
×
2429

2430
                return nil
×
2431
        }, func() {
×
2432
                prunedNodes = nil
×
2433
                closedChans = nil
×
2434
        })
×
2435
        if err != nil {
×
2436
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2437
        }
×
2438

2439
        for _, channel := range closedChans {
×
2440
                s.rejectCache.remove(channel.ChannelID)
×
2441
                s.chanCache.remove(channel.ChannelID)
×
2442
        }
×
2443

2444
        return closedChans, prunedNodes, nil
×
2445
}
2446

2447
// ChannelView returns the verifiable edge information for each active channel
2448
// within the known channel graph. The set of UTXOs (along with their scripts)
2449
// returned are the ones that need to be watched on chain to detect channel
2450
// closes on the resident blockchain.
2451
//
2452
// NOTE: part of the V1Store interface.
2453
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2454
        var (
×
2455
                ctx        = context.TODO()
×
2456
                edgePoints []EdgePoint
×
2457
        )
×
2458

×
2459
        handleChannel := func(db SQLQueries,
×
2460
                channel sqlc.ListChannelsPaginatedRow) error {
×
2461

×
2462
                pkScript, err := genMultiSigP2WSH(
×
2463
                        channel.BitcoinKey1, channel.BitcoinKey2,
×
2464
                )
×
2465
                if err != nil {
×
2466
                        return err
×
2467
                }
×
2468

2469
                op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2470
                if err != nil {
×
2471
                        return err
×
2472
                }
×
2473

2474
                edgePoints = append(edgePoints, EdgePoint{
×
2475
                        FundingPkScript: pkScript,
×
2476
                        OutPoint:        *op,
×
2477
                })
×
2478

×
2479
                return nil
×
2480
        }
2481

2482
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2483
                lastID := int64(-1)
×
2484
                for {
×
2485
                        rows, err := db.ListChannelsPaginated(
×
2486
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2487
                                        Version: int16(ProtocolV1),
×
2488
                                        ID:      lastID,
×
2489
                                        Limit:   pageSize,
×
2490
                                },
×
2491
                        )
×
2492
                        if err != nil {
×
2493
                                return err
×
2494
                        }
×
2495

2496
                        if len(rows) == 0 {
×
2497
                                break
×
2498
                        }
2499

2500
                        for _, row := range rows {
×
2501
                                err := handleChannel(db, row)
×
2502
                                if err != nil {
×
2503
                                        return err
×
2504
                                }
×
2505

2506
                                lastID = row.ID
×
2507
                        }
2508
                }
2509

2510
                return nil
×
2511
        }, func() {
×
2512
                edgePoints = nil
×
2513
        })
×
2514
        if err != nil {
×
2515
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2516
        }
×
2517

2518
        return edgePoints, nil
×
2519
}
2520

2521
// PruneTip returns the block height and hash of the latest block that has been
2522
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2523
// to tell if the graph is currently in sync with the current best known UTXO
2524
// state.
2525
//
2526
// NOTE: part of the V1Store interface.
2527
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2528
        var (
×
2529
                ctx       = context.TODO()
×
2530
                tipHash   chainhash.Hash
×
2531
                tipHeight uint32
×
2532
        )
×
2533
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2534
                pruneTip, err := db.GetPruneTip(ctx)
×
2535
                if errors.Is(err, sql.ErrNoRows) {
×
2536
                        return ErrGraphNeverPruned
×
2537
                } else if err != nil {
×
2538
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2539
                }
×
2540

2541
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2542
                tipHeight = uint32(pruneTip.BlockHeight)
×
2543

×
2544
                return nil
×
2545
        }, sqldb.NoOpReset)
2546
        if err != nil {
×
2547
                return nil, 0, err
×
2548
        }
×
2549

2550
        return &tipHash, tipHeight, nil
×
2551
}
2552

2553
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2554
//
2555
// NOTE: this prunes nodes across protocol versions. It will never prune the
2556
// source nodes.
2557
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2558
        db SQLQueries) ([]route.Vertex, error) {
×
2559

×
2560
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2561
        if err != nil {
×
2562
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2563
                        "nodes: %w", err)
×
2564
        }
×
2565

2566
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2567
        for i, nodeKey := range nodeKeys {
×
2568
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2569
                if err != nil {
×
2570
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2571
                                "from bytes: %w", err)
×
2572
                }
×
2573

2574
                prunedNodes[i] = pub
×
2575
        }
2576

2577
        return prunedNodes, nil
×
2578
}
2579

2580
// DisconnectBlockAtHeight is used to indicate that the block specified
2581
// by the passed height has been disconnected from the main chain. This
2582
// will "rewind" the graph back to the height below, deleting channels
2583
// that are no longer confirmed from the graph. The prune log will be
2584
// set to the last prune height valid for the remaining chain.
2585
// Channels that were removed from the graph resulting from the
2586
// disconnected block are returned.
2587
//
2588
// NOTE: part of the V1Store interface.
2589
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2590
        []*models.ChannelEdgeInfo, error) {
×
2591

×
2592
        ctx := context.TODO()
×
2593

×
2594
        var (
×
2595
                // Every channel having a ShortChannelID starting at 'height'
×
2596
                // will no longer be confirmed.
×
2597
                startShortChanID = lnwire.ShortChannelID{
×
2598
                        BlockHeight: height,
×
2599
                }
×
2600

×
2601
                // Delete everything after this height from the db up until the
×
2602
                // SCID alias range.
×
2603
                endShortChanID = aliasmgr.StartingAlias
×
2604

×
2605
                removedChans []*models.ChannelEdgeInfo
×
2606

×
2607
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2608
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2609
        )
×
2610

×
2611
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2612
                rows, err := db.GetChannelsBySCIDRange(
×
2613
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2614
                                StartScid: chanIDStart,
×
2615
                                EndScid:   chanIDEnd,
×
2616
                        },
×
2617
                )
×
2618
                if err != nil {
×
2619
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2620
                }
×
2621

2622
                for _, row := range rows {
×
2623
                        node1, node2, err := buildNodeVertices(
×
2624
                                row.Node1PubKey, row.Node2PubKey,
×
2625
                        )
×
2626
                        if err != nil {
×
2627
                                return err
×
2628
                        }
×
2629

2630
                        channel, err := getAndBuildEdgeInfo(
×
2631
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2632
                                row.Channel, node1, node2,
×
2633
                        )
×
2634
                        if err != nil {
×
2635
                                return err
×
2636
                        }
×
2637

2638
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2639
                        if err != nil {
×
2640
                                return fmt.Errorf("unable to delete "+
×
2641
                                        "channel: %w", err)
×
2642
                        }
×
2643

2644
                        removedChans = append(removedChans, channel)
×
2645
                }
2646

2647
                return db.DeletePruneLogEntriesInRange(
×
2648
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2649
                                StartHeight: int64(height),
×
2650
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2651
                        },
×
2652
                )
×
2653
        }, func() {
×
2654
                removedChans = nil
×
2655
        })
×
2656
        if err != nil {
×
2657
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2658
                        "height: %w", err)
×
2659
        }
×
2660

2661
        for _, channel := range removedChans {
×
2662
                s.rejectCache.remove(channel.ChannelID)
×
2663
                s.chanCache.remove(channel.ChannelID)
×
2664
        }
×
2665

2666
        return removedChans, nil
×
2667
}
2668

2669
// AddEdgeProof sets the proof of an existing edge in the graph database.
2670
//
2671
// NOTE: part of the V1Store interface.
2672
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2673
        proof *models.ChannelAuthProof) error {
×
2674

×
2675
        var (
×
2676
                ctx       = context.TODO()
×
2677
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2678
        )
×
2679

×
2680
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2681
                res, err := db.AddV1ChannelProof(
×
2682
                        ctx, sqlc.AddV1ChannelProofParams{
×
2683
                                Scid:              scidBytes,
×
2684
                                Node1Signature:    proof.NodeSig1Bytes,
×
2685
                                Node2Signature:    proof.NodeSig2Bytes,
×
2686
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2687
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2688
                        },
×
2689
                )
×
2690
                if err != nil {
×
2691
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2692
                }
×
2693

2694
                n, err := res.RowsAffected()
×
2695
                if err != nil {
×
2696
                        return err
×
2697
                }
×
2698

2699
                if n == 0 {
×
2700
                        return fmt.Errorf("no rows affected when adding edge "+
×
2701
                                "proof for SCID %v", scid)
×
2702
                } else if n > 1 {
×
2703
                        return fmt.Errorf("multiple rows affected when adding "+
×
2704
                                "edge proof for SCID %v: %d rows affected",
×
2705
                                scid, n)
×
2706
                }
×
2707

2708
                return nil
×
2709
        }, sqldb.NoOpReset)
2710
        if err != nil {
×
2711
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2712
        }
×
2713

2714
        return nil
×
2715
}
2716

2717
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2718
// that we can ignore channel announcements that we know to be closed without
2719
// having to validate them and fetch a block.
2720
//
2721
// NOTE: part of the V1Store interface.
2722
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2723
        var (
×
2724
                ctx     = context.TODO()
×
2725
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2726
        )
×
2727

×
2728
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2729
                return db.InsertClosedChannel(ctx, chanIDB)
×
2730
        }, sqldb.NoOpReset)
×
2731
}
2732

2733
// IsClosedScid checks whether a channel identified by the passed in scid is
2734
// closed. This helps avoid having to perform expensive validation checks.
2735
//
2736
// NOTE: part of the V1Store interface.
2737
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2738
        var (
×
2739
                ctx      = context.TODO()
×
2740
                isClosed bool
×
2741
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2742
        )
×
2743
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2744
                var err error
×
2745
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2746
                if err != nil {
×
2747
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2748
                                err)
×
2749
                }
×
2750

2751
                return nil
×
2752
        }, sqldb.NoOpReset)
2753
        if err != nil {
×
2754
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2755
                        err)
×
2756
        }
×
2757

2758
        return isClosed, nil
×
2759
}
2760

2761
// GraphSession will provide the call-back with access to a NodeTraverser
2762
// instance which can be used to perform queries against the channel graph.
2763
//
2764
// NOTE: part of the V1Store interface.
2765
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error) error {
×
2766
        var ctx = context.TODO()
×
2767

×
2768
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2769
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2770
        }, sqldb.NoOpReset)
×
2771
}
2772

2773
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2774
// read only transaction for a consistent view of the graph.
2775
type sqlNodeTraverser struct {
2776
        db    SQLQueries
2777
        chain chainhash.Hash
2778
}
2779

2780
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2781
// NodeTraverser interface.
2782
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2783

2784
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2785
func newSQLNodeTraverser(db SQLQueries,
2786
        chain chainhash.Hash) *sqlNodeTraverser {
×
2787

×
2788
        return &sqlNodeTraverser{
×
2789
                db:    db,
×
2790
                chain: chain,
×
2791
        }
×
2792
}
×
2793

2794
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2795
// node.
2796
//
2797
// NOTE: Part of the NodeTraverser interface.
2798
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2799
        cb func(channel *DirectedChannel) error) error {
×
2800

×
2801
        ctx := context.TODO()
×
2802

×
2803
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2804
}
×
2805

2806
// FetchNodeFeatures returns the features of the given node. If the node is
2807
// unknown, assume no additional features are supported.
2808
//
2809
// NOTE: Part of the NodeTraverser interface.
2810
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2811
        *lnwire.FeatureVector, error) {
×
2812

×
2813
        ctx := context.TODO()
×
2814

×
2815
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2816
}
×
2817

2818
// forEachNodeDirectedChannel iterates through all channels of a given
2819
// node, executing the passed callback on the directed edge representing the
2820
// channel and its incoming policy. If the node is not found, no error is
2821
// returned.
2822
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2823
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2824

×
2825
        toNodeCallback := func() route.Vertex {
×
2826
                return nodePub
×
2827
        }
×
2828

2829
        dbID, err := db.GetNodeIDByPubKey(
×
2830
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2831
                        Version: int16(ProtocolV1),
×
2832
                        PubKey:  nodePub[:],
×
2833
                },
×
2834
        )
×
2835
        if errors.Is(err, sql.ErrNoRows) {
×
2836
                return nil
×
2837
        } else if err != nil {
×
2838
                return fmt.Errorf("unable to fetch node: %w", err)
×
2839
        }
×
2840

2841
        rows, err := db.ListChannelsByNodeID(
×
2842
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2843
                        Version: int16(ProtocolV1),
×
2844
                        NodeID1: dbID,
×
2845
                },
×
2846
        )
×
2847
        if err != nil {
×
2848
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2849
        }
×
2850

2851
        // Exit early if there are no channels for this node so we don't
2852
        // do the unnecessary feature fetching.
2853
        if len(rows) == 0 {
×
2854
                return nil
×
2855
        }
×
2856

2857
        features, err := getNodeFeatures(ctx, db, dbID)
×
2858
        if err != nil {
×
2859
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2860
        }
×
2861

2862
        for _, row := range rows {
×
2863
                node1, node2, err := buildNodeVertices(
×
2864
                        row.Node1Pubkey, row.Node2Pubkey,
×
2865
                )
×
2866
                if err != nil {
×
2867
                        return fmt.Errorf("unable to build node vertices: %w",
×
2868
                                err)
×
2869
                }
×
2870

2871
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
2872

×
2873
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2874
                if err != nil {
×
2875
                        return err
×
2876
                }
×
2877

2878
                var p1, p2 *models.CachedEdgePolicy
×
2879
                if dbPol1 != nil {
×
2880
                        policy1, err := buildChanPolicy(
×
2881
                                *dbPol1, edge.ChannelID, nil, node2,
×
2882
                        )
×
2883
                        if err != nil {
×
2884
                                return err
×
2885
                        }
×
2886

2887
                        p1 = models.NewCachedPolicy(policy1)
×
2888
                }
2889
                if dbPol2 != nil {
×
2890
                        policy2, err := buildChanPolicy(
×
2891
                                *dbPol2, edge.ChannelID, nil, node1,
×
2892
                        )
×
2893
                        if err != nil {
×
2894
                                return err
×
2895
                        }
×
2896

2897
                        p2 = models.NewCachedPolicy(policy2)
×
2898
                }
2899

2900
                // Determine the outgoing and incoming policy for this
2901
                // channel and node combo.
2902
                outPolicy, inPolicy := p1, p2
×
2903
                if p1 != nil && node2 == nodePub {
×
2904
                        outPolicy, inPolicy = p2, p1
×
2905
                } else if p2 != nil && node1 != nodePub {
×
2906
                        outPolicy, inPolicy = p2, p1
×
2907
                }
×
2908

2909
                var cachedInPolicy *models.CachedEdgePolicy
×
2910
                if inPolicy != nil {
×
2911
                        cachedInPolicy = inPolicy
×
2912
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
2913
                        cachedInPolicy.ToNodeFeatures = features
×
2914
                }
×
2915

2916
                directedChannel := &DirectedChannel{
×
2917
                        ChannelID:    edge.ChannelID,
×
2918
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2919
                        OtherNode:    edge.NodeKey2Bytes,
×
2920
                        Capacity:     edge.Capacity,
×
2921
                        OutPolicySet: outPolicy != nil,
×
2922
                        InPolicy:     cachedInPolicy,
×
2923
                }
×
2924
                if outPolicy != nil {
×
2925
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2926
                                directedChannel.InboundFee = fee
×
2927
                        })
×
2928
                }
2929

2930
                if nodePub == edge.NodeKey2Bytes {
×
2931
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
2932
                }
×
2933

2934
                if err := cb(directedChannel); err != nil {
×
2935
                        return err
×
2936
                }
×
2937
        }
2938

2939
        return nil
×
2940
}
2941

2942
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
2943
// and executes the provided callback for each node.
2944
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
2945
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
2946

×
2947
        lastID := int64(-1)
×
2948

×
2949
        for {
×
2950
                nodes, err := db.ListNodeIDsAndPubKeys(
×
2951
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
2952
                                Version: int16(ProtocolV1),
×
2953
                                ID:      lastID,
×
2954
                                Limit:   pageSize,
×
2955
                        },
×
2956
                )
×
2957
                if err != nil {
×
2958
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
2959
                }
×
2960

2961
                if len(nodes) == 0 {
×
2962
                        break
×
2963
                }
2964

2965
                for _, node := range nodes {
×
2966
                        var pub route.Vertex
×
2967
                        copy(pub[:], node.PubKey)
×
2968

×
2969
                        if err := cb(node.ID, pub); err != nil {
×
2970
                                return fmt.Errorf("forEachNodeCacheable "+
×
2971
                                        "callback failed for node(id=%d): %w",
×
2972
                                        node.ID, err)
×
2973
                        }
×
2974

2975
                        lastID = node.ID
×
2976
                }
2977
        }
2978

2979
        return nil
×
2980
}
2981

2982
// forEachNodeChannel iterates through all channels of a node, executing
2983
// the passed callback on each. The call-back is provided with the channel's
2984
// edge information, the outgoing policy and the incoming policy for the
2985
// channel and node combo.
2986
func forEachNodeChannel(ctx context.Context, db SQLQueries,
2987
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
2988
                *models.ChannelEdgePolicy,
2989
                *models.ChannelEdgePolicy) error) error {
×
2990

×
2991
        // Get all the V1 channels for this node.Add commentMore actions
×
2992
        rows, err := db.ListChannelsByNodeID(
×
2993
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2994
                        Version: int16(ProtocolV1),
×
2995
                        NodeID1: id,
×
2996
                },
×
2997
        )
×
2998
        if err != nil {
×
2999
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3000
        }
×
3001

3002
        // Call the call-back for each channel and its known policies.
3003
        for _, row := range rows {
×
3004
                node1, node2, err := buildNodeVertices(
×
3005
                        row.Node1Pubkey, row.Node2Pubkey,
×
3006
                )
×
3007
                if err != nil {
×
3008
                        return fmt.Errorf("unable to build node vertices: %w",
×
3009
                                err)
×
3010
                }
×
3011

3012
                edge, err := getAndBuildEdgeInfo(
×
3013
                        ctx, db, chain, row.Channel.ID, row.Channel, node1,
×
3014
                        node2,
×
3015
                )
×
3016
                if err != nil {
×
3017
                        return fmt.Errorf("unable to build channel info: %w",
×
3018
                                err)
×
3019
                }
×
3020

3021
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3022
                if err != nil {
×
3023
                        return fmt.Errorf("unable to extract channel "+
×
3024
                                "policies: %w", err)
×
3025
                }
×
3026

3027
                p1, p2, err := getAndBuildChanPolicies(
×
3028
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3029
                )
×
3030
                if err != nil {
×
3031
                        return fmt.Errorf("unable to build channel "+
×
3032
                                "policies: %w", err)
×
3033
                }
×
3034

3035
                // Determine the outgoing and incoming policy for this
3036
                // channel and node combo.
3037
                p1ToNode := row.Channel.NodeID2
×
3038
                p2ToNode := row.Channel.NodeID1
×
3039
                outPolicy, inPolicy := p1, p2
×
3040
                if (p1 != nil && p1ToNode == id) ||
×
3041
                        (p2 != nil && p2ToNode != id) {
×
3042

×
3043
                        outPolicy, inPolicy = p2, p1
×
3044
                }
×
3045

3046
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3047
                        return err
×
3048
                }
×
3049
        }
3050

3051
        return nil
×
3052
}
3053

3054
// updateChanEdgePolicy upserts the channel policy info we have stored for
3055
// a channel we already know of.
3056
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3057
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3058
        error) {
×
3059

×
3060
        var (
×
3061
                node1Pub, node2Pub route.Vertex
×
3062
                isNode1            bool
×
3063
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3064
        )
×
3065

×
3066
        // Check that this edge policy refers to a channel that we already
×
3067
        // know of. We do this explicitly so that we can return the appropriate
×
3068
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3069
        // abort the transaction which would abort the entire batch.
×
3070
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3071
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3072
                        Scid:    chanIDB,
×
3073
                        Version: int16(ProtocolV1),
×
3074
                },
×
3075
        )
×
3076
        if errors.Is(err, sql.ErrNoRows) {
×
3077
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3078
        } else if err != nil {
×
3079
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3080
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3081
        }
×
3082

3083
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3084
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3085

×
3086
        // Figure out which node this edge is from.
×
3087
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3088
        nodeID := dbChan.NodeID1
×
3089
        if !isNode1 {
×
3090
                nodeID = dbChan.NodeID2
×
3091
        }
×
3092

3093
        var (
×
3094
                inboundBase sql.NullInt64
×
3095
                inboundRate sql.NullInt64
×
3096
        )
×
3097
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3098
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3099
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3100
        })
×
3101

3102
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3103
                Version:     int16(ProtocolV1),
×
3104
                ChannelID:   dbChan.ID,
×
3105
                NodeID:      nodeID,
×
3106
                Timelock:    int32(edge.TimeLockDelta),
×
3107
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3108
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3109
                MinHtlcMsat: int64(edge.MinHTLC),
×
3110
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3111
                Disabled: sql.NullBool{
×
3112
                        Valid: true,
×
3113
                        Bool:  edge.IsDisabled(),
×
3114
                },
×
3115
                MaxHtlcMsat: sql.NullInt64{
×
3116
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3117
                        Int64: int64(edge.MaxHTLC),
×
3118
                },
×
3119
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3120
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3121
                InboundBaseFeeMsat:      inboundBase,
×
3122
                InboundFeeRateMilliMsat: inboundRate,
×
3123
                Signature:               edge.SigBytes,
×
3124
        })
×
3125
        if err != nil {
×
3126
                return node1Pub, node2Pub, isNode1,
×
3127
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3128
        }
×
3129

3130
        // Convert the flat extra opaque data into a map of TLV types to
3131
        // values.
3132
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3133
        if err != nil {
×
3134
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3135
                        "marshal extra opaque data: %w", err)
×
3136
        }
×
3137

3138
        // Update the channel policy's extra signed fields.
3139
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3140
        if err != nil {
×
3141
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3142
                        "policy extra TLVs: %w", err)
×
3143
        }
×
3144

3145
        return node1Pub, node2Pub, isNode1, nil
×
3146
}
3147

3148
// getNodeByPubKey attempts to look up a target node by its public key.
3149
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3150
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3151

×
3152
        dbNode, err := db.GetNodeByPubKey(
×
3153
                ctx, sqlc.GetNodeByPubKeyParams{
×
3154
                        Version: int16(ProtocolV1),
×
3155
                        PubKey:  pubKey[:],
×
3156
                },
×
3157
        )
×
3158
        if errors.Is(err, sql.ErrNoRows) {
×
3159
                return 0, nil, ErrGraphNodeNotFound
×
3160
        } else if err != nil {
×
3161
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3162
        }
×
3163

3164
        node, err := buildNode(ctx, db, &dbNode)
×
3165
        if err != nil {
×
3166
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3167
        }
×
3168

3169
        return dbNode.ID, node, nil
×
3170
}
3171

3172
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3173
// provided database channel row and the public keys of the two nodes
3174
// involved in the channel.
3175
func buildCacheableChannelInfo(dbChan sqlc.Channel, node1Pub,
3176
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3177

×
3178
        return &models.CachedEdgeInfo{
×
3179
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
3180
                NodeKey1Bytes: node1Pub,
×
3181
                NodeKey2Bytes: node2Pub,
×
3182
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
3183
        }
×
3184
}
×
3185

3186
// buildNode constructs a LightningNode instance from the given database node
3187
// record. The node's features, addresses and extra signed fields are also
3188
// fetched from the database and set on the node.
3189
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
3190
        *models.LightningNode, error) {
×
3191

×
3192
        if dbNode.Version != int16(ProtocolV1) {
×
3193
                return nil, fmt.Errorf("unsupported node version: %d",
×
3194
                        dbNode.Version)
×
3195
        }
×
3196

3197
        var pub [33]byte
×
3198
        copy(pub[:], dbNode.PubKey)
×
3199

×
3200
        node := &models.LightningNode{
×
3201
                PubKeyBytes: pub,
×
3202
                Features:    lnwire.EmptyFeatureVector(),
×
3203
                LastUpdate:  time.Unix(0, 0),
×
3204
        }
×
3205

×
3206
        if len(dbNode.Signature) == 0 {
×
3207
                return node, nil
×
3208
        }
×
3209

3210
        node.HaveNodeAnnouncement = true
×
3211
        node.AuthSigBytes = dbNode.Signature
×
3212
        node.Alias = dbNode.Alias.String
×
3213
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3214

×
3215
        var err error
×
3216
        if dbNode.Color.Valid {
×
3217
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3218
                if err != nil {
×
3219
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3220
                                err)
×
3221
                }
×
3222
        }
3223

3224
        // Fetch the node's features.
3225
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
3226
        if err != nil {
×
3227
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3228
                        "features: %w", dbNode.ID, err)
×
3229
        }
×
3230

3231
        // Fetch the node's addresses.
3232
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
3233
        if err != nil {
×
3234
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3235
                        "addresses: %w", dbNode.ID, err)
×
3236
        }
×
3237

3238
        // Fetch the node's extra signed fields.
3239
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
3240
        if err != nil {
×
3241
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3242
                        "extra signed fields: %w", dbNode.ID, err)
×
3243
        }
×
3244

3245
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
3246
        if err != nil {
×
3247
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3248
                        "fields: %w", err)
×
3249
        }
×
3250

3251
        if len(recs) != 0 {
×
3252
                node.ExtraOpaqueData = recs
×
3253
        }
×
3254

3255
        return node, nil
×
3256
}
3257

3258
// getNodeFeatures fetches the feature bits and constructs the feature vector
3259
// for a node with the given DB ID.
3260
func getNodeFeatures(ctx context.Context, db SQLQueries,
3261
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3262

×
3263
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3264
        if err != nil {
×
3265
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3266
                        nodeID, err)
×
3267
        }
×
3268

3269
        features := lnwire.EmptyFeatureVector()
×
3270
        for _, feature := range rows {
×
3271
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3272
        }
×
3273

3274
        return features, nil
×
3275
}
3276

3277
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
3278
// given DB ID.
3279
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3280
        nodeID int64) (map[uint64][]byte, error) {
×
3281

×
3282
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3283
        if err != nil {
×
3284
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
3285
                        "signed fields: %w", nodeID, err)
×
3286
        }
×
3287

3288
        extraFields := make(map[uint64][]byte)
×
3289
        for _, field := range fields {
×
3290
                extraFields[uint64(field.Type)] = field.Value
×
3291
        }
×
3292

3293
        return extraFields, nil
×
3294
}
3295

3296
// upsertNode upserts the node record into the database. If the node already
3297
// exists, then the node's information is updated. If the node doesn't exist,
3298
// then a new node is created. The node's features, addresses and extra TLV
3299
// types are also updated. The node's DB ID is returned.
3300
func upsertNode(ctx context.Context, db SQLQueries,
3301
        node *models.LightningNode) (int64, error) {
×
3302

×
3303
        params := sqlc.UpsertNodeParams{
×
3304
                Version: int16(ProtocolV1),
×
3305
                PubKey:  node.PubKeyBytes[:],
×
3306
        }
×
3307

×
3308
        if node.HaveNodeAnnouncement {
×
3309
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3310
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3311
                params.Alias = sqldb.SQLStr(node.Alias)
×
3312
                params.Signature = node.AuthSigBytes
×
3313
        }
×
3314

3315
        nodeID, err := db.UpsertNode(ctx, params)
×
3316
        if err != nil {
×
3317
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3318
                        err)
×
3319
        }
×
3320

3321
        // We can exit here if we don't have the announcement yet.
3322
        if !node.HaveNodeAnnouncement {
×
3323
                return nodeID, nil
×
3324
        }
×
3325

3326
        // Update the node's features.
3327
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3328
        if err != nil {
×
3329
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3330
        }
×
3331

3332
        // Update the node's addresses.
3333
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3334
        if err != nil {
×
3335
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3336
        }
×
3337

3338
        // Convert the flat extra opaque data into a map of TLV types to
3339
        // values.
3340
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3341
        if err != nil {
×
3342
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3343
                        err)
×
3344
        }
×
3345

3346
        // Update the node's extra signed fields.
3347
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3348
        if err != nil {
×
3349
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3350
        }
×
3351

3352
        return nodeID, nil
×
3353
}
3354

3355
// upsertNodeFeatures updates the node's features node_features table. This
3356
// includes deleting any feature bits no longer present and inserting any new
3357
// feature bits. If the feature bit does not yet exist in the features table,
3358
// then an entry is created in that table first.
3359
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3360
        features *lnwire.FeatureVector) error {
×
3361

×
3362
        // Get any existing features for the node.
×
3363
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3364
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3365
                return err
×
3366
        }
×
3367

3368
        // Copy the nodes latest set of feature bits.
3369
        newFeatures := make(map[int32]struct{})
×
3370
        if features != nil {
×
3371
                for feature := range features.Features() {
×
3372
                        newFeatures[int32(feature)] = struct{}{}
×
3373
                }
×
3374
        }
3375

3376
        // For any current feature that already exists in the DB, remove it from
3377
        // the in-memory map. For any existing feature that does not exist in
3378
        // the in-memory map, delete it from the database.
3379
        for _, feature := range existingFeatures {
×
3380
                // The feature is still present, so there are no updates to be
×
3381
                // made.
×
3382
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3383
                        delete(newFeatures, feature.FeatureBit)
×
3384
                        continue
×
3385
                }
3386

3387
                // The feature is no longer present, so we remove it from the
3388
                // database.
3389
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3390
                        NodeID:     nodeID,
×
3391
                        FeatureBit: feature.FeatureBit,
×
3392
                })
×
3393
                if err != nil {
×
3394
                        return fmt.Errorf("unable to delete node(%d) "+
×
3395
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3396
                                err)
×
3397
                }
×
3398
        }
3399

3400
        // Any remaining entries in newFeatures are new features that need to be
3401
        // added to the database for the first time.
3402
        for feature := range newFeatures {
×
3403
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3404
                        NodeID:     nodeID,
×
3405
                        FeatureBit: feature,
×
3406
                })
×
3407
                if err != nil {
×
3408
                        return fmt.Errorf("unable to insert node(%d) "+
×
3409
                                "feature(%v): %w", nodeID, feature, err)
×
3410
                }
×
3411
        }
3412

3413
        return nil
×
3414
}
3415

3416
// fetchNodeFeatures fetches the features for a node with the given public key.
3417
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3418
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3419

×
3420
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3421
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3422
                        PubKey:  nodePub[:],
×
3423
                        Version: int16(ProtocolV1),
×
3424
                },
×
3425
        )
×
3426
        if err != nil {
×
3427
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3428
                        nodePub, err)
×
3429
        }
×
3430

3431
        features := lnwire.EmptyFeatureVector()
×
3432
        for _, bit := range rows {
×
3433
                features.Set(lnwire.FeatureBit(bit))
×
3434
        }
×
3435

3436
        return features, nil
×
3437
}
3438

3439
// dbAddressType enumerates the kinds of address stored in the node_addresses
// table. The type recorded for a row decides how its address string is
// serialised and later deserialised.
//
// NOTE: these values are written to the database, so the value of an
// existing constant must never change.
type dbAddressType uint8

const (
	// addressTypeIPv4 marks a TCP address whose host is an IPv4 address.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 marks a TCP address whose host is an IPv6 address.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 marks a Tor v2 onion service address.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 marks a Tor v3 onion service address.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeOpaque marks an opaque address blob, stored hex-encoded.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3451

3452
// upsertNodeAddresses updates the node's addresses in the database. This
3453
// includes deleting any existing addresses and inserting the new set of
3454
// addresses. The deletion is necessary since the ordering of the addresses may
3455
// change, and we need to ensure that the database reflects the latest set of
3456
// addresses so that at the time of reconstructing the node announcement, the
3457
// order is preserved and the signature over the message remains valid.
3458
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3459
        addresses []net.Addr) error {
×
3460

×
3461
        // Delete any existing addresses for the node. This is required since
×
3462
        // even if the new set of addresses is the same, the ordering may have
×
3463
        // changed for a given address type.
×
3464
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3465
        if err != nil {
×
3466
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3467
                        nodeID, err)
×
3468
        }
×
3469

3470
        // Copy the nodes latest set of addresses.
3471
        newAddresses := map[dbAddressType][]string{
×
3472
                addressTypeIPv4:   {},
×
3473
                addressTypeIPv6:   {},
×
3474
                addressTypeTorV2:  {},
×
3475
                addressTypeTorV3:  {},
×
3476
                addressTypeOpaque: {},
×
3477
        }
×
3478
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3479
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3480
        }
×
3481

3482
        for _, address := range addresses {
×
3483
                switch addr := address.(type) {
×
3484
                case *net.TCPAddr:
×
3485
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3486
                                addAddr(addressTypeIPv4, addr)
×
3487
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3488
                                addAddr(addressTypeIPv6, addr)
×
3489
                        } else {
×
3490
                                return fmt.Errorf("unhandled IP address: %v",
×
3491
                                        addr)
×
3492
                        }
×
3493

3494
                case *tor.OnionAddr:
×
3495
                        switch len(addr.OnionService) {
×
3496
                        case tor.V2Len:
×
3497
                                addAddr(addressTypeTorV2, addr)
×
3498
                        case tor.V3Len:
×
3499
                                addAddr(addressTypeTorV3, addr)
×
3500
                        default:
×
3501
                                return fmt.Errorf("invalid length for a tor " +
×
3502
                                        "address")
×
3503
                        }
3504

3505
                case *lnwire.OpaqueAddrs:
×
3506
                        addAddr(addressTypeOpaque, addr)
×
3507

3508
                default:
×
3509
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3510
                }
3511
        }
3512

3513
        // Any remaining entries in newAddresses are new addresses that need to
3514
        // be added to the database for the first time.
3515
        for addrType, addrList := range newAddresses {
×
3516
                for position, addr := range addrList {
×
3517
                        err := db.InsertNodeAddress(
×
3518
                                ctx, sqlc.InsertNodeAddressParams{
×
3519
                                        NodeID:   nodeID,
×
3520
                                        Type:     int16(addrType),
×
3521
                                        Address:  addr,
×
3522
                                        Position: int32(position),
×
3523
                                },
×
3524
                        )
×
3525
                        if err != nil {
×
3526
                                return fmt.Errorf("unable to insert "+
×
3527
                                        "node(%d) address(%v): %w", nodeID,
×
3528
                                        addr, err)
×
3529
                        }
×
3530
                }
3531
        }
3532

3533
        return nil
×
3534
}
3535

3536
// getNodeAddresses fetches the addresses for a node with the given public key.
3537
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3538
        []net.Addr, error) {
×
3539

×
3540
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
3541
        // are returned in the same order as they were inserted.
×
3542
        rows, err := db.GetNodeAddressesByPubKey(
×
3543
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
3544
                        Version: int16(ProtocolV1),
×
3545
                        PubKey:  nodePub,
×
3546
                },
×
3547
        )
×
3548
        if err != nil {
×
3549
                return false, nil, err
×
3550
        }
×
3551

3552
        // GetNodeAddressesByPubKey uses a left join so there should always be
3553
        // at least one row returned if the node exists even if it has no
3554
        // addresses.
3555
        if len(rows) == 0 {
×
3556
                return false, nil, nil
×
3557
        }
×
3558

3559
        addresses := make([]net.Addr, 0, len(rows))
×
3560
        for _, addr := range rows {
×
3561
                if !(addr.Type.Valid && addr.Address.Valid) {
×
3562
                        continue
×
3563
                }
3564

3565
                address := addr.Address.String
×
3566

×
3567
                switch dbAddressType(addr.Type.Int16) {
×
3568
                case addressTypeIPv4:
×
3569
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
3570
                        if err != nil {
×
3571
                                return false, nil, nil
×
3572
                        }
×
3573
                        tcp.IP = tcp.IP.To4()
×
3574

×
3575
                        addresses = append(addresses, tcp)
×
3576

3577
                case addressTypeIPv6:
×
3578
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
3579
                        if err != nil {
×
3580
                                return false, nil, nil
×
3581
                        }
×
3582
                        addresses = append(addresses, tcp)
×
3583

3584
                case addressTypeTorV3, addressTypeTorV2:
×
3585
                        service, portStr, err := net.SplitHostPort(address)
×
3586
                        if err != nil {
×
3587
                                return false, nil, fmt.Errorf("unable to "+
×
3588
                                        "split tor v3 address: %v",
×
3589
                                        addr.Address)
×
3590
                        }
×
3591

3592
                        port, err := strconv.Atoi(portStr)
×
3593
                        if err != nil {
×
3594
                                return false, nil, err
×
3595
                        }
×
3596

3597
                        addresses = append(addresses, &tor.OnionAddr{
×
3598
                                OnionService: service,
×
3599
                                Port:         port,
×
3600
                        })
×
3601

3602
                case addressTypeOpaque:
×
3603
                        opaque, err := hex.DecodeString(address)
×
3604
                        if err != nil {
×
3605
                                return false, nil, fmt.Errorf("unable to "+
×
3606
                                        "decode opaque address: %v", addr)
×
3607
                        }
×
3608

3609
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
3610
                                Payload: opaque,
×
3611
                        })
×
3612

3613
                default:
×
3614
                        return false, nil, fmt.Errorf("unknown address "+
×
3615
                                "type: %v", addr.Type)
×
3616
                }
3617
        }
3618

3619
        // If we have no addresses, then we'll return nil instead of an
3620
        // empty slice.
3621
        if len(addresses) == 0 {
×
3622
                addresses = nil
×
3623
        }
×
3624

3625
        return true, addresses, nil
×
3626
}
3627

3628
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3629
// database. This includes updating any existing types, inserting any new types,
3630
// and deleting any types that are no longer present.
3631
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3632
        nodeID int64, extraFields map[uint64][]byte) error {
×
3633

×
3634
        // Get any existing extra signed fields for the node.
×
3635
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3636
        if err != nil {
×
3637
                return err
×
3638
        }
×
3639

3640
        // Make a lookup map of the existing field types so that we can use it
3641
        // to keep track of any fields we should delete.
3642
        m := make(map[uint64]bool)
×
3643
        for _, field := range existingFields {
×
3644
                m[uint64(field.Type)] = true
×
3645
        }
×
3646

3647
        // For all the new fields, we'll upsert them and remove them from the
3648
        // map of existing fields.
3649
        for tlvType, value := range extraFields {
×
3650
                err = db.UpsertNodeExtraType(
×
3651
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3652
                                NodeID: nodeID,
×
3653
                                Type:   int64(tlvType),
×
3654
                                Value:  value,
×
3655
                        },
×
3656
                )
×
3657
                if err != nil {
×
3658
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3659
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3660
                }
×
3661

3662
                // Remove the field from the map of existing fields if it was
3663
                // present.
3664
                delete(m, tlvType)
×
3665
        }
3666

3667
        // For all the fields that are left in the map of existing fields, we'll
3668
        // delete them as they are no longer present in the new set of fields.
3669
        for tlvType := range m {
×
3670
                err = db.DeleteExtraNodeType(
×
3671
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3672
                                NodeID: nodeID,
×
3673
                                Type:   int64(tlvType),
×
3674
                        },
×
3675
                )
×
3676
                if err != nil {
×
3677
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3678
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3679
                }
×
3680
        }
3681

3682
        return nil
×
3683
}
3684

3685
// srcNodeInfo holds the information about the source node of the graph.
3686
type srcNodeInfo struct {
3687
        // id is the DB level ID of the source node entry in the "nodes" table.
3688
        id int64
3689

3690
        // pub is the public key of the source node.
3691
        pub route.Vertex
3692
}
3693

3694
// sourceNode returns the DB node ID and pub key of the source node for the
3695
// specified protocol version.
3696
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3697
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3698

×
3699
        s.srcNodeMu.Lock()
×
3700
        defer s.srcNodeMu.Unlock()
×
3701

×
3702
        // If we already have the source node ID and pub key cached, then
×
3703
        // return them.
×
3704
        if info, ok := s.srcNodes[version]; ok {
×
3705
                return info.id, info.pub, nil
×
3706
        }
×
3707

3708
        var pubKey route.Vertex
×
3709

×
3710
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3711
        if err != nil {
×
3712
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3713
                        err)
×
3714
        }
×
3715

3716
        if len(nodes) == 0 {
×
3717
                return 0, pubKey, ErrSourceNodeNotSet
×
3718
        } else if len(nodes) > 1 {
×
3719
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3720
                        "protocol %s found", version)
×
3721
        }
×
3722

3723
        copy(pubKey[:], nodes[0].PubKey)
×
3724

×
3725
        s.srcNodes[version] = &srcNodeInfo{
×
3726
                id:  nodes[0].NodeID,
×
3727
                pub: pubKey,
×
3728
        }
×
3729

×
3730
        return nodes[0].NodeID, pubKey, nil
×
3731
}
3732

3733
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3734
// This then produces a map from TLV type to value. If the input is not a
3735
// valid TLV stream, then an error is returned.
3736
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3737
        r := bytes.NewReader(data)
×
3738

×
3739
        tlvStream, err := tlv.NewStream()
×
3740
        if err != nil {
×
3741
                return nil, err
×
3742
        }
×
3743

3744
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3745
        // pass it into the P2P decoding variant.
3746
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3747
        if err != nil {
×
3748
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3749
        }
×
3750
        if len(parsedTypes) == 0 {
×
3751
                return nil, nil
×
3752
        }
×
3753

3754
        records := make(map[uint64][]byte)
×
3755
        for k, v := range parsedTypes {
×
3756
                records[uint64(k)] = v
×
3757
        }
×
3758

3759
        return records, nil
×
3760
}
3761

3762
// dbChanInfo bundles the database-level IDs of a channel record and of the
// two node records it connects, as returned by insertChannel.
type dbChanInfo struct {
	// channelID is the DB ID of the channel record.
	channelID int64

	// node1ID is the DB ID of the channel's first node.
	node1ID int64

	// node2ID is the DB ID of the channel's second node.
	node2ID int64
}
3769

3770
// insertChannel inserts a new channel record into the database.
3771
func insertChannel(ctx context.Context, db SQLQueries,
3772
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
3773

×
3774
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3775

×
3776
        // Make sure that the channel doesn't already exist. We do this
×
3777
        // explicitly instead of relying on catching a unique constraint error
×
3778
        // because relying on SQL to throw that error would abort the entire
×
3779
        // batch of transactions.
×
3780
        _, err := db.GetChannelBySCID(
×
3781
                ctx, sqlc.GetChannelBySCIDParams{
×
3782
                        Scid:    chanIDB,
×
3783
                        Version: int16(ProtocolV1),
×
3784
                },
×
3785
        )
×
3786
        if err == nil {
×
3787
                return nil, ErrEdgeAlreadyExist
×
3788
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3789
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
×
3790
        }
×
3791

3792
        // Make sure that at least a "shell" entry for each node is present in
3793
        // the nodes table.
3794
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3795
        if err != nil {
×
3796
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3797
        }
×
3798

3799
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3800
        if err != nil {
×
3801
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3802
        }
×
3803

3804
        var capacity sql.NullInt64
×
3805
        if edge.Capacity != 0 {
×
3806
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3807
        }
×
3808

3809
        createParams := sqlc.CreateChannelParams{
×
3810
                Version:     int16(ProtocolV1),
×
3811
                Scid:        chanIDB,
×
3812
                NodeID1:     node1DBID,
×
3813
                NodeID2:     node2DBID,
×
3814
                Outpoint:    edge.ChannelPoint.String(),
×
3815
                Capacity:    capacity,
×
3816
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3817
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3818
        }
×
3819

×
3820
        if edge.AuthProof != nil {
×
3821
                proof := edge.AuthProof
×
3822

×
3823
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3824
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3825
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3826
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3827
        }
×
3828

3829
        // Insert the new channel record.
3830
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3831
        if err != nil {
×
3832
                return nil, err
×
3833
        }
×
3834

3835
        // Insert any channel features.
NEW
3836
        for feature := range edge.Features.Features() {
×
NEW
3837
                err = db.InsertChannelFeature(
×
NEW
3838
                        ctx, sqlc.InsertChannelFeatureParams{
×
NEW
3839
                                ChannelID:  dbChanID,
×
NEW
3840
                                FeatureBit: int32(feature),
×
NEW
3841
                        },
×
NEW
3842
                )
×
3843
                if err != nil {
×
NEW
3844
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
×
NEW
3845
                                "feature(%v): %w", dbChanID, feature, err)
×
UNCOV
3846
                }
×
3847
        }
3848

3849
        // Finally, insert any extra TLV fields in the channel announcement.
3850
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3851
        if err != nil {
×
3852
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
3853
                        "data: %w", err)
×
3854
        }
×
3855

3856
        for tlvType, value := range extra {
×
3857
                err := db.CreateChannelExtraType(
×
3858
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3859
                                ChannelID: dbChanID,
×
3860
                                Type:      int64(tlvType),
×
3861
                                Value:     value,
×
3862
                        },
×
3863
                )
×
3864
                if err != nil {
×
3865
                        return nil, fmt.Errorf("unable to upsert "+
×
3866
                                "channel(%d) extra signed field(%v): %w",
×
3867
                                edge.ChannelID, tlvType, err)
×
3868
                }
×
3869
        }
3870

3871
        return &dbChanInfo{
×
3872
                channelID: dbChanID,
×
3873
                node1ID:   node1DBID,
×
3874
                node2ID:   node2DBID,
×
3875
        }, nil
×
3876
}
3877

3878
// maybeCreateShellNode checks if a shell node entry exists for the
3879
// given public key. If it does not exist, then a new shell node entry is
3880
// created. The ID of the node is returned. A shell node only has a protocol
3881
// version and public key persisted.
3882
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3883
        pubKey route.Vertex) (int64, error) {
×
3884

×
3885
        dbNode, err := db.GetNodeByPubKey(
×
3886
                ctx, sqlc.GetNodeByPubKeyParams{
×
3887
                        PubKey:  pubKey[:],
×
3888
                        Version: int16(ProtocolV1),
×
3889
                },
×
3890
        )
×
3891
        // The node exists. Return the ID.
×
3892
        if err == nil {
×
3893
                return dbNode.ID, nil
×
3894
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3895
                return 0, err
×
3896
        }
×
3897

3898
        // Otherwise, the node does not exist, so we create a shell entry for
3899
        // it.
3900
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3901
                Version: int16(ProtocolV1),
×
3902
                PubKey:  pubKey[:],
×
3903
        })
×
3904
        if err != nil {
×
3905
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3906
        }
×
3907

3908
        return id, nil
×
3909
}
3910

3911
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3912
// the database. This includes deleting any existing types and then inserting
3913
// the new types.
3914
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3915
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3916

×
3917
        // Delete all existing extra signed fields for the channel policy.
×
3918
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
3919
        if err != nil {
×
3920
                return fmt.Errorf("unable to delete "+
×
3921
                        "existing policy extra signed fields for policy %d: %w",
×
3922
                        chanPolicyID, err)
×
3923
        }
×
3924

3925
        // Insert all new extra signed fields for the channel policy.
3926
        for tlvType, value := range extraFields {
×
3927
                err = db.InsertChanPolicyExtraType(
×
3928
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
3929
                                ChannelPolicyID: chanPolicyID,
×
3930
                                Type:            int64(tlvType),
×
3931
                                Value:           value,
×
3932
                        },
×
3933
                )
×
3934
                if err != nil {
×
3935
                        return fmt.Errorf("unable to insert "+
×
3936
                                "channel_policy(%d) extra signed field(%v): %w",
×
3937
                                chanPolicyID, tlvType, err)
×
3938
                }
×
3939
        }
3940

3941
        return nil
×
3942
}
3943

3944
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3945
// provided dbChanRow and also fetches any other required information
3946
// to construct the edge info.
3947
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
3948
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.Channel, node1,
3949
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
3950

×
3951
        if dbChan.Version != int16(ProtocolV1) {
×
3952
                return nil, fmt.Errorf("unsupported channel version: %d",
×
3953
                        dbChan.Version)
×
3954
        }
×
3955

3956
        fv, extras, err := getChanFeaturesAndExtras(
×
3957
                ctx, db, dbChanID,
×
3958
        )
×
3959
        if err != nil {
×
3960
                return nil, err
×
3961
        }
×
3962

3963
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
3964
        if err != nil {
×
3965
                return nil, err
×
3966
        }
×
3967

3968
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
3969
        if err != nil {
×
3970
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3971
                        "fields: %w", err)
×
3972
        }
×
3973
        if recs == nil {
×
3974
                recs = make([]byte, 0)
×
3975
        }
×
3976

3977
        var btcKey1, btcKey2 route.Vertex
×
3978
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
3979
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
3980

×
3981
        channel := &models.ChannelEdgeInfo{
×
3982
                ChainHash:        chain,
×
3983
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
3984
                NodeKey1Bytes:    node1,
×
3985
                NodeKey2Bytes:    node2,
×
3986
                BitcoinKey1Bytes: btcKey1,
×
3987
                BitcoinKey2Bytes: btcKey2,
×
3988
                ChannelPoint:     *op,
×
3989
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
NEW
3990
                Features:         fv,
×
3991
                ExtraOpaqueData:  recs,
×
3992
        }
×
3993

×
3994
        // We always set all the signatures at the same time, so we can
×
3995
        // safely check if one signature is present to determine if we have the
×
3996
        // rest of the signatures for the auth proof.
×
3997
        if len(dbChan.Bitcoin1Signature) > 0 {
×
3998
                channel.AuthProof = &models.ChannelAuthProof{
×
3999
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4000
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4001
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4002
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4003
                }
×
4004
        }
×
4005

4006
        return channel, nil
×
4007
}
4008

4009
// buildNodeVertices is a helper that converts raw node public keys
4010
// into route.Vertex instances.
4011
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4012
        route.Vertex, error) {
×
4013

×
4014
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4015
        if err != nil {
×
4016
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4017
                        "create vertex from node1 pubkey: %w", err)
×
4018
        }
×
4019

4020
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4021
        if err != nil {
×
4022
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4023
                        "create vertex from node2 pubkey: %w", err)
×
4024
        }
×
4025

4026
        return node1Vertex, node2Vertex, nil
×
4027
}
4028

4029
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
4030
// for a channel with the given ID.
4031
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
4032
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
4033

×
4034
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
4035
        if err != nil {
×
4036
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
4037
                        "features and extras: %w", err)
×
4038
        }
×
4039

4040
        var (
×
4041
                fv     = lnwire.EmptyFeatureVector()
×
4042
                extras = make(map[uint64][]byte)
×
4043
        )
×
4044
        for _, row := range rows {
×
4045
                if row.IsFeature {
×
4046
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
4047

×
4048
                        continue
×
4049
                }
4050

4051
                tlvType, ok := row.ExtraKey.(int64)
×
4052
                if !ok {
×
4053
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4054
                                "TLV type: %T", row.ExtraKey)
×
4055
                }
×
4056

4057
                valueBytes, ok := row.Value.([]byte)
×
4058
                if !ok {
×
4059
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4060
                                "Value: %T", row.Value)
×
4061
                }
×
4062

4063
                extras[uint64(tlvType)] = valueBytes
×
4064
        }
4065

4066
        return fv, extras, nil
×
4067
}
4068

4069
// getAndBuildChanPolicies uses the given sqlc.ChannelPolicy and also retrieves
4070
// all the extra info required to build the complete models.ChannelEdgePolicy
4071
// types. It returns two policies, which may be nil if the provided
4072
// sqlc.ChannelPolicy records are nil.
4073
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4074
        dbPol1, dbPol2 *sqlc.ChannelPolicy, channelID uint64, node1,
4075
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4076
        *models.ChannelEdgePolicy, error) {
×
4077

×
4078
        if dbPol1 == nil && dbPol2 == nil {
×
4079
                return nil, nil, nil
×
4080
        }
×
4081

4082
        var (
×
4083
                policy1ID int64
×
4084
                policy2ID int64
×
4085
        )
×
4086
        if dbPol1 != nil {
×
4087
                policy1ID = dbPol1.ID
×
4088
        }
×
4089
        if dbPol2 != nil {
×
4090
                policy2ID = dbPol2.ID
×
4091
        }
×
4092
        rows, err := db.GetChannelPolicyExtraTypes(
×
4093
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
4094
                        ID:   policy1ID,
×
4095
                        ID_2: policy2ID,
×
4096
                },
×
4097
        )
×
4098
        if err != nil {
×
4099
                return nil, nil, err
×
4100
        }
×
4101

4102
        var (
×
4103
                dbPol1Extras = make(map[uint64][]byte)
×
4104
                dbPol2Extras = make(map[uint64][]byte)
×
4105
        )
×
4106
        for _, row := range rows {
×
4107
                switch row.PolicyID {
×
4108
                case policy1ID:
×
4109
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
4110
                case policy2ID:
×
4111
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
4112
                default:
×
4113
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
4114
                                "in row: %v", row.PolicyID, row)
×
4115
                }
4116
        }
4117

4118
        var pol1, pol2 *models.ChannelEdgePolicy
×
4119
        if dbPol1 != nil {
×
4120
                pol1, err = buildChanPolicy(
×
4121
                        *dbPol1, channelID, dbPol1Extras, node2,
×
4122
                )
×
4123
                if err != nil {
×
4124
                        return nil, nil, err
×
4125
                }
×
4126
        }
4127
        if dbPol2 != nil {
×
4128
                pol2, err = buildChanPolicy(
×
4129
                        *dbPol2, channelID, dbPol2Extras, node1,
×
4130
                )
×
4131
                if err != nil {
×
4132
                        return nil, nil, err
×
4133
                }
×
4134
        }
4135

4136
        return pol1, pol2, nil
×
4137
}
4138

4139
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4140
// provided sqlc.ChannelPolicy and other required information.
4141
func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
4142
        extras map[uint64][]byte,
4143
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4144

×
4145
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4146
        if err != nil {
×
4147
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4148
                        "fields: %w", err)
×
4149
        }
×
4150

4151
        var inboundFee fn.Option[lnwire.Fee]
×
4152
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4153
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4154

×
4155
                inboundFee = fn.Some(lnwire.Fee{
×
4156
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4157
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4158
                })
×
4159
        }
×
4160

4161
        return &models.ChannelEdgePolicy{
×
4162
                SigBytes:  dbPolicy.Signature,
×
4163
                ChannelID: channelID,
×
4164
                LastUpdate: time.Unix(
×
4165
                        dbPolicy.LastUpdate.Int64, 0,
×
4166
                ),
×
4167
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4168
                        dbPolicy.MessageFlags,
×
4169
                ),
×
4170
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4171
                        dbPolicy.ChannelFlags,
×
4172
                ),
×
4173
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4174
                MinHTLC: lnwire.MilliSatoshi(
×
4175
                        dbPolicy.MinHtlcMsat,
×
4176
                ),
×
4177
                MaxHTLC: lnwire.MilliSatoshi(
×
4178
                        dbPolicy.MaxHtlcMsat.Int64,
×
4179
                ),
×
4180
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4181
                        dbPolicy.BaseFeeMsat,
×
4182
                ),
×
4183
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4184
                ToNode:                    toNode,
×
4185
                InboundFee:                inboundFee,
×
4186
                ExtraOpaqueData:           recs,
×
4187
        }, nil
×
4188
}
4189

4190
// buildNodes builds the models.LightningNode instances for the
4191
// given row which is expected to be a sqlc type that contains node information.
4192
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4193
        dbNode2 sqlc.Node) (*models.LightningNode, *models.LightningNode,
4194
        error) {
×
4195

×
4196
        node1, err := buildNode(ctx, db, &dbNode1)
×
4197
        if err != nil {
×
4198
                return nil, nil, err
×
4199
        }
×
4200

4201
        node2, err := buildNode(ctx, db, &dbNode2)
×
4202
        if err != nil {
×
4203
                return nil, nil, err
×
4204
        }
×
4205

4206
        return node1, node2, nil
×
4207
}
4208

4209
// extractChannelPolicies extracts the sqlc.ChannelPolicy records from the give
4210
// row which is expected to be a sqlc type that contains channel policy
4211
// information. It returns two policies, which may be nil if the policy
4212
// information is not present in the row.
4213
//
4214
//nolint:ll,dupl,funlen
4215
func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
4216
        error) {
×
4217

×
4218
        var policy1, policy2 *sqlc.ChannelPolicy
×
4219
        switch r := row.(type) {
×
4220
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4221
                if r.Policy1ID.Valid {
×
4222
                        policy1 = &sqlc.ChannelPolicy{
×
4223
                                ID:                      r.Policy1ID.Int64,
×
4224
                                Version:                 r.Policy1Version.Int16,
×
4225
                                ChannelID:               r.Channel.ID,
×
4226
                                NodeID:                  r.Policy1NodeID.Int64,
×
4227
                                Timelock:                r.Policy1Timelock.Int32,
×
4228
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4229
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4230
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4231
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4232
                                LastUpdate:              r.Policy1LastUpdate,
×
4233
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4234
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4235
                                Disabled:                r.Policy1Disabled,
×
4236
                                MessageFlags:            r.Policy1MessageFlags,
×
4237
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4238
                                Signature:               r.Policy1Signature,
×
4239
                        }
×
4240
                }
×
4241
                if r.Policy2ID.Valid {
×
4242
                        policy2 = &sqlc.ChannelPolicy{
×
4243
                                ID:                      r.Policy2ID.Int64,
×
4244
                                Version:                 r.Policy2Version.Int16,
×
4245
                                ChannelID:               r.Channel.ID,
×
4246
                                NodeID:                  r.Policy2NodeID.Int64,
×
4247
                                Timelock:                r.Policy2Timelock.Int32,
×
4248
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4249
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4250
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4251
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4252
                                LastUpdate:              r.Policy2LastUpdate,
×
4253
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4254
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4255
                                Disabled:                r.Policy2Disabled,
×
4256
                                MessageFlags:            r.Policy2MessageFlags,
×
4257
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4258
                                Signature:               r.Policy2Signature,
×
4259
                        }
×
4260
                }
×
4261

4262
                return policy1, policy2, nil
×
4263

4264
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4265
                if r.Policy1ID.Valid {
×
4266
                        policy1 = &sqlc.ChannelPolicy{
×
4267
                                ID:                      r.Policy1ID.Int64,
×
4268
                                Version:                 r.Policy1Version.Int16,
×
4269
                                ChannelID:               r.Channel.ID,
×
4270
                                NodeID:                  r.Policy1NodeID.Int64,
×
4271
                                Timelock:                r.Policy1Timelock.Int32,
×
4272
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4273
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4274
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4275
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4276
                                LastUpdate:              r.Policy1LastUpdate,
×
4277
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4278
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4279
                                Disabled:                r.Policy1Disabled,
×
4280
                                MessageFlags:            r.Policy1MessageFlags,
×
4281
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4282
                                Signature:               r.Policy1Signature,
×
4283
                        }
×
4284
                }
×
4285
                if r.Policy2ID.Valid {
×
4286
                        policy2 = &sqlc.ChannelPolicy{
×
4287
                                ID:                      r.Policy2ID.Int64,
×
4288
                                Version:                 r.Policy2Version.Int16,
×
4289
                                ChannelID:               r.Channel.ID,
×
4290
                                NodeID:                  r.Policy2NodeID.Int64,
×
4291
                                Timelock:                r.Policy2Timelock.Int32,
×
4292
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4293
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4294
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4295
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4296
                                LastUpdate:              r.Policy2LastUpdate,
×
4297
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4298
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4299
                                Disabled:                r.Policy2Disabled,
×
4300
                                MessageFlags:            r.Policy2MessageFlags,
×
4301
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4302
                                Signature:               r.Policy2Signature,
×
4303
                        }
×
4304
                }
×
4305

4306
                return policy1, policy2, nil
×
4307

4308
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4309
                if r.Policy1ID.Valid {
×
4310
                        policy1 = &sqlc.ChannelPolicy{
×
4311
                                ID:                      r.Policy1ID.Int64,
×
4312
                                Version:                 r.Policy1Version.Int16,
×
4313
                                ChannelID:               r.Channel.ID,
×
4314
                                NodeID:                  r.Policy1NodeID.Int64,
×
4315
                                Timelock:                r.Policy1Timelock.Int32,
×
4316
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4317
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4318
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4319
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4320
                                LastUpdate:              r.Policy1LastUpdate,
×
4321
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4322
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4323
                                Disabled:                r.Policy1Disabled,
×
4324
                                MessageFlags:            r.Policy1MessageFlags,
×
4325
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4326
                                Signature:               r.Policy1Signature,
×
4327
                        }
×
4328
                }
×
4329
                if r.Policy2ID.Valid {
×
4330
                        policy2 = &sqlc.ChannelPolicy{
×
4331
                                ID:                      r.Policy2ID.Int64,
×
4332
                                Version:                 r.Policy2Version.Int16,
×
4333
                                ChannelID:               r.Channel.ID,
×
4334
                                NodeID:                  r.Policy2NodeID.Int64,
×
4335
                                Timelock:                r.Policy2Timelock.Int32,
×
4336
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4337
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4338
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4339
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4340
                                LastUpdate:              r.Policy2LastUpdate,
×
4341
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4342
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4343
                                Disabled:                r.Policy2Disabled,
×
4344
                                MessageFlags:            r.Policy2MessageFlags,
×
4345
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4346
                                Signature:               r.Policy2Signature,
×
4347
                        }
×
4348
                }
×
4349

4350
                return policy1, policy2, nil
×
4351

4352
        case sqlc.ListChannelsByNodeIDRow:
×
4353
                if r.Policy1ID.Valid {
×
4354
                        policy1 = &sqlc.ChannelPolicy{
×
4355
                                ID:                      r.Policy1ID.Int64,
×
4356
                                Version:                 r.Policy1Version.Int16,
×
4357
                                ChannelID:               r.Channel.ID,
×
4358
                                NodeID:                  r.Policy1NodeID.Int64,
×
4359
                                Timelock:                r.Policy1Timelock.Int32,
×
4360
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4361
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4362
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4363
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4364
                                LastUpdate:              r.Policy1LastUpdate,
×
4365
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4366
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4367
                                Disabled:                r.Policy1Disabled,
×
4368
                                MessageFlags:            r.Policy1MessageFlags,
×
4369
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4370
                                Signature:               r.Policy1Signature,
×
4371
                        }
×
4372
                }
×
4373
                if r.Policy2ID.Valid {
×
4374
                        policy2 = &sqlc.ChannelPolicy{
×
4375
                                ID:                      r.Policy2ID.Int64,
×
4376
                                Version:                 r.Policy2Version.Int16,
×
4377
                                ChannelID:               r.Channel.ID,
×
4378
                                NodeID:                  r.Policy2NodeID.Int64,
×
4379
                                Timelock:                r.Policy2Timelock.Int32,
×
4380
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4381
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4382
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4383
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4384
                                LastUpdate:              r.Policy2LastUpdate,
×
4385
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4386
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4387
                                Disabled:                r.Policy2Disabled,
×
4388
                                MessageFlags:            r.Policy2MessageFlags,
×
4389
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4390
                                Signature:               r.Policy2Signature,
×
4391
                        }
×
4392
                }
×
4393

4394
                return policy1, policy2, nil
×
4395

4396
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4397
                if r.Policy1ID.Valid {
×
4398
                        policy1 = &sqlc.ChannelPolicy{
×
4399
                                ID:                      r.Policy1ID.Int64,
×
4400
                                Version:                 r.Policy1Version.Int16,
×
4401
                                ChannelID:               r.Channel.ID,
×
4402
                                NodeID:                  r.Policy1NodeID.Int64,
×
4403
                                Timelock:                r.Policy1Timelock.Int32,
×
4404
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4405
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4406
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4407
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4408
                                LastUpdate:              r.Policy1LastUpdate,
×
4409
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4410
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4411
                                Disabled:                r.Policy1Disabled,
×
4412
                                MessageFlags:            r.Policy1MessageFlags,
×
4413
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4414
                                Signature:               r.Policy1Signature,
×
4415
                        }
×
4416
                }
×
4417
                if r.Policy2ID.Valid {
×
4418
                        policy2 = &sqlc.ChannelPolicy{
×
4419
                                ID:                      r.Policy2ID.Int64,
×
4420
                                Version:                 r.Policy2Version.Int16,
×
4421
                                ChannelID:               r.Channel.ID,
×
4422
                                NodeID:                  r.Policy2NodeID.Int64,
×
4423
                                Timelock:                r.Policy2Timelock.Int32,
×
4424
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4425
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4426
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4427
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4428
                                LastUpdate:              r.Policy2LastUpdate,
×
4429
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4430
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4431
                                Disabled:                r.Policy2Disabled,
×
4432
                                MessageFlags:            r.Policy2MessageFlags,
×
4433
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4434
                                Signature:               r.Policy2Signature,
×
4435
                        }
×
4436
                }
×
4437

4438
                return policy1, policy2, nil
×
4439
        default:
×
4440
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4441
                        "extractChannelPolicies: %T", r)
×
4442
        }
4443
}
4444

4445
// channelIDToBytes converts a channel ID (SCID) to a byte array
4446
// representation.
4447
func channelIDToBytes(channelID uint64) []byte {
×
4448
        var chanIDB [8]byte
×
4449
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4450

×
4451
        return chanIDB[:]
×
4452
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc