• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 16111275918

07 Jul 2025 07:59AM UTC coverage: 57.777% (-0.03%) from 57.803%
16111275918

Pull #10043

github

web-flow
Merge f70ed0693 into ff32e90d1
Pull Request #10043: multi: add context.Context param to more graphdb.V1Store methods

25 of 38 new or added lines in 8 files covered. (65.79%)

84 existing lines in 10 files now uncovered.

98445 of 170387 relevant lines covered (57.78%)

1.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// pageSize caps how many rows a single paginated query may return. The value
// is a starting point and may be tuned once benchmarks are available.
const pageSize = 2000

// ProtocolVersion identifies the gossip protocol version that a message
// belongs to.
type ProtocolVersion uint8

const (
	// ProtocolV1 is the gossip protocol version defined in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String renders the protocol version as the letter "V" followed by its
// numeric value, e.g. "V1".
func (v ProtocolVersion) String() string {
	// Convert to the underlying integer type explicitly so the formatting
	// is obviously not routed back through this String method.
	return fmt.Sprintf("V%d", uint8(v))
}
×
51

52
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
53
// execute queries against the SQL graph tables.
54
//
55
//nolint:ll,interfacebloat
56
type SQLQueries interface {
57
        /*
58
                Node queries.
59
        */
60
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
61
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
62
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
63
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
64
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error)
65
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
66
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
67
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
68
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
69
        DeleteNode(ctx context.Context, id int64) error
70

71
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
72
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
73
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
74

75
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
76
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
77
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
78

79
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
80
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
81
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
82
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
83

84
        /*
85
                Source node queries.
86
        */
87
        AddSourceNode(ctx context.Context, nodeID int64) error
88
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
89

90
        /*
91
                Channel queries.
92
        */
93
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
94
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
95
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
96
        GetChannelByOutpoint(ctx context.Context, outpoint string) (sqlc.GetChannelByOutpointRow, error)
97
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
98
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
99
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
100
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
101
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
102
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
103
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
104
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
105
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
106
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
107
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
108
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
109
        DeleteChannel(ctx context.Context, id int64) error
110

111
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
112
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
113

114
        /*
115
                Channel Policy table queries.
116
        */
117
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
118
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.ChannelPolicy, error)
119
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
120

121
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
122
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
123
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
124

125
        /*
126
                Zombie index queries.
127
        */
128
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
129
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.ZombieChannel, error)
130
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
131
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
132
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
133

134
        /*
135
                Prune log table queries.
136
        */
137
        GetPruneTip(ctx context.Context) (sqlc.PruneLog, error)
138
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
139
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
140

141
        /*
142
                Closed SCID table queries.
143
        */
144
        InsertClosedChannel(ctx context.Context, scid []byte) error
145
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
146
}
147

148
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
149
// database operations.
150
type BatchedSQLQueries interface {
151
        SQLQueries
152
        sqldb.BatchedTx[SQLQueries]
153
}
154

155
// SQLStore is an implementation of the V1Store interface that uses a SQL
156
// database as the backend.
157
type SQLStore struct {
158
        cfg *SQLStoreConfig
159
        db  BatchedSQLQueries
160

161
        // cacheMu guards all caches (rejectCache and chanCache). If
162
        // this mutex will be acquired at the same time as the DB mutex then
163
        // the cacheMu MUST be acquired first to prevent deadlock.
164
        cacheMu     sync.RWMutex
165
        rejectCache *rejectCache
166
        chanCache   *channelCache
167

168
        chanScheduler batch.Scheduler[SQLQueries]
169
        nodeScheduler batch.Scheduler[SQLQueries]
170

171
        srcNodes  map[ProtocolVersion]*srcNodeInfo
172
        srcNodeMu sync.Mutex
173
}
174

175
// A compile-time assertion to ensure that SQLStore implements the V1Store
176
// interface.
177
var _ V1Store = (*SQLStore)(nil)
178

179
// SQLStoreConfig holds the configuration for the SQLStore.
180
type SQLStoreConfig struct {
181
        // ChainHash is the genesis hash for the chain that all the gossip
182
        // messages in this store are aimed at.
183
        ChainHash chainhash.Hash
184
}
185

186
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
187
// storage backend.
188
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
189
        options ...StoreOptionModifier) (*SQLStore, error) {
×
190

×
191
        opts := DefaultOptions()
×
192
        for _, o := range options {
×
193
                o(opts)
×
194
        }
×
195

196
        if opts.NoMigration {
×
197
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
198
                        "supported for SQL stores")
×
199
        }
×
200

201
        s := &SQLStore{
×
202
                cfg:         cfg,
×
203
                db:          db,
×
204
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
205
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
206
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
207
        }
×
208

×
209
        s.chanScheduler = batch.NewTimeScheduler(
×
210
                db, &s.cacheMu, opts.BatchCommitInterval,
×
211
        )
×
212
        s.nodeScheduler = batch.NewTimeScheduler(
×
213
                db, nil, opts.BatchCommitInterval,
×
214
        )
×
215

×
216
        return s, nil
×
217
}
218

219
// AddLightningNode adds a vertex/node to the graph database. If the node is not
220
// in the database from before, this will add a new, unconnected one to the
221
// graph. If it is present from before, this will update that node's
222
// information.
223
//
224
// NOTE: part of the V1Store interface.
225
func (s *SQLStore) AddLightningNode(ctx context.Context,
226
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
227

×
228
        r := &batch.Request[SQLQueries]{
×
229
                Opts: batch.NewSchedulerOptions(opts...),
×
230
                Do: func(queries SQLQueries) error {
×
231
                        _, err := upsertNode(ctx, queries, node)
×
232
                        return err
×
233
                },
×
234
        }
235

236
        return s.nodeScheduler.Execute(ctx, r)
×
237
}
238

239
// FetchLightningNode attempts to look up a target node by its identity public
240
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
241
// returned.
242
//
243
// NOTE: part of the V1Store interface.
244
func (s *SQLStore) FetchLightningNode(ctx context.Context,
245
        pubKey route.Vertex) (*models.LightningNode, error) {
×
246

×
247
        var node *models.LightningNode
×
248
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
249
                var err error
×
250
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
251

×
252
                return err
×
253
        }, sqldb.NoOpReset)
×
254
        if err != nil {
×
255
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
256
        }
×
257

258
        return node, nil
×
259
}
260

261
// HasLightningNode determines if the graph has a vertex identified by the
262
// target node identity public key. If the node exists in the database, a
263
// timestamp of when the data for the node was lasted updated is returned along
264
// with a true boolean. Otherwise, an empty time.Time is returned with a false
265
// boolean.
266
//
267
// NOTE: part of the V1Store interface.
268
func (s *SQLStore) HasLightningNode(ctx context.Context,
269
        pubKey [33]byte) (time.Time, bool, error) {
×
270

×
271
        var (
×
272
                exists     bool
×
273
                lastUpdate time.Time
×
274
        )
×
275
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
276
                dbNode, err := db.GetNodeByPubKey(
×
277
                        ctx, sqlc.GetNodeByPubKeyParams{
×
278
                                Version: int16(ProtocolV1),
×
279
                                PubKey:  pubKey[:],
×
280
                        },
×
281
                )
×
282
                if errors.Is(err, sql.ErrNoRows) {
×
283
                        return nil
×
284
                } else if err != nil {
×
285
                        return fmt.Errorf("unable to fetch node: %w", err)
×
286
                }
×
287

288
                exists = true
×
289

×
290
                if dbNode.LastUpdate.Valid {
×
291
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
292
                }
×
293

294
                return nil
×
295
        }, sqldb.NoOpReset)
296
        if err != nil {
×
297
                return time.Time{}, false,
×
298
                        fmt.Errorf("unable to fetch node: %w", err)
×
299
        }
×
300

301
        return lastUpdate, exists, nil
×
302
}
303

304
// AddrsForNode returns all known addresses for the target node public key
305
// that the graph DB is aware of. The returned boolean indicates if the
306
// given node is unknown to the graph DB or not.
307
//
308
// NOTE: part of the V1Store interface.
309
func (s *SQLStore) AddrsForNode(ctx context.Context,
310
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
311

×
312
        var (
×
313
                addresses []net.Addr
×
314
                known     bool
×
315
        )
×
316
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
317
                var err error
×
318
                known, addresses, err = getNodeAddresses(
×
319
                        ctx, db, nodePub.SerializeCompressed(),
×
320
                )
×
321
                if err != nil {
×
322
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
323
                                err)
×
324
                }
×
325

326
                return nil
×
327
        }, sqldb.NoOpReset)
328
        if err != nil {
×
329
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
330
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
331
        }
×
332

333
        return known, addresses, nil
×
334
}
335

336
// DeleteLightningNode starts a new database transaction to remove a vertex/node
337
// from the database according to the node's public key.
338
//
339
// NOTE: part of the V1Store interface.
340
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
341
        pubKey route.Vertex) error {
×
342

×
343
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
344
                res, err := db.DeleteNodeByPubKey(
×
345
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
346
                                Version: int16(ProtocolV1),
×
347
                                PubKey:  pubKey[:],
×
348
                        },
×
349
                )
×
350
                if err != nil {
×
351
                        return err
×
352
                }
×
353

354
                rows, err := res.RowsAffected()
×
355
                if err != nil {
×
356
                        return err
×
357
                }
×
358

359
                if rows == 0 {
×
360
                        return ErrGraphNodeNotFound
×
361
                } else if rows > 1 {
×
362
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
363
                }
×
364

365
                return err
×
366
        }, sqldb.NoOpReset)
367
        if err != nil {
×
368
                return fmt.Errorf("unable to delete node: %w", err)
×
369
        }
×
370

371
        return nil
×
372
}
373

374
// FetchNodeFeatures returns the features of the given node. If no features are
375
// known for the node, an empty feature vector is returned.
376
//
377
// NOTE: this is part of the graphdb.NodeTraverser interface.
378
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
379
        *lnwire.FeatureVector, error) {
×
380

×
381
        ctx := context.TODO()
×
382

×
383
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
384
}
×
385

386
// DisabledChannelIDs returns the channel ids of disabled channels.
387
// A channel is disabled when two of the associated ChanelEdgePolicies
388
// have their disabled bit on.
389
//
390
// NOTE: part of the V1Store interface.
391
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
392
        var (
×
393
                ctx     = context.TODO()
×
394
                chanIDs []uint64
×
395
        )
×
396
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
397
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
398
                if err != nil {
×
399
                        return fmt.Errorf("unable to fetch disabled "+
×
400
                                "channels: %w", err)
×
401
                }
×
402

403
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
404

×
405
                return nil
×
406
        }, sqldb.NoOpReset)
407
        if err != nil {
×
408
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
409
                        err)
×
410
        }
×
411

412
        return chanIDs, nil
×
413
}
414

415
// LookupAlias attempts to return the alias as advertised by the target node.
416
//
417
// NOTE: part of the V1Store interface.
418
func (s *SQLStore) LookupAlias(ctx context.Context,
419
        pub *btcec.PublicKey) (string, error) {
×
420

×
421
        var alias string
×
422
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
423
                dbNode, err := db.GetNodeByPubKey(
×
424
                        ctx, sqlc.GetNodeByPubKeyParams{
×
425
                                Version: int16(ProtocolV1),
×
426
                                PubKey:  pub.SerializeCompressed(),
×
427
                        },
×
428
                )
×
429
                if errors.Is(err, sql.ErrNoRows) {
×
430
                        return ErrNodeAliasNotFound
×
431
                } else if err != nil {
×
432
                        return fmt.Errorf("unable to fetch node: %w", err)
×
433
                }
×
434

435
                if !dbNode.Alias.Valid {
×
436
                        return ErrNodeAliasNotFound
×
437
                }
×
438

439
                alias = dbNode.Alias.String
×
440

×
441
                return nil
×
442
        }, sqldb.NoOpReset)
443
        if err != nil {
×
444
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
445
        }
×
446

447
        return alias, nil
×
448
}
449

450
// SourceNode returns the source node of the graph. The source node is treated
451
// as the center node within a star-graph. This method may be used to kick off
452
// a path finding algorithm in order to explore the reachability of another
453
// node based off the source node.
454
//
455
// NOTE: part of the V1Store interface.
456
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
457
        error) {
×
458

×
459
        var node *models.LightningNode
×
460
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
461
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
462
                if err != nil {
×
463
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
464
                                err)
×
465
                }
×
466

467
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
468

×
469
                return err
×
470
        }, sqldb.NoOpReset)
471
        if err != nil {
×
472
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
473
        }
×
474

475
        return node, nil
×
476
}
477

478
// SetSourceNode sets the source node within the graph database. The source
479
// node is to be used as the center of a star-graph within path finding
480
// algorithms.
481
//
482
// NOTE: part of the V1Store interface.
483
func (s *SQLStore) SetSourceNode(ctx context.Context,
484
        node *models.LightningNode) error {
×
485

×
486
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
487
                id, err := upsertNode(ctx, db, node)
×
488
                if err != nil {
×
489
                        return fmt.Errorf("unable to upsert source node: %w",
×
490
                                err)
×
491
                }
×
492

493
                // Make sure that if a source node for this version is already
494
                // set, then the ID is the same as the one we are about to set.
495
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
496
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
497
                        return fmt.Errorf("unable to fetch source node: %w",
×
498
                                err)
×
499
                } else if err == nil {
×
500
                        if dbSourceNodeID != id {
×
501
                                return fmt.Errorf("v1 source node already "+
×
502
                                        "set to a different node: %d vs %d",
×
503
                                        dbSourceNodeID, id)
×
504
                        }
×
505

506
                        return nil
×
507
                }
508

509
                return db.AddSourceNode(ctx, id)
×
510
        }, sqldb.NoOpReset)
511
}
512

513
// NodeUpdatesInHorizon returns all the known lightning node which have an
514
// update timestamp within the passed range. This method can be used by two
515
// nodes to quickly determine if they have the same set of up to date node
516
// announcements.
517
//
518
// NOTE: This is part of the V1Store interface.
519
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
520
        endTime time.Time) ([]models.LightningNode, error) {
×
521

×
522
        ctx := context.TODO()
×
523

×
524
        var nodes []models.LightningNode
×
525
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
526
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
527
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
528
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
529
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
530
                        },
×
531
                )
×
532
                if err != nil {
×
533
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
534
                }
×
535

536
                for _, dbNode := range dbNodes {
×
537
                        node, err := buildNode(ctx, db, &dbNode)
×
538
                        if err != nil {
×
539
                                return fmt.Errorf("unable to build node: %w",
×
540
                                        err)
×
541
                        }
×
542

543
                        nodes = append(nodes, *node)
×
544
                }
545

546
                return nil
×
547
        }, sqldb.NoOpReset)
548
        if err != nil {
×
549
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
550
        }
×
551

552
        return nodes, nil
×
553
}
554

555
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
556
// undirected edge from the two target nodes are created. The information stored
557
// denotes the static attributes of the channel, such as the channelID, the keys
558
// involved in creation of the channel, and the set of features that the channel
559
// supports. The chanPoint and chanID are used to uniquely identify the edge
560
// globally within the database.
561
//
562
// NOTE: part of the V1Store interface.
563
func (s *SQLStore) AddChannelEdge(ctx context.Context,
564
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
565

×
566
        var alreadyExists bool
×
567
        r := &batch.Request[SQLQueries]{
×
568
                Opts: batch.NewSchedulerOptions(opts...),
×
569
                Reset: func() {
×
570
                        alreadyExists = false
×
571
                },
×
572
                Do: func(tx SQLQueries) error {
×
573
                        _, err := insertChannel(ctx, tx, edge)
×
574

×
575
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
576
                        // succeed, but propagate the error via local state.
×
577
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
578
                                alreadyExists = true
×
579
                                return nil
×
580
                        }
×
581

582
                        return err
×
583
                },
584
                OnCommit: func(err error) error {
×
585
                        switch {
×
586
                        case err != nil:
×
587
                                return err
×
588
                        case alreadyExists:
×
589
                                return ErrEdgeAlreadyExist
×
590
                        default:
×
591
                                s.rejectCache.remove(edge.ChannelID)
×
592
                                s.chanCache.remove(edge.ChannelID)
×
593
                                return nil
×
594
                        }
595
                },
596
        }
597

598
        return s.chanScheduler.Execute(ctx, r)
×
599
}
600

601
// HighestChanID returns the "highest" known channel ID in the channel graph.
602
// This represents the "newest" channel from the PoV of the chain. This method
603
// can be used by peers to quickly determine if their graphs are in sync.
604
//
605
// NOTE: This is part of the V1Store interface.
606
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
607
        var highestChanID uint64
×
608
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
609
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
610
                if errors.Is(err, sql.ErrNoRows) {
×
611
                        return nil
×
612
                } else if err != nil {
×
613
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
614
                                err)
×
615
                }
×
616

617
                highestChanID = byteOrder.Uint64(chanID)
×
618

×
619
                return nil
×
620
        }, sqldb.NoOpReset)
621
        if err != nil {
×
622
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
623
        }
×
624

625
        return highestChanID, nil
×
626
}
627

628
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
629
// within the database for the referenced channel. The `flags` attribute within
630
// the ChannelEdgePolicy determines which of the directed edges are being
631
// updated. If the flag is 1, then the first node's information is being
632
// updated, otherwise it's the second node's information. The node ordering is
633
// determined by the lexicographical ordering of the identity public keys of the
634
// nodes on either side of the channel.
635
//
636
// NOTE: part of the V1Store interface.
637
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
638
        edge *models.ChannelEdgePolicy,
639
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
640

×
641
        var (
×
642
                isUpdate1    bool
×
643
                edgeNotFound bool
×
644
                from, to     route.Vertex
×
645
        )
×
646

×
647
        r := &batch.Request[SQLQueries]{
×
648
                Opts: batch.NewSchedulerOptions(opts...),
×
649
                Reset: func() {
×
650
                        isUpdate1 = false
×
651
                        edgeNotFound = false
×
652
                },
×
653
                Do: func(tx SQLQueries) error {
×
654
                        var err error
×
655
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
656
                                ctx, tx, edge,
×
657
                        )
×
658
                        if err != nil {
×
659
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
660
                        }
×
661

662
                        // Silence ErrEdgeNotFound so that the batch can
663
                        // succeed, but propagate the error via local state.
664
                        if errors.Is(err, ErrEdgeNotFound) {
×
665
                                edgeNotFound = true
×
666
                                return nil
×
667
                        }
×
668

669
                        return err
×
670
                },
671
                OnCommit: func(err error) error {
×
672
                        switch {
×
673
                        case err != nil:
×
674
                                return err
×
675
                        case edgeNotFound:
×
676
                                return ErrEdgeNotFound
×
677
                        default:
×
678
                                s.updateEdgeCache(edge, isUpdate1)
×
679
                                return nil
×
680
                        }
681
                },
682
        }
683

684
        err := s.chanScheduler.Execute(ctx, r)
×
685

×
686
        return from, to, err
×
687
}
688

689
// updateEdgeCache updates our reject and channel caches with the new
690
// edge policy information.
691
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
692
        isUpdate1 bool) {
×
693

×
694
        // If an entry for this channel is found in reject cache, we'll modify
×
695
        // the entry with the updated timestamp for the direction that was just
×
696
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
697
        // during the next query for this edge.
×
698
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
699
                if isUpdate1 {
×
700
                        entry.upd1Time = e.LastUpdate.Unix()
×
701
                } else {
×
702
                        entry.upd2Time = e.LastUpdate.Unix()
×
703
                }
×
704
                s.rejectCache.insert(e.ChannelID, entry)
×
705
        }
706

707
        // If an entry for this channel is found in channel cache, we'll modify
708
        // the entry with the updated policy for the direction that was just
709
        // written. If the edge doesn't exist, we'll defer loading the info and
710
        // policies and lazily read from disk during the next query.
711
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
712
                if isUpdate1 {
×
713
                        channel.Policy1 = e
×
714
                } else {
×
715
                        channel.Policy2 = e
×
716
                }
×
717
                s.chanCache.insert(e.ChannelID, channel)
×
718
        }
719
}
720

721
// ForEachSourceNodeChannel iterates through all channels of the source node,
722
// executing the passed callback on each. The call-back is provided with the
723
// channel's outpoint, whether we have a policy for the channel and the channel
724
// peer's node information.
725
//
726
// NOTE: part of the V1Store interface.
727
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
728
        cb func(chanPoint wire.OutPoint, havePolicy bool,
NEW
729
                otherNode *models.LightningNode) error) error {
×
730

×
731
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
732
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
733
                if err != nil {
×
734
                        return fmt.Errorf("unable to fetch source node: %w",
×
735
                                err)
×
736
                }
×
737

738
                return forEachNodeChannel(
×
739
                        ctx, db, s.cfg.ChainHash, nodeID,
×
740
                        func(info *models.ChannelEdgeInfo,
×
741
                                outPolicy *models.ChannelEdgePolicy,
×
742
                                _ *models.ChannelEdgePolicy) error {
×
743

×
744
                                // Fetch the other node.
×
745
                                var (
×
746
                                        otherNodePub [33]byte
×
747
                                        node1        = info.NodeKey1Bytes
×
748
                                        node2        = info.NodeKey2Bytes
×
749
                                )
×
750
                                switch {
×
751
                                case bytes.Equal(node1[:], nodePub[:]):
×
752
                                        otherNodePub = node2
×
753
                                case bytes.Equal(node2[:], nodePub[:]):
×
754
                                        otherNodePub = node1
×
755
                                default:
×
756
                                        return fmt.Errorf("node not " +
×
757
                                                "participating in this channel")
×
758
                                }
759

760
                                _, otherNode, err := getNodeByPubKey(
×
761
                                        ctx, db, otherNodePub,
×
762
                                )
×
763
                                if err != nil {
×
764
                                        return fmt.Errorf("unable to fetch "+
×
765
                                                "other node(%x): %w",
×
766
                                                otherNodePub, err)
×
767
                                }
×
768

769
                                return cb(
×
770
                                        info.ChannelPoint, outPolicy != nil,
×
771
                                        otherNode,
×
772
                                )
×
773
                        },
774
                )
775
        }, sqldb.NoOpReset)
776
}
777

778
// ForEachNode iterates through all the stored vertices/nodes in the graph,
779
// executing the passed callback with each node encountered. If the callback
780
// returns an error, then the transaction is aborted and the iteration stops
781
// early. Any operations performed on the NodeTx passed to the call-back are
782
// executed under the same read transaction and so, methods on the NodeTx object
783
// _MUST_ only be called from within the call-back.
784
//
785
// NOTE: part of the V1Store interface.
786
func (s *SQLStore) ForEachNode(ctx context.Context,
NEW
787
        cb func(tx NodeRTx) error) error {
×
788

×
NEW
789
        var lastID int64 = 0
×
790
        handleNode := func(db SQLQueries, dbNode sqlc.Node) error {
×
791
                node, err := buildNode(ctx, db, &dbNode)
×
792
                if err != nil {
×
793
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
794
                                dbNode.ID, err)
×
795
                }
×
796

797
                err = cb(
×
798
                        newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
×
799
                )
×
800
                if err != nil {
×
801
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
802
                                dbNode.ID, err)
×
803
                }
×
804

805
                return nil
×
806
        }
807

808
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
809
                for {
×
810
                        nodes, err := db.ListNodesPaginated(
×
811
                                ctx, sqlc.ListNodesPaginatedParams{
×
812
                                        Version: int16(ProtocolV1),
×
813
                                        ID:      lastID,
×
814
                                        Limit:   pageSize,
×
815
                                },
×
816
                        )
×
817
                        if err != nil {
×
818
                                return fmt.Errorf("unable to fetch nodes: %w",
×
819
                                        err)
×
820
                        }
×
821

822
                        if len(nodes) == 0 {
×
823
                                break
×
824
                        }
825

826
                        for _, dbNode := range nodes {
×
827
                                err = handleNode(db, dbNode)
×
828
                                if err != nil {
×
829
                                        return err
×
830
                                }
×
831

832
                                lastID = dbNode.ID
×
833
                        }
834
                }
835

836
                return nil
×
837
        }, sqldb.NoOpReset)
838
}
839

840
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
841
// SQLStore and a SQL transaction.
842
type sqlGraphNodeTx struct {
843
        db    SQLQueries
844
        id    int64
845
        node  *models.LightningNode
846
        chain chainhash.Hash
847
}
848

849
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
850
// interface.
851
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
852

853
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
854
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
855

×
856
        return &sqlGraphNodeTx{
×
857
                db:    db,
×
858
                chain: chain,
×
859
                id:    id,
×
860
                node:  node,
×
861
        }
×
862
}
×
863

864
// Node returns the raw information of the node.
865
//
866
// NOTE: This is a part of the NodeRTx interface.
867
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
868
        return s.node
×
869
}
×
870

871
// ForEachChannel can be used to iterate over the node's channels under the same
872
// transaction used to fetch the node.
873
//
874
// NOTE: This is a part of the NodeRTx interface.
875
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
876
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
877

×
878
        ctx := context.TODO()
×
879

×
880
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
881
}
×
882

883
// FetchNode fetches the node with the given pub key under the same transaction
884
// used to fetch the current node. The returned node is also a NodeRTx and any
885
// operations on that NodeRTx will also be done under the same transaction.
886
//
887
// NOTE: This is a part of the NodeRTx interface.
888
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
889
        ctx := context.TODO()
×
890

×
891
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
892
        if err != nil {
×
893
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
894
                        nodePub, err)
×
895
        }
×
896

897
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
898
}
899

900
// ForEachNodeDirectedChannel iterates through all channels of a given node,
901
// executing the passed callback on the directed edge representing the channel
902
// and its incoming policy. If the callback returns an error, then the iteration
903
// is halted with the error propagated back up to the caller.
904
//
905
// Unknown policies are passed into the callback as nil values.
906
//
907
// NOTE: this is part of the graphdb.NodeTraverser interface.
908
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
909
        cb func(channel *DirectedChannel) error) error {
×
910

×
911
        var ctx = context.TODO()
×
912

×
913
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
914
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
915
        }, sqldb.NoOpReset)
×
916
}
917

918
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
919
// graph, executing the passed callback with each node encountered. If the
920
// callback returns an error, then the transaction is aborted and the iteration
921
// stops early.
922
//
923
// NOTE: This is a part of the V1Store interface.
924
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
NEW
925
        cb func(route.Vertex, *lnwire.FeatureVector) error) error {
×
926

×
927
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
928
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
929
                        nodePub route.Vertex) error {
×
930

×
931
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
932
                        if err != nil {
×
933
                                return fmt.Errorf("unable to fetch node "+
×
934
                                        "features: %w", err)
×
935
                        }
×
936

937
                        return cb(nodePub, features)
×
938
                })
939
        }, sqldb.NoOpReset)
940
        if err != nil {
×
941
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
942
        }
×
943

944
        return nil
×
945
}
946

947
// ForEachNodeChannel iterates through all channels of the given node,
948
// executing the passed callback with an edge info structure and the policies
949
// of each end of the channel. The first edge policy is the outgoing edge *to*
950
// the connecting node, while the second is the incoming edge *from* the
951
// connecting node. If the callback returns an error, then the iteration is
952
// halted with the error propagated back up to the caller.
953
//
954
// Unknown policies are passed into the callback as nil values.
955
//
956
// NOTE: part of the V1Store interface.
957
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
958
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
959
                *models.ChannelEdgePolicy) error) error {
×
960

×
961
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
962
                dbNode, err := db.GetNodeByPubKey(
×
963
                        ctx, sqlc.GetNodeByPubKeyParams{
×
964
                                Version: int16(ProtocolV1),
×
965
                                PubKey:  nodePub[:],
×
966
                        },
×
967
                )
×
968
                if errors.Is(err, sql.ErrNoRows) {
×
969
                        return nil
×
970
                } else if err != nil {
×
971
                        return fmt.Errorf("unable to fetch node: %w", err)
×
972
                }
×
973

974
                return forEachNodeChannel(
×
975
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
976
                )
×
977
        }, sqldb.NoOpReset)
978
}
979

980
// ChanUpdatesInHorizon returns all the known channel edges which have at least
981
// one edge that has an update timestamp within the specified horizon.
982
//
983
// NOTE: This is part of the V1Store interface.
984
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
985
        endTime time.Time) ([]ChannelEdge, error) {
×
986

×
987
        s.cacheMu.Lock()
×
988
        defer s.cacheMu.Unlock()
×
989

×
990
        var (
×
991
                ctx = context.TODO()
×
992
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
993
                // an additional map to keep track of the edges already seen to
×
994
                // prevent re-adding it.
×
995
                edgesSeen    = make(map[uint64]struct{})
×
996
                edgesToCache = make(map[uint64]ChannelEdge)
×
997
                edges        []ChannelEdge
×
998
                hits         int
×
999
        )
×
1000
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1001
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1002
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1003
                                Version:   int16(ProtocolV1),
×
1004
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
1005
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
1006
                        },
×
1007
                )
×
1008
                if err != nil {
×
1009
                        return err
×
1010
                }
×
1011

1012
                for _, row := range rows {
×
1013
                        // If we've already retrieved the info and policies for
×
1014
                        // this edge, then we can skip it as we don't need to do
×
1015
                        // so again.
×
1016
                        chanIDInt := byteOrder.Uint64(row.Channel.Scid)
×
1017
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1018
                                continue
×
1019
                        }
1020

1021
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1022
                                hits++
×
1023
                                edgesSeen[chanIDInt] = struct{}{}
×
1024
                                edges = append(edges, channel)
×
1025

×
1026
                                continue
×
1027
                        }
1028

1029
                        node1, node2, err := buildNodes(
×
1030
                                ctx, db, row.Node, row.Node_2,
×
1031
                        )
×
1032
                        if err != nil {
×
1033
                                return err
×
1034
                        }
×
1035

1036
                        channel, err := getAndBuildEdgeInfo(
×
1037
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1038
                                row.Channel, node1.PubKeyBytes,
×
1039
                                node2.PubKeyBytes,
×
1040
                        )
×
1041
                        if err != nil {
×
1042
                                return fmt.Errorf("unable to build channel "+
×
1043
                                        "info: %w", err)
×
1044
                        }
×
1045

1046
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1047
                        if err != nil {
×
1048
                                return fmt.Errorf("unable to extract channel "+
×
1049
                                        "policies: %w", err)
×
1050
                        }
×
1051

1052
                        p1, p2, err := getAndBuildChanPolicies(
×
1053
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1054
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1055
                        )
×
1056
                        if err != nil {
×
1057
                                return fmt.Errorf("unable to build channel "+
×
1058
                                        "policies: %w", err)
×
1059
                        }
×
1060

1061
                        edgesSeen[chanIDInt] = struct{}{}
×
1062
                        chanEdge := ChannelEdge{
×
1063
                                Info:    channel,
×
1064
                                Policy1: p1,
×
1065
                                Policy2: p2,
×
1066
                                Node1:   node1,
×
1067
                                Node2:   node2,
×
1068
                        }
×
1069
                        edges = append(edges, chanEdge)
×
1070
                        edgesToCache[chanIDInt] = chanEdge
×
1071
                }
1072

1073
                return nil
×
1074
        }, func() {
×
1075
                edgesSeen = make(map[uint64]struct{})
×
1076
                edgesToCache = make(map[uint64]ChannelEdge)
×
1077
                edges = nil
×
1078
        })
×
1079
        if err != nil {
×
1080
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1081
        }
×
1082

1083
        // Insert any edges loaded from disk into the cache.
1084
        for chanid, channel := range edgesToCache {
×
1085
                s.chanCache.insert(chanid, channel)
×
1086
        }
×
1087

1088
        if len(edges) > 0 {
×
1089
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
1090
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1091
        } else {
×
1092
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1093
                        "horizon (%s, %s)", startTime, endTime)
×
1094
        }
×
1095

1096
        return edges, nil
×
1097
}
1098

1099
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1100
// data to the call-back.
1101
//
1102
// NOTE: The callback contents MUST not be modified.
1103
//
1104
// NOTE: part of the V1Store interface.
1105
func (s *SQLStore) ForEachNodeCached(ctx context.Context,
1106
        cb func(node route.Vertex,
NEW
1107
                chans map[uint64]*DirectedChannel) error) error {
×
1108

×
1109
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1110
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
1111
                        nodePub route.Vertex) error {
×
1112

×
1113
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
1114
                        if err != nil {
×
1115
                                return fmt.Errorf("unable to fetch "+
×
1116
                                        "node(id=%d) features: %w", nodeID, err)
×
1117
                        }
×
1118

1119
                        toNodeCallback := func() route.Vertex {
×
1120
                                return nodePub
×
1121
                        }
×
1122

1123
                        rows, err := db.ListChannelsByNodeID(
×
1124
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
1125
                                        Version: int16(ProtocolV1),
×
1126
                                        NodeID1: nodeID,
×
1127
                                },
×
1128
                        )
×
1129
                        if err != nil {
×
1130
                                return fmt.Errorf("unable to fetch channels "+
×
1131
                                        "of node(id=%d): %w", nodeID, err)
×
1132
                        }
×
1133

1134
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
1135
                        for _, row := range rows {
×
1136
                                node1, node2, err := buildNodeVertices(
×
1137
                                        row.Node1Pubkey, row.Node2Pubkey,
×
1138
                                )
×
1139
                                if err != nil {
×
1140
                                        return err
×
1141
                                }
×
1142

1143
                                e, err := getAndBuildEdgeInfo(
×
1144
                                        ctx, db, s.cfg.ChainHash,
×
1145
                                        row.Channel.ID, row.Channel, node1,
×
1146
                                        node2,
×
1147
                                )
×
1148
                                if err != nil {
×
1149
                                        return fmt.Errorf("unable to build "+
×
1150
                                                "channel info: %w", err)
×
1151
                                }
×
1152

1153
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1154
                                        row,
×
1155
                                )
×
1156
                                if err != nil {
×
1157
                                        return fmt.Errorf("unable to "+
×
1158
                                                "extract channel "+
×
1159
                                                "policies: %w", err)
×
1160
                                }
×
1161

1162
                                p1, p2, err := getAndBuildChanPolicies(
×
1163
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
1164
                                        node1, node2,
×
1165
                                )
×
1166
                                if err != nil {
×
1167
                                        return fmt.Errorf("unable to "+
×
1168
                                                "build channel policies: %w",
×
1169
                                                err)
×
1170
                                }
×
1171

1172
                                // Determine the outgoing and incoming policy
1173
                                // for this channel and node combo.
1174
                                outPolicy, inPolicy := p1, p2
×
1175
                                if p1 != nil && p1.ToNode == nodePub {
×
1176
                                        outPolicy, inPolicy = p2, p1
×
1177
                                } else if p2 != nil && p2.ToNode != nodePub {
×
1178
                                        outPolicy, inPolicy = p2, p1
×
1179
                                }
×
1180

1181
                                var cachedInPolicy *models.CachedEdgePolicy
×
1182
                                if inPolicy != nil {
×
1183
                                        cachedInPolicy = models.NewCachedPolicy(
×
1184
                                                p2,
×
1185
                                        )
×
1186
                                        cachedInPolicy.ToNodePubKey =
×
1187
                                                toNodeCallback
×
1188
                                        cachedInPolicy.ToNodeFeatures =
×
1189
                                                features
×
1190
                                }
×
1191

1192
                                var inboundFee lnwire.Fee
×
1193
                                outPolicy.InboundFee.WhenSome(
×
1194
                                        func(fee lnwire.Fee) {
×
1195
                                                inboundFee = fee
×
1196
                                        },
×
1197
                                )
1198

1199
                                directedChannel := &DirectedChannel{
×
1200
                                        ChannelID: e.ChannelID,
×
1201
                                        IsNode1: nodePub ==
×
1202
                                                e.NodeKey1Bytes,
×
1203
                                        OtherNode:    e.NodeKey2Bytes,
×
1204
                                        Capacity:     e.Capacity,
×
1205
                                        OutPolicySet: p1 != nil,
×
1206
                                        InPolicy:     cachedInPolicy,
×
1207
                                        InboundFee:   inboundFee,
×
1208
                                }
×
1209

×
1210
                                if nodePub == e.NodeKey2Bytes {
×
1211
                                        directedChannel.OtherNode =
×
1212
                                                e.NodeKey1Bytes
×
1213
                                }
×
1214

1215
                                channels[e.ChannelID] = directedChannel
×
1216
                        }
1217

1218
                        return cb(nodePub, channels)
×
1219
                })
1220
        }, sqldb.NoOpReset)
1221
}
1222

1223
// ForEachChannelCacheable iterates through all the channel edges stored
1224
// within the graph and invokes the passed callback for each edge. The
1225
// callback takes two edges as since this is a directed graph, both the
1226
// in/out edges are visited. If the callback returns an error, then the
1227
// transaction is aborted and the iteration stops early.
1228
//
1229
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1230
// pointer for that particular channel edge routing policy will be
1231
// passed into the callback.
1232
//
1233
// NOTE: this method is like ForEachChannel but fetches only the data
1234
// required for the graph cache.
1235
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1236
        *models.CachedEdgePolicy,
1237
        *models.CachedEdgePolicy) error) error {
×
1238

×
1239
        ctx := context.TODO()
×
1240

×
1241
        handleChannel := func(db SQLQueries,
×
1242
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1243

×
1244
                node1, node2, err := buildNodeVertices(
×
1245
                        row.Node1Pubkey, row.Node2Pubkey,
×
1246
                )
×
1247
                if err != nil {
×
1248
                        return err
×
1249
                }
×
1250

1251
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
1252

×
1253
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1254
                if err != nil {
×
1255
                        return err
×
1256
                }
×
1257

1258
                var pol1, pol2 *models.CachedEdgePolicy
×
1259
                if dbPol1 != nil {
×
1260
                        policy1, err := buildChanPolicy(
×
1261
                                *dbPol1, edge.ChannelID, nil, node2,
×
1262
                        )
×
1263
                        if err != nil {
×
1264
                                return err
×
1265
                        }
×
1266

1267
                        pol1 = models.NewCachedPolicy(policy1)
×
1268
                }
1269
                if dbPol2 != nil {
×
1270
                        policy2, err := buildChanPolicy(
×
1271
                                *dbPol2, edge.ChannelID, nil, node1,
×
1272
                        )
×
1273
                        if err != nil {
×
1274
                                return err
×
1275
                        }
×
1276

1277
                        pol2 = models.NewCachedPolicy(policy2)
×
1278
                }
1279

1280
                if err := cb(edge, pol1, pol2); err != nil {
×
1281
                        return err
×
1282
                }
×
1283

1284
                return nil
×
1285
        }
1286

1287
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1288
                lastID := int64(-1)
×
1289
                for {
×
1290
                        //nolint:ll
×
1291
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1292
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1293
                                        Version: int16(ProtocolV1),
×
1294
                                        ID:      lastID,
×
1295
                                        Limit:   pageSize,
×
1296
                                },
×
1297
                        )
×
1298
                        if err != nil {
×
1299
                                return err
×
1300
                        }
×
1301

1302
                        if len(rows) == 0 {
×
1303
                                break
×
1304
                        }
1305

1306
                        for _, row := range rows {
×
1307
                                err := handleChannel(db, row)
×
1308
                                if err != nil {
×
1309
                                        return err
×
1310
                                }
×
1311

1312
                                lastID = row.Channel.ID
×
1313
                        }
1314
                }
1315

1316
                return nil
×
1317
        }, sqldb.NoOpReset)
1318
}
1319

1320
// ForEachChannel iterates through all the channel edges stored within the
1321
// graph and invokes the passed callback for each edge. The callback takes two
1322
// edges as since this is a directed graph, both the in/out edges are visited.
1323
// If the callback returns an error, then the transaction is aborted and the
1324
// iteration stops early.
1325
//
1326
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1327
// for that particular channel edge routing policy will be passed into the
1328
// callback.
1329
//
1330
// NOTE: part of the V1Store interface.
1331
func (s *SQLStore) ForEachChannel(ctx context.Context,
1332
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
NEW
1333
                *models.ChannelEdgePolicy) error) error {
×
1334

×
1335
        handleChannel := func(db SQLQueries,
×
1336
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1337

×
1338
                node1, node2, err := buildNodeVertices(
×
1339
                        row.Node1Pubkey, row.Node2Pubkey,
×
1340
                )
×
1341
                if err != nil {
×
1342
                        return fmt.Errorf("unable to build node vertices: %w",
×
1343
                                err)
×
1344
                }
×
1345

1346
                edge, err := getAndBuildEdgeInfo(
×
1347
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1348
                        node1, node2,
×
1349
                )
×
1350
                if err != nil {
×
1351
                        return fmt.Errorf("unable to build channel info: %w",
×
1352
                                err)
×
1353
                }
×
1354

1355
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1356
                if err != nil {
×
1357
                        return fmt.Errorf("unable to extract channel "+
×
1358
                                "policies: %w", err)
×
1359
                }
×
1360

1361
                p1, p2, err := getAndBuildChanPolicies(
×
1362
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1363
                )
×
1364
                if err != nil {
×
1365
                        return fmt.Errorf("unable to build channel "+
×
1366
                                "policies: %w", err)
×
1367
                }
×
1368

1369
                err = cb(edge, p1, p2)
×
1370
                if err != nil {
×
1371
                        return fmt.Errorf("callback failed for channel "+
×
1372
                                "id=%d: %w", edge.ChannelID, err)
×
1373
                }
×
1374

1375
                return nil
×
1376
        }
1377

1378
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1379
                lastID := int64(-1)
×
1380
                for {
×
1381
                        //nolint:ll
×
1382
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1383
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1384
                                        Version: int16(ProtocolV1),
×
1385
                                        ID:      lastID,
×
1386
                                        Limit:   pageSize,
×
1387
                                },
×
1388
                        )
×
1389
                        if err != nil {
×
1390
                                return err
×
1391
                        }
×
1392

1393
                        if len(rows) == 0 {
×
1394
                                break
×
1395
                        }
1396

1397
                        for _, row := range rows {
×
1398
                                err := handleChannel(db, row)
×
1399
                                if err != nil {
×
1400
                                        return err
×
1401
                                }
×
1402

1403
                                lastID = row.Channel.ID
×
1404
                        }
1405
                }
1406

1407
                return nil
×
1408
        }, sqldb.NoOpReset)
1409
}
1410

1411
// FilterChannelRange returns the channel ID's of all known channels which were
1412
// mined in a block height within the passed range. The channel IDs are grouped
1413
// by their common block height. This method can be used to quickly share with a
1414
// peer the set of channels we know of within a particular range to catch them
1415
// up after a period of time offline. If withTimestamps is true then the
1416
// timestamp info of the latest received channel update messages of the channel
1417
// will be included in the response.
1418
//
1419
// NOTE: This is part of the V1Store interface.
1420
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1421
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1422

×
1423
        var (
×
1424
                ctx       = context.TODO()
×
1425
                startSCID = &lnwire.ShortChannelID{
×
1426
                        BlockHeight: startHeight,
×
1427
                }
×
1428
                endSCID = lnwire.ShortChannelID{
×
1429
                        BlockHeight: endHeight,
×
1430
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1431
                        TxPosition:  math.MaxUint16,
×
1432
                }
×
1433
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1434
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1435
        )
×
1436

×
1437
        // 1) get all channels where channelID is between start and end chan ID.
×
1438
        // 2) skip if not public (ie, no channel_proof)
×
1439
        // 3) collect that channel.
×
1440
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1441
        //    and add those timestamps to the collected channel.
×
1442
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1443
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1444
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1445
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1446
                                StartScid: chanIDStart,
×
1447
                                EndScid:   chanIDEnd,
×
1448
                        },
×
1449
                )
×
1450
                if err != nil {
×
1451
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1452
                                err)
×
1453
                }
×
1454

1455
                for _, dbChan := range dbChans {
×
1456
                        cid := lnwire.NewShortChanIDFromInt(
×
1457
                                byteOrder.Uint64(dbChan.Scid),
×
1458
                        )
×
1459
                        chanInfo := NewChannelUpdateInfo(
×
1460
                                cid, time.Time{}, time.Time{},
×
1461
                        )
×
1462

×
1463
                        if !withTimestamps {
×
1464
                                channelsPerBlock[cid.BlockHeight] = append(
×
1465
                                        channelsPerBlock[cid.BlockHeight],
×
1466
                                        chanInfo,
×
1467
                                )
×
1468

×
1469
                                continue
×
1470
                        }
1471

1472
                        //nolint:ll
1473
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1474
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1475
                                        Version:   int16(ProtocolV1),
×
1476
                                        ChannelID: dbChan.ID,
×
1477
                                        NodeID:    dbChan.NodeID1,
×
1478
                                },
×
1479
                        )
×
1480
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1481
                                return fmt.Errorf("unable to fetch node1 "+
×
1482
                                        "policy: %w", err)
×
1483
                        } else if err == nil {
×
1484
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1485
                                        node1Policy.LastUpdate.Int64, 0,
×
1486
                                )
×
1487
                        }
×
1488

1489
                        //nolint:ll
1490
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1491
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1492
                                        Version:   int16(ProtocolV1),
×
1493
                                        ChannelID: dbChan.ID,
×
1494
                                        NodeID:    dbChan.NodeID2,
×
1495
                                },
×
1496
                        )
×
1497
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1498
                                return fmt.Errorf("unable to fetch node2 "+
×
1499
                                        "policy: %w", err)
×
1500
                        } else if err == nil {
×
1501
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1502
                                        node2Policy.LastUpdate.Int64, 0,
×
1503
                                )
×
1504
                        }
×
1505

1506
                        channelsPerBlock[cid.BlockHeight] = append(
×
1507
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1508
                        )
×
1509
                }
1510

1511
                return nil
×
1512
        }, func() {
×
1513
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1514
        })
×
1515
        if err != nil {
×
1516
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1517
        }
×
1518

1519
        if len(channelsPerBlock) == 0 {
×
1520
                return nil, nil
×
1521
        }
×
1522

1523
        // Return the channel ranges in ascending block height order.
1524
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1525
        slices.Sort(blocks)
×
1526

×
1527
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1528
                return BlockChannelRange{
×
1529
                        Height:   block,
×
1530
                        Channels: channelsPerBlock[block],
×
1531
                }
×
1532
        }), nil
×
1533
}
1534

1535
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1536
// zombie. This method is used on an ad-hoc basis, when channels need to be
1537
// marked as zombies outside the normal pruning cycle.
1538
//
1539
// NOTE: part of the V1Store interface.
1540
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1541
        pubKey1, pubKey2 [33]byte) error {
×
1542

×
1543
        ctx := context.TODO()
×
1544

×
1545
        s.cacheMu.Lock()
×
1546
        defer s.cacheMu.Unlock()
×
1547

×
1548
        chanIDB := channelIDToBytes(chanID)
×
1549

×
1550
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1551
                return db.UpsertZombieChannel(
×
1552
                        ctx, sqlc.UpsertZombieChannelParams{
×
1553
                                Version:  int16(ProtocolV1),
×
1554
                                Scid:     chanIDB,
×
1555
                                NodeKey1: pubKey1[:],
×
1556
                                NodeKey2: pubKey2[:],
×
1557
                        },
×
1558
                )
×
1559
        }, sqldb.NoOpReset)
×
1560
        if err != nil {
×
1561
                return fmt.Errorf("unable to upsert zombie channel "+
×
1562
                        "(channel_id=%d): %w", chanID, err)
×
1563
        }
×
1564

1565
        s.rejectCache.remove(chanID)
×
1566
        s.chanCache.remove(chanID)
×
1567

×
1568
        return nil
×
1569
}
1570

1571
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1572
//
1573
// NOTE: part of the V1Store interface.
1574
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1575
        s.cacheMu.Lock()
×
1576
        defer s.cacheMu.Unlock()
×
1577

×
1578
        var (
×
1579
                ctx     = context.TODO()
×
1580
                chanIDB = channelIDToBytes(chanID)
×
1581
        )
×
1582

×
1583
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1584
                res, err := db.DeleteZombieChannel(
×
1585
                        ctx, sqlc.DeleteZombieChannelParams{
×
1586
                                Scid:    chanIDB,
×
1587
                                Version: int16(ProtocolV1),
×
1588
                        },
×
1589
                )
×
1590
                if err != nil {
×
1591
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1592
                                err)
×
1593
                }
×
1594

1595
                rows, err := res.RowsAffected()
×
1596
                if err != nil {
×
1597
                        return err
×
1598
                }
×
1599

1600
                if rows == 0 {
×
1601
                        return ErrZombieEdgeNotFound
×
1602
                } else if rows > 1 {
×
1603
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1604
                                "expected 1", rows)
×
1605
                }
×
1606

1607
                return nil
×
1608
        }, sqldb.NoOpReset)
1609
        if err != nil {
×
1610
                return fmt.Errorf("unable to mark edge live "+
×
1611
                        "(channel_id=%d): %w", chanID, err)
×
1612
        }
×
1613

1614
        s.rejectCache.remove(chanID)
×
1615
        s.chanCache.remove(chanID)
×
1616

×
1617
        return err
×
1618
}
1619

1620
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1621
// zombie, then the two node public keys corresponding to this edge are also
1622
// returned.
1623
//
1624
// NOTE: part of the V1Store interface.
1625
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1626
        error) {
×
1627

×
1628
        var (
×
1629
                ctx              = context.TODO()
×
1630
                isZombie         bool
×
1631
                pubKey1, pubKey2 route.Vertex
×
1632
                chanIDB          = channelIDToBytes(chanID)
×
1633
        )
×
1634

×
1635
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1636
                zombie, err := db.GetZombieChannel(
×
1637
                        ctx, sqlc.GetZombieChannelParams{
×
1638
                                Scid:    chanIDB,
×
1639
                                Version: int16(ProtocolV1),
×
1640
                        },
×
1641
                )
×
1642
                if errors.Is(err, sql.ErrNoRows) {
×
1643
                        return nil
×
1644
                }
×
1645
                if err != nil {
×
1646
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1647
                                err)
×
1648
                }
×
1649

1650
                copy(pubKey1[:], zombie.NodeKey1)
×
1651
                copy(pubKey2[:], zombie.NodeKey2)
×
1652
                isZombie = true
×
1653

×
1654
                return nil
×
1655
        }, sqldb.NoOpReset)
1656
        if err != nil {
×
1657
                return false, route.Vertex{}, route.Vertex{},
×
1658
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1659
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1660
        }
×
1661

1662
        return isZombie, pubKey1, pubKey2, nil
×
1663
}
1664

1665
// NumZombies returns the current number of zombie channels in the graph.
1666
//
1667
// NOTE: part of the V1Store interface.
1668
func (s *SQLStore) NumZombies() (uint64, error) {
×
1669
        var (
×
1670
                ctx        = context.TODO()
×
1671
                numZombies uint64
×
1672
        )
×
1673
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1674
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1675
                if err != nil {
×
1676
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1677
                                err)
×
1678
                }
×
1679

1680
                numZombies = uint64(count)
×
1681

×
1682
                return nil
×
1683
        }, sqldb.NoOpReset)
1684
        if err != nil {
×
1685
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1686
        }
×
1687

1688
        return numZombies, nil
×
1689
}
1690

1691
// DeleteChannelEdges removes edges with the given channel IDs from the
1692
// database and marks them as zombies. This ensures that we're unable to re-add
1693
// it to our database once again. If an edge does not exist within the
1694
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1695
// true, then when we mark these edges as zombies, we'll set up the keys such
1696
// that we require the node that failed to send the fresh update to be the one
1697
// that resurrects the channel from its zombie state. The markZombie bool
1698
// denotes whether to mark the channel as a zombie.
1699
//
1700
// NOTE: part of the V1Store interface.
1701
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1702
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1703

×
1704
        s.cacheMu.Lock()
×
1705
        defer s.cacheMu.Unlock()
×
1706

×
1707
        var (
×
1708
                ctx     = context.TODO()
×
1709
                deleted []*models.ChannelEdgeInfo
×
1710
        )
×
1711
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1712
                for _, chanID := range chanIDs {
×
1713
                        chanIDB := channelIDToBytes(chanID)
×
1714

×
1715
                        row, err := db.GetChannelBySCIDWithPolicies(
×
1716
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1717
                                        Scid:    chanIDB,
×
1718
                                        Version: int16(ProtocolV1),
×
1719
                                },
×
1720
                        )
×
1721
                        if errors.Is(err, sql.ErrNoRows) {
×
1722
                                return ErrEdgeNotFound
×
1723
                        } else if err != nil {
×
1724
                                return fmt.Errorf("unable to fetch channel: %w",
×
1725
                                        err)
×
1726
                        }
×
1727

1728
                        node1, node2, err := buildNodeVertices(
×
1729
                                row.Node.PubKey, row.Node_2.PubKey,
×
1730
                        )
×
1731
                        if err != nil {
×
1732
                                return err
×
1733
                        }
×
1734

1735
                        info, err := getAndBuildEdgeInfo(
×
1736
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1737
                                row.Channel, node1, node2,
×
1738
                        )
×
1739
                        if err != nil {
×
1740
                                return err
×
1741
                        }
×
1742

1743
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
1744
                        if err != nil {
×
1745
                                return fmt.Errorf("unable to delete "+
×
1746
                                        "channel: %w", err)
×
1747
                        }
×
1748

1749
                        deleted = append(deleted, info)
×
1750

×
1751
                        if !markZombie {
×
1752
                                continue
×
1753
                        }
1754

1755
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1756
                                info.NodeKey2Bytes
×
1757
                        if strictZombiePruning {
×
1758
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1759
                                if row.Policy1LastUpdate.Valid {
×
1760
                                        e1Time := time.Unix(
×
1761
                                                row.Policy1LastUpdate.Int64, 0,
×
1762
                                        )
×
1763
                                        e1UpdateTime = &e1Time
×
1764
                                }
×
1765
                                if row.Policy2LastUpdate.Valid {
×
1766
                                        e2Time := time.Unix(
×
1767
                                                row.Policy2LastUpdate.Int64, 0,
×
1768
                                        )
×
1769
                                        e2UpdateTime = &e2Time
×
1770
                                }
×
1771

1772
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1773
                                        info, e1UpdateTime, e2UpdateTime,
×
1774
                                )
×
1775
                        }
1776

1777
                        err = db.UpsertZombieChannel(
×
1778
                                ctx, sqlc.UpsertZombieChannelParams{
×
1779
                                        Version:  int16(ProtocolV1),
×
1780
                                        Scid:     chanIDB,
×
1781
                                        NodeKey1: nodeKey1[:],
×
1782
                                        NodeKey2: nodeKey2[:],
×
1783
                                },
×
1784
                        )
×
1785
                        if err != nil {
×
1786
                                return fmt.Errorf("unable to mark channel as "+
×
1787
                                        "zombie: %w", err)
×
1788
                        }
×
1789
                }
1790

1791
                return nil
×
1792
        }, func() {
×
1793
                deleted = nil
×
1794
        })
×
1795
        if err != nil {
×
1796
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1797
                        err)
×
1798
        }
×
1799

1800
        for _, chanID := range chanIDs {
×
1801
                s.rejectCache.remove(chanID)
×
1802
                s.chanCache.remove(chanID)
×
1803
        }
×
1804

1805
        return deleted, nil
×
1806
}
1807

1808
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1809
// channel identified by the channel ID. If the channel can't be found, then
1810
// ErrEdgeNotFound is returned. A struct which houses the general information
1811
// for the channel itself is returned as well as two structs that contain the
1812
// routing policies for the channel in either direction.
1813
//
1814
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1815
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1816
// the ChannelEdgeInfo will only include the public keys of each node.
1817
//
1818
// NOTE: part of the V1Store interface.
1819
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1820
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1821
        *models.ChannelEdgePolicy, error) {
×
1822

×
1823
        var (
×
1824
                ctx              = context.TODO()
×
1825
                edge             *models.ChannelEdgeInfo
×
1826
                policy1, policy2 *models.ChannelEdgePolicy
×
1827
                chanIDB          = channelIDToBytes(chanID)
×
1828
        )
×
1829
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1830
                row, err := db.GetChannelBySCIDWithPolicies(
×
1831
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1832
                                Scid:    chanIDB,
×
1833
                                Version: int16(ProtocolV1),
×
1834
                        },
×
1835
                )
×
1836
                if errors.Is(err, sql.ErrNoRows) {
×
1837
                        // First check if this edge is perhaps in the zombie
×
1838
                        // index.
×
1839
                        zombie, err := db.GetZombieChannel(
×
1840
                                ctx, sqlc.GetZombieChannelParams{
×
1841
                                        Scid:    chanIDB,
×
1842
                                        Version: int16(ProtocolV1),
×
1843
                                },
×
1844
                        )
×
1845
                        if errors.Is(err, sql.ErrNoRows) {
×
1846
                                return ErrEdgeNotFound
×
1847
                        } else if err != nil {
×
1848
                                return fmt.Errorf("unable to check if "+
×
1849
                                        "channel is zombie: %w", err)
×
1850
                        }
×
1851

1852
                        // At this point, we know the channel is a zombie, so
1853
                        // we'll return an error indicating this, and we will
1854
                        // populate the edge info with the public keys of each
1855
                        // party as this is the only information we have about
1856
                        // it.
1857
                        edge = &models.ChannelEdgeInfo{}
×
1858
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1859
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1860

×
1861
                        return ErrZombieEdge
×
1862
                } else if err != nil {
×
1863
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1864
                }
×
1865

1866
                node1, node2, err := buildNodeVertices(
×
1867
                        row.Node.PubKey, row.Node_2.PubKey,
×
1868
                )
×
1869
                if err != nil {
×
1870
                        return err
×
1871
                }
×
1872

1873
                edge, err = getAndBuildEdgeInfo(
×
1874
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1875
                        node1, node2,
×
1876
                )
×
1877
                if err != nil {
×
1878
                        return fmt.Errorf("unable to build channel info: %w",
×
1879
                                err)
×
1880
                }
×
1881

1882
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1883
                if err != nil {
×
1884
                        return fmt.Errorf("unable to extract channel "+
×
1885
                                "policies: %w", err)
×
1886
                }
×
1887

1888
                policy1, policy2, err = getAndBuildChanPolicies(
×
1889
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1890
                )
×
1891
                if err != nil {
×
1892
                        return fmt.Errorf("unable to build channel "+
×
1893
                                "policies: %w", err)
×
1894
                }
×
1895

1896
                return nil
×
1897
        }, sqldb.NoOpReset)
1898
        if err != nil {
×
1899
                // If we are returning the ErrZombieEdge, then we also need to
×
1900
                // return the edge info as the method comment indicates that
×
1901
                // this will be populated when the edge is a zombie.
×
1902
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1903
                        err)
×
1904
        }
×
1905

1906
        return edge, policy1, policy2, nil
×
1907
}
1908

1909
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1910
// the channel identified by the funding outpoint. If the channel can't be
1911
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1912
// information for the channel itself is returned as well as two structs that
1913
// contain the routing policies for the channel in either direction.
1914
//
1915
// NOTE: part of the V1Store interface.
1916
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1917
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1918
        *models.ChannelEdgePolicy, error) {
×
1919

×
1920
        var (
×
1921
                ctx              = context.TODO()
×
1922
                edge             *models.ChannelEdgeInfo
×
1923
                policy1, policy2 *models.ChannelEdgePolicy
×
1924
        )
×
1925
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1926
                row, err := db.GetChannelByOutpointWithPolicies(
×
1927
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1928
                                Outpoint: op.String(),
×
1929
                                Version:  int16(ProtocolV1),
×
1930
                        },
×
1931
                )
×
1932
                if errors.Is(err, sql.ErrNoRows) {
×
1933
                        return ErrEdgeNotFound
×
1934
                } else if err != nil {
×
1935
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1936
                }
×
1937

1938
                node1, node2, err := buildNodeVertices(
×
1939
                        row.Node1Pubkey, row.Node2Pubkey,
×
1940
                )
×
1941
                if err != nil {
×
1942
                        return err
×
1943
                }
×
1944

1945
                edge, err = getAndBuildEdgeInfo(
×
1946
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1947
                        node1, node2,
×
1948
                )
×
1949
                if err != nil {
×
1950
                        return fmt.Errorf("unable to build channel info: %w",
×
1951
                                err)
×
1952
                }
×
1953

1954
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1955
                if err != nil {
×
1956
                        return fmt.Errorf("unable to extract channel "+
×
1957
                                "policies: %w", err)
×
1958
                }
×
1959

1960
                policy1, policy2, err = getAndBuildChanPolicies(
×
1961
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1962
                )
×
1963
                if err != nil {
×
1964
                        return fmt.Errorf("unable to build channel "+
×
1965
                                "policies: %w", err)
×
1966
                }
×
1967

1968
                return nil
×
1969
        }, sqldb.NoOpReset)
1970
        if err != nil {
×
1971
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1972
                        err)
×
1973
        }
×
1974

1975
        return edge, policy1, policy2, nil
×
1976
}
1977

1978
// HasChannelEdge returns true if the database knows of a channel edge with the
1979
// passed channel ID, and false otherwise. If an edge with that ID is found
1980
// within the graph, then two time stamps representing the last time the edge
1981
// was updated for both directed edges are returned along with the boolean. If
1982
// it is not found, then the zombie index is checked and its result is returned
1983
// as the second boolean.
1984
//
1985
// NOTE: part of the V1Store interface.
1986
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1987
        bool, error) {
×
1988

×
1989
        ctx := context.TODO()
×
1990

×
1991
        var (
×
1992
                exists          bool
×
1993
                isZombie        bool
×
1994
                node1LastUpdate time.Time
×
1995
                node2LastUpdate time.Time
×
1996
        )
×
1997

×
1998
        // We'll query the cache with the shared lock held to allow multiple
×
1999
        // readers to access values in the cache concurrently if they exist.
×
2000
        s.cacheMu.RLock()
×
2001
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2002
                s.cacheMu.RUnlock()
×
2003
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2004
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2005
                exists, isZombie = entry.flags.unpack()
×
2006

×
2007
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2008
        }
×
2009
        s.cacheMu.RUnlock()
×
2010

×
2011
        s.cacheMu.Lock()
×
2012
        defer s.cacheMu.Unlock()
×
2013

×
2014
        // The item was not found with the shared lock, so we'll acquire the
×
2015
        // exclusive lock and check the cache again in case another method added
×
2016
        // the entry to the cache while no lock was held.
×
2017
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2018
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2019
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2020
                exists, isZombie = entry.flags.unpack()
×
2021

×
2022
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2023
        }
×
2024

2025
        chanIDB := channelIDToBytes(chanID)
×
2026
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2027
                channel, err := db.GetChannelBySCID(
×
2028
                        ctx, sqlc.GetChannelBySCIDParams{
×
2029
                                Scid:    chanIDB,
×
2030
                                Version: int16(ProtocolV1),
×
2031
                        },
×
2032
                )
×
2033
                if errors.Is(err, sql.ErrNoRows) {
×
2034
                        // Check if it is a zombie channel.
×
2035
                        isZombie, err = db.IsZombieChannel(
×
2036
                                ctx, sqlc.IsZombieChannelParams{
×
2037
                                        Scid:    chanIDB,
×
2038
                                        Version: int16(ProtocolV1),
×
2039
                                },
×
2040
                        )
×
2041
                        if err != nil {
×
2042
                                return fmt.Errorf("could not check if channel "+
×
2043
                                        "is zombie: %w", err)
×
2044
                        }
×
2045

2046
                        return nil
×
2047
                } else if err != nil {
×
2048
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2049
                }
×
2050

2051
                exists = true
×
2052

×
2053
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2054
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2055
                                Version:   int16(ProtocolV1),
×
2056
                                ChannelID: channel.ID,
×
2057
                                NodeID:    channel.NodeID1,
×
2058
                        },
×
2059
                )
×
2060
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2061
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2062
                                err)
×
2063
                } else if err == nil {
×
2064
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2065
                }
×
2066

2067
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2068
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2069
                                Version:   int16(ProtocolV1),
×
2070
                                ChannelID: channel.ID,
×
2071
                                NodeID:    channel.NodeID2,
×
2072
                        },
×
2073
                )
×
2074
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2075
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2076
                                err)
×
2077
                } else if err == nil {
×
2078
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2079
                }
×
2080

2081
                return nil
×
2082
        }, sqldb.NoOpReset)
2083
        if err != nil {
×
2084
                return time.Time{}, time.Time{}, false, false,
×
2085
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2086
        }
×
2087

2088
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2089
                upd1Time: node1LastUpdate.Unix(),
×
2090
                upd2Time: node2LastUpdate.Unix(),
×
2091
                flags:    packRejectFlags(exists, isZombie),
×
2092
        })
×
2093

×
2094
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2095
}
2096

2097
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2098
// passed channel point (outpoint). If the passed channel doesn't exist within
2099
// the database, then ErrEdgeNotFound is returned.
2100
//
2101
// NOTE: part of the V1Store interface.
2102
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2103
        var (
×
2104
                ctx       = context.TODO()
×
2105
                channelID uint64
×
2106
        )
×
2107
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2108
                chanID, err := db.GetSCIDByOutpoint(
×
2109
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2110
                                Outpoint: chanPoint.String(),
×
2111
                                Version:  int16(ProtocolV1),
×
2112
                        },
×
2113
                )
×
2114
                if errors.Is(err, sql.ErrNoRows) {
×
2115
                        return ErrEdgeNotFound
×
2116
                } else if err != nil {
×
2117
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2118
                                err)
×
2119
                }
×
2120

2121
                channelID = byteOrder.Uint64(chanID)
×
2122

×
2123
                return nil
×
2124
        }, sqldb.NoOpReset)
2125
        if err != nil {
×
2126
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2127
        }
×
2128

2129
        return channelID, nil
×
2130
}
2131

2132
// IsPublicNode is a helper method that determines whether the node with the
2133
// given public key is seen as a public node in the graph from the graph's
2134
// source node's point of view.
2135
//
2136
// NOTE: part of the V1Store interface.
2137
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2138
        ctx := context.TODO()
×
2139

×
2140
        var isPublic bool
×
2141
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2142
                var err error
×
2143
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2144

×
2145
                return err
×
2146
        }, sqldb.NoOpReset)
×
2147
        if err != nil {
×
2148
                return false, fmt.Errorf("unable to check if node is "+
×
2149
                        "public: %w", err)
×
2150
        }
×
2151

2152
        return isPublic, nil
×
2153
}
2154

2155
// FetchChanInfos returns the set of channel edges that correspond to the passed
2156
// channel ID's. If an edge is the query is unknown to the database, it will
2157
// skipped and the result will contain only those edges that exist at the time
2158
// of the query. This can be used to respond to peer queries that are seeking to
2159
// fill in gaps in their view of the channel graph.
2160
//
2161
// NOTE: part of the V1Store interface.
2162
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2163
        var (
×
2164
                ctx   = context.TODO()
×
2165
                edges []ChannelEdge
×
2166
        )
×
2167
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2168
                for _, chanID := range chanIDs {
×
2169
                        chanIDB := channelIDToBytes(chanID)
×
2170

×
2171
                        // TODO(elle): potentially optimize this by using
×
2172
                        //  sqlc.slice() once that works for both SQLite and
×
2173
                        //  Postgres.
×
2174
                        row, err := db.GetChannelBySCIDWithPolicies(
×
2175
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
2176
                                        Scid:    chanIDB,
×
2177
                                        Version: int16(ProtocolV1),
×
2178
                                },
×
2179
                        )
×
2180
                        if errors.Is(err, sql.ErrNoRows) {
×
2181
                                continue
×
2182
                        } else if err != nil {
×
2183
                                return fmt.Errorf("unable to fetch channel: %w",
×
2184
                                        err)
×
2185
                        }
×
2186

2187
                        node1, node2, err := buildNodes(
×
2188
                                ctx, db, row.Node, row.Node_2,
×
2189
                        )
×
2190
                        if err != nil {
×
2191
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2192
                                        err)
×
2193
                        }
×
2194

2195
                        edge, err := getAndBuildEdgeInfo(
×
2196
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2197
                                row.Channel, node1.PubKeyBytes,
×
2198
                                node2.PubKeyBytes,
×
2199
                        )
×
2200
                        if err != nil {
×
2201
                                return fmt.Errorf("unable to build "+
×
2202
                                        "channel info: %w", err)
×
2203
                        }
×
2204

2205
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2206
                        if err != nil {
×
2207
                                return fmt.Errorf("unable to extract channel "+
×
2208
                                        "policies: %w", err)
×
2209
                        }
×
2210

2211
                        p1, p2, err := getAndBuildChanPolicies(
×
2212
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2213
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2214
                        )
×
2215
                        if err != nil {
×
2216
                                return fmt.Errorf("unable to build channel "+
×
2217
                                        "policies: %w", err)
×
2218
                        }
×
2219

2220
                        edges = append(edges, ChannelEdge{
×
2221
                                Info:    edge,
×
2222
                                Policy1: p1,
×
2223
                                Policy2: p2,
×
2224
                                Node1:   node1,
×
2225
                                Node2:   node2,
×
2226
                        })
×
2227
                }
2228

2229
                return nil
×
2230
        }, func() {
×
2231
                edges = nil
×
2232
        })
×
2233
        if err != nil {
×
2234
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2235
        }
×
2236

2237
        return edges, nil
×
2238
}
2239

2240
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2241
// ID's that we don't know and are not known zombies of the passed set. In other
2242
// words, we perform a set difference of our set of chan ID's and the ones
2243
// passed in. This method can be used by callers to determine the set of
2244
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2245
// known zombies is also returned.
2246
//
2247
// NOTE: part of the V1Store interface.
2248
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2249
        []ChannelUpdateInfo, error) {
×
2250

×
2251
        var (
×
2252
                ctx          = context.TODO()
×
2253
                newChanIDs   []uint64
×
2254
                knownZombies []ChannelUpdateInfo
×
2255
        )
×
2256
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2257
                for _, chanInfo := range chansInfo {
×
2258
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2259
                        chanIDB := channelIDToBytes(channelID)
×
2260

×
2261
                        // TODO(elle): potentially optimize this by using
×
2262
                        //  sqlc.slice() once that works for both SQLite and
×
2263
                        //  Postgres.
×
2264
                        _, err := db.GetChannelBySCID(
×
2265
                                ctx, sqlc.GetChannelBySCIDParams{
×
2266
                                        Version: int16(ProtocolV1),
×
2267
                                        Scid:    chanIDB,
×
2268
                                },
×
2269
                        )
×
2270
                        if err == nil {
×
2271
                                continue
×
2272
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
2273
                                return fmt.Errorf("unable to fetch channel: %w",
×
2274
                                        err)
×
2275
                        }
×
2276

2277
                        isZombie, err := db.IsZombieChannel(
×
2278
                                ctx, sqlc.IsZombieChannelParams{
×
2279
                                        Scid:    chanIDB,
×
2280
                                        Version: int16(ProtocolV1),
×
2281
                                },
×
2282
                        )
×
2283
                        if err != nil {
×
2284
                                return fmt.Errorf("unable to fetch zombie "+
×
2285
                                        "channel: %w", err)
×
2286
                        }
×
2287

2288
                        if isZombie {
×
2289
                                knownZombies = append(knownZombies, chanInfo)
×
2290

×
2291
                                continue
×
2292
                        }
2293

2294
                        newChanIDs = append(newChanIDs, channelID)
×
2295
                }
2296

2297
                return nil
×
2298
        }, func() {
×
2299
                newChanIDs = nil
×
2300
                knownZombies = nil
×
2301
        })
×
2302
        if err != nil {
×
2303
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2304
        }
×
2305

2306
        return newChanIDs, knownZombies, nil
×
2307
}
2308

2309
// PruneGraphNodes is a garbage collection method which attempts to prune out
2310
// any nodes from the channel graph that are currently unconnected. This ensure
2311
// that we only maintain a graph of reachable nodes. In the event that a pruned
2312
// node gains more channels, it will be re-added back to the graph.
2313
//
2314
// NOTE: this prunes nodes across protocol versions. It will never prune the
2315
// source nodes.
2316
//
2317
// NOTE: part of the V1Store interface.
2318
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2319
        var ctx = context.TODO()
×
2320

×
2321
        var prunedNodes []route.Vertex
×
2322
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2323
                var err error
×
2324
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2325

×
2326
                return err
×
2327
        }, func() {
×
2328
                prunedNodes = nil
×
2329
        })
×
2330
        if err != nil {
×
2331
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2332
        }
×
2333

2334
        return prunedNodes, nil
×
2335
}
2336

2337
// PruneGraph prunes newly closed channels from the channel graph in response
2338
// to a new block being solved on the network. Any transactions which spend the
2339
// funding output of any known channels within he graph will be deleted.
2340
// Additionally, the "prune tip", or the last block which has been used to
2341
// prune the graph is stored so callers can ensure the graph is fully in sync
2342
// with the current UTXO state. A slice of channels that have been closed by
2343
// the target block along with any pruned nodes are returned if the function
2344
// succeeds without error.
2345
//
2346
// NOTE: part of the V1Store interface.
2347
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2348
        blockHash *chainhash.Hash, blockHeight uint32) (
2349
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2350

×
2351
        ctx := context.TODO()
×
2352

×
2353
        s.cacheMu.Lock()
×
2354
        defer s.cacheMu.Unlock()
×
2355

×
2356
        var (
×
2357
                closedChans []*models.ChannelEdgeInfo
×
2358
                prunedNodes []route.Vertex
×
2359
        )
×
2360
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2361
                for _, outpoint := range spentOutputs {
×
2362
                        // TODO(elle): potentially optimize this by using
×
2363
                        //  sqlc.slice() once that works for both SQLite and
×
2364
                        //  Postgres.
×
2365
                        //
×
2366
                        // NOTE: this fetches channels for all protocol
×
2367
                        // versions.
×
2368
                        row, err := db.GetChannelByOutpoint(
×
2369
                                ctx, outpoint.String(),
×
2370
                        )
×
2371
                        if errors.Is(err, sql.ErrNoRows) {
×
2372
                                continue
×
2373
                        } else if err != nil {
×
2374
                                return fmt.Errorf("unable to fetch channel: %w",
×
2375
                                        err)
×
2376
                        }
×
2377

2378
                        node1, node2, err := buildNodeVertices(
×
2379
                                row.Node1Pubkey, row.Node2Pubkey,
×
2380
                        )
×
2381
                        if err != nil {
×
2382
                                return err
×
2383
                        }
×
2384

2385
                        info, err := getAndBuildEdgeInfo(
×
2386
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2387
                                row.Channel, node1, node2,
×
2388
                        )
×
2389
                        if err != nil {
×
2390
                                return err
×
2391
                        }
×
2392

2393
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2394
                        if err != nil {
×
2395
                                return fmt.Errorf("unable to delete "+
×
2396
                                        "channel: %w", err)
×
2397
                        }
×
2398

2399
                        closedChans = append(closedChans, info)
×
2400
                }
2401

2402
                err := db.UpsertPruneLogEntry(
×
2403
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2404
                                BlockHash:   blockHash[:],
×
2405
                                BlockHeight: int64(blockHeight),
×
2406
                        },
×
2407
                )
×
2408
                if err != nil {
×
2409
                        return fmt.Errorf("unable to insert prune log "+
×
2410
                                "entry: %w", err)
×
2411
                }
×
2412

2413
                // Now that we've pruned some channels, we'll also prune any
2414
                // nodes that no longer have any channels.
2415
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2416
                if err != nil {
×
2417
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2418
                                err)
×
2419
                }
×
2420

2421
                return nil
×
2422
        }, func() {
×
2423
                prunedNodes = nil
×
2424
                closedChans = nil
×
2425
        })
×
2426
        if err != nil {
×
2427
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2428
        }
×
2429

2430
        for _, channel := range closedChans {
×
2431
                s.rejectCache.remove(channel.ChannelID)
×
2432
                s.chanCache.remove(channel.ChannelID)
×
2433
        }
×
2434

2435
        return closedChans, prunedNodes, nil
×
2436
}
2437

2438
// ChannelView returns the verifiable edge information for each active channel
2439
// within the known channel graph. The set of UTXOs (along with their scripts)
2440
// returned are the ones that need to be watched on chain to detect channel
2441
// closes on the resident blockchain.
2442
//
2443
// NOTE: part of the V1Store interface.
2444
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2445
        var (
×
2446
                ctx        = context.TODO()
×
2447
                edgePoints []EdgePoint
×
2448
        )
×
2449

×
2450
        handleChannel := func(db SQLQueries,
×
2451
                channel sqlc.ListChannelsPaginatedRow) error {
×
2452

×
2453
                pkScript, err := genMultiSigP2WSH(
×
2454
                        channel.BitcoinKey1, channel.BitcoinKey2,
×
2455
                )
×
2456
                if err != nil {
×
2457
                        return err
×
2458
                }
×
2459

2460
                op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2461
                if err != nil {
×
2462
                        return err
×
2463
                }
×
2464

2465
                edgePoints = append(edgePoints, EdgePoint{
×
2466
                        FundingPkScript: pkScript,
×
2467
                        OutPoint:        *op,
×
2468
                })
×
2469

×
2470
                return nil
×
2471
        }
2472

2473
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2474
                lastID := int64(-1)
×
2475
                for {
×
2476
                        rows, err := db.ListChannelsPaginated(
×
2477
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2478
                                        Version: int16(ProtocolV1),
×
2479
                                        ID:      lastID,
×
2480
                                        Limit:   pageSize,
×
2481
                                },
×
2482
                        )
×
2483
                        if err != nil {
×
2484
                                return err
×
2485
                        }
×
2486

2487
                        if len(rows) == 0 {
×
2488
                                break
×
2489
                        }
2490

2491
                        for _, row := range rows {
×
2492
                                err := handleChannel(db, row)
×
2493
                                if err != nil {
×
2494
                                        return err
×
2495
                                }
×
2496

2497
                                lastID = row.ID
×
2498
                        }
2499
                }
2500

2501
                return nil
×
2502
        }, func() {
×
2503
                edgePoints = nil
×
2504
        })
×
2505
        if err != nil {
×
2506
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2507
        }
×
2508

2509
        return edgePoints, nil
×
2510
}
2511

2512
// PruneTip returns the block height and hash of the latest block that has been
2513
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2514
// to tell if the graph is currently in sync with the current best known UTXO
2515
// state.
2516
//
2517
// NOTE: part of the V1Store interface.
2518
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2519
        var (
×
2520
                ctx       = context.TODO()
×
2521
                tipHash   chainhash.Hash
×
2522
                tipHeight uint32
×
2523
        )
×
2524
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2525
                pruneTip, err := db.GetPruneTip(ctx)
×
2526
                if errors.Is(err, sql.ErrNoRows) {
×
2527
                        return ErrGraphNeverPruned
×
2528
                } else if err != nil {
×
2529
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2530
                }
×
2531

2532
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2533
                tipHeight = uint32(pruneTip.BlockHeight)
×
2534

×
2535
                return nil
×
2536
        }, sqldb.NoOpReset)
2537
        if err != nil {
×
2538
                return nil, 0, err
×
2539
        }
×
2540

2541
        return &tipHash, tipHeight, nil
×
2542
}
2543

2544
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2545
//
2546
// NOTE: this prunes nodes across protocol versions. It will never prune the
2547
// source nodes.
2548
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2549
        db SQLQueries) ([]route.Vertex, error) {
×
2550

×
2551
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2552
        if err != nil {
×
2553
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2554
                        "nodes: %w", err)
×
2555
        }
×
2556

2557
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2558
        for i, nodeKey := range nodeKeys {
×
2559
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2560
                if err != nil {
×
2561
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2562
                                "from bytes: %w", err)
×
2563
                }
×
2564

2565
                prunedNodes[i] = pub
×
2566
        }
2567

2568
        return prunedNodes, nil
×
2569
}
2570

2571
// DisconnectBlockAtHeight is used to indicate that the block specified
2572
// by the passed height has been disconnected from the main chain. This
2573
// will "rewind" the graph back to the height below, deleting channels
2574
// that are no longer confirmed from the graph. The prune log will be
2575
// set to the last prune height valid for the remaining chain.
2576
// Channels that were removed from the graph resulting from the
2577
// disconnected block are returned.
2578
//
2579
// NOTE: part of the V1Store interface.
2580
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2581
        []*models.ChannelEdgeInfo, error) {
×
2582

×
2583
        ctx := context.TODO()
×
2584

×
2585
        var (
×
2586
                // Every channel having a ShortChannelID starting at 'height'
×
2587
                // will no longer be confirmed.
×
2588
                startShortChanID = lnwire.ShortChannelID{
×
2589
                        BlockHeight: height,
×
2590
                }
×
2591

×
2592
                // Delete everything after this height from the db up until the
×
2593
                // SCID alias range.
×
2594
                endShortChanID = aliasmgr.StartingAlias
×
2595

×
2596
                removedChans []*models.ChannelEdgeInfo
×
2597

×
2598
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2599
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2600
        )
×
2601

×
2602
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2603
                rows, err := db.GetChannelsBySCIDRange(
×
2604
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2605
                                StartScid: chanIDStart,
×
2606
                                EndScid:   chanIDEnd,
×
2607
                        },
×
2608
                )
×
2609
                if err != nil {
×
2610
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2611
                }
×
2612

2613
                for _, row := range rows {
×
2614
                        node1, node2, err := buildNodeVertices(
×
2615
                                row.Node1PubKey, row.Node2PubKey,
×
2616
                        )
×
2617
                        if err != nil {
×
2618
                                return err
×
2619
                        }
×
2620

2621
                        channel, err := getAndBuildEdgeInfo(
×
2622
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2623
                                row.Channel, node1, node2,
×
2624
                        )
×
2625
                        if err != nil {
×
2626
                                return err
×
2627
                        }
×
2628

2629
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2630
                        if err != nil {
×
2631
                                return fmt.Errorf("unable to delete "+
×
2632
                                        "channel: %w", err)
×
2633
                        }
×
2634

2635
                        removedChans = append(removedChans, channel)
×
2636
                }
2637

2638
                return db.DeletePruneLogEntriesInRange(
×
2639
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2640
                                StartHeight: int64(height),
×
2641
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2642
                        },
×
2643
                )
×
2644
        }, func() {
×
2645
                removedChans = nil
×
2646
        })
×
2647
        if err != nil {
×
2648
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2649
                        "height: %w", err)
×
2650
        }
×
2651

2652
        for _, channel := range removedChans {
×
2653
                s.rejectCache.remove(channel.ChannelID)
×
2654
                s.chanCache.remove(channel.ChannelID)
×
2655
        }
×
2656

2657
        return removedChans, nil
×
2658
}
2659

2660
// AddEdgeProof sets the proof of an existing edge in the graph database.
2661
//
2662
// NOTE: part of the V1Store interface.
2663
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2664
        proof *models.ChannelAuthProof) error {
×
2665

×
2666
        var (
×
2667
                ctx       = context.TODO()
×
2668
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2669
        )
×
2670

×
2671
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2672
                res, err := db.AddV1ChannelProof(
×
2673
                        ctx, sqlc.AddV1ChannelProofParams{
×
2674
                                Scid:              scidBytes,
×
2675
                                Node1Signature:    proof.NodeSig1Bytes,
×
2676
                                Node2Signature:    proof.NodeSig2Bytes,
×
2677
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2678
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2679
                        },
×
2680
                )
×
2681
                if err != nil {
×
2682
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2683
                }
×
2684

2685
                n, err := res.RowsAffected()
×
2686
                if err != nil {
×
2687
                        return err
×
2688
                }
×
2689

2690
                if n == 0 {
×
2691
                        return fmt.Errorf("no rows affected when adding edge "+
×
2692
                                "proof for SCID %v", scid)
×
2693
                } else if n > 1 {
×
2694
                        return fmt.Errorf("multiple rows affected when adding "+
×
2695
                                "edge proof for SCID %v: %d rows affected",
×
2696
                                scid, n)
×
2697
                }
×
2698

2699
                return nil
×
2700
        }, sqldb.NoOpReset)
2701
        if err != nil {
×
2702
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2703
        }
×
2704

2705
        return nil
×
2706
}
2707

2708
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2709
// that we can ignore channel announcements that we know to be closed without
2710
// having to validate them and fetch a block.
2711
//
2712
// NOTE: part of the V1Store interface.
2713
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2714
        var (
×
2715
                ctx     = context.TODO()
×
2716
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2717
        )
×
2718

×
2719
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2720
                return db.InsertClosedChannel(ctx, chanIDB)
×
2721
        }, sqldb.NoOpReset)
×
2722
}
2723

2724
// IsClosedScid checks whether a channel identified by the passed in scid is
2725
// closed. This helps avoid having to perform expensive validation checks.
2726
//
2727
// NOTE: part of the V1Store interface.
2728
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2729
        var (
×
2730
                ctx      = context.TODO()
×
2731
                isClosed bool
×
2732
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2733
        )
×
2734
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2735
                var err error
×
2736
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2737
                if err != nil {
×
2738
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2739
                                err)
×
2740
                }
×
2741

2742
                return nil
×
2743
        }, sqldb.NoOpReset)
2744
        if err != nil {
×
2745
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2746
                        err)
×
2747
        }
×
2748

2749
        return isClosed, nil
×
2750
}
2751

2752
// GraphSession will provide the call-back with access to a NodeTraverser
2753
// instance which can be used to perform queries against the channel graph.
2754
//
2755
// NOTE: part of the V1Store interface.
2756
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error) error {
×
2757
        var ctx = context.TODO()
×
2758

×
2759
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2760
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2761
        }, sqldb.NoOpReset)
×
2762
}
2763

2764
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2765
// read only transaction for a consistent view of the graph.
2766
type sqlNodeTraverser struct {
2767
        db    SQLQueries
2768
        chain chainhash.Hash
2769
}
2770

2771
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2772
// NodeTraverser interface.
2773
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2774

2775
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2776
func newSQLNodeTraverser(db SQLQueries,
2777
        chain chainhash.Hash) *sqlNodeTraverser {
×
2778

×
2779
        return &sqlNodeTraverser{
×
2780
                db:    db,
×
2781
                chain: chain,
×
2782
        }
×
2783
}
×
2784

2785
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2786
// node.
2787
//
2788
// NOTE: Part of the NodeTraverser interface.
2789
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2790
        cb func(channel *DirectedChannel) error) error {
×
2791

×
2792
        ctx := context.TODO()
×
2793

×
2794
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2795
}
×
2796

2797
// FetchNodeFeatures returns the features of the given node. If the node is
2798
// unknown, assume no additional features are supported.
2799
//
2800
// NOTE: Part of the NodeTraverser interface.
2801
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2802
        *lnwire.FeatureVector, error) {
×
2803

×
2804
        ctx := context.TODO()
×
2805

×
2806
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2807
}
×
2808

2809
// forEachNodeDirectedChannel iterates through all channels of a given
2810
// node, executing the passed callback on the directed edge representing the
2811
// channel and its incoming policy. If the node is not found, no error is
2812
// returned.
2813
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2814
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2815

×
2816
        toNodeCallback := func() route.Vertex {
×
2817
                return nodePub
×
2818
        }
×
2819

2820
        dbID, err := db.GetNodeIDByPubKey(
×
2821
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2822
                        Version: int16(ProtocolV1),
×
2823
                        PubKey:  nodePub[:],
×
2824
                },
×
2825
        )
×
2826
        if errors.Is(err, sql.ErrNoRows) {
×
2827
                return nil
×
2828
        } else if err != nil {
×
2829
                return fmt.Errorf("unable to fetch node: %w", err)
×
2830
        }
×
2831

2832
        rows, err := db.ListChannelsByNodeID(
×
2833
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2834
                        Version: int16(ProtocolV1),
×
2835
                        NodeID1: dbID,
×
2836
                },
×
2837
        )
×
2838
        if err != nil {
×
2839
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2840
        }
×
2841

2842
        // Exit early if there are no channels for this node so we don't
2843
        // do the unnecessary feature fetching.
2844
        if len(rows) == 0 {
×
2845
                return nil
×
2846
        }
×
2847

2848
        features, err := getNodeFeatures(ctx, db, dbID)
×
2849
        if err != nil {
×
2850
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2851
        }
×
2852

2853
        for _, row := range rows {
×
2854
                node1, node2, err := buildNodeVertices(
×
2855
                        row.Node1Pubkey, row.Node2Pubkey,
×
2856
                )
×
2857
                if err != nil {
×
2858
                        return fmt.Errorf("unable to build node vertices: %w",
×
2859
                                err)
×
2860
                }
×
2861

2862
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
2863

×
2864
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2865
                if err != nil {
×
2866
                        return err
×
2867
                }
×
2868

2869
                var p1, p2 *models.CachedEdgePolicy
×
2870
                if dbPol1 != nil {
×
2871
                        policy1, err := buildChanPolicy(
×
2872
                                *dbPol1, edge.ChannelID, nil, node2,
×
2873
                        )
×
2874
                        if err != nil {
×
2875
                                return err
×
2876
                        }
×
2877

2878
                        p1 = models.NewCachedPolicy(policy1)
×
2879
                }
2880
                if dbPol2 != nil {
×
2881
                        policy2, err := buildChanPolicy(
×
2882
                                *dbPol2, edge.ChannelID, nil, node1,
×
2883
                        )
×
2884
                        if err != nil {
×
2885
                                return err
×
2886
                        }
×
2887

2888
                        p2 = models.NewCachedPolicy(policy2)
×
2889
                }
2890

2891
                // Determine the outgoing and incoming policy for this
2892
                // channel and node combo.
2893
                outPolicy, inPolicy := p1, p2
×
2894
                if p1 != nil && node2 == nodePub {
×
2895
                        outPolicy, inPolicy = p2, p1
×
2896
                } else if p2 != nil && node1 != nodePub {
×
2897
                        outPolicy, inPolicy = p2, p1
×
2898
                }
×
2899

2900
                var cachedInPolicy *models.CachedEdgePolicy
×
2901
                if inPolicy != nil {
×
2902
                        cachedInPolicy = inPolicy
×
2903
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
2904
                        cachedInPolicy.ToNodeFeatures = features
×
2905
                }
×
2906

2907
                directedChannel := &DirectedChannel{
×
2908
                        ChannelID:    edge.ChannelID,
×
2909
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2910
                        OtherNode:    edge.NodeKey2Bytes,
×
2911
                        Capacity:     edge.Capacity,
×
2912
                        OutPolicySet: outPolicy != nil,
×
2913
                        InPolicy:     cachedInPolicy,
×
2914
                }
×
2915
                if outPolicy != nil {
×
2916
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2917
                                directedChannel.InboundFee = fee
×
2918
                        })
×
2919
                }
2920

2921
                if nodePub == edge.NodeKey2Bytes {
×
2922
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
2923
                }
×
2924

2925
                if err := cb(directedChannel); err != nil {
×
2926
                        return err
×
2927
                }
×
2928
        }
2929

2930
        return nil
×
2931
}
2932

2933
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
2934
// and executes the provided callback for each node.
2935
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
2936
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
2937

×
2938
        lastID := int64(-1)
×
2939

×
2940
        for {
×
2941
                nodes, err := db.ListNodeIDsAndPubKeys(
×
2942
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
2943
                                Version: int16(ProtocolV1),
×
2944
                                ID:      lastID,
×
2945
                                Limit:   pageSize,
×
2946
                        },
×
2947
                )
×
2948
                if err != nil {
×
2949
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
2950
                }
×
2951

2952
                if len(nodes) == 0 {
×
2953
                        break
×
2954
                }
2955

2956
                for _, node := range nodes {
×
2957
                        var pub route.Vertex
×
2958
                        copy(pub[:], node.PubKey)
×
2959

×
2960
                        if err := cb(node.ID, pub); err != nil {
×
2961
                                return fmt.Errorf("forEachNodeCacheable "+
×
2962
                                        "callback failed for node(id=%d): %w",
×
2963
                                        node.ID, err)
×
2964
                        }
×
2965

2966
                        lastID = node.ID
×
2967
                }
2968
        }
2969

2970
        return nil
×
2971
}
2972

2973
// forEachNodeChannel iterates through all channels of a node, executing
2974
// the passed callback on each. The call-back is provided with the channel's
2975
// edge information, the outgoing policy and the incoming policy for the
2976
// channel and node combo.
2977
func forEachNodeChannel(ctx context.Context, db SQLQueries,
2978
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
2979
                *models.ChannelEdgePolicy,
2980
                *models.ChannelEdgePolicy) error) error {
×
2981

×
2982
        // Get all the V1 channels for this node.Add commentMore actions
×
2983
        rows, err := db.ListChannelsByNodeID(
×
2984
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2985
                        Version: int16(ProtocolV1),
×
2986
                        NodeID1: id,
×
2987
                },
×
2988
        )
×
2989
        if err != nil {
×
2990
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2991
        }
×
2992

2993
        // Call the call-back for each channel and its known policies.
2994
        for _, row := range rows {
×
2995
                node1, node2, err := buildNodeVertices(
×
2996
                        row.Node1Pubkey, row.Node2Pubkey,
×
2997
                )
×
2998
                if err != nil {
×
2999
                        return fmt.Errorf("unable to build node vertices: %w",
×
3000
                                err)
×
3001
                }
×
3002

3003
                edge, err := getAndBuildEdgeInfo(
×
3004
                        ctx, db, chain, row.Channel.ID, row.Channel, node1,
×
3005
                        node2,
×
3006
                )
×
3007
                if err != nil {
×
3008
                        return fmt.Errorf("unable to build channel info: %w",
×
3009
                                err)
×
3010
                }
×
3011

3012
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3013
                if err != nil {
×
3014
                        return fmt.Errorf("unable to extract channel "+
×
3015
                                "policies: %w", err)
×
3016
                }
×
3017

3018
                p1, p2, err := getAndBuildChanPolicies(
×
3019
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3020
                )
×
3021
                if err != nil {
×
3022
                        return fmt.Errorf("unable to build channel "+
×
3023
                                "policies: %w", err)
×
3024
                }
×
3025

3026
                // Determine the outgoing and incoming policy for this
3027
                // channel and node combo.
3028
                p1ToNode := row.Channel.NodeID2
×
3029
                p2ToNode := row.Channel.NodeID1
×
3030
                outPolicy, inPolicy := p1, p2
×
3031
                if (p1 != nil && p1ToNode == id) ||
×
3032
                        (p2 != nil && p2ToNode != id) {
×
3033

×
3034
                        outPolicy, inPolicy = p2, p1
×
3035
                }
×
3036

3037
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3038
                        return err
×
3039
                }
×
3040
        }
3041

3042
        return nil
×
3043
}
3044

3045
// updateChanEdgePolicy upserts the channel policy info we have stored for
3046
// a channel we already know of.
3047
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3048
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3049
        error) {
×
3050

×
3051
        var (
×
3052
                node1Pub, node2Pub route.Vertex
×
3053
                isNode1            bool
×
3054
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3055
        )
×
3056

×
3057
        // Check that this edge policy refers to a channel that we already
×
3058
        // know of. We do this explicitly so that we can return the appropriate
×
3059
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3060
        // abort the transaction which would abort the entire batch.
×
3061
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3062
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3063
                        Scid:    chanIDB,
×
3064
                        Version: int16(ProtocolV1),
×
3065
                },
×
3066
        )
×
3067
        if errors.Is(err, sql.ErrNoRows) {
×
3068
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3069
        } else if err != nil {
×
3070
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3071
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3072
        }
×
3073

3074
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3075
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3076

×
3077
        // Figure out which node this edge is from.
×
3078
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3079
        nodeID := dbChan.NodeID1
×
3080
        if !isNode1 {
×
3081
                nodeID = dbChan.NodeID2
×
3082
        }
×
3083

3084
        var (
×
3085
                inboundBase sql.NullInt64
×
3086
                inboundRate sql.NullInt64
×
3087
        )
×
3088
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3089
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3090
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3091
        })
×
3092

3093
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3094
                Version:     int16(ProtocolV1),
×
3095
                ChannelID:   dbChan.ID,
×
3096
                NodeID:      nodeID,
×
3097
                Timelock:    int32(edge.TimeLockDelta),
×
3098
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3099
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3100
                MinHtlcMsat: int64(edge.MinHTLC),
×
3101
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3102
                Disabled: sql.NullBool{
×
3103
                        Valid: true,
×
3104
                        Bool:  edge.IsDisabled(),
×
3105
                },
×
3106
                MaxHtlcMsat: sql.NullInt64{
×
3107
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3108
                        Int64: int64(edge.MaxHTLC),
×
3109
                },
×
3110
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3111
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3112
                InboundBaseFeeMsat:      inboundBase,
×
3113
                InboundFeeRateMilliMsat: inboundRate,
×
3114
                Signature:               edge.SigBytes,
×
3115
        })
×
3116
        if err != nil {
×
3117
                return node1Pub, node2Pub, isNode1,
×
3118
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3119
        }
×
3120

3121
        // Convert the flat extra opaque data into a map of TLV types to
3122
        // values.
3123
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3124
        if err != nil {
×
3125
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3126
                        "marshal extra opaque data: %w", err)
×
3127
        }
×
3128

3129
        // Update the channel policy's extra signed fields.
3130
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3131
        if err != nil {
×
3132
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3133
                        "policy extra TLVs: %w", err)
×
3134
        }
×
3135

3136
        return node1Pub, node2Pub, isNode1, nil
×
3137
}
3138

3139
// getNodeByPubKey attempts to look up a target node by its public key.
3140
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3141
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3142

×
3143
        dbNode, err := db.GetNodeByPubKey(
×
3144
                ctx, sqlc.GetNodeByPubKeyParams{
×
3145
                        Version: int16(ProtocolV1),
×
3146
                        PubKey:  pubKey[:],
×
3147
                },
×
3148
        )
×
3149
        if errors.Is(err, sql.ErrNoRows) {
×
3150
                return 0, nil, ErrGraphNodeNotFound
×
3151
        } else if err != nil {
×
3152
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3153
        }
×
3154

3155
        node, err := buildNode(ctx, db, &dbNode)
×
3156
        if err != nil {
×
3157
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3158
        }
×
3159

3160
        return dbNode.ID, node, nil
×
3161
}
3162

3163
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3164
// provided database channel row and the public keys of the two nodes
3165
// involved in the channel.
3166
func buildCacheableChannelInfo(dbChan sqlc.Channel, node1Pub,
3167
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3168

×
3169
        return &models.CachedEdgeInfo{
×
3170
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
3171
                NodeKey1Bytes: node1Pub,
×
3172
                NodeKey2Bytes: node2Pub,
×
3173
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
3174
        }
×
3175
}
×
3176

3177
// buildNode constructs a LightningNode instance from the given database node
3178
// record. The node's features, addresses and extra signed fields are also
3179
// fetched from the database and set on the node.
3180
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
3181
        *models.LightningNode, error) {
×
3182

×
3183
        if dbNode.Version != int16(ProtocolV1) {
×
3184
                return nil, fmt.Errorf("unsupported node version: %d",
×
3185
                        dbNode.Version)
×
3186
        }
×
3187

3188
        var pub [33]byte
×
3189
        copy(pub[:], dbNode.PubKey)
×
3190

×
3191
        node := &models.LightningNode{
×
3192
                PubKeyBytes: pub,
×
3193
                Features:    lnwire.EmptyFeatureVector(),
×
3194
                LastUpdate:  time.Unix(0, 0),
×
3195
        }
×
3196

×
3197
        if len(dbNode.Signature) == 0 {
×
3198
                return node, nil
×
3199
        }
×
3200

3201
        node.HaveNodeAnnouncement = true
×
3202
        node.AuthSigBytes = dbNode.Signature
×
3203
        node.Alias = dbNode.Alias.String
×
3204
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3205

×
3206
        var err error
×
3207
        if dbNode.Color.Valid {
×
3208
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3209
                if err != nil {
×
3210
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3211
                                err)
×
3212
                }
×
3213
        }
3214

3215
        // Fetch the node's features.
3216
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
3217
        if err != nil {
×
3218
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3219
                        "features: %w", dbNode.ID, err)
×
3220
        }
×
3221

3222
        // Fetch the node's addresses.
3223
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
3224
        if err != nil {
×
3225
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3226
                        "addresses: %w", dbNode.ID, err)
×
3227
        }
×
3228

3229
        // Fetch the node's extra signed fields.
3230
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
3231
        if err != nil {
×
3232
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3233
                        "extra signed fields: %w", dbNode.ID, err)
×
3234
        }
×
3235

3236
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
3237
        if err != nil {
×
3238
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3239
                        "fields: %w", err)
×
3240
        }
×
3241

3242
        if len(recs) != 0 {
×
3243
                node.ExtraOpaqueData = recs
×
3244
        }
×
3245

3246
        return node, nil
×
3247
}
3248

3249
// getNodeFeatures fetches the feature bits and constructs the feature vector
3250
// for a node with the given DB ID.
3251
func getNodeFeatures(ctx context.Context, db SQLQueries,
3252
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3253

×
3254
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3255
        if err != nil {
×
3256
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3257
                        nodeID, err)
×
3258
        }
×
3259

3260
        features := lnwire.EmptyFeatureVector()
×
3261
        for _, feature := range rows {
×
3262
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3263
        }
×
3264

3265
        return features, nil
×
3266
}
3267

3268
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
3269
// given DB ID.
3270
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3271
        nodeID int64) (map[uint64][]byte, error) {
×
3272

×
3273
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3274
        if err != nil {
×
3275
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
3276
                        "signed fields: %w", nodeID, err)
×
3277
        }
×
3278

3279
        extraFields := make(map[uint64][]byte)
×
3280
        for _, field := range fields {
×
3281
                extraFields[uint64(field.Type)] = field.Value
×
3282
        }
×
3283

3284
        return extraFields, nil
×
3285
}
3286

3287
// upsertNode upserts the node record into the database. If the node already
3288
// exists, then the node's information is updated. If the node doesn't exist,
3289
// then a new node is created. The node's features, addresses and extra TLV
3290
// types are also updated. The node's DB ID is returned.
3291
func upsertNode(ctx context.Context, db SQLQueries,
3292
        node *models.LightningNode) (int64, error) {
×
3293

×
3294
        params := sqlc.UpsertNodeParams{
×
3295
                Version: int16(ProtocolV1),
×
3296
                PubKey:  node.PubKeyBytes[:],
×
3297
        }
×
3298

×
3299
        if node.HaveNodeAnnouncement {
×
3300
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3301
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3302
                params.Alias = sqldb.SQLStr(node.Alias)
×
3303
                params.Signature = node.AuthSigBytes
×
3304
        }
×
3305

3306
        nodeID, err := db.UpsertNode(ctx, params)
×
3307
        if err != nil {
×
3308
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3309
                        err)
×
3310
        }
×
3311

3312
        // We can exit here if we don't have the announcement yet.
3313
        if !node.HaveNodeAnnouncement {
×
3314
                return nodeID, nil
×
3315
        }
×
3316

3317
        // Update the node's features.
3318
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3319
        if err != nil {
×
3320
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3321
        }
×
3322

3323
        // Update the node's addresses.
3324
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3325
        if err != nil {
×
3326
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3327
        }
×
3328

3329
        // Convert the flat extra opaque data into a map of TLV types to
3330
        // values.
3331
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3332
        if err != nil {
×
3333
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3334
                        err)
×
3335
        }
×
3336

3337
        // Update the node's extra signed fields.
3338
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3339
        if err != nil {
×
3340
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3341
        }
×
3342

3343
        return nodeID, nil
×
3344
}
3345

3346
// upsertNodeFeatures updates the node's features node_features table. This
3347
// includes deleting any feature bits no longer present and inserting any new
3348
// feature bits. If the feature bit does not yet exist in the features table,
3349
// then an entry is created in that table first.
3350
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3351
        features *lnwire.FeatureVector) error {
×
3352

×
3353
        // Get any existing features for the node.
×
3354
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3355
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3356
                return err
×
3357
        }
×
3358

3359
        // Copy the nodes latest set of feature bits.
3360
        newFeatures := make(map[int32]struct{})
×
3361
        if features != nil {
×
3362
                for feature := range features.Features() {
×
3363
                        newFeatures[int32(feature)] = struct{}{}
×
3364
                }
×
3365
        }
3366

3367
        // For any current feature that already exists in the DB, remove it from
3368
        // the in-memory map. For any existing feature that does not exist in
3369
        // the in-memory map, delete it from the database.
3370
        for _, feature := range existingFeatures {
×
3371
                // The feature is still present, so there are no updates to be
×
3372
                // made.
×
3373
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3374
                        delete(newFeatures, feature.FeatureBit)
×
3375
                        continue
×
3376
                }
3377

3378
                // The feature is no longer present, so we remove it from the
3379
                // database.
3380
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3381
                        NodeID:     nodeID,
×
3382
                        FeatureBit: feature.FeatureBit,
×
3383
                })
×
3384
                if err != nil {
×
3385
                        return fmt.Errorf("unable to delete node(%d) "+
×
3386
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3387
                                err)
×
3388
                }
×
3389
        }
3390

3391
        // Any remaining entries in newFeatures are new features that need to be
3392
        // added to the database for the first time.
3393
        for feature := range newFeatures {
×
3394
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3395
                        NodeID:     nodeID,
×
3396
                        FeatureBit: feature,
×
3397
                })
×
3398
                if err != nil {
×
3399
                        return fmt.Errorf("unable to insert node(%d) "+
×
3400
                                "feature(%v): %w", nodeID, feature, err)
×
3401
                }
×
3402
        }
3403

3404
        return nil
×
3405
}
3406

3407
// fetchNodeFeatures fetches the features for a node with the given public key.
3408
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3409
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3410

×
3411
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3412
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3413
                        PubKey:  nodePub[:],
×
3414
                        Version: int16(ProtocolV1),
×
3415
                },
×
3416
        )
×
3417
        if err != nil {
×
3418
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3419
                        nodePub, err)
×
3420
        }
×
3421

3422
        features := lnwire.EmptyFeatureVector()
×
3423
        for _, bit := range rows {
×
3424
                features.Set(lnwire.FeatureBit(bit))
×
3425
        }
×
3426

3427
        return features, nil
×
3428
}
3429

3430
// dbAddressType is an enum type that represents the different address types
// that we store in the node_addresses table. The address type determines how
// the address is to be serialized/deserialized.
//
// NOTE: these values are written to the database, so the numeric value of an
// existing constant must never change.
type dbAddressType uint8

const (
	// addressTypeIPv4 marks a TCP address whose host is an IPv4 address.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 marks a TCP address whose host is an IPv6 address.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 marks a Tor v2 onion service address.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 marks a Tor v3 onion service address.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeOpaque marks an opaque (lnwire.OpaqueAddrs) address.
	// It is deliberately placed at the top of the int8 range, well away
	// from the sequential known types above.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3442

3443
// upsertNodeAddresses updates the node's addresses in the database. This
3444
// includes deleting any existing addresses and inserting the new set of
3445
// addresses. The deletion is necessary since the ordering of the addresses may
3446
// change, and we need to ensure that the database reflects the latest set of
3447
// addresses so that at the time of reconstructing the node announcement, the
3448
// order is preserved and the signature over the message remains valid.
3449
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3450
        addresses []net.Addr) error {
×
3451

×
3452
        // Delete any existing addresses for the node. This is required since
×
3453
        // even if the new set of addresses is the same, the ordering may have
×
3454
        // changed for a given address type.
×
3455
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3456
        if err != nil {
×
3457
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3458
                        nodeID, err)
×
3459
        }
×
3460

3461
        // Copy the nodes latest set of addresses.
3462
        newAddresses := map[dbAddressType][]string{
×
3463
                addressTypeIPv4:   {},
×
3464
                addressTypeIPv6:   {},
×
3465
                addressTypeTorV2:  {},
×
3466
                addressTypeTorV3:  {},
×
3467
                addressTypeOpaque: {},
×
3468
        }
×
3469
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3470
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3471
        }
×
3472

3473
        for _, address := range addresses {
×
3474
                switch addr := address.(type) {
×
3475
                case *net.TCPAddr:
×
3476
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3477
                                addAddr(addressTypeIPv4, addr)
×
3478
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3479
                                addAddr(addressTypeIPv6, addr)
×
3480
                        } else {
×
3481
                                return fmt.Errorf("unhandled IP address: %v",
×
3482
                                        addr)
×
3483
                        }
×
3484

3485
                case *tor.OnionAddr:
×
3486
                        switch len(addr.OnionService) {
×
3487
                        case tor.V2Len:
×
3488
                                addAddr(addressTypeTorV2, addr)
×
3489
                        case tor.V3Len:
×
3490
                                addAddr(addressTypeTorV3, addr)
×
3491
                        default:
×
3492
                                return fmt.Errorf("invalid length for a tor " +
×
3493
                                        "address")
×
3494
                        }
3495

3496
                case *lnwire.OpaqueAddrs:
×
3497
                        addAddr(addressTypeOpaque, addr)
×
3498

3499
                default:
×
3500
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3501
                }
3502
        }
3503

3504
        // Any remaining entries in newAddresses are new addresses that need to
3505
        // be added to the database for the first time.
3506
        for addrType, addrList := range newAddresses {
×
3507
                for position, addr := range addrList {
×
3508
                        err := db.InsertNodeAddress(
×
3509
                                ctx, sqlc.InsertNodeAddressParams{
×
3510
                                        NodeID:   nodeID,
×
3511
                                        Type:     int16(addrType),
×
3512
                                        Address:  addr,
×
3513
                                        Position: int32(position),
×
3514
                                },
×
3515
                        )
×
3516
                        if err != nil {
×
3517
                                return fmt.Errorf("unable to insert "+
×
3518
                                        "node(%d) address(%v): %w", nodeID,
×
3519
                                        addr, err)
×
3520
                        }
×
3521
                }
3522
        }
3523

3524
        return nil
×
3525
}
3526

3527
// getNodeAddresses fetches the addresses for a node with the given public key.
3528
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3529
        []net.Addr, error) {
×
3530

×
3531
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
3532
        // are returned in the same order as they were inserted.
×
3533
        rows, err := db.GetNodeAddressesByPubKey(
×
3534
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
3535
                        Version: int16(ProtocolV1),
×
3536
                        PubKey:  nodePub,
×
3537
                },
×
3538
        )
×
3539
        if err != nil {
×
3540
                return false, nil, err
×
3541
        }
×
3542

3543
        // GetNodeAddressesByPubKey uses a left join so there should always be
3544
        // at least one row returned if the node exists even if it has no
3545
        // addresses.
3546
        if len(rows) == 0 {
×
3547
                return false, nil, nil
×
3548
        }
×
3549

3550
        addresses := make([]net.Addr, 0, len(rows))
×
3551
        for _, addr := range rows {
×
3552
                if !(addr.Type.Valid && addr.Address.Valid) {
×
3553
                        continue
×
3554
                }
3555

3556
                address := addr.Address.String
×
3557

×
3558
                switch dbAddressType(addr.Type.Int16) {
×
3559
                case addressTypeIPv4:
×
3560
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
3561
                        if err != nil {
×
3562
                                return false, nil, nil
×
3563
                        }
×
3564
                        tcp.IP = tcp.IP.To4()
×
3565

×
3566
                        addresses = append(addresses, tcp)
×
3567

3568
                case addressTypeIPv6:
×
3569
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
3570
                        if err != nil {
×
3571
                                return false, nil, nil
×
3572
                        }
×
3573
                        addresses = append(addresses, tcp)
×
3574

3575
                case addressTypeTorV3, addressTypeTorV2:
×
3576
                        service, portStr, err := net.SplitHostPort(address)
×
3577
                        if err != nil {
×
3578
                                return false, nil, fmt.Errorf("unable to "+
×
3579
                                        "split tor v3 address: %v",
×
3580
                                        addr.Address)
×
3581
                        }
×
3582

3583
                        port, err := strconv.Atoi(portStr)
×
3584
                        if err != nil {
×
3585
                                return false, nil, err
×
3586
                        }
×
3587

3588
                        addresses = append(addresses, &tor.OnionAddr{
×
3589
                                OnionService: service,
×
3590
                                Port:         port,
×
3591
                        })
×
3592

3593
                case addressTypeOpaque:
×
3594
                        opaque, err := hex.DecodeString(address)
×
3595
                        if err != nil {
×
3596
                                return false, nil, fmt.Errorf("unable to "+
×
3597
                                        "decode opaque address: %v", addr)
×
3598
                        }
×
3599

3600
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
3601
                                Payload: opaque,
×
3602
                        })
×
3603

3604
                default:
×
3605
                        return false, nil, fmt.Errorf("unknown address "+
×
3606
                                "type: %v", addr.Type)
×
3607
                }
3608
        }
3609

3610
        // If we have no addresses, then we'll return nil instead of an
3611
        // empty slice.
3612
        if len(addresses) == 0 {
×
3613
                addresses = nil
×
3614
        }
×
3615

3616
        return true, addresses, nil
×
3617
}
3618

3619
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3620
// database. This includes updating any existing types, inserting any new types,
3621
// and deleting any types that are no longer present.
3622
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3623
        nodeID int64, extraFields map[uint64][]byte) error {
×
3624

×
3625
        // Get any existing extra signed fields for the node.
×
3626
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3627
        if err != nil {
×
3628
                return err
×
3629
        }
×
3630

3631
        // Make a lookup map of the existing field types so that we can use it
3632
        // to keep track of any fields we should delete.
3633
        m := make(map[uint64]bool)
×
3634
        for _, field := range existingFields {
×
3635
                m[uint64(field.Type)] = true
×
3636
        }
×
3637

3638
        // For all the new fields, we'll upsert them and remove them from the
3639
        // map of existing fields.
3640
        for tlvType, value := range extraFields {
×
3641
                err = db.UpsertNodeExtraType(
×
3642
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3643
                                NodeID: nodeID,
×
3644
                                Type:   int64(tlvType),
×
3645
                                Value:  value,
×
3646
                        },
×
3647
                )
×
3648
                if err != nil {
×
3649
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3650
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3651
                }
×
3652

3653
                // Remove the field from the map of existing fields if it was
3654
                // present.
3655
                delete(m, tlvType)
×
3656
        }
3657

3658
        // For all the fields that are left in the map of existing fields, we'll
3659
        // delete them as they are no longer present in the new set of fields.
3660
        for tlvType := range m {
×
3661
                err = db.DeleteExtraNodeType(
×
3662
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3663
                                NodeID: nodeID,
×
3664
                                Type:   int64(tlvType),
×
3665
                        },
×
3666
                )
×
3667
                if err != nil {
×
3668
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3669
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3670
                }
×
3671
        }
3672

3673
        return nil
×
3674
}
3675

3676
// srcNodeInfo holds the information about the source node of the graph.
3677
type srcNodeInfo struct {
3678
        // id is the DB level ID of the source node entry in the "nodes" table.
3679
        id int64
3680

3681
        // pub is the public key of the source node.
3682
        pub route.Vertex
3683
}
3684

3685
// sourceNode returns the DB node ID and pub key of the source node for the
3686
// specified protocol version.
3687
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3688
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3689

×
3690
        s.srcNodeMu.Lock()
×
3691
        defer s.srcNodeMu.Unlock()
×
3692

×
3693
        // If we already have the source node ID and pub key cached, then
×
3694
        // return them.
×
3695
        if info, ok := s.srcNodes[version]; ok {
×
3696
                return info.id, info.pub, nil
×
3697
        }
×
3698

3699
        var pubKey route.Vertex
×
3700

×
3701
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3702
        if err != nil {
×
3703
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3704
                        err)
×
3705
        }
×
3706

3707
        if len(nodes) == 0 {
×
3708
                return 0, pubKey, ErrSourceNodeNotSet
×
3709
        } else if len(nodes) > 1 {
×
3710
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3711
                        "protocol %s found", version)
×
3712
        }
×
3713

3714
        copy(pubKey[:], nodes[0].PubKey)
×
3715

×
3716
        s.srcNodes[version] = &srcNodeInfo{
×
3717
                id:  nodes[0].NodeID,
×
3718
                pub: pubKey,
×
3719
        }
×
3720

×
3721
        return nodes[0].NodeID, pubKey, nil
×
3722
}
3723

3724
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3725
// This then produces a map from TLV type to value. If the input is not a
3726
// valid TLV stream, then an error is returned.
3727
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3728
        r := bytes.NewReader(data)
×
3729

×
3730
        tlvStream, err := tlv.NewStream()
×
3731
        if err != nil {
×
3732
                return nil, err
×
3733
        }
×
3734

3735
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3736
        // pass it into the P2P decoding variant.
3737
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3738
        if err != nil {
×
3739
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3740
        }
×
3741
        if len(parsedTypes) == 0 {
×
3742
                return nil, nil
×
3743
        }
×
3744

3745
        records := make(map[uint64][]byte)
×
3746
        for k, v := range parsedTypes {
×
3747
                records[uint64(k)] = v
×
3748
        }
×
3749

3750
        return records, nil
×
3751
}
3752

3753
// dbChanInfo bundles DB-level (row) IDs for a channel: the ID of the channel's
// own row and the IDs of the rows of the two nodes it connects. This is, for
// example, what insertChannel returns.
type dbChanInfo struct {
	// channelID is the DB-level ID of the channel's row.
	channelID int64

	// node1ID is the DB-level ID of the channel's first node.
	node1ID int64

	// node2ID is the DB-level ID of the channel's second node.
	node2ID int64
}
3760

3761
// insertChannel inserts a new channel record into the database.
3762
func insertChannel(ctx context.Context, db SQLQueries,
3763
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
3764

×
3765
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3766

×
3767
        // Make sure that the channel doesn't already exist. We do this
×
3768
        // explicitly instead of relying on catching a unique constraint error
×
3769
        // because relying on SQL to throw that error would abort the entire
×
3770
        // batch of transactions.
×
3771
        _, err := db.GetChannelBySCID(
×
3772
                ctx, sqlc.GetChannelBySCIDParams{
×
3773
                        Scid:    chanIDB,
×
3774
                        Version: int16(ProtocolV1),
×
3775
                },
×
3776
        )
×
3777
        if err == nil {
×
3778
                return nil, ErrEdgeAlreadyExist
×
3779
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3780
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
×
3781
        }
×
3782

3783
        // Make sure that at least a "shell" entry for each node is present in
3784
        // the nodes table.
3785
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3786
        if err != nil {
×
3787
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3788
        }
×
3789

3790
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3791
        if err != nil {
×
3792
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3793
        }
×
3794

3795
        var capacity sql.NullInt64
×
3796
        if edge.Capacity != 0 {
×
3797
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3798
        }
×
3799

3800
        createParams := sqlc.CreateChannelParams{
×
3801
                Version:     int16(ProtocolV1),
×
3802
                Scid:        chanIDB,
×
3803
                NodeID1:     node1DBID,
×
3804
                NodeID2:     node2DBID,
×
3805
                Outpoint:    edge.ChannelPoint.String(),
×
3806
                Capacity:    capacity,
×
3807
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3808
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3809
        }
×
3810

×
3811
        if edge.AuthProof != nil {
×
3812
                proof := edge.AuthProof
×
3813

×
3814
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3815
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3816
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3817
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3818
        }
×
3819

3820
        // Insert the new channel record.
3821
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3822
        if err != nil {
×
3823
                return nil, err
×
3824
        }
×
3825

3826
        // Insert any channel features.
3827
        for feature := range edge.Features.Features() {
×
3828
                err = db.InsertChannelFeature(
×
3829
                        ctx, sqlc.InsertChannelFeatureParams{
×
3830
                                ChannelID:  dbChanID,
×
3831
                                FeatureBit: int32(feature),
×
3832
                        },
×
3833
                )
×
3834
                if err != nil {
×
3835
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
×
3836
                                "feature(%v): %w", dbChanID, feature, err)
×
3837
                }
×
3838
        }
3839

3840
        // Finally, insert any extra TLV fields in the channel announcement.
3841
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3842
        if err != nil {
×
3843
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
3844
                        "data: %w", err)
×
3845
        }
×
3846

3847
        for tlvType, value := range extra {
×
3848
                err := db.CreateChannelExtraType(
×
3849
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3850
                                ChannelID: dbChanID,
×
3851
                                Type:      int64(tlvType),
×
3852
                                Value:     value,
×
3853
                        },
×
3854
                )
×
3855
                if err != nil {
×
3856
                        return nil, fmt.Errorf("unable to upsert "+
×
3857
                                "channel(%d) extra signed field(%v): %w",
×
3858
                                edge.ChannelID, tlvType, err)
×
3859
                }
×
3860
        }
3861

3862
        return &dbChanInfo{
×
3863
                channelID: dbChanID,
×
3864
                node1ID:   node1DBID,
×
3865
                node2ID:   node2DBID,
×
3866
        }, nil
×
3867
}
3868

3869
// maybeCreateShellNode checks if a shell node entry exists for the
3870
// given public key. If it does not exist, then a new shell node entry is
3871
// created. The ID of the node is returned. A shell node only has a protocol
3872
// version and public key persisted.
3873
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3874
        pubKey route.Vertex) (int64, error) {
×
3875

×
3876
        dbNode, err := db.GetNodeByPubKey(
×
3877
                ctx, sqlc.GetNodeByPubKeyParams{
×
3878
                        PubKey:  pubKey[:],
×
3879
                        Version: int16(ProtocolV1),
×
3880
                },
×
3881
        )
×
3882
        // The node exists. Return the ID.
×
3883
        if err == nil {
×
3884
                return dbNode.ID, nil
×
3885
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3886
                return 0, err
×
3887
        }
×
3888

3889
        // Otherwise, the node does not exist, so we create a shell entry for
3890
        // it.
3891
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3892
                Version: int16(ProtocolV1),
×
3893
                PubKey:  pubKey[:],
×
3894
        })
×
3895
        if err != nil {
×
3896
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3897
        }
×
3898

3899
        return id, nil
×
3900
}
3901

3902
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3903
// the database. This includes deleting any existing types and then inserting
3904
// the new types.
3905
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3906
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3907

×
3908
        // Delete all existing extra signed fields for the channel policy.
×
3909
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
3910
        if err != nil {
×
3911
                return fmt.Errorf("unable to delete "+
×
3912
                        "existing policy extra signed fields for policy %d: %w",
×
3913
                        chanPolicyID, err)
×
3914
        }
×
3915

3916
        // Insert all new extra signed fields for the channel policy.
3917
        for tlvType, value := range extraFields {
×
3918
                err = db.InsertChanPolicyExtraType(
×
3919
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
3920
                                ChannelPolicyID: chanPolicyID,
×
3921
                                Type:            int64(tlvType),
×
3922
                                Value:           value,
×
3923
                        },
×
3924
                )
×
3925
                if err != nil {
×
3926
                        return fmt.Errorf("unable to insert "+
×
3927
                                "channel_policy(%d) extra signed field(%v): %w",
×
3928
                                chanPolicyID, tlvType, err)
×
3929
                }
×
3930
        }
3931

3932
        return nil
×
3933
}
3934

3935
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3936
// provided dbChanRow and also fetches any other required information
3937
// to construct the edge info.
3938
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
3939
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.Channel, node1,
3940
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
3941

×
3942
        if dbChan.Version != int16(ProtocolV1) {
×
3943
                return nil, fmt.Errorf("unsupported channel version: %d",
×
3944
                        dbChan.Version)
×
3945
        }
×
3946

3947
        fv, extras, err := getChanFeaturesAndExtras(
×
3948
                ctx, db, dbChanID,
×
3949
        )
×
3950
        if err != nil {
×
3951
                return nil, err
×
3952
        }
×
3953

3954
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
3955
        if err != nil {
×
3956
                return nil, err
×
3957
        }
×
3958

3959
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
3960
        if err != nil {
×
3961
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3962
                        "fields: %w", err)
×
3963
        }
×
3964
        if recs == nil {
×
3965
                recs = make([]byte, 0)
×
3966
        }
×
3967

3968
        var btcKey1, btcKey2 route.Vertex
×
3969
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
3970
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
3971

×
3972
        channel := &models.ChannelEdgeInfo{
×
3973
                ChainHash:        chain,
×
3974
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
3975
                NodeKey1Bytes:    node1,
×
3976
                NodeKey2Bytes:    node2,
×
3977
                BitcoinKey1Bytes: btcKey1,
×
3978
                BitcoinKey2Bytes: btcKey2,
×
3979
                ChannelPoint:     *op,
×
3980
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
3981
                Features:         fv,
×
3982
                ExtraOpaqueData:  recs,
×
3983
        }
×
3984

×
3985
        // We always set all the signatures at the same time, so we can
×
3986
        // safely check if one signature is present to determine if we have the
×
3987
        // rest of the signatures for the auth proof.
×
3988
        if len(dbChan.Bitcoin1Signature) > 0 {
×
3989
                channel.AuthProof = &models.ChannelAuthProof{
×
3990
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
3991
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
3992
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
3993
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
3994
                }
×
3995
        }
×
3996

3997
        return channel, nil
×
3998
}
3999

4000
// buildNodeVertices is a helper that converts raw node public keys
4001
// into route.Vertex instances.
4002
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4003
        route.Vertex, error) {
×
4004

×
4005
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4006
        if err != nil {
×
4007
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4008
                        "create vertex from node1 pubkey: %w", err)
×
4009
        }
×
4010

4011
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4012
        if err != nil {
×
4013
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4014
                        "create vertex from node2 pubkey: %w", err)
×
4015
        }
×
4016

4017
        return node1Vertex, node2Vertex, nil
×
4018
}
4019

4020
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
4021
// for a channel with the given ID.
4022
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
4023
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
4024

×
4025
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
4026
        if err != nil {
×
4027
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
4028
                        "features and extras: %w", err)
×
4029
        }
×
4030

4031
        var (
×
4032
                fv     = lnwire.EmptyFeatureVector()
×
4033
                extras = make(map[uint64][]byte)
×
4034
        )
×
4035
        for _, row := range rows {
×
4036
                if row.IsFeature {
×
4037
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
4038

×
4039
                        continue
×
4040
                }
4041

4042
                tlvType, ok := row.ExtraKey.(int64)
×
4043
                if !ok {
×
4044
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4045
                                "TLV type: %T", row.ExtraKey)
×
4046
                }
×
4047

4048
                valueBytes, ok := row.Value.([]byte)
×
4049
                if !ok {
×
4050
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4051
                                "Value: %T", row.Value)
×
4052
                }
×
4053

4054
                extras[uint64(tlvType)] = valueBytes
×
4055
        }
4056

4057
        return fv, extras, nil
×
4058
}
4059

4060
// getAndBuildChanPolicies uses the given sqlc.ChannelPolicy and also retrieves
4061
// all the extra info required to build the complete models.ChannelEdgePolicy
4062
// types. It returns two policies, which may be nil if the provided
4063
// sqlc.ChannelPolicy records are nil.
4064
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4065
        dbPol1, dbPol2 *sqlc.ChannelPolicy, channelID uint64, node1,
4066
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4067
        *models.ChannelEdgePolicy, error) {
×
4068

×
4069
        if dbPol1 == nil && dbPol2 == nil {
×
4070
                return nil, nil, nil
×
4071
        }
×
4072

4073
        var (
×
4074
                policy1ID int64
×
4075
                policy2ID int64
×
4076
        )
×
4077
        if dbPol1 != nil {
×
4078
                policy1ID = dbPol1.ID
×
4079
        }
×
4080
        if dbPol2 != nil {
×
4081
                policy2ID = dbPol2.ID
×
4082
        }
×
4083
        rows, err := db.GetChannelPolicyExtraTypes(
×
4084
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
4085
                        ID:   policy1ID,
×
4086
                        ID_2: policy2ID,
×
4087
                },
×
4088
        )
×
4089
        if err != nil {
×
4090
                return nil, nil, err
×
4091
        }
×
4092

4093
        var (
×
4094
                dbPol1Extras = make(map[uint64][]byte)
×
4095
                dbPol2Extras = make(map[uint64][]byte)
×
4096
        )
×
4097
        for _, row := range rows {
×
4098
                switch row.PolicyID {
×
4099
                case policy1ID:
×
4100
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
4101
                case policy2ID:
×
4102
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
4103
                default:
×
4104
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
4105
                                "in row: %v", row.PolicyID, row)
×
4106
                }
4107
        }
4108

4109
        var pol1, pol2 *models.ChannelEdgePolicy
×
4110
        if dbPol1 != nil {
×
4111
                pol1, err = buildChanPolicy(
×
4112
                        *dbPol1, channelID, dbPol1Extras, node2,
×
4113
                )
×
4114
                if err != nil {
×
4115
                        return nil, nil, err
×
4116
                }
×
4117
        }
4118
        if dbPol2 != nil {
×
4119
                pol2, err = buildChanPolicy(
×
4120
                        *dbPol2, channelID, dbPol2Extras, node1,
×
4121
                )
×
4122
                if err != nil {
×
4123
                        return nil, nil, err
×
4124
                }
×
4125
        }
4126

4127
        return pol1, pol2, nil
×
4128
}
4129

4130
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4131
// provided sqlc.ChannelPolicy and other required information.
4132
func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
4133
        extras map[uint64][]byte,
4134
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4135

×
4136
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4137
        if err != nil {
×
4138
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4139
                        "fields: %w", err)
×
4140
        }
×
4141

4142
        var inboundFee fn.Option[lnwire.Fee]
×
4143
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4144
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4145

×
4146
                inboundFee = fn.Some(lnwire.Fee{
×
4147
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4148
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4149
                })
×
4150
        }
×
4151

4152
        return &models.ChannelEdgePolicy{
×
4153
                SigBytes:  dbPolicy.Signature,
×
4154
                ChannelID: channelID,
×
4155
                LastUpdate: time.Unix(
×
4156
                        dbPolicy.LastUpdate.Int64, 0,
×
4157
                ),
×
4158
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4159
                        dbPolicy.MessageFlags,
×
4160
                ),
×
4161
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4162
                        dbPolicy.ChannelFlags,
×
4163
                ),
×
4164
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4165
                MinHTLC: lnwire.MilliSatoshi(
×
4166
                        dbPolicy.MinHtlcMsat,
×
4167
                ),
×
4168
                MaxHTLC: lnwire.MilliSatoshi(
×
4169
                        dbPolicy.MaxHtlcMsat.Int64,
×
4170
                ),
×
4171
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4172
                        dbPolicy.BaseFeeMsat,
×
4173
                ),
×
4174
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4175
                ToNode:                    toNode,
×
4176
                InboundFee:                inboundFee,
×
4177
                ExtraOpaqueData:           recs,
×
4178
        }, nil
×
4179
}
4180

4181
// buildNodes builds the models.LightningNode instances for the
4182
// given row which is expected to be a sqlc type that contains node information.
4183
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4184
        dbNode2 sqlc.Node) (*models.LightningNode, *models.LightningNode,
4185
        error) {
×
4186

×
4187
        node1, err := buildNode(ctx, db, &dbNode1)
×
4188
        if err != nil {
×
4189
                return nil, nil, err
×
4190
        }
×
4191

4192
        node2, err := buildNode(ctx, db, &dbNode2)
×
4193
        if err != nil {
×
4194
                return nil, nil, err
×
4195
        }
×
4196

4197
        return node1, node2, nil
×
4198
}
4199

4200
// extractChannelPolicies extracts the sqlc.ChannelPolicy records from the give
4201
// row which is expected to be a sqlc type that contains channel policy
4202
// information. It returns two policies, which may be nil if the policy
4203
// information is not present in the row.
4204
//
4205
//nolint:ll,dupl,funlen
4206
func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
4207
        error) {
×
4208

×
4209
        var policy1, policy2 *sqlc.ChannelPolicy
×
4210
        switch r := row.(type) {
×
4211
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4212
                if r.Policy1ID.Valid {
×
4213
                        policy1 = &sqlc.ChannelPolicy{
×
4214
                                ID:                      r.Policy1ID.Int64,
×
4215
                                Version:                 r.Policy1Version.Int16,
×
4216
                                ChannelID:               r.Channel.ID,
×
4217
                                NodeID:                  r.Policy1NodeID.Int64,
×
4218
                                Timelock:                r.Policy1Timelock.Int32,
×
4219
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4220
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4221
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4222
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4223
                                LastUpdate:              r.Policy1LastUpdate,
×
4224
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4225
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4226
                                Disabled:                r.Policy1Disabled,
×
4227
                                MessageFlags:            r.Policy1MessageFlags,
×
4228
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4229
                                Signature:               r.Policy1Signature,
×
4230
                        }
×
4231
                }
×
4232
                if r.Policy2ID.Valid {
×
4233
                        policy2 = &sqlc.ChannelPolicy{
×
4234
                                ID:                      r.Policy2ID.Int64,
×
4235
                                Version:                 r.Policy2Version.Int16,
×
4236
                                ChannelID:               r.Channel.ID,
×
4237
                                NodeID:                  r.Policy2NodeID.Int64,
×
4238
                                Timelock:                r.Policy2Timelock.Int32,
×
4239
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4240
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4241
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4242
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4243
                                LastUpdate:              r.Policy2LastUpdate,
×
4244
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4245
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4246
                                Disabled:                r.Policy2Disabled,
×
4247
                                MessageFlags:            r.Policy2MessageFlags,
×
4248
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4249
                                Signature:               r.Policy2Signature,
×
4250
                        }
×
4251
                }
×
4252

4253
                return policy1, policy2, nil
×
4254

4255
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4256
                if r.Policy1ID.Valid {
×
4257
                        policy1 = &sqlc.ChannelPolicy{
×
4258
                                ID:                      r.Policy1ID.Int64,
×
4259
                                Version:                 r.Policy1Version.Int16,
×
4260
                                ChannelID:               r.Channel.ID,
×
4261
                                NodeID:                  r.Policy1NodeID.Int64,
×
4262
                                Timelock:                r.Policy1Timelock.Int32,
×
4263
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4264
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4265
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4266
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4267
                                LastUpdate:              r.Policy1LastUpdate,
×
4268
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4269
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4270
                                Disabled:                r.Policy1Disabled,
×
4271
                                MessageFlags:            r.Policy1MessageFlags,
×
4272
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4273
                                Signature:               r.Policy1Signature,
×
4274
                        }
×
4275
                }
×
4276
                if r.Policy2ID.Valid {
×
4277
                        policy2 = &sqlc.ChannelPolicy{
×
4278
                                ID:                      r.Policy2ID.Int64,
×
4279
                                Version:                 r.Policy2Version.Int16,
×
4280
                                ChannelID:               r.Channel.ID,
×
4281
                                NodeID:                  r.Policy2NodeID.Int64,
×
4282
                                Timelock:                r.Policy2Timelock.Int32,
×
4283
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4284
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4285
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4286
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4287
                                LastUpdate:              r.Policy2LastUpdate,
×
4288
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4289
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4290
                                Disabled:                r.Policy2Disabled,
×
4291
                                MessageFlags:            r.Policy2MessageFlags,
×
4292
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4293
                                Signature:               r.Policy2Signature,
×
4294
                        }
×
4295
                }
×
4296

4297
                return policy1, policy2, nil
×
4298

4299
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4300
                if r.Policy1ID.Valid {
×
4301
                        policy1 = &sqlc.ChannelPolicy{
×
4302
                                ID:                      r.Policy1ID.Int64,
×
4303
                                Version:                 r.Policy1Version.Int16,
×
4304
                                ChannelID:               r.Channel.ID,
×
4305
                                NodeID:                  r.Policy1NodeID.Int64,
×
4306
                                Timelock:                r.Policy1Timelock.Int32,
×
4307
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4308
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4309
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4310
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4311
                                LastUpdate:              r.Policy1LastUpdate,
×
4312
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4313
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4314
                                Disabled:                r.Policy1Disabled,
×
4315
                                MessageFlags:            r.Policy1MessageFlags,
×
4316
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4317
                                Signature:               r.Policy1Signature,
×
4318
                        }
×
4319
                }
×
4320
                if r.Policy2ID.Valid {
×
4321
                        policy2 = &sqlc.ChannelPolicy{
×
4322
                                ID:                      r.Policy2ID.Int64,
×
4323
                                Version:                 r.Policy2Version.Int16,
×
4324
                                ChannelID:               r.Channel.ID,
×
4325
                                NodeID:                  r.Policy2NodeID.Int64,
×
4326
                                Timelock:                r.Policy2Timelock.Int32,
×
4327
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4328
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4329
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4330
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4331
                                LastUpdate:              r.Policy2LastUpdate,
×
4332
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4333
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4334
                                Disabled:                r.Policy2Disabled,
×
4335
                                MessageFlags:            r.Policy2MessageFlags,
×
4336
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4337
                                Signature:               r.Policy2Signature,
×
4338
                        }
×
4339
                }
×
4340

4341
                return policy1, policy2, nil
×
4342

4343
        case sqlc.ListChannelsByNodeIDRow:
×
4344
                if r.Policy1ID.Valid {
×
4345
                        policy1 = &sqlc.ChannelPolicy{
×
4346
                                ID:                      r.Policy1ID.Int64,
×
4347
                                Version:                 r.Policy1Version.Int16,
×
4348
                                ChannelID:               r.Channel.ID,
×
4349
                                NodeID:                  r.Policy1NodeID.Int64,
×
4350
                                Timelock:                r.Policy1Timelock.Int32,
×
4351
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4352
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4353
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4354
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4355
                                LastUpdate:              r.Policy1LastUpdate,
×
4356
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4357
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4358
                                Disabled:                r.Policy1Disabled,
×
4359
                                MessageFlags:            r.Policy1MessageFlags,
×
4360
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4361
                                Signature:               r.Policy1Signature,
×
4362
                        }
×
4363
                }
×
4364
                if r.Policy2ID.Valid {
×
4365
                        policy2 = &sqlc.ChannelPolicy{
×
4366
                                ID:                      r.Policy2ID.Int64,
×
4367
                                Version:                 r.Policy2Version.Int16,
×
4368
                                ChannelID:               r.Channel.ID,
×
4369
                                NodeID:                  r.Policy2NodeID.Int64,
×
4370
                                Timelock:                r.Policy2Timelock.Int32,
×
4371
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4372
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4373
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4374
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4375
                                LastUpdate:              r.Policy2LastUpdate,
×
4376
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4377
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4378
                                Disabled:                r.Policy2Disabled,
×
4379
                                MessageFlags:            r.Policy2MessageFlags,
×
4380
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4381
                                Signature:               r.Policy2Signature,
×
4382
                        }
×
4383
                }
×
4384

4385
                return policy1, policy2, nil
×
4386

4387
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4388
                if r.Policy1ID.Valid {
×
4389
                        policy1 = &sqlc.ChannelPolicy{
×
4390
                                ID:                      r.Policy1ID.Int64,
×
4391
                                Version:                 r.Policy1Version.Int16,
×
4392
                                ChannelID:               r.Channel.ID,
×
4393
                                NodeID:                  r.Policy1NodeID.Int64,
×
4394
                                Timelock:                r.Policy1Timelock.Int32,
×
4395
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4396
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4397
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4398
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4399
                                LastUpdate:              r.Policy1LastUpdate,
×
4400
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4401
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4402
                                Disabled:                r.Policy1Disabled,
×
4403
                                MessageFlags:            r.Policy1MessageFlags,
×
4404
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4405
                                Signature:               r.Policy1Signature,
×
4406
                        }
×
4407
                }
×
4408
                if r.Policy2ID.Valid {
×
4409
                        policy2 = &sqlc.ChannelPolicy{
×
4410
                                ID:                      r.Policy2ID.Int64,
×
4411
                                Version:                 r.Policy2Version.Int16,
×
4412
                                ChannelID:               r.Channel.ID,
×
4413
                                NodeID:                  r.Policy2NodeID.Int64,
×
4414
                                Timelock:                r.Policy2Timelock.Int32,
×
4415
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4416
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4417
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4418
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4419
                                LastUpdate:              r.Policy2LastUpdate,
×
4420
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4421
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4422
                                Disabled:                r.Policy2Disabled,
×
4423
                                MessageFlags:            r.Policy2MessageFlags,
×
4424
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4425
                                Signature:               r.Policy2Signature,
×
4426
                        }
×
4427
                }
×
4428

4429
                return policy1, policy2, nil
×
4430
        default:
×
4431
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4432
                        "extractChannelPolicies: %T", r)
×
4433
        }
4434
}
4435

4436
// channelIDToBytes converts a channel ID (SCID) to a byte array
4437
// representation.
4438
func channelIDToBytes(channelID uint64) []byte {
×
4439
        var chanIDB [8]byte
×
4440
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4441

×
4442
        return chanIDB[:]
×
4443
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc