• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 16023352958

02 Jul 2025 11:02AM UTC coverage: 57.603% (-0.2%) from 57.803%
16023352958

Pull #10025

github

web-flow
Merge d7fd9e180 into 1d2e5472b
Pull Request #10025: [draft] graph/db: kvdb -> SQL migration

15 of 608 new or added lines in 8 files covered. (2.47%)

71 existing lines in 13 files now uncovered.

98475 of 170954 relevant lines covered (57.6%)

1.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// pageSize is the limit for the number of records that can be returned
35
// in a paginated query. This can be tuned after some benchmarks.
36
const pageSize = 2000
37

38
// ProtocolVersion is an enum that defines the gossip protocol version of a
39
// message.
40
type ProtocolVersion uint8
41

42
const (
43
        // ProtocolV1 is the gossip protocol version defined in BOLT #7.
44
        ProtocolV1 ProtocolVersion = 1
45
)
46

47
// String returns a string representation of the protocol version.
48
func (v ProtocolVersion) String() string {
×
49
        return fmt.Sprintf("V%d", v)
×
50
}
×
51

52
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
53
// execute queries against the SQL graph tables.
54
//
55
//nolint:ll,interfacebloat
56
type SQLQueries interface {
57
        /*
58
                Node queries.
59
        */
60
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
61
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
62
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
63
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
64
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error)
65
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
66
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
67
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
68
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
69
        DeleteNode(ctx context.Context, id int64) error
70

71
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
72
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
73
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
74

75
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
76
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
77
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
78

79
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
80
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
81
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
82
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
83

84
        /*
85
                Source node queries.
86
        */
87
        AddSourceNode(ctx context.Context, nodeID int64) error
88
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
89

90
        /*
91
                Channel queries.
92
        */
93
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
94
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
95
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
96
        GetChannelByOutpoint(ctx context.Context, outpoint string) (sqlc.GetChannelByOutpointRow, error)
97
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
98
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
99
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
100
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
101
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
102
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
103
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
104
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
105
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
106
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
107
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
108
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
109
        DeleteChannel(ctx context.Context, id int64) error
110

111
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
112
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
113

114
        /*
115
                Channel Policy table queries.
116
        */
117
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
118
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.ChannelPolicy, error)
119
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
120

121
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
122
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
123
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
124

125
        /*
126
                Zombie index queries.
127
        */
128
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
129
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.ZombieChannel, error)
130
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
131
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
132
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
133

134
        /*
135
                Prune log table queries.
136
        */
137
        GetPruneTip(ctx context.Context) (sqlc.PruneLog, error)
138
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
139
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
140
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
141

142
        /*
143
                Closed SCID table queries.
144
        */
145
        InsertClosedChannel(ctx context.Context, scid []byte) error
146
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
147
}
148

149
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
150
// database operations.
151
type BatchedSQLQueries interface {
152
        SQLQueries
153
        sqldb.BatchedTx[SQLQueries]
154
}
155

156
// SQLStore is an implementation of the V1Store interface that uses a SQL
157
// database as the backend.
158
type SQLStore struct {
159
        cfg *SQLStoreConfig
160
        db  BatchedSQLQueries
161

162
        // cacheMu guards all caches (rejectCache and chanCache). If
163
        // this mutex will be acquired at the same time as the DB mutex then
164
        // the cacheMu MUST be acquired first to prevent deadlock.
165
        cacheMu     sync.RWMutex
166
        rejectCache *rejectCache
167
        chanCache   *channelCache
168

169
        chanScheduler batch.Scheduler[SQLQueries]
170
        nodeScheduler batch.Scheduler[SQLQueries]
171

172
        srcNodes  map[ProtocolVersion]*srcNodeInfo
173
        srcNodeMu sync.Mutex
174
}
175

176
// A compile-time assertion to ensure that SQLStore implements the V1Store
177
// interface.
178
var _ V1Store = (*SQLStore)(nil)
179

180
// SQLStoreConfig holds the configuration for the SQLStore.
181
type SQLStoreConfig struct {
182
        // ChainHash is the genesis hash for the chain that all the gossip
183
        // messages in this store are aimed at.
184
        ChainHash chainhash.Hash
185
}
186

187
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
188
// storage backend.
189
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
190
        options ...StoreOptionModifier) (*SQLStore, error) {
×
191

×
192
        opts := DefaultOptions()
×
193
        for _, o := range options {
×
194
                o(opts)
×
195
        }
×
196

197
        if opts.NoMigration {
×
198
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
199
                        "supported for SQL stores")
×
200
        }
×
201

202
        s := &SQLStore{
×
203
                cfg:         cfg,
×
204
                db:          db,
×
205
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
206
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
207
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
208
        }
×
209

×
210
        s.chanScheduler = batch.NewTimeScheduler(
×
211
                db, &s.cacheMu, opts.BatchCommitInterval,
×
212
        )
×
213
        s.nodeScheduler = batch.NewTimeScheduler(
×
214
                db, nil, opts.BatchCommitInterval,
×
215
        )
×
216

×
217
        return s, nil
×
218
}
219

220
// AddLightningNode adds a vertex/node to the graph database. If the node is not
221
// in the database from before, this will add a new, unconnected one to the
222
// graph. If it is present from before, this will update that node's
223
// information.
224
//
225
// NOTE: part of the V1Store interface.
226
func (s *SQLStore) AddLightningNode(ctx context.Context,
227
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
228

×
229
        r := &batch.Request[SQLQueries]{
×
230
                Opts: batch.NewSchedulerOptions(opts...),
×
231
                Do: func(queries SQLQueries) error {
×
232
                        _, err := upsertNode(ctx, queries, node)
×
233
                        return err
×
234
                },
×
235
        }
236

237
        return s.nodeScheduler.Execute(ctx, r)
×
238
}
239

240
// FetchLightningNode attempts to look up a target node by its identity public
241
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
242
// returned.
243
//
244
// NOTE: part of the V1Store interface.
245
func (s *SQLStore) FetchLightningNode(ctx context.Context,
246
        pubKey route.Vertex) (*models.LightningNode, error) {
×
247

×
248
        var node *models.LightningNode
×
249
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
250
                var err error
×
251
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
252

×
253
                return err
×
254
        }, sqldb.NoOpReset)
×
255
        if err != nil {
×
256
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
257
        }
×
258

259
        return node, nil
×
260
}
261

262
// HasLightningNode determines if the graph has a vertex identified by the
263
// target node identity public key. If the node exists in the database, a
264
// timestamp of when the data for the node was lasted updated is returned along
265
// with a true boolean. Otherwise, an empty time.Time is returned with a false
266
// boolean.
267
//
268
// NOTE: part of the V1Store interface.
269
func (s *SQLStore) HasLightningNode(ctx context.Context,
270
        pubKey [33]byte) (time.Time, bool, error) {
×
271

×
272
        var (
×
273
                exists     bool
×
274
                lastUpdate time.Time
×
275
        )
×
276
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
277
                dbNode, err := db.GetNodeByPubKey(
×
278
                        ctx, sqlc.GetNodeByPubKeyParams{
×
279
                                Version: int16(ProtocolV1),
×
280
                                PubKey:  pubKey[:],
×
281
                        },
×
282
                )
×
283
                if errors.Is(err, sql.ErrNoRows) {
×
284
                        return nil
×
285
                } else if err != nil {
×
286
                        return fmt.Errorf("unable to fetch node: %w", err)
×
287
                }
×
288

289
                exists = true
×
290

×
291
                if dbNode.LastUpdate.Valid {
×
292
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
293
                }
×
294

295
                return nil
×
296
        }, sqldb.NoOpReset)
297
        if err != nil {
×
298
                return time.Time{}, false,
×
299
                        fmt.Errorf("unable to fetch node: %w", err)
×
300
        }
×
301

302
        return lastUpdate, exists, nil
×
303
}
304

305
// AddrsForNode returns all known addresses for the target node public key
306
// that the graph DB is aware of. The returned boolean indicates if the
307
// given node is unknown to the graph DB or not.
308
//
309
// NOTE: part of the V1Store interface.
310
func (s *SQLStore) AddrsForNode(ctx context.Context,
311
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
312

×
313
        var (
×
314
                addresses []net.Addr
×
315
                known     bool
×
316
        )
×
317
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
318
                var err error
×
319
                known, addresses, err = getNodeAddresses(
×
320
                        ctx, db, nodePub.SerializeCompressed(),
×
321
                )
×
322
                if err != nil {
×
323
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
324
                                err)
×
325
                }
×
326

327
                return nil
×
328
        }, sqldb.NoOpReset)
329
        if err != nil {
×
330
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
331
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
332
        }
×
333

334
        return known, addresses, nil
×
335
}
336

337
// DeleteLightningNode starts a new database transaction to remove a vertex/node
338
// from the database according to the node's public key.
339
//
340
// NOTE: part of the V1Store interface.
341
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
342
        pubKey route.Vertex) error {
×
343

×
344
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
345
                res, err := db.DeleteNodeByPubKey(
×
346
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
347
                                Version: int16(ProtocolV1),
×
348
                                PubKey:  pubKey[:],
×
349
                        },
×
350
                )
×
351
                if err != nil {
×
352
                        return err
×
353
                }
×
354

355
                rows, err := res.RowsAffected()
×
356
                if err != nil {
×
357
                        return err
×
358
                }
×
359

360
                if rows == 0 {
×
361
                        return ErrGraphNodeNotFound
×
362
                } else if rows > 1 {
×
363
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
364
                }
×
365

366
                return err
×
367
        }, sqldb.NoOpReset)
368
        if err != nil {
×
369
                return fmt.Errorf("unable to delete node: %w", err)
×
370
        }
×
371

372
        return nil
×
373
}
374

375
// FetchNodeFeatures returns the features of the given node. If no features are
376
// known for the node, an empty feature vector is returned.
377
//
378
// NOTE: this is part of the graphdb.NodeTraverser interface.
379
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
380
        *lnwire.FeatureVector, error) {
×
381

×
382
        ctx := context.TODO()
×
383

×
384
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
385
}
×
386

387
// DisabledChannelIDs returns the channel ids of disabled channels.
388
// A channel is disabled when two of the associated ChanelEdgePolicies
389
// have their disabled bit on.
390
//
391
// NOTE: part of the V1Store interface.
392
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
393
        var (
×
394
                ctx     = context.TODO()
×
395
                chanIDs []uint64
×
396
        )
×
397
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
398
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
399
                if err != nil {
×
400
                        return fmt.Errorf("unable to fetch disabled "+
×
401
                                "channels: %w", err)
×
402
                }
×
403

404
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
405

×
406
                return nil
×
407
        }, sqldb.NoOpReset)
408
        if err != nil {
×
409
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
410
                        err)
×
411
        }
×
412

413
        return chanIDs, nil
×
414
}
415

416
// LookupAlias attempts to return the alias as advertised by the target node.
417
//
418
// NOTE: part of the V1Store interface.
419
func (s *SQLStore) LookupAlias(ctx context.Context,
420
        pub *btcec.PublicKey) (string, error) {
×
421

×
422
        var alias string
×
423
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
424
                dbNode, err := db.GetNodeByPubKey(
×
425
                        ctx, sqlc.GetNodeByPubKeyParams{
×
426
                                Version: int16(ProtocolV1),
×
427
                                PubKey:  pub.SerializeCompressed(),
×
428
                        },
×
429
                )
×
430
                if errors.Is(err, sql.ErrNoRows) {
×
431
                        return ErrNodeAliasNotFound
×
432
                } else if err != nil {
×
433
                        return fmt.Errorf("unable to fetch node: %w", err)
×
434
                }
×
435

436
                if !dbNode.Alias.Valid {
×
437
                        return ErrNodeAliasNotFound
×
438
                }
×
439

440
                alias = dbNode.Alias.String
×
441

×
442
                return nil
×
443
        }, sqldb.NoOpReset)
444
        if err != nil {
×
445
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
446
        }
×
447

448
        return alias, nil
×
449
}
450

451
// SourceNode returns the source node of the graph. The source node is treated
452
// as the center node within a star-graph. This method may be used to kick off
453
// a path finding algorithm in order to explore the reachability of another
454
// node based off the source node.
455
//
456
// NOTE: part of the V1Store interface.
457
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
458
        error) {
×
459

×
460
        var node *models.LightningNode
×
461
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
462
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
463
                if err != nil {
×
464
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
465
                                err)
×
466
                }
×
467

468
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
469

×
470
                return err
×
471
        }, sqldb.NoOpReset)
472
        if err != nil {
×
473
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
474
        }
×
475

476
        return node, nil
×
477
}
478

479
// SetSourceNode sets the source node within the graph database. The source
480
// node is to be used as the center of a star-graph within path finding
481
// algorithms.
482
//
483
// NOTE: part of the V1Store interface.
484
func (s *SQLStore) SetSourceNode(ctx context.Context,
485
        node *models.LightningNode) error {
×
486

×
487
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
488
                id, err := upsertNode(ctx, db, node)
×
489
                if err != nil {
×
490
                        return fmt.Errorf("unable to upsert source node: %w",
×
491
                                err)
×
492
                }
×
493

494
                // Make sure that if a source node for this version is already
495
                // set, then the ID is the same as the one we are about to set.
496
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
497
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
498
                        return fmt.Errorf("unable to fetch source node: %w",
×
499
                                err)
×
500
                } else if err == nil {
×
501
                        if dbSourceNodeID != id {
×
502
                                return fmt.Errorf("v1 source node already "+
×
503
                                        "set to a different node: %d vs %d",
×
504
                                        dbSourceNodeID, id)
×
505
                        }
×
506

507
                        return nil
×
508
                }
509

510
                return db.AddSourceNode(ctx, id)
×
511
        }, sqldb.NoOpReset)
512
}
513

514
// NodeUpdatesInHorizon returns all the known lightning node which have an
515
// update timestamp within the passed range. This method can be used by two
516
// nodes to quickly determine if they have the same set of up to date node
517
// announcements.
518
//
519
// NOTE: This is part of the V1Store interface.
520
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
521
        endTime time.Time) ([]models.LightningNode, error) {
×
522

×
523
        ctx := context.TODO()
×
524

×
525
        var nodes []models.LightningNode
×
526
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
527
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
528
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
529
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
530
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
531
                        },
×
532
                )
×
533
                if err != nil {
×
534
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
535
                }
×
536

537
                for _, dbNode := range dbNodes {
×
538
                        node, err := buildNode(ctx, db, &dbNode)
×
539
                        if err != nil {
×
540
                                return fmt.Errorf("unable to build node: %w",
×
541
                                        err)
×
542
                        }
×
543

544
                        nodes = append(nodes, *node)
×
545
                }
546

547
                return nil
×
548
        }, sqldb.NoOpReset)
549
        if err != nil {
×
550
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
551
        }
×
552

553
        return nodes, nil
×
554
}
555

556
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
557
// undirected edge from the two target nodes are created. The information stored
558
// denotes the static attributes of the channel, such as the channelID, the keys
559
// involved in creation of the channel, and the set of features that the channel
560
// supports. The chanPoint and chanID are used to uniquely identify the edge
561
// globally within the database.
562
//
563
// NOTE: part of the V1Store interface.
564
func (s *SQLStore) AddChannelEdge(ctx context.Context,
565
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
566

×
567
        var alreadyExists bool
×
568
        r := &batch.Request[SQLQueries]{
×
569
                Opts: batch.NewSchedulerOptions(opts...),
×
570
                Reset: func() {
×
571
                        alreadyExists = false
×
572
                },
×
573
                Do: func(tx SQLQueries) error {
×
574
                        _, err := insertChannel(ctx, tx, edge)
×
575

×
576
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
577
                        // succeed, but propagate the error via local state.
×
578
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
579
                                alreadyExists = true
×
580
                                return nil
×
581
                        }
×
582

583
                        return err
×
584
                },
585
                OnCommit: func(err error) error {
×
586
                        switch {
×
587
                        case err != nil:
×
588
                                return err
×
589
                        case alreadyExists:
×
590
                                return ErrEdgeAlreadyExist
×
591
                        default:
×
592
                                s.rejectCache.remove(edge.ChannelID)
×
593
                                s.chanCache.remove(edge.ChannelID)
×
594
                                return nil
×
595
                        }
596
                },
597
        }
598

599
        return s.chanScheduler.Execute(ctx, r)
×
600
}
601

602
// HighestChanID returns the "highest" known channel ID in the channel graph.
603
// This represents the "newest" channel from the PoV of the chain. This method
604
// can be used by peers to quickly determine if their graphs are in sync.
605
//
606
// NOTE: This is part of the V1Store interface.
607
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
608
        var highestChanID uint64
×
609
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
610
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
611
                if errors.Is(err, sql.ErrNoRows) {
×
612
                        return nil
×
613
                } else if err != nil {
×
614
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
615
                                err)
×
616
                }
×
617

618
                highestChanID = byteOrder.Uint64(chanID)
×
619

×
620
                return nil
×
621
        }, sqldb.NoOpReset)
622
        if err != nil {
×
623
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
624
        }
×
625

626
        return highestChanID, nil
×
627
}
628

629
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
630
// within the database for the referenced channel. The `flags` attribute within
631
// the ChannelEdgePolicy determines which of the directed edges are being
632
// updated. If the flag is 1, then the first node's information is being
633
// updated, otherwise it's the second node's information. The node ordering is
634
// determined by the lexicographical ordering of the identity public keys of the
635
// nodes on either side of the channel.
636
//
637
// NOTE: part of the V1Store interface.
638
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
639
        edge *models.ChannelEdgePolicy,
640
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
641

×
642
        var (
×
643
                isUpdate1    bool
×
644
                edgeNotFound bool
×
645
                from, to     route.Vertex
×
646
        )
×
647

×
648
        r := &batch.Request[SQLQueries]{
×
649
                Opts: batch.NewSchedulerOptions(opts...),
×
650
                Reset: func() {
×
651
                        isUpdate1 = false
×
652
                        edgeNotFound = false
×
653
                },
×
654
                Do: func(tx SQLQueries) error {
×
655
                        var err error
×
656
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
657
                                ctx, tx, edge,
×
658
                        )
×
659
                        if err != nil {
×
660
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
661
                        }
×
662

663
                        // Silence ErrEdgeNotFound so that the batch can
664
                        // succeed, but propagate the error via local state.
665
                        if errors.Is(err, ErrEdgeNotFound) {
×
666
                                edgeNotFound = true
×
667
                                return nil
×
668
                        }
×
669

670
                        return err
×
671
                },
672
                OnCommit: func(err error) error {
×
673
                        switch {
×
674
                        case err != nil:
×
675
                                return err
×
676
                        case edgeNotFound:
×
677
                                return ErrEdgeNotFound
×
678
                        default:
×
679
                                s.updateEdgeCache(edge, isUpdate1)
×
680
                                return nil
×
681
                        }
682
                },
683
        }
684

685
        err := s.chanScheduler.Execute(ctx, r)
×
686

×
687
        return from, to, err
×
688
}
689

690
// updateEdgeCache updates our reject and channel caches with the new
691
// edge policy information.
692
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
693
        isUpdate1 bool) {
×
694

×
695
        // If an entry for this channel is found in reject cache, we'll modify
×
696
        // the entry with the updated timestamp for the direction that was just
×
697
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
698
        // during the next query for this edge.
×
699
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
700
                if isUpdate1 {
×
701
                        entry.upd1Time = e.LastUpdate.Unix()
×
702
                } else {
×
703
                        entry.upd2Time = e.LastUpdate.Unix()
×
704
                }
×
705
                s.rejectCache.insert(e.ChannelID, entry)
×
706
        }
707

708
        // If an entry for this channel is found in channel cache, we'll modify
709
        // the entry with the updated policy for the direction that was just
710
        // written. If the edge doesn't exist, we'll defer loading the info and
711
        // policies and lazily read from disk during the next query.
712
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
713
                if isUpdate1 {
×
714
                        channel.Policy1 = e
×
715
                } else {
×
716
                        channel.Policy2 = e
×
717
                }
×
718
                s.chanCache.insert(e.ChannelID, channel)
×
719
        }
720
}
721

722
// ForEachSourceNodeChannel iterates through all channels of the source node,
723
// executing the passed callback on each. The call-back is provided with the
724
// channel's outpoint, whether we have a policy for the channel and the channel
725
// peer's node information.
726
//
727
// NOTE: part of the V1Store interface.
728
func (s *SQLStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint,
729
        havePolicy bool, otherNode *models.LightningNode) error) error {
×
730

×
731
        var ctx = context.TODO()
×
732

×
733
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
734
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
735
                if err != nil {
×
736
                        return fmt.Errorf("unable to fetch source node: %w",
×
737
                                err)
×
738
                }
×
739

740
                return forEachNodeChannel(
×
741
                        ctx, db, s.cfg.ChainHash, nodeID,
×
742
                        func(info *models.ChannelEdgeInfo,
×
743
                                outPolicy *models.ChannelEdgePolicy,
×
744
                                _ *models.ChannelEdgePolicy) error {
×
745

×
746
                                // Fetch the other node.
×
747
                                var (
×
748
                                        otherNodePub [33]byte
×
749
                                        node1        = info.NodeKey1Bytes
×
750
                                        node2        = info.NodeKey2Bytes
×
751
                                )
×
752
                                switch {
×
753
                                case bytes.Equal(node1[:], nodePub[:]):
×
754
                                        otherNodePub = node2
×
755
                                case bytes.Equal(node2[:], nodePub[:]):
×
756
                                        otherNodePub = node1
×
757
                                default:
×
758
                                        return fmt.Errorf("node not " +
×
759
                                                "participating in this channel")
×
760
                                }
761

762
                                _, otherNode, err := getNodeByPubKey(
×
763
                                        ctx, db, otherNodePub,
×
764
                                )
×
765
                                if err != nil {
×
766
                                        return fmt.Errorf("unable to fetch "+
×
767
                                                "other node(%x): %w",
×
768
                                                otherNodePub, err)
×
769
                                }
×
770

771
                                return cb(
×
772
                                        info.ChannelPoint, outPolicy != nil,
×
773
                                        otherNode,
×
774
                                )
×
775
                        },
776
                )
777
        }, sqldb.NoOpReset)
778
}
779

780
// ForEachNode iterates through all the stored vertices/nodes in the graph,
781
// executing the passed callback with each node encountered. If the callback
782
// returns an error, then the transaction is aborted and the iteration stops
783
// early. Any operations performed on the NodeTx passed to the call-back are
784
// executed under the same read transaction and so, methods on the NodeTx object
785
// _MUST_ only be called from within the call-back.
786
//
787
// NOTE: part of the V1Store interface.
788
func (s *SQLStore) ForEachNode(cb func(tx NodeRTx) error) error {
×
789
        var (
×
790
                ctx          = context.TODO()
×
791
                lastID int64 = 0
×
792
        )
×
793

×
794
        handleNode := func(db SQLQueries, dbNode sqlc.Node) error {
×
795
                node, err := buildNode(ctx, db, &dbNode)
×
796
                if err != nil {
×
797
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
798
                                dbNode.ID, err)
×
799
                }
×
800

801
                err = cb(
×
802
                        newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
×
803
                )
×
804
                if err != nil {
×
805
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
806
                                dbNode.ID, err)
×
807
                }
×
808

809
                return nil
×
810
        }
811

812
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
813
                for {
×
814
                        nodes, err := db.ListNodesPaginated(
×
815
                                ctx, sqlc.ListNodesPaginatedParams{
×
816
                                        Version: int16(ProtocolV1),
×
817
                                        ID:      lastID,
×
818
                                        Limit:   pageSize,
×
819
                                },
×
820
                        )
×
821
                        if err != nil {
×
822
                                return fmt.Errorf("unable to fetch nodes: %w",
×
823
                                        err)
×
824
                        }
×
825

826
                        if len(nodes) == 0 {
×
827
                                break
×
828
                        }
829

830
                        for _, dbNode := range nodes {
×
831
                                err = handleNode(db, dbNode)
×
832
                                if err != nil {
×
833
                                        return err
×
834
                                }
×
835

836
                                lastID = dbNode.ID
×
837
                        }
838
                }
839

840
                return nil
×
841
        }, sqldb.NoOpReset)
842
}
843

844
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
// SQLStore and a SQL transaction.
type sqlGraphNodeTx struct {
	// db is the query handle that this NodeRTx was created with. Channel
	// iteration and node lookups made through this NodeRTx use it.
	db SQLQueries

	// id is the database ID of the wrapped node. It is used to look up
	// the node's channels.
	id int64

	// node is the fully built node model returned by Node.
	node *models.LightningNode

	// chain is the chain hash forwarded to channel iteration.
	chain chainhash.Hash
}

// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
// interface.
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
856

857
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
858
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
859

×
860
        return &sqlGraphNodeTx{
×
861
                db:    db,
×
862
                chain: chain,
×
863
                id:    id,
×
864
                node:  node,
×
865
        }
×
866
}
×
867

868
// Node returns the raw information of the node.
869
//
870
// NOTE: This is a part of the NodeRTx interface.
871
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
872
        return s.node
×
873
}
×
874

875
// ForEachChannel can be used to iterate over the node's channels under the same
876
// transaction used to fetch the node.
877
//
878
// NOTE: This is a part of the NodeRTx interface.
879
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
880
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
881

×
882
        ctx := context.TODO()
×
883

×
884
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
885
}
×
886

887
// FetchNode fetches the node with the given pub key under the same transaction
888
// used to fetch the current node. The returned node is also a NodeRTx and any
889
// operations on that NodeRTx will also be done under the same transaction.
890
//
891
// NOTE: This is a part of the NodeRTx interface.
892
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
893
        ctx := context.TODO()
×
894

×
895
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
896
        if err != nil {
×
897
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
898
                        nodePub, err)
×
899
        }
×
900

901
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
902
}
903

904
// ForEachNodeDirectedChannel iterates through all channels of a given node,
905
// executing the passed callback on the directed edge representing the channel
906
// and its incoming policy. If the callback returns an error, then the iteration
907
// is halted with the error propagated back up to the caller.
908
//
909
// Unknown policies are passed into the callback as nil values.
910
//
911
// NOTE: this is part of the graphdb.NodeTraverser interface.
912
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
913
        cb func(channel *DirectedChannel) error) error {
×
914

×
915
        var ctx = context.TODO()
×
916

×
917
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
918
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
919
        }, sqldb.NoOpReset)
×
920
}
921

922
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
923
// graph, executing the passed callback with each node encountered. If the
924
// callback returns an error, then the transaction is aborted and the iteration
925
// stops early.
926
//
927
// NOTE: This is a part of the V1Store interface.
928
func (s *SQLStore) ForEachNodeCacheable(cb func(route.Vertex,
929
        *lnwire.FeatureVector) error) error {
×
930

×
931
        ctx := context.TODO()
×
932

×
933
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
934
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
935
                        nodePub route.Vertex) error {
×
936

×
937
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
938
                        if err != nil {
×
939
                                return fmt.Errorf("unable to fetch node "+
×
940
                                        "features: %w", err)
×
941
                        }
×
942

943
                        return cb(nodePub, features)
×
944
                })
945
        }, sqldb.NoOpReset)
946
        if err != nil {
×
947
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
948
        }
×
949

950
        return nil
×
951
}
952

953
// ForEachNodeChannel iterates through all channels of the given node,
954
// executing the passed callback with an edge info structure and the policies
955
// of each end of the channel. The first edge policy is the outgoing edge *to*
956
// the connecting node, while the second is the incoming edge *from* the
957
// connecting node. If the callback returns an error, then the iteration is
958
// halted with the error propagated back up to the caller.
959
//
960
// Unknown policies are passed into the callback as nil values.
961
//
962
// NOTE: part of the V1Store interface.
963
func (s *SQLStore) ForEachNodeChannel(nodePub route.Vertex,
964
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
965
                *models.ChannelEdgePolicy) error) error {
×
966

×
967
        var ctx = context.TODO()
×
968

×
969
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
970
                dbNode, err := db.GetNodeByPubKey(
×
971
                        ctx, sqlc.GetNodeByPubKeyParams{
×
972
                                Version: int16(ProtocolV1),
×
973
                                PubKey:  nodePub[:],
×
974
                        },
×
975
                )
×
976
                if errors.Is(err, sql.ErrNoRows) {
×
977
                        return nil
×
978
                } else if err != nil {
×
979
                        return fmt.Errorf("unable to fetch node: %w", err)
×
980
                }
×
981

982
                return forEachNodeChannel(
×
983
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
984
                )
×
985
        }, sqldb.NoOpReset)
986
}
987

988
// ChanUpdatesInHorizon returns all the known channel edges which have at least
989
// one edge that has an update timestamp within the specified horizon.
990
//
991
// NOTE: This is part of the V1Store interface.
992
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
993
        endTime time.Time) ([]ChannelEdge, error) {
×
994

×
995
        s.cacheMu.Lock()
×
996
        defer s.cacheMu.Unlock()
×
997

×
998
        var (
×
999
                ctx = context.TODO()
×
1000
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
1001
                // an additional map to keep track of the edges already seen to
×
1002
                // prevent re-adding it.
×
1003
                edgesSeen    = make(map[uint64]struct{})
×
1004
                edgesToCache = make(map[uint64]ChannelEdge)
×
1005
                edges        []ChannelEdge
×
1006
                hits         int
×
1007
        )
×
1008
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1009
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1010
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1011
                                Version:   int16(ProtocolV1),
×
1012
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
1013
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
1014
                        },
×
1015
                )
×
1016
                if err != nil {
×
1017
                        return err
×
1018
                }
×
1019

1020
                for _, row := range rows {
×
1021
                        // If we've already retrieved the info and policies for
×
1022
                        // this edge, then we can skip it as we don't need to do
×
1023
                        // so again.
×
1024
                        chanIDInt := byteOrder.Uint64(row.Channel.Scid)
×
1025
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1026
                                continue
×
1027
                        }
1028

1029
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1030
                                hits++
×
1031
                                edgesSeen[chanIDInt] = struct{}{}
×
1032
                                edges = append(edges, channel)
×
1033

×
1034
                                continue
×
1035
                        }
1036

1037
                        node1, node2, err := buildNodes(
×
1038
                                ctx, db, row.Node, row.Node_2,
×
1039
                        )
×
1040
                        if err != nil {
×
1041
                                return err
×
1042
                        }
×
1043

1044
                        channel, err := getAndBuildEdgeInfo(
×
1045
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1046
                                row.Channel, node1.PubKeyBytes,
×
1047
                                node2.PubKeyBytes,
×
1048
                        )
×
1049
                        if err != nil {
×
1050
                                return fmt.Errorf("unable to build channel "+
×
1051
                                        "info: %w", err)
×
1052
                        }
×
1053

1054
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1055
                        if err != nil {
×
1056
                                return fmt.Errorf("unable to extract channel "+
×
1057
                                        "policies: %w", err)
×
1058
                        }
×
1059

1060
                        p1, p2, err := getAndBuildChanPolicies(
×
1061
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1062
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1063
                        )
×
1064
                        if err != nil {
×
1065
                                return fmt.Errorf("unable to build channel "+
×
1066
                                        "policies: %w", err)
×
1067
                        }
×
1068

1069
                        edgesSeen[chanIDInt] = struct{}{}
×
1070
                        chanEdge := ChannelEdge{
×
1071
                                Info:    channel,
×
1072
                                Policy1: p1,
×
1073
                                Policy2: p2,
×
1074
                                Node1:   node1,
×
1075
                                Node2:   node2,
×
1076
                        }
×
1077
                        edges = append(edges, chanEdge)
×
1078
                        edgesToCache[chanIDInt] = chanEdge
×
1079
                }
1080

1081
                return nil
×
1082
        }, func() {
×
1083
                edgesSeen = make(map[uint64]struct{})
×
1084
                edgesToCache = make(map[uint64]ChannelEdge)
×
1085
                edges = nil
×
1086
        })
×
1087
        if err != nil {
×
1088
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1089
        }
×
1090

1091
        // Insert any edges loaded from disk into the cache.
1092
        for chanid, channel := range edgesToCache {
×
1093
                s.chanCache.insert(chanid, channel)
×
1094
        }
×
1095

1096
        if len(edges) > 0 {
×
1097
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
1098
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1099
        } else {
×
1100
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1101
                        "horizon (%s, %s)", startTime, endTime)
×
1102
        }
×
1103

1104
        return edges, nil
×
1105
}
1106

1107
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1108
// data to the call-back.
1109
//
1110
// NOTE: The callback contents MUST not be modified.
1111
//
1112
// NOTE: part of the V1Store interface.
1113
func (s *SQLStore) ForEachNodeCached(cb func(node route.Vertex,
1114
        chans map[uint64]*DirectedChannel) error) error {
×
1115

×
1116
        var ctx = context.TODO()
×
1117

×
1118
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1119
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
1120
                        nodePub route.Vertex) error {
×
1121

×
1122
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
1123
                        if err != nil {
×
1124
                                return fmt.Errorf("unable to fetch "+
×
1125
                                        "node(id=%d) features: %w", nodeID, err)
×
1126
                        }
×
1127

1128
                        toNodeCallback := func() route.Vertex {
×
1129
                                return nodePub
×
1130
                        }
×
1131

1132
                        rows, err := db.ListChannelsByNodeID(
×
1133
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
1134
                                        Version: int16(ProtocolV1),
×
1135
                                        NodeID1: nodeID,
×
1136
                                },
×
1137
                        )
×
1138
                        if err != nil {
×
1139
                                return fmt.Errorf("unable to fetch channels "+
×
1140
                                        "of node(id=%d): %w", nodeID, err)
×
1141
                        }
×
1142

1143
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
1144
                        for _, row := range rows {
×
1145
                                node1, node2, err := buildNodeVertices(
×
1146
                                        row.Node1Pubkey, row.Node2Pubkey,
×
1147
                                )
×
1148
                                if err != nil {
×
1149
                                        return err
×
1150
                                }
×
1151

1152
                                e, err := getAndBuildEdgeInfo(
×
1153
                                        ctx, db, s.cfg.ChainHash,
×
1154
                                        row.Channel.ID, row.Channel, node1,
×
1155
                                        node2,
×
1156
                                )
×
1157
                                if err != nil {
×
1158
                                        return fmt.Errorf("unable to build "+
×
1159
                                                "channel info: %w", err)
×
1160
                                }
×
1161

1162
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1163
                                        row,
×
1164
                                )
×
1165
                                if err != nil {
×
1166
                                        return fmt.Errorf("unable to "+
×
1167
                                                "extract channel "+
×
1168
                                                "policies: %w", err)
×
1169
                                }
×
1170

1171
                                p1, p2, err := getAndBuildChanPolicies(
×
1172
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
1173
                                        node1, node2,
×
1174
                                )
×
1175
                                if err != nil {
×
1176
                                        return fmt.Errorf("unable to "+
×
1177
                                                "build channel policies: %w",
×
1178
                                                err)
×
1179
                                }
×
1180

1181
                                // Determine the outgoing and incoming policy
1182
                                // for this channel and node combo.
1183
                                outPolicy, inPolicy := p1, p2
×
1184
                                if p1 != nil && p1.ToNode == nodePub {
×
1185
                                        outPolicy, inPolicy = p2, p1
×
1186
                                } else if p2 != nil && p2.ToNode != nodePub {
×
1187
                                        outPolicy, inPolicy = p2, p1
×
1188
                                }
×
1189

1190
                                var cachedInPolicy *models.CachedEdgePolicy
×
1191
                                if inPolicy != nil {
×
1192
                                        cachedInPolicy = models.NewCachedPolicy(
×
1193
                                                p2,
×
1194
                                        )
×
1195
                                        cachedInPolicy.ToNodePubKey =
×
1196
                                                toNodeCallback
×
1197
                                        cachedInPolicy.ToNodeFeatures =
×
1198
                                                features
×
1199
                                }
×
1200

1201
                                var inboundFee lnwire.Fee
×
1202
                                outPolicy.InboundFee.WhenSome(
×
1203
                                        func(fee lnwire.Fee) {
×
1204
                                                inboundFee = fee
×
1205
                                        },
×
1206
                                )
1207

1208
                                directedChannel := &DirectedChannel{
×
1209
                                        ChannelID: e.ChannelID,
×
1210
                                        IsNode1: nodePub ==
×
1211
                                                e.NodeKey1Bytes,
×
1212
                                        OtherNode:    e.NodeKey2Bytes,
×
1213
                                        Capacity:     e.Capacity,
×
1214
                                        OutPolicySet: p1 != nil,
×
1215
                                        InPolicy:     cachedInPolicy,
×
1216
                                        InboundFee:   inboundFee,
×
1217
                                }
×
1218

×
1219
                                if nodePub == e.NodeKey2Bytes {
×
1220
                                        directedChannel.OtherNode =
×
1221
                                                e.NodeKey1Bytes
×
1222
                                }
×
1223

1224
                                channels[e.ChannelID] = directedChannel
×
1225
                        }
1226

1227
                        return cb(nodePub, channels)
×
1228
                })
1229
        }, sqldb.NoOpReset)
1230
}
1231

1232
// ForEachChannelCacheable iterates through all the channel edges stored
1233
// within the graph and invokes the passed callback for each edge. The
1234
// callback takes two edges as since this is a directed graph, both the
1235
// in/out edges are visited. If the callback returns an error, then the
1236
// transaction is aborted and the iteration stops early.
1237
//
1238
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1239
// pointer for that particular channel edge routing policy will be
1240
// passed into the callback.
1241
//
1242
// NOTE: this method is like ForEachChannel but fetches only the data
1243
// required for the graph cache.
1244
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1245
        *models.CachedEdgePolicy,
1246
        *models.CachedEdgePolicy) error) error {
×
1247

×
1248
        ctx := context.TODO()
×
1249

×
1250
        handleChannel := func(db SQLQueries,
×
1251
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1252

×
1253
                node1, node2, err := buildNodeVertices(
×
1254
                        row.Node1Pubkey, row.Node2Pubkey,
×
1255
                )
×
1256
                if err != nil {
×
1257
                        return err
×
1258
                }
×
1259

1260
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
1261

×
1262
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1263
                if err != nil {
×
1264
                        return err
×
1265
                }
×
1266

1267
                var pol1, pol2 *models.CachedEdgePolicy
×
1268
                if dbPol1 != nil {
×
1269
                        policy1, err := buildChanPolicy(
×
1270
                                *dbPol1, edge.ChannelID, nil, node2,
×
1271
                        )
×
1272
                        if err != nil {
×
1273
                                return err
×
1274
                        }
×
1275

1276
                        pol1 = models.NewCachedPolicy(policy1)
×
1277
                }
1278
                if dbPol2 != nil {
×
1279
                        policy2, err := buildChanPolicy(
×
1280
                                *dbPol2, edge.ChannelID, nil, node1,
×
1281
                        )
×
1282
                        if err != nil {
×
1283
                                return err
×
1284
                        }
×
1285

1286
                        pol2 = models.NewCachedPolicy(policy2)
×
1287
                }
1288

1289
                if err := cb(edge, pol1, pol2); err != nil {
×
1290
                        return err
×
1291
                }
×
1292

1293
                return nil
×
1294
        }
1295

1296
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1297
                lastID := int64(-1)
×
1298
                for {
×
1299
                        //nolint:ll
×
1300
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1301
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1302
                                        Version: int16(ProtocolV1),
×
1303
                                        ID:      lastID,
×
1304
                                        Limit:   pageSize,
×
1305
                                },
×
1306
                        )
×
1307
                        if err != nil {
×
1308
                                return err
×
1309
                        }
×
1310

1311
                        if len(rows) == 0 {
×
1312
                                break
×
1313
                        }
1314

1315
                        for _, row := range rows {
×
1316
                                err := handleChannel(db, row)
×
1317
                                if err != nil {
×
1318
                                        return err
×
1319
                                }
×
1320

1321
                                lastID = row.Channel.ID
×
1322
                        }
1323
                }
1324

1325
                return nil
×
1326
        }, sqldb.NoOpReset)
1327
}
1328

1329
// ForEachChannel iterates through all the channel edges stored within the
1330
// graph and invokes the passed callback for each edge. The callback takes two
1331
// edges as since this is a directed graph, both the in/out edges are visited.
1332
// If the callback returns an error, then the transaction is aborted and the
1333
// iteration stops early.
1334
//
1335
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1336
// for that particular channel edge routing policy will be passed into the
1337
// callback.
1338
//
1339
// NOTE: part of the V1Store interface.
1340
func (s *SQLStore) ForEachChannel(cb func(*models.ChannelEdgeInfo,
1341
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
1342

×
1343
        ctx := context.TODO()
×
1344

×
1345
        handleChannel := func(db SQLQueries,
×
1346
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1347

×
1348
                node1, node2, err := buildNodeVertices(
×
1349
                        row.Node1Pubkey, row.Node2Pubkey,
×
1350
                )
×
1351
                if err != nil {
×
1352
                        return fmt.Errorf("unable to build node vertices: %w",
×
1353
                                err)
×
1354
                }
×
1355

1356
                edge, err := getAndBuildEdgeInfo(
×
1357
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1358
                        node1, node2,
×
1359
                )
×
1360
                if err != nil {
×
1361
                        return fmt.Errorf("unable to build channel info: %w",
×
1362
                                err)
×
1363
                }
×
1364

1365
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1366
                if err != nil {
×
1367
                        return fmt.Errorf("unable to extract channel "+
×
1368
                                "policies: %w", err)
×
1369
                }
×
1370

1371
                p1, p2, err := getAndBuildChanPolicies(
×
1372
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1373
                )
×
1374
                if err != nil {
×
1375
                        return fmt.Errorf("unable to build channel "+
×
1376
                                "policies: %w", err)
×
1377
                }
×
1378

1379
                err = cb(edge, p1, p2)
×
1380
                if err != nil {
×
1381
                        return fmt.Errorf("callback failed for channel "+
×
1382
                                "id=%d: %w", edge.ChannelID, err)
×
1383
                }
×
1384

1385
                return nil
×
1386
        }
1387

1388
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1389
                lastID := int64(-1)
×
1390
                for {
×
1391
                        //nolint:ll
×
1392
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1393
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1394
                                        Version: int16(ProtocolV1),
×
1395
                                        ID:      lastID,
×
1396
                                        Limit:   pageSize,
×
1397
                                },
×
1398
                        )
×
1399
                        if err != nil {
×
1400
                                return err
×
1401
                        }
×
1402

1403
                        if len(rows) == 0 {
×
1404
                                break
×
1405
                        }
1406

1407
                        for _, row := range rows {
×
1408
                                err := handleChannel(db, row)
×
1409
                                if err != nil {
×
1410
                                        return err
×
1411
                                }
×
1412

1413
                                lastID = row.Channel.ID
×
1414
                        }
1415
                }
1416

1417
                return nil
×
1418
        }, sqldb.NoOpReset)
1419
}
1420

1421
// FilterChannelRange returns the channel ID's of all known channels which were
1422
// mined in a block height within the passed range. The channel IDs are grouped
1423
// by their common block height. This method can be used to quickly share with a
1424
// peer the set of channels we know of within a particular range to catch them
1425
// up after a period of time offline. If withTimestamps is true then the
1426
// timestamp info of the latest received channel update messages of the channel
1427
// will be included in the response.
1428
//
1429
// NOTE: This is part of the V1Store interface.
1430
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1431
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1432

×
1433
        var (
×
1434
                ctx       = context.TODO()
×
1435
                startSCID = &lnwire.ShortChannelID{
×
1436
                        BlockHeight: startHeight,
×
1437
                }
×
1438
                endSCID = lnwire.ShortChannelID{
×
1439
                        BlockHeight: endHeight,
×
1440
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1441
                        TxPosition:  math.MaxUint16,
×
1442
                }
×
1443
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1444
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1445
        )
×
1446

×
1447
        // 1) get all channels where channelID is between start and end chan ID.
×
1448
        // 2) skip if not public (ie, no channel_proof)
×
1449
        // 3) collect that channel.
×
1450
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1451
        //    and add those timestamps to the collected channel.
×
1452
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1453
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1454
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1455
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1456
                                StartScid: chanIDStart,
×
1457
                                EndScid:   chanIDEnd,
×
1458
                        },
×
1459
                )
×
1460
                if err != nil {
×
1461
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1462
                                err)
×
1463
                }
×
1464

1465
                for _, dbChan := range dbChans {
×
1466
                        cid := lnwire.NewShortChanIDFromInt(
×
1467
                                byteOrder.Uint64(dbChan.Scid),
×
1468
                        )
×
1469
                        chanInfo := NewChannelUpdateInfo(
×
1470
                                cid, time.Time{}, time.Time{},
×
1471
                        )
×
1472

×
1473
                        if !withTimestamps {
×
1474
                                channelsPerBlock[cid.BlockHeight] = append(
×
1475
                                        channelsPerBlock[cid.BlockHeight],
×
1476
                                        chanInfo,
×
1477
                                )
×
1478

×
1479
                                continue
×
1480
                        }
1481

1482
                        //nolint:ll
1483
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1484
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1485
                                        Version:   int16(ProtocolV1),
×
1486
                                        ChannelID: dbChan.ID,
×
1487
                                        NodeID:    dbChan.NodeID1,
×
1488
                                },
×
1489
                        )
×
1490
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1491
                                return fmt.Errorf("unable to fetch node1 "+
×
1492
                                        "policy: %w", err)
×
1493
                        } else if err == nil {
×
1494
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1495
                                        node1Policy.LastUpdate.Int64, 0,
×
1496
                                )
×
1497
                        }
×
1498

1499
                        //nolint:ll
1500
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1501
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1502
                                        Version:   int16(ProtocolV1),
×
1503
                                        ChannelID: dbChan.ID,
×
1504
                                        NodeID:    dbChan.NodeID2,
×
1505
                                },
×
1506
                        )
×
1507
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1508
                                return fmt.Errorf("unable to fetch node2 "+
×
1509
                                        "policy: %w", err)
×
1510
                        } else if err == nil {
×
1511
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1512
                                        node2Policy.LastUpdate.Int64, 0,
×
1513
                                )
×
1514
                        }
×
1515

1516
                        channelsPerBlock[cid.BlockHeight] = append(
×
1517
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1518
                        )
×
1519
                }
1520

1521
                return nil
×
1522
        }, func() {
×
1523
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1524
        })
×
1525
        if err != nil {
×
1526
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1527
        }
×
1528

1529
        if len(channelsPerBlock) == 0 {
×
1530
                return nil, nil
×
1531
        }
×
1532

1533
        // Return the channel ranges in ascending block height order.
1534
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1535
        slices.Sort(blocks)
×
1536

×
1537
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1538
                return BlockChannelRange{
×
1539
                        Height:   block,
×
1540
                        Channels: channelsPerBlock[block],
×
1541
                }
×
1542
        }), nil
×
1543
}
1544

1545
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1546
// zombie. This method is used on an ad-hoc basis, when channels need to be
1547
// marked as zombies outside the normal pruning cycle.
1548
//
1549
// NOTE: part of the V1Store interface.
1550
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1551
        pubKey1, pubKey2 [33]byte) error {
×
1552

×
1553
        ctx := context.TODO()
×
1554

×
1555
        s.cacheMu.Lock()
×
1556
        defer s.cacheMu.Unlock()
×
1557

×
1558
        chanIDB := channelIDToBytes(chanID)
×
1559

×
1560
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1561
                return db.UpsertZombieChannel(
×
1562
                        ctx, sqlc.UpsertZombieChannelParams{
×
1563
                                Version:  int16(ProtocolV1),
×
1564
                                Scid:     chanIDB,
×
1565
                                NodeKey1: pubKey1[:],
×
1566
                                NodeKey2: pubKey2[:],
×
1567
                        },
×
1568
                )
×
1569
        }, sqldb.NoOpReset)
×
1570
        if err != nil {
×
1571
                return fmt.Errorf("unable to upsert zombie channel "+
×
1572
                        "(channel_id=%d): %w", chanID, err)
×
1573
        }
×
1574

1575
        s.rejectCache.remove(chanID)
×
1576
        s.chanCache.remove(chanID)
×
1577

×
1578
        return nil
×
1579
}
1580

1581
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1582
//
1583
// NOTE: part of the V1Store interface.
1584
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1585
        s.cacheMu.Lock()
×
1586
        defer s.cacheMu.Unlock()
×
1587

×
1588
        var (
×
1589
                ctx     = context.TODO()
×
1590
                chanIDB = channelIDToBytes(chanID)
×
1591
        )
×
1592

×
1593
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1594
                res, err := db.DeleteZombieChannel(
×
1595
                        ctx, sqlc.DeleteZombieChannelParams{
×
1596
                                Scid:    chanIDB,
×
1597
                                Version: int16(ProtocolV1),
×
1598
                        },
×
1599
                )
×
1600
                if err != nil {
×
1601
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1602
                                err)
×
1603
                }
×
1604

1605
                rows, err := res.RowsAffected()
×
1606
                if err != nil {
×
1607
                        return err
×
1608
                }
×
1609

1610
                if rows == 0 {
×
1611
                        return ErrZombieEdgeNotFound
×
1612
                } else if rows > 1 {
×
1613
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1614
                                "expected 1", rows)
×
1615
                }
×
1616

1617
                return nil
×
1618
        }, sqldb.NoOpReset)
1619
        if err != nil {
×
1620
                return fmt.Errorf("unable to mark edge live "+
×
1621
                        "(channel_id=%d): %w", chanID, err)
×
1622
        }
×
1623

1624
        s.rejectCache.remove(chanID)
×
1625
        s.chanCache.remove(chanID)
×
1626

×
1627
        return err
×
1628
}
1629

1630
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1631
// zombie, then the two node public keys corresponding to this edge are also
1632
// returned.
1633
//
1634
// NOTE: part of the V1Store interface.
1635
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1636
        error) {
×
1637

×
1638
        var (
×
1639
                ctx              = context.TODO()
×
1640
                isZombie         bool
×
1641
                pubKey1, pubKey2 route.Vertex
×
1642
                chanIDB          = channelIDToBytes(chanID)
×
1643
        )
×
1644

×
1645
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1646
                zombie, err := db.GetZombieChannel(
×
1647
                        ctx, sqlc.GetZombieChannelParams{
×
1648
                                Scid:    chanIDB,
×
1649
                                Version: int16(ProtocolV1),
×
1650
                        },
×
1651
                )
×
1652
                if errors.Is(err, sql.ErrNoRows) {
×
1653
                        return nil
×
1654
                }
×
1655
                if err != nil {
×
1656
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1657
                                err)
×
1658
                }
×
1659

1660
                copy(pubKey1[:], zombie.NodeKey1)
×
1661
                copy(pubKey2[:], zombie.NodeKey2)
×
1662
                isZombie = true
×
1663

×
1664
                return nil
×
1665
        }, sqldb.NoOpReset)
1666
        if err != nil {
×
1667
                return false, route.Vertex{}, route.Vertex{},
×
1668
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1669
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1670
        }
×
1671

1672
        return isZombie, pubKey1, pubKey2, nil
×
1673
}
1674

1675
// NumZombies returns the current number of zombie channels in the graph.
1676
//
1677
// NOTE: part of the V1Store interface.
1678
func (s *SQLStore) NumZombies() (uint64, error) {
×
1679
        var (
×
1680
                ctx        = context.TODO()
×
1681
                numZombies uint64
×
1682
        )
×
1683
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1684
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1685
                if err != nil {
×
1686
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1687
                                err)
×
1688
                }
×
1689

1690
                numZombies = uint64(count)
×
1691

×
1692
                return nil
×
1693
        }, sqldb.NoOpReset)
1694
        if err != nil {
×
1695
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1696
        }
×
1697

1698
        return numZombies, nil
×
1699
}
1700

1701
// DeleteChannelEdges removes edges with the given channel IDs from the
1702
// database and marks them as zombies. This ensures that we're unable to re-add
1703
// it to our database once again. If an edge does not exist within the
1704
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1705
// true, then when we mark these edges as zombies, we'll set up the keys such
1706
// that we require the node that failed to send the fresh update to be the one
1707
// that resurrects the channel from its zombie state. The markZombie bool
1708
// denotes whether to mark the channel as a zombie.
1709
//
1710
// NOTE: part of the V1Store interface.
1711
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1712
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1713

×
1714
        s.cacheMu.Lock()
×
1715
        defer s.cacheMu.Unlock()
×
1716

×
1717
        var (
×
1718
                ctx     = context.TODO()
×
1719
                deleted []*models.ChannelEdgeInfo
×
1720
        )
×
1721
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1722
                for _, chanID := range chanIDs {
×
1723
                        chanIDB := channelIDToBytes(chanID)
×
1724

×
1725
                        row, err := db.GetChannelBySCIDWithPolicies(
×
1726
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1727
                                        Scid:    chanIDB,
×
1728
                                        Version: int16(ProtocolV1),
×
1729
                                },
×
1730
                        )
×
1731
                        if errors.Is(err, sql.ErrNoRows) {
×
1732
                                return ErrEdgeNotFound
×
1733
                        } else if err != nil {
×
1734
                                return fmt.Errorf("unable to fetch channel: %w",
×
1735
                                        err)
×
1736
                        }
×
1737

1738
                        node1, node2, err := buildNodeVertices(
×
1739
                                row.Node.PubKey, row.Node_2.PubKey,
×
1740
                        )
×
1741
                        if err != nil {
×
1742
                                return err
×
1743
                        }
×
1744

1745
                        info, err := getAndBuildEdgeInfo(
×
1746
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1747
                                row.Channel, node1, node2,
×
1748
                        )
×
1749
                        if err != nil {
×
1750
                                return err
×
1751
                        }
×
1752

1753
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
1754
                        if err != nil {
×
1755
                                return fmt.Errorf("unable to delete "+
×
1756
                                        "channel: %w", err)
×
1757
                        }
×
1758

1759
                        deleted = append(deleted, info)
×
1760

×
1761
                        if !markZombie {
×
1762
                                continue
×
1763
                        }
1764

1765
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1766
                                info.NodeKey2Bytes
×
1767
                        if strictZombiePruning {
×
1768
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1769
                                if row.Policy1LastUpdate.Valid {
×
1770
                                        e1Time := time.Unix(
×
1771
                                                row.Policy1LastUpdate.Int64, 0,
×
1772
                                        )
×
1773
                                        e1UpdateTime = &e1Time
×
1774
                                }
×
1775
                                if row.Policy2LastUpdate.Valid {
×
1776
                                        e2Time := time.Unix(
×
1777
                                                row.Policy2LastUpdate.Int64, 0,
×
1778
                                        )
×
1779
                                        e2UpdateTime = &e2Time
×
1780
                                }
×
1781

1782
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1783
                                        info, e1UpdateTime, e2UpdateTime,
×
1784
                                )
×
1785
                        }
1786

1787
                        err = db.UpsertZombieChannel(
×
1788
                                ctx, sqlc.UpsertZombieChannelParams{
×
1789
                                        Version:  int16(ProtocolV1),
×
1790
                                        Scid:     chanIDB,
×
1791
                                        NodeKey1: nodeKey1[:],
×
1792
                                        NodeKey2: nodeKey2[:],
×
1793
                                },
×
1794
                        )
×
1795
                        if err != nil {
×
1796
                                return fmt.Errorf("unable to mark channel as "+
×
1797
                                        "zombie: %w", err)
×
1798
                        }
×
1799
                }
1800

1801
                return nil
×
1802
        }, func() {
×
1803
                deleted = nil
×
1804
        })
×
1805
        if err != nil {
×
1806
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1807
                        err)
×
1808
        }
×
1809

1810
        for _, chanID := range chanIDs {
×
1811
                s.rejectCache.remove(chanID)
×
1812
                s.chanCache.remove(chanID)
×
1813
        }
×
1814

1815
        return deleted, nil
×
1816
}
1817

1818
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1819
// channel identified by the channel ID. If the channel can't be found, then
1820
// ErrEdgeNotFound is returned. A struct which houses the general information
1821
// for the channel itself is returned as well as two structs that contain the
1822
// routing policies for the channel in either direction.
1823
//
1824
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1825
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1826
// the ChannelEdgeInfo will only include the public keys of each node.
1827
//
1828
// NOTE: part of the V1Store interface.
1829
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1830
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1831
        *models.ChannelEdgePolicy, error) {
×
1832

×
1833
        var (
×
1834
                ctx              = context.TODO()
×
1835
                edge             *models.ChannelEdgeInfo
×
1836
                policy1, policy2 *models.ChannelEdgePolicy
×
1837
                chanIDB          = channelIDToBytes(chanID)
×
1838
        )
×
1839
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1840
                row, err := db.GetChannelBySCIDWithPolicies(
×
1841
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1842
                                Scid:    chanIDB,
×
1843
                                Version: int16(ProtocolV1),
×
1844
                        },
×
1845
                )
×
1846
                if errors.Is(err, sql.ErrNoRows) {
×
1847
                        // First check if this edge is perhaps in the zombie
×
1848
                        // index.
×
1849
                        zombie, err := db.GetZombieChannel(
×
1850
                                ctx, sqlc.GetZombieChannelParams{
×
1851
                                        Scid:    chanIDB,
×
1852
                                        Version: int16(ProtocolV1),
×
1853
                                },
×
1854
                        )
×
1855
                        if errors.Is(err, sql.ErrNoRows) {
×
1856
                                return ErrEdgeNotFound
×
1857
                        } else if err != nil {
×
1858
                                return fmt.Errorf("unable to check if "+
×
1859
                                        "channel is zombie: %w", err)
×
1860
                        }
×
1861

1862
                        // At this point, we know the channel is a zombie, so
1863
                        // we'll return an error indicating this, and we will
1864
                        // populate the edge info with the public keys of each
1865
                        // party as this is the only information we have about
1866
                        // it.
1867
                        edge = &models.ChannelEdgeInfo{}
×
1868
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1869
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1870

×
1871
                        return ErrZombieEdge
×
1872
                } else if err != nil {
×
1873
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1874
                }
×
1875

1876
                node1, node2, err := buildNodeVertices(
×
1877
                        row.Node.PubKey, row.Node_2.PubKey,
×
1878
                )
×
1879
                if err != nil {
×
1880
                        return err
×
1881
                }
×
1882

1883
                edge, err = getAndBuildEdgeInfo(
×
1884
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1885
                        node1, node2,
×
1886
                )
×
1887
                if err != nil {
×
1888
                        return fmt.Errorf("unable to build channel info: %w",
×
1889
                                err)
×
1890
                }
×
1891

1892
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1893
                if err != nil {
×
1894
                        return fmt.Errorf("unable to extract channel "+
×
1895
                                "policies: %w", err)
×
1896
                }
×
1897

1898
                policy1, policy2, err = getAndBuildChanPolicies(
×
1899
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1900
                )
×
1901
                if err != nil {
×
1902
                        return fmt.Errorf("unable to build channel "+
×
1903
                                "policies: %w", err)
×
1904
                }
×
1905

1906
                return nil
×
1907
        }, sqldb.NoOpReset)
1908
        if err != nil {
×
1909
                // If we are returning the ErrZombieEdge, then we also need to
×
1910
                // return the edge info as the method comment indicates that
×
1911
                // this will be populated when the edge is a zombie.
×
1912
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1913
                        err)
×
1914
        }
×
1915

1916
        return edge, policy1, policy2, nil
×
1917
}
1918

1919
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1920
// the channel identified by the funding outpoint. If the channel can't be
1921
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1922
// information for the channel itself is returned as well as two structs that
1923
// contain the routing policies for the channel in either direction.
1924
//
1925
// NOTE: part of the V1Store interface.
1926
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1927
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1928
        *models.ChannelEdgePolicy, error) {
×
1929

×
1930
        var (
×
1931
                ctx              = context.TODO()
×
1932
                edge             *models.ChannelEdgeInfo
×
1933
                policy1, policy2 *models.ChannelEdgePolicy
×
1934
        )
×
1935
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1936
                row, err := db.GetChannelByOutpointWithPolicies(
×
1937
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1938
                                Outpoint: op.String(),
×
1939
                                Version:  int16(ProtocolV1),
×
1940
                        },
×
1941
                )
×
1942
                if errors.Is(err, sql.ErrNoRows) {
×
1943
                        return ErrEdgeNotFound
×
1944
                } else if err != nil {
×
1945
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1946
                }
×
1947

1948
                node1, node2, err := buildNodeVertices(
×
1949
                        row.Node1Pubkey, row.Node2Pubkey,
×
1950
                )
×
1951
                if err != nil {
×
1952
                        return err
×
1953
                }
×
1954

1955
                edge, err = getAndBuildEdgeInfo(
×
1956
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1957
                        node1, node2,
×
1958
                )
×
1959
                if err != nil {
×
1960
                        return fmt.Errorf("unable to build channel info: %w",
×
1961
                                err)
×
1962
                }
×
1963

1964
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1965
                if err != nil {
×
1966
                        return fmt.Errorf("unable to extract channel "+
×
1967
                                "policies: %w", err)
×
1968
                }
×
1969

1970
                policy1, policy2, err = getAndBuildChanPolicies(
×
1971
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1972
                )
×
1973
                if err != nil {
×
1974
                        return fmt.Errorf("unable to build channel "+
×
1975
                                "policies: %w", err)
×
1976
                }
×
1977

1978
                return nil
×
1979
        }, sqldb.NoOpReset)
1980
        if err != nil {
×
1981
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1982
                        err)
×
1983
        }
×
1984

1985
        return edge, policy1, policy2, nil
×
1986
}
1987

1988
// HasChannelEdge returns true if the database knows of a channel edge with the
1989
// passed channel ID, and false otherwise. If an edge with that ID is found
1990
// within the graph, then two time stamps representing the last time the edge
1991
// was updated for both directed edges are returned along with the boolean. If
1992
// it is not found, then the zombie index is checked and its result is returned
1993
// as the second boolean.
1994
//
1995
// NOTE: part of the V1Store interface.
1996
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1997
        bool, error) {
×
1998

×
1999
        ctx := context.TODO()
×
2000

×
2001
        var (
×
2002
                exists          bool
×
2003
                isZombie        bool
×
2004
                node1LastUpdate time.Time
×
2005
                node2LastUpdate time.Time
×
2006
        )
×
2007

×
2008
        // We'll query the cache with the shared lock held to allow multiple
×
2009
        // readers to access values in the cache concurrently if they exist.
×
2010
        s.cacheMu.RLock()
×
2011
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2012
                s.cacheMu.RUnlock()
×
2013
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2014
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2015
                exists, isZombie = entry.flags.unpack()
×
2016

×
2017
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2018
        }
×
2019
        s.cacheMu.RUnlock()
×
2020

×
2021
        s.cacheMu.Lock()
×
2022
        defer s.cacheMu.Unlock()
×
2023

×
2024
        // The item was not found with the shared lock, so we'll acquire the
×
2025
        // exclusive lock and check the cache again in case another method added
×
2026
        // the entry to the cache while no lock was held.
×
2027
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2028
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2029
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2030
                exists, isZombie = entry.flags.unpack()
×
2031

×
2032
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2033
        }
×
2034

2035
        chanIDB := channelIDToBytes(chanID)
×
2036
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2037
                channel, err := db.GetChannelBySCID(
×
2038
                        ctx, sqlc.GetChannelBySCIDParams{
×
2039
                                Scid:    chanIDB,
×
2040
                                Version: int16(ProtocolV1),
×
2041
                        },
×
2042
                )
×
2043
                if errors.Is(err, sql.ErrNoRows) {
×
2044
                        // Check if it is a zombie channel.
×
2045
                        isZombie, err = db.IsZombieChannel(
×
2046
                                ctx, sqlc.IsZombieChannelParams{
×
2047
                                        Scid:    chanIDB,
×
2048
                                        Version: int16(ProtocolV1),
×
2049
                                },
×
2050
                        )
×
2051
                        if err != nil {
×
2052
                                return fmt.Errorf("could not check if channel "+
×
2053
                                        "is zombie: %w", err)
×
2054
                        }
×
2055

2056
                        return nil
×
2057
                } else if err != nil {
×
2058
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2059
                }
×
2060

2061
                exists = true
×
2062

×
2063
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2064
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2065
                                Version:   int16(ProtocolV1),
×
2066
                                ChannelID: channel.ID,
×
2067
                                NodeID:    channel.NodeID1,
×
2068
                        },
×
2069
                )
×
2070
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2071
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2072
                                err)
×
2073
                } else if err == nil {
×
2074
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2075
                }
×
2076

2077
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2078
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2079
                                Version:   int16(ProtocolV1),
×
2080
                                ChannelID: channel.ID,
×
2081
                                NodeID:    channel.NodeID2,
×
2082
                        },
×
2083
                )
×
2084
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2085
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2086
                                err)
×
2087
                } else if err == nil {
×
2088
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2089
                }
×
2090

2091
                return nil
×
2092
        }, sqldb.NoOpReset)
2093
        if err != nil {
×
2094
                return time.Time{}, time.Time{}, false, false,
×
2095
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2096
        }
×
2097

2098
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2099
                upd1Time: node1LastUpdate.Unix(),
×
2100
                upd2Time: node2LastUpdate.Unix(),
×
2101
                flags:    packRejectFlags(exists, isZombie),
×
2102
        })
×
2103

×
2104
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2105
}
2106

2107
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2108
// passed channel point (outpoint). If the passed channel doesn't exist within
2109
// the database, then ErrEdgeNotFound is returned.
2110
//
2111
// NOTE: part of the V1Store interface.
2112
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2113
        var (
×
2114
                ctx       = context.TODO()
×
2115
                channelID uint64
×
2116
        )
×
2117
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2118
                chanID, err := db.GetSCIDByOutpoint(
×
2119
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2120
                                Outpoint: chanPoint.String(),
×
2121
                                Version:  int16(ProtocolV1),
×
2122
                        },
×
2123
                )
×
2124
                if errors.Is(err, sql.ErrNoRows) {
×
2125
                        return ErrEdgeNotFound
×
2126
                } else if err != nil {
×
2127
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2128
                                err)
×
2129
                }
×
2130

2131
                channelID = byteOrder.Uint64(chanID)
×
2132

×
2133
                return nil
×
2134
        }, sqldb.NoOpReset)
2135
        if err != nil {
×
2136
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2137
        }
×
2138

2139
        return channelID, nil
×
2140
}
2141

2142
// IsPublicNode is a helper method that determines whether the node with the
2143
// given public key is seen as a public node in the graph from the graph's
2144
// source node's point of view.
2145
//
2146
// NOTE: part of the V1Store interface.
2147
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2148
        ctx := context.TODO()
×
2149

×
2150
        var isPublic bool
×
2151
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2152
                var err error
×
2153
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2154

×
2155
                return err
×
2156
        }, sqldb.NoOpReset)
×
2157
        if err != nil {
×
2158
                return false, fmt.Errorf("unable to check if node is "+
×
2159
                        "public: %w", err)
×
2160
        }
×
2161

2162
        return isPublic, nil
×
2163
}
2164

2165
// FetchChanInfos returns the set of channel edges that correspond to the passed
2166
// channel ID's. If an edge is the query is unknown to the database, it will
2167
// skipped and the result will contain only those edges that exist at the time
2168
// of the query. This can be used to respond to peer queries that are seeking to
2169
// fill in gaps in their view of the channel graph.
2170
//
2171
// NOTE: part of the V1Store interface.
2172
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2173
        var (
×
2174
                ctx   = context.TODO()
×
2175
                edges []ChannelEdge
×
2176
        )
×
2177
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2178
                for _, chanID := range chanIDs {
×
2179
                        chanIDB := channelIDToBytes(chanID)
×
2180

×
2181
                        // TODO(elle): potentially optimize this by using
×
2182
                        //  sqlc.slice() once that works for both SQLite and
×
2183
                        //  Postgres.
×
2184
                        row, err := db.GetChannelBySCIDWithPolicies(
×
2185
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
2186
                                        Scid:    chanIDB,
×
2187
                                        Version: int16(ProtocolV1),
×
2188
                                },
×
2189
                        )
×
2190
                        if errors.Is(err, sql.ErrNoRows) {
×
2191
                                continue
×
2192
                        } else if err != nil {
×
2193
                                return fmt.Errorf("unable to fetch channel: %w",
×
2194
                                        err)
×
2195
                        }
×
2196

2197
                        node1, node2, err := buildNodes(
×
2198
                                ctx, db, row.Node, row.Node_2,
×
2199
                        )
×
2200
                        if err != nil {
×
2201
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2202
                                        err)
×
2203
                        }
×
2204

2205
                        edge, err := getAndBuildEdgeInfo(
×
2206
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2207
                                row.Channel, node1.PubKeyBytes,
×
2208
                                node2.PubKeyBytes,
×
2209
                        )
×
2210
                        if err != nil {
×
2211
                                return fmt.Errorf("unable to build "+
×
2212
                                        "channel info: %w", err)
×
2213
                        }
×
2214

2215
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2216
                        if err != nil {
×
2217
                                return fmt.Errorf("unable to extract channel "+
×
2218
                                        "policies: %w", err)
×
2219
                        }
×
2220

2221
                        p1, p2, err := getAndBuildChanPolicies(
×
2222
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2223
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2224
                        )
×
2225
                        if err != nil {
×
2226
                                return fmt.Errorf("unable to build channel "+
×
2227
                                        "policies: %w", err)
×
2228
                        }
×
2229

2230
                        edges = append(edges, ChannelEdge{
×
2231
                                Info:    edge,
×
2232
                                Policy1: p1,
×
2233
                                Policy2: p2,
×
2234
                                Node1:   node1,
×
2235
                                Node2:   node2,
×
2236
                        })
×
2237
                }
2238

2239
                return nil
×
2240
        }, func() {
×
2241
                edges = nil
×
2242
        })
×
2243
        if err != nil {
×
2244
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2245
        }
×
2246

2247
        return edges, nil
×
2248
}
2249

2250
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2251
// ID's that we don't know and are not known zombies of the passed set. In other
2252
// words, we perform a set difference of our set of chan ID's and the ones
2253
// passed in. This method can be used by callers to determine the set of
2254
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2255
// known zombies is also returned.
2256
//
2257
// NOTE: part of the V1Store interface.
2258
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2259
        []ChannelUpdateInfo, error) {
×
2260

×
2261
        var (
×
2262
                ctx          = context.TODO()
×
2263
                newChanIDs   []uint64
×
2264
                knownZombies []ChannelUpdateInfo
×
2265
        )
×
2266
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2267
                for _, chanInfo := range chansInfo {
×
2268
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2269
                        chanIDB := channelIDToBytes(channelID)
×
2270

×
2271
                        // TODO(elle): potentially optimize this by using
×
2272
                        //  sqlc.slice() once that works for both SQLite and
×
2273
                        //  Postgres.
×
2274
                        _, err := db.GetChannelBySCID(
×
2275
                                ctx, sqlc.GetChannelBySCIDParams{
×
2276
                                        Version: int16(ProtocolV1),
×
2277
                                        Scid:    chanIDB,
×
2278
                                },
×
2279
                        )
×
2280
                        if err == nil {
×
2281
                                continue
×
2282
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
2283
                                return fmt.Errorf("unable to fetch channel: %w",
×
2284
                                        err)
×
2285
                        }
×
2286

2287
                        isZombie, err := db.IsZombieChannel(
×
2288
                                ctx, sqlc.IsZombieChannelParams{
×
2289
                                        Scid:    chanIDB,
×
2290
                                        Version: int16(ProtocolV1),
×
2291
                                },
×
2292
                        )
×
2293
                        if err != nil {
×
2294
                                return fmt.Errorf("unable to fetch zombie "+
×
2295
                                        "channel: %w", err)
×
2296
                        }
×
2297

2298
                        if isZombie {
×
2299
                                knownZombies = append(knownZombies, chanInfo)
×
2300

×
2301
                                continue
×
2302
                        }
2303

2304
                        newChanIDs = append(newChanIDs, channelID)
×
2305
                }
2306

2307
                return nil
×
2308
        }, func() {
×
2309
                newChanIDs = nil
×
2310
                knownZombies = nil
×
2311
        })
×
2312
        if err != nil {
×
2313
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2314
        }
×
2315

2316
        return newChanIDs, knownZombies, nil
×
2317
}
2318

2319
// PruneGraphNodes is a garbage collection method which attempts to prune out
2320
// any nodes from the channel graph that are currently unconnected. This ensure
2321
// that we only maintain a graph of reachable nodes. In the event that a pruned
2322
// node gains more channels, it will be re-added back to the graph.
2323
//
2324
// NOTE: this prunes nodes across protocol versions. It will never prune the
2325
// source nodes.
2326
//
2327
// NOTE: part of the V1Store interface.
2328
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2329
        var ctx = context.TODO()
×
2330

×
2331
        var prunedNodes []route.Vertex
×
2332
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2333
                var err error
×
2334
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2335

×
2336
                return err
×
2337
        }, func() {
×
2338
                prunedNodes = nil
×
2339
        })
×
2340
        if err != nil {
×
2341
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2342
        }
×
2343

2344
        return prunedNodes, nil
×
2345
}
2346

2347
// PruneGraph prunes newly closed channels from the channel graph in response
2348
// to a new block being solved on the network. Any transactions which spend the
2349
// funding output of any known channels within he graph will be deleted.
2350
// Additionally, the "prune tip", or the last block which has been used to
2351
// prune the graph is stored so callers can ensure the graph is fully in sync
2352
// with the current UTXO state. A slice of channels that have been closed by
2353
// the target block along with any pruned nodes are returned if the function
2354
// succeeds without error.
2355
//
2356
// NOTE: part of the V1Store interface.
2357
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2358
        blockHash *chainhash.Hash, blockHeight uint32) (
2359
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2360

×
2361
        ctx := context.TODO()
×
2362

×
2363
        s.cacheMu.Lock()
×
2364
        defer s.cacheMu.Unlock()
×
2365

×
2366
        var (
×
2367
                closedChans []*models.ChannelEdgeInfo
×
2368
                prunedNodes []route.Vertex
×
2369
        )
×
2370
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2371
                for _, outpoint := range spentOutputs {
×
2372
                        // TODO(elle): potentially optimize this by using
×
2373
                        //  sqlc.slice() once that works for both SQLite and
×
2374
                        //  Postgres.
×
2375
                        //
×
2376
                        // NOTE: this fetches channels for all protocol
×
2377
                        // versions.
×
2378
                        row, err := db.GetChannelByOutpoint(
×
2379
                                ctx, outpoint.String(),
×
2380
                        )
×
2381
                        if errors.Is(err, sql.ErrNoRows) {
×
2382
                                continue
×
2383
                        } else if err != nil {
×
2384
                                return fmt.Errorf("unable to fetch channel: %w",
×
2385
                                        err)
×
2386
                        }
×
2387

2388
                        node1, node2, err := buildNodeVertices(
×
2389
                                row.Node1Pubkey, row.Node2Pubkey,
×
2390
                        )
×
2391
                        if err != nil {
×
2392
                                return err
×
2393
                        }
×
2394

2395
                        info, err := getAndBuildEdgeInfo(
×
2396
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2397
                                row.Channel, node1, node2,
×
2398
                        )
×
2399
                        if err != nil {
×
2400
                                return err
×
2401
                        }
×
2402

2403
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2404
                        if err != nil {
×
2405
                                return fmt.Errorf("unable to delete "+
×
2406
                                        "channel: %w", err)
×
2407
                        }
×
2408

2409
                        closedChans = append(closedChans, info)
×
2410
                }
2411

2412
                err := db.UpsertPruneLogEntry(
×
2413
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2414
                                BlockHash:   blockHash[:],
×
2415
                                BlockHeight: int64(blockHeight),
×
2416
                        },
×
2417
                )
×
2418
                if err != nil {
×
2419
                        return fmt.Errorf("unable to insert prune log "+
×
2420
                                "entry: %w", err)
×
2421
                }
×
2422

2423
                // Now that we've pruned some channels, we'll also prune any
2424
                // nodes that no longer have any channels.
2425
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2426
                if err != nil {
×
2427
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2428
                                err)
×
2429
                }
×
2430

2431
                return nil
×
2432
        }, func() {
×
2433
                prunedNodes = nil
×
2434
                closedChans = nil
×
2435
        })
×
2436
        if err != nil {
×
2437
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2438
        }
×
2439

2440
        for _, channel := range closedChans {
×
2441
                s.rejectCache.remove(channel.ChannelID)
×
2442
                s.chanCache.remove(channel.ChannelID)
×
2443
        }
×
2444

2445
        return closedChans, prunedNodes, nil
×
2446
}
2447

2448
// ChannelView returns the verifiable edge information for each active channel
2449
// within the known channel graph. The set of UTXOs (along with their scripts)
2450
// returned are the ones that need to be watched on chain to detect channel
2451
// closes on the resident blockchain.
2452
//
2453
// NOTE: part of the V1Store interface.
2454
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2455
        var (
×
2456
                ctx        = context.TODO()
×
2457
                edgePoints []EdgePoint
×
2458
        )
×
2459

×
2460
        handleChannel := func(db SQLQueries,
×
2461
                channel sqlc.ListChannelsPaginatedRow) error {
×
2462

×
2463
                pkScript, err := genMultiSigP2WSH(
×
2464
                        channel.BitcoinKey1, channel.BitcoinKey2,
×
2465
                )
×
2466
                if err != nil {
×
2467
                        return err
×
2468
                }
×
2469

2470
                op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2471
                if err != nil {
×
2472
                        return err
×
2473
                }
×
2474

2475
                edgePoints = append(edgePoints, EdgePoint{
×
2476
                        FundingPkScript: pkScript,
×
2477
                        OutPoint:        *op,
×
2478
                })
×
2479

×
2480
                return nil
×
2481
        }
2482

2483
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2484
                lastID := int64(-1)
×
2485
                for {
×
2486
                        rows, err := db.ListChannelsPaginated(
×
2487
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2488
                                        Version: int16(ProtocolV1),
×
2489
                                        ID:      lastID,
×
2490
                                        Limit:   pageSize,
×
2491
                                },
×
2492
                        )
×
2493
                        if err != nil {
×
2494
                                return err
×
2495
                        }
×
2496

2497
                        if len(rows) == 0 {
×
2498
                                break
×
2499
                        }
2500

2501
                        for _, row := range rows {
×
2502
                                err := handleChannel(db, row)
×
2503
                                if err != nil {
×
2504
                                        return err
×
2505
                                }
×
2506

2507
                                lastID = row.ID
×
2508
                        }
2509
                }
2510

2511
                return nil
×
2512
        }, func() {
×
2513
                edgePoints = nil
×
2514
        })
×
2515
        if err != nil {
×
2516
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2517
        }
×
2518

2519
        return edgePoints, nil
×
2520
}
2521

2522
// PruneTip returns the block height and hash of the latest block that has been
2523
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2524
// to tell if the graph is currently in sync with the current best known UTXO
2525
// state.
2526
//
2527
// NOTE: part of the V1Store interface.
2528
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2529
        var (
×
2530
                ctx       = context.TODO()
×
2531
                tipHash   chainhash.Hash
×
2532
                tipHeight uint32
×
2533
        )
×
2534
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2535
                pruneTip, err := db.GetPruneTip(ctx)
×
2536
                if errors.Is(err, sql.ErrNoRows) {
×
2537
                        return ErrGraphNeverPruned
×
2538
                } else if err != nil {
×
2539
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2540
                }
×
2541

2542
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2543
                tipHeight = uint32(pruneTip.BlockHeight)
×
2544

×
2545
                return nil
×
2546
        }, sqldb.NoOpReset)
2547
        if err != nil {
×
2548
                return nil, 0, err
×
2549
        }
×
2550

2551
        return &tipHash, tipHeight, nil
×
2552
}
2553

2554
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2555
//
2556
// NOTE: this prunes nodes across protocol versions. It will never prune the
2557
// source nodes.
2558
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2559
        db SQLQueries) ([]route.Vertex, error) {
×
2560

×
2561
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2562
        if err != nil {
×
2563
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2564
                        "nodes: %w", err)
×
2565
        }
×
2566

2567
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2568
        for i, nodeKey := range nodeKeys {
×
2569
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2570
                if err != nil {
×
2571
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2572
                                "from bytes: %w", err)
×
2573
                }
×
2574

2575
                prunedNodes[i] = pub
×
2576
        }
2577

2578
        return prunedNodes, nil
×
2579
}
2580

2581
// DisconnectBlockAtHeight is used to indicate that the block specified
2582
// by the passed height has been disconnected from the main chain. This
2583
// will "rewind" the graph back to the height below, deleting channels
2584
// that are no longer confirmed from the graph. The prune log will be
2585
// set to the last prune height valid for the remaining chain.
2586
// Channels that were removed from the graph resulting from the
2587
// disconnected block are returned.
2588
//
2589
// NOTE: part of the V1Store interface.
2590
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2591
        []*models.ChannelEdgeInfo, error) {
×
2592

×
2593
        ctx := context.TODO()
×
2594

×
2595
        var (
×
2596
                // Every channel having a ShortChannelID starting at 'height'
×
2597
                // will no longer be confirmed.
×
2598
                startShortChanID = lnwire.ShortChannelID{
×
2599
                        BlockHeight: height,
×
2600
                }
×
2601

×
2602
                // Delete everything after this height from the db up until the
×
2603
                // SCID alias range.
×
2604
                endShortChanID = aliasmgr.StartingAlias
×
2605

×
2606
                removedChans []*models.ChannelEdgeInfo
×
2607

×
2608
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2609
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2610
        )
×
2611

×
2612
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2613
                rows, err := db.GetChannelsBySCIDRange(
×
2614
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2615
                                StartScid: chanIDStart,
×
2616
                                EndScid:   chanIDEnd,
×
2617
                        },
×
2618
                )
×
2619
                if err != nil {
×
2620
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2621
                }
×
2622

2623
                for _, row := range rows {
×
2624
                        node1, node2, err := buildNodeVertices(
×
2625
                                row.Node1PubKey, row.Node2PubKey,
×
2626
                        )
×
2627
                        if err != nil {
×
2628
                                return err
×
2629
                        }
×
2630

2631
                        channel, err := getAndBuildEdgeInfo(
×
2632
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2633
                                row.Channel, node1, node2,
×
2634
                        )
×
2635
                        if err != nil {
×
2636
                                return err
×
2637
                        }
×
2638

2639
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2640
                        if err != nil {
×
2641
                                return fmt.Errorf("unable to delete "+
×
2642
                                        "channel: %w", err)
×
2643
                        }
×
2644

2645
                        removedChans = append(removedChans, channel)
×
2646
                }
2647

2648
                return db.DeletePruneLogEntriesInRange(
×
2649
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2650
                                StartHeight: int64(height),
×
2651
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2652
                        },
×
2653
                )
×
2654
        }, func() {
×
2655
                removedChans = nil
×
2656
        })
×
2657
        if err != nil {
×
2658
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2659
                        "height: %w", err)
×
2660
        }
×
2661

2662
        for _, channel := range removedChans {
×
2663
                s.rejectCache.remove(channel.ChannelID)
×
2664
                s.chanCache.remove(channel.ChannelID)
×
2665
        }
×
2666

2667
        return removedChans, nil
×
2668
}
2669

2670
// AddEdgeProof sets the proof of an existing edge in the graph database.
2671
//
2672
// NOTE: part of the V1Store interface.
2673
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2674
        proof *models.ChannelAuthProof) error {
×
2675

×
2676
        var (
×
2677
                ctx       = context.TODO()
×
2678
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2679
        )
×
2680

×
2681
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2682
                res, err := db.AddV1ChannelProof(
×
2683
                        ctx, sqlc.AddV1ChannelProofParams{
×
2684
                                Scid:              scidBytes,
×
2685
                                Node1Signature:    proof.NodeSig1Bytes,
×
2686
                                Node2Signature:    proof.NodeSig2Bytes,
×
2687
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2688
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2689
                        },
×
2690
                )
×
2691
                if err != nil {
×
2692
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2693
                }
×
2694

2695
                n, err := res.RowsAffected()
×
2696
                if err != nil {
×
2697
                        return err
×
2698
                }
×
2699

2700
                if n == 0 {
×
2701
                        return fmt.Errorf("no rows affected when adding edge "+
×
2702
                                "proof for SCID %v", scid)
×
2703
                } else if n > 1 {
×
2704
                        return fmt.Errorf("multiple rows affected when adding "+
×
2705
                                "edge proof for SCID %v: %d rows affected",
×
2706
                                scid, n)
×
2707
                }
×
2708

2709
                return nil
×
2710
        }, sqldb.NoOpReset)
2711
        if err != nil {
×
2712
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2713
        }
×
2714

2715
        return nil
×
2716
}
2717

2718
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2719
// that we can ignore channel announcements that we know to be closed without
2720
// having to validate them and fetch a block.
2721
//
2722
// NOTE: part of the V1Store interface.
2723
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2724
        var (
×
2725
                ctx     = context.TODO()
×
2726
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2727
        )
×
2728

×
2729
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2730
                return db.InsertClosedChannel(ctx, chanIDB)
×
2731
        }, sqldb.NoOpReset)
×
2732
}
2733

2734
// IsClosedScid checks whether a channel identified by the passed in scid is
2735
// closed. This helps avoid having to perform expensive validation checks.
2736
//
2737
// NOTE: part of the V1Store interface.
2738
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2739
        var (
×
2740
                ctx      = context.TODO()
×
2741
                isClosed bool
×
2742
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2743
        )
×
2744
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2745
                var err error
×
2746
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2747
                if err != nil {
×
2748
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2749
                                err)
×
2750
                }
×
2751

2752
                return nil
×
2753
        }, sqldb.NoOpReset)
2754
        if err != nil {
×
2755
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2756
                        err)
×
2757
        }
×
2758

2759
        return isClosed, nil
×
2760
}
2761

2762
// GraphSession will provide the call-back with access to a NodeTraverser
2763
// instance which can be used to perform queries against the channel graph.
2764
//
2765
// NOTE: part of the V1Store interface.
2766
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error) error {
×
2767
        var ctx = context.TODO()
×
2768

×
2769
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2770
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2771
        }, sqldb.NoOpReset)
×
2772
}
2773

2774
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2775
// read only transaction for a consistent view of the graph.
2776
type sqlNodeTraverser struct {
2777
        db    SQLQueries
2778
        chain chainhash.Hash
2779
}
2780

2781
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2782
// NodeTraverser interface.
2783
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2784

2785
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2786
func newSQLNodeTraverser(db SQLQueries,
2787
        chain chainhash.Hash) *sqlNodeTraverser {
×
2788

×
2789
        return &sqlNodeTraverser{
×
2790
                db:    db,
×
2791
                chain: chain,
×
2792
        }
×
2793
}
×
2794

2795
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2796
// node.
2797
//
2798
// NOTE: Part of the NodeTraverser interface.
2799
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2800
        cb func(channel *DirectedChannel) error) error {
×
2801

×
2802
        ctx := context.TODO()
×
2803

×
2804
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2805
}
×
2806

2807
// FetchNodeFeatures returns the features of the given node. If the node is
2808
// unknown, assume no additional features are supported.
2809
//
2810
// NOTE: Part of the NodeTraverser interface.
2811
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2812
        *lnwire.FeatureVector, error) {
×
2813

×
2814
        ctx := context.TODO()
×
2815

×
2816
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2817
}
×
2818

2819
// forEachNodeDirectedChannel iterates through all channels of a given
2820
// node, executing the passed callback on the directed edge representing the
2821
// channel and its incoming policy. If the node is not found, no error is
2822
// returned.
2823
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2824
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2825

×
2826
        toNodeCallback := func() route.Vertex {
×
2827
                return nodePub
×
2828
        }
×
2829

2830
        dbID, err := db.GetNodeIDByPubKey(
×
2831
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2832
                        Version: int16(ProtocolV1),
×
2833
                        PubKey:  nodePub[:],
×
2834
                },
×
2835
        )
×
2836
        if errors.Is(err, sql.ErrNoRows) {
×
2837
                return nil
×
2838
        } else if err != nil {
×
2839
                return fmt.Errorf("unable to fetch node: %w", err)
×
2840
        }
×
2841

2842
        rows, err := db.ListChannelsByNodeID(
×
2843
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2844
                        Version: int16(ProtocolV1),
×
2845
                        NodeID1: dbID,
×
2846
                },
×
2847
        )
×
2848
        if err != nil {
×
2849
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2850
        }
×
2851

2852
        // Exit early if there are no channels for this node so we don't
2853
        // do the unnecessary feature fetching.
2854
        if len(rows) == 0 {
×
2855
                return nil
×
2856
        }
×
2857

2858
        features, err := getNodeFeatures(ctx, db, dbID)
×
2859
        if err != nil {
×
2860
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2861
        }
×
2862

2863
        for _, row := range rows {
×
2864
                node1, node2, err := buildNodeVertices(
×
2865
                        row.Node1Pubkey, row.Node2Pubkey,
×
2866
                )
×
2867
                if err != nil {
×
2868
                        return fmt.Errorf("unable to build node vertices: %w",
×
2869
                                err)
×
2870
                }
×
2871

2872
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
2873

×
2874
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2875
                if err != nil {
×
2876
                        return err
×
2877
                }
×
2878

2879
                var p1, p2 *models.CachedEdgePolicy
×
2880
                if dbPol1 != nil {
×
2881
                        policy1, err := buildChanPolicy(
×
2882
                                *dbPol1, edge.ChannelID, nil, node2,
×
2883
                        )
×
2884
                        if err != nil {
×
2885
                                return err
×
2886
                        }
×
2887

2888
                        p1 = models.NewCachedPolicy(policy1)
×
2889
                }
2890
                if dbPol2 != nil {
×
2891
                        policy2, err := buildChanPolicy(
×
2892
                                *dbPol2, edge.ChannelID, nil, node1,
×
2893
                        )
×
2894
                        if err != nil {
×
2895
                                return err
×
2896
                        }
×
2897

2898
                        p2 = models.NewCachedPolicy(policy2)
×
2899
                }
2900

2901
                // Determine the outgoing and incoming policy for this
2902
                // channel and node combo.
2903
                outPolicy, inPolicy := p1, p2
×
2904
                if p1 != nil && node2 == nodePub {
×
2905
                        outPolicy, inPolicy = p2, p1
×
2906
                } else if p2 != nil && node1 != nodePub {
×
2907
                        outPolicy, inPolicy = p2, p1
×
2908
                }
×
2909

2910
                var cachedInPolicy *models.CachedEdgePolicy
×
2911
                if inPolicy != nil {
×
2912
                        cachedInPolicy = inPolicy
×
2913
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
2914
                        cachedInPolicy.ToNodeFeatures = features
×
2915
                }
×
2916

2917
                directedChannel := &DirectedChannel{
×
2918
                        ChannelID:    edge.ChannelID,
×
2919
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2920
                        OtherNode:    edge.NodeKey2Bytes,
×
2921
                        Capacity:     edge.Capacity,
×
2922
                        OutPolicySet: outPolicy != nil,
×
2923
                        InPolicy:     cachedInPolicy,
×
2924
                }
×
2925
                if outPolicy != nil {
×
2926
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2927
                                directedChannel.InboundFee = fee
×
2928
                        })
×
2929
                }
2930

2931
                if nodePub == edge.NodeKey2Bytes {
×
2932
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
2933
                }
×
2934

2935
                if err := cb(directedChannel); err != nil {
×
2936
                        return err
×
2937
                }
×
2938
        }
2939

2940
        return nil
×
2941
}
2942

2943
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
2944
// and executes the provided callback for each node.
2945
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
2946
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
2947

×
2948
        lastID := int64(-1)
×
2949

×
2950
        for {
×
2951
                nodes, err := db.ListNodeIDsAndPubKeys(
×
2952
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
2953
                                Version: int16(ProtocolV1),
×
2954
                                ID:      lastID,
×
2955
                                Limit:   pageSize,
×
2956
                        },
×
2957
                )
×
2958
                if err != nil {
×
2959
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
2960
                }
×
2961

2962
                if len(nodes) == 0 {
×
2963
                        break
×
2964
                }
2965

2966
                for _, node := range nodes {
×
2967
                        var pub route.Vertex
×
2968
                        copy(pub[:], node.PubKey)
×
2969

×
2970
                        if err := cb(node.ID, pub); err != nil {
×
2971
                                return fmt.Errorf("forEachNodeCacheable "+
×
2972
                                        "callback failed for node(id=%d): %w",
×
2973
                                        node.ID, err)
×
2974
                        }
×
2975

2976
                        lastID = node.ID
×
2977
                }
2978
        }
2979

2980
        return nil
×
2981
}
2982

2983
// forEachNodeChannel iterates through all channels of a node, executing
2984
// the passed callback on each. The call-back is provided with the channel's
2985
// edge information, the outgoing policy and the incoming policy for the
2986
// channel and node combo.
2987
func forEachNodeChannel(ctx context.Context, db SQLQueries,
2988
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
2989
                *models.ChannelEdgePolicy,
2990
                *models.ChannelEdgePolicy) error) error {
×
2991

×
2992
        // Get all the V1 channels for this node.Add commentMore actions
×
2993
        rows, err := db.ListChannelsByNodeID(
×
2994
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2995
                        Version: int16(ProtocolV1),
×
2996
                        NodeID1: id,
×
2997
                },
×
2998
        )
×
2999
        if err != nil {
×
3000
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3001
        }
×
3002

3003
        // Call the call-back for each channel and its known policies.
3004
        for _, row := range rows {
×
3005
                node1, node2, err := buildNodeVertices(
×
3006
                        row.Node1Pubkey, row.Node2Pubkey,
×
3007
                )
×
3008
                if err != nil {
×
3009
                        return fmt.Errorf("unable to build node vertices: %w",
×
3010
                                err)
×
3011
                }
×
3012

3013
                edge, err := getAndBuildEdgeInfo(
×
3014
                        ctx, db, chain, row.Channel.ID, row.Channel, node1,
×
3015
                        node2,
×
3016
                )
×
3017
                if err != nil {
×
3018
                        return fmt.Errorf("unable to build channel info: %w",
×
3019
                                err)
×
3020
                }
×
3021

3022
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3023
                if err != nil {
×
3024
                        return fmt.Errorf("unable to extract channel "+
×
3025
                                "policies: %w", err)
×
3026
                }
×
3027

3028
                p1, p2, err := getAndBuildChanPolicies(
×
3029
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3030
                )
×
3031
                if err != nil {
×
3032
                        return fmt.Errorf("unable to build channel "+
×
3033
                                "policies: %w", err)
×
3034
                }
×
3035

3036
                // Determine the outgoing and incoming policy for this
3037
                // channel and node combo.
3038
                p1ToNode := row.Channel.NodeID2
×
3039
                p2ToNode := row.Channel.NodeID1
×
3040
                outPolicy, inPolicy := p1, p2
×
3041
                if (p1 != nil && p1ToNode == id) ||
×
3042
                        (p2 != nil && p2ToNode != id) {
×
3043

×
3044
                        outPolicy, inPolicy = p2, p1
×
3045
                }
×
3046

3047
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3048
                        return err
×
3049
                }
×
3050
        }
3051

3052
        return nil
×
3053
}
3054

3055
// updateChanEdgePolicy upserts the channel policy info we have stored for
3056
// a channel we already know of.
3057
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3058
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3059
        error) {
×
3060

×
3061
        var (
×
3062
                node1Pub, node2Pub route.Vertex
×
3063
                isNode1            bool
×
3064
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3065
        )
×
3066

×
3067
        // Check that this edge policy refers to a channel that we already
×
3068
        // know of. We do this explicitly so that we can return the appropriate
×
3069
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3070
        // abort the transaction which would abort the entire batch.
×
3071
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3072
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3073
                        Scid:    chanIDB,
×
3074
                        Version: int16(ProtocolV1),
×
3075
                },
×
3076
        )
×
3077
        if errors.Is(err, sql.ErrNoRows) {
×
3078
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3079
        } else if err != nil {
×
3080
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3081
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3082
        }
×
3083

3084
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3085
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3086

×
3087
        // Figure out which node this edge is from.
×
3088
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3089
        nodeID := dbChan.NodeID1
×
3090
        if !isNode1 {
×
3091
                nodeID = dbChan.NodeID2
×
3092
        }
×
3093

3094
        var (
×
3095
                inboundBase sql.NullInt64
×
3096
                inboundRate sql.NullInt64
×
3097
        )
×
3098
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3099
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3100
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3101
        })
×
3102

3103
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3104
                Version:     int16(ProtocolV1),
×
3105
                ChannelID:   dbChan.ID,
×
3106
                NodeID:      nodeID,
×
3107
                Timelock:    int32(edge.TimeLockDelta),
×
3108
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3109
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3110
                MinHtlcMsat: int64(edge.MinHTLC),
×
3111
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3112
                Disabled: sql.NullBool{
×
3113
                        Valid: true,
×
3114
                        Bool:  edge.IsDisabled(),
×
3115
                },
×
3116
                MaxHtlcMsat: sql.NullInt64{
×
3117
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3118
                        Int64: int64(edge.MaxHTLC),
×
3119
                },
×
3120
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3121
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3122
                InboundBaseFeeMsat:      inboundBase,
×
3123
                InboundFeeRateMilliMsat: inboundRate,
×
3124
                Signature:               edge.SigBytes,
×
3125
        })
×
3126
        if err != nil {
×
3127
                return node1Pub, node2Pub, isNode1,
×
3128
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3129
        }
×
3130

3131
        // Convert the flat extra opaque data into a map of TLV types to
3132
        // values.
3133
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3134
        if err != nil {
×
3135
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3136
                        "marshal extra opaque data: %w", err)
×
3137
        }
×
3138

3139
        // Update the channel policy's extra signed fields.
3140
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3141
        if err != nil {
×
3142
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3143
                        "policy extra TLVs: %w", err)
×
3144
        }
×
3145

3146
        return node1Pub, node2Pub, isNode1, nil
×
3147
}
3148

3149
// getNodeByPubKey attempts to look up a target node by its public key.
3150
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3151
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3152

×
3153
        dbNode, err := db.GetNodeByPubKey(
×
3154
                ctx, sqlc.GetNodeByPubKeyParams{
×
3155
                        Version: int16(ProtocolV1),
×
3156
                        PubKey:  pubKey[:],
×
3157
                },
×
3158
        )
×
3159
        if errors.Is(err, sql.ErrNoRows) {
×
3160
                return 0, nil, ErrGraphNodeNotFound
×
3161
        } else if err != nil {
×
3162
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3163
        }
×
3164

3165
        node, err := buildNode(ctx, db, &dbNode)
×
3166
        if err != nil {
×
3167
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3168
        }
×
3169

3170
        return dbNode.ID, node, nil
×
3171
}
3172

3173
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3174
// provided database channel row and the public keys of the two nodes
3175
// involved in the channel.
3176
func buildCacheableChannelInfo(dbChan sqlc.Channel, node1Pub,
3177
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3178

×
3179
        return &models.CachedEdgeInfo{
×
3180
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
3181
                NodeKey1Bytes: node1Pub,
×
3182
                NodeKey2Bytes: node2Pub,
×
3183
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
3184
        }
×
3185
}
×
3186

3187
// buildNode constructs a LightningNode instance from the given database node
3188
// record. The node's features, addresses and extra signed fields are also
3189
// fetched from the database and set on the node.
3190
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
3191
        *models.LightningNode, error) {
×
3192

×
3193
        if dbNode.Version != int16(ProtocolV1) {
×
3194
                return nil, fmt.Errorf("unsupported node version: %d",
×
3195
                        dbNode.Version)
×
3196
        }
×
3197

3198
        var pub [33]byte
×
3199
        copy(pub[:], dbNode.PubKey)
×
3200

×
3201
        node := &models.LightningNode{
×
3202
                PubKeyBytes: pub,
×
3203
                Features:    lnwire.EmptyFeatureVector(),
×
3204
                LastUpdate:  time.Unix(0, 0),
×
3205
        }
×
3206

×
3207
        if len(dbNode.Signature) == 0 {
×
3208
                return node, nil
×
3209
        }
×
3210

3211
        node.HaveNodeAnnouncement = true
×
3212
        node.AuthSigBytes = dbNode.Signature
×
3213
        node.Alias = dbNode.Alias.String
×
3214
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3215

×
3216
        var err error
×
3217
        if dbNode.Color.Valid {
×
3218
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3219
                if err != nil {
×
3220
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3221
                                err)
×
3222
                }
×
3223
        }
3224

3225
        // Fetch the node's features.
3226
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
3227
        if err != nil {
×
3228
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3229
                        "features: %w", dbNode.ID, err)
×
3230
        }
×
3231

3232
        // Fetch the node's addresses.
3233
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
3234
        if err != nil {
×
3235
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3236
                        "addresses: %w", dbNode.ID, err)
×
3237
        }
×
3238

3239
        // Fetch the node's extra signed fields.
3240
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
3241
        if err != nil {
×
3242
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3243
                        "extra signed fields: %w", dbNode.ID, err)
×
3244
        }
×
3245

3246
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
3247
        if err != nil {
×
3248
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3249
                        "fields: %w", err)
×
3250
        }
×
3251

3252
        if len(recs) != 0 {
×
3253
                node.ExtraOpaqueData = recs
×
3254
        }
×
3255

3256
        return node, nil
×
3257
}
3258

3259
// getNodeFeatures fetches the feature bits and constructs the feature vector
3260
// for a node with the given DB ID.
3261
func getNodeFeatures(ctx context.Context, db SQLQueries,
3262
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3263

×
3264
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3265
        if err != nil {
×
3266
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3267
                        nodeID, err)
×
3268
        }
×
3269

3270
        features := lnwire.EmptyFeatureVector()
×
3271
        for _, feature := range rows {
×
3272
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3273
        }
×
3274

3275
        return features, nil
×
3276
}
3277

3278
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
3279
// given DB ID.
3280
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3281
        nodeID int64) (map[uint64][]byte, error) {
×
3282

×
3283
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3284
        if err != nil {
×
3285
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
3286
                        "signed fields: %w", nodeID, err)
×
3287
        }
×
3288

3289
        extraFields := make(map[uint64][]byte)
×
3290
        for _, field := range fields {
×
3291
                extraFields[uint64(field.Type)] = field.Value
×
3292
        }
×
3293

3294
        return extraFields, nil
×
3295
}
3296

3297
// upsertNode upserts the node record into the database. If the node already
3298
// exists, then the node's information is updated. If the node doesn't exist,
3299
// then a new node is created. The node's features, addresses and extra TLV
3300
// types are also updated. The node's DB ID is returned.
3301
func upsertNode(ctx context.Context, db SQLQueries,
3302
        node *models.LightningNode) (int64, error) {
×
3303

×
3304
        params := sqlc.UpsertNodeParams{
×
3305
                Version: int16(ProtocolV1),
×
3306
                PubKey:  node.PubKeyBytes[:],
×
3307
        }
×
3308

×
3309
        if node.HaveNodeAnnouncement {
×
3310
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3311
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3312
                params.Alias = sqldb.SQLStr(node.Alias)
×
3313
                params.Signature = node.AuthSigBytes
×
3314
        }
×
3315

3316
        nodeID, err := db.UpsertNode(ctx, params)
×
3317
        if err != nil {
×
3318
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3319
                        err)
×
3320
        }
×
3321

3322
        // We can exit here if we don't have the announcement yet.
3323
        if !node.HaveNodeAnnouncement {
×
3324
                return nodeID, nil
×
3325
        }
×
3326

3327
        // Update the node's features.
3328
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3329
        if err != nil {
×
3330
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3331
        }
×
3332

3333
        // Update the node's addresses.
3334
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3335
        if err != nil {
×
3336
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3337
        }
×
3338

3339
        // Convert the flat extra opaque data into a map of TLV types to
3340
        // values.
3341
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3342
        if err != nil {
×
3343
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3344
                        err)
×
3345
        }
×
3346

3347
        // Update the node's extra signed fields.
3348
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3349
        if err != nil {
×
3350
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3351
        }
×
3352

3353
        return nodeID, nil
×
3354
}
3355

3356
// upsertNodeFeatures updates the node's features node_features table. This
3357
// includes deleting any feature bits no longer present and inserting any new
3358
// feature bits. If the feature bit does not yet exist in the features table,
3359
// then an entry is created in that table first.
3360
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3361
        features *lnwire.FeatureVector) error {
×
3362

×
3363
        // Get any existing features for the node.
×
3364
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3365
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3366
                return err
×
3367
        }
×
3368

3369
        // Copy the nodes latest set of feature bits.
3370
        newFeatures := make(map[int32]struct{})
×
3371
        if features != nil {
×
3372
                for feature := range features.Features() {
×
3373
                        newFeatures[int32(feature)] = struct{}{}
×
3374
                }
×
3375
        }
3376

3377
        // For any current feature that already exists in the DB, remove it from
3378
        // the in-memory map. For any existing feature that does not exist in
3379
        // the in-memory map, delete it from the database.
3380
        for _, feature := range existingFeatures {
×
3381
                // The feature is still present, so there are no updates to be
×
3382
                // made.
×
3383
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3384
                        delete(newFeatures, feature.FeatureBit)
×
3385
                        continue
×
3386
                }
3387

3388
                // The feature is no longer present, so we remove it from the
3389
                // database.
3390
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3391
                        NodeID:     nodeID,
×
3392
                        FeatureBit: feature.FeatureBit,
×
3393
                })
×
3394
                if err != nil {
×
3395
                        return fmt.Errorf("unable to delete node(%d) "+
×
3396
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3397
                                err)
×
3398
                }
×
3399
        }
3400

3401
        // Any remaining entries in newFeatures are new features that need to be
3402
        // added to the database for the first time.
3403
        for feature := range newFeatures {
×
3404
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3405
                        NodeID:     nodeID,
×
3406
                        FeatureBit: feature,
×
3407
                })
×
3408
                if err != nil {
×
3409
                        return fmt.Errorf("unable to insert node(%d) "+
×
3410
                                "feature(%v): %w", nodeID, feature, err)
×
3411
                }
×
3412
        }
3413

3414
        return nil
×
3415
}
3416

3417
// fetchNodeFeatures fetches the features for a node with the given public key.
3418
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3419
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3420

×
3421
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3422
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3423
                        PubKey:  nodePub[:],
×
3424
                        Version: int16(ProtocolV1),
×
3425
                },
×
3426
        )
×
3427
        if err != nil {
×
3428
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3429
                        nodePub, err)
×
3430
        }
×
3431

3432
        features := lnwire.EmptyFeatureVector()
×
3433
        for _, bit := range rows {
×
3434
                features.Set(lnwire.FeatureBit(bit))
×
3435
        }
×
3436

3437
        return features, nil
×
3438
}
3439

3440
// dbAddressType enumerates the kinds of address stored in the node_addresses
// table. The type determines how an address string is serialised when written
// and deserialised when read back. The numeric values are written to the
// database, so existing values must never be renumbered.
type dbAddressType uint8

const (
	// addressTypeIPv4 marks a TCP address whose host is an IPv4 address.
	addressTypeIPv4 dbAddressType = iota + 1

	// addressTypeIPv6 marks a TCP address whose host is an IPv6 address.
	addressTypeIPv6

	// addressTypeTorV2 marks a Tor v2 onion service address.
	addressTypeTorV2

	// addressTypeTorV3 marks a Tor v3 onion service address.
	addressTypeTorV3

	// addressTypeOpaque marks an address kept as opaque bytes. It is
	// pinned to the largest int8 value, well clear of the sequential
	// values above.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3452

3453
// upsertNodeAddresses updates the node's addresses in the database. This
3454
// includes deleting any existing addresses and inserting the new set of
3455
// addresses. The deletion is necessary since the ordering of the addresses may
3456
// change, and we need to ensure that the database reflects the latest set of
3457
// addresses so that at the time of reconstructing the node announcement, the
3458
// order is preserved and the signature over the message remains valid.
3459
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3460
        addresses []net.Addr) error {
×
3461

×
3462
        // Delete any existing addresses for the node. This is required since
×
3463
        // even if the new set of addresses is the same, the ordering may have
×
3464
        // changed for a given address type.
×
3465
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3466
        if err != nil {
×
3467
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3468
                        nodeID, err)
×
3469
        }
×
3470

3471
        // Copy the nodes latest set of addresses.
3472
        newAddresses := map[dbAddressType][]string{
×
3473
                addressTypeIPv4:   {},
×
3474
                addressTypeIPv6:   {},
×
3475
                addressTypeTorV2:  {},
×
3476
                addressTypeTorV3:  {},
×
3477
                addressTypeOpaque: {},
×
3478
        }
×
3479
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3480
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3481
        }
×
3482

3483
        for _, address := range addresses {
×
3484
                switch addr := address.(type) {
×
3485
                case *net.TCPAddr:
×
3486
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3487
                                addAddr(addressTypeIPv4, addr)
×
3488
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3489
                                addAddr(addressTypeIPv6, addr)
×
3490
                        } else {
×
3491
                                return fmt.Errorf("unhandled IP address: %v",
×
3492
                                        addr)
×
3493
                        }
×
3494

3495
                case *tor.OnionAddr:
×
3496
                        switch len(addr.OnionService) {
×
3497
                        case tor.V2Len:
×
3498
                                addAddr(addressTypeTorV2, addr)
×
3499
                        case tor.V3Len:
×
3500
                                addAddr(addressTypeTorV3, addr)
×
3501
                        default:
×
3502
                                return fmt.Errorf("invalid length for a tor " +
×
3503
                                        "address")
×
3504
                        }
3505

3506
                case *lnwire.OpaqueAddrs:
×
3507
                        addAddr(addressTypeOpaque, addr)
×
3508

3509
                default:
×
3510
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3511
                }
3512
        }
3513

3514
        // Any remaining entries in newAddresses are new addresses that need to
3515
        // be added to the database for the first time.
3516
        for addrType, addrList := range newAddresses {
×
3517
                for position, addr := range addrList {
×
3518
                        err := db.InsertNodeAddress(
×
3519
                                ctx, sqlc.InsertNodeAddressParams{
×
3520
                                        NodeID:   nodeID,
×
3521
                                        Type:     int16(addrType),
×
3522
                                        Address:  addr,
×
3523
                                        Position: int32(position),
×
3524
                                },
×
3525
                        )
×
3526
                        if err != nil {
×
3527
                                return fmt.Errorf("unable to insert "+
×
3528
                                        "node(%d) address(%v): %w", nodeID,
×
3529
                                        addr, err)
×
3530
                        }
×
3531
                }
3532
        }
3533

3534
        return nil
×
3535
}
3536

3537
// getNodeAddresses fetches the addresses for a node with the given public key.
3538
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3539
        []net.Addr, error) {
×
3540

×
3541
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
3542
        // are returned in the same order as they were inserted.
×
3543
        rows, err := db.GetNodeAddressesByPubKey(
×
3544
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
3545
                        Version: int16(ProtocolV1),
×
3546
                        PubKey:  nodePub,
×
3547
                },
×
3548
        )
×
3549
        if err != nil {
×
3550
                return false, nil, err
×
3551
        }
×
3552

3553
        // GetNodeAddressesByPubKey uses a left join so there should always be
3554
        // at least one row returned if the node exists even if it has no
3555
        // addresses.
3556
        if len(rows) == 0 {
×
3557
                return false, nil, nil
×
3558
        }
×
3559

3560
        addresses := make([]net.Addr, 0, len(rows))
×
3561
        for _, addr := range rows {
×
3562
                if !(addr.Type.Valid && addr.Address.Valid) {
×
3563
                        continue
×
3564
                }
3565

3566
                address := addr.Address.String
×
3567

×
3568
                switch dbAddressType(addr.Type.Int16) {
×
3569
                case addressTypeIPv4:
×
3570
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
3571
                        if err != nil {
×
3572
                                return false, nil, nil
×
3573
                        }
×
3574
                        tcp.IP = tcp.IP.To4()
×
3575

×
3576
                        addresses = append(addresses, tcp)
×
3577

3578
                case addressTypeIPv6:
×
3579
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
3580
                        if err != nil {
×
3581
                                return false, nil, nil
×
3582
                        }
×
3583
                        addresses = append(addresses, tcp)
×
3584

3585
                case addressTypeTorV3, addressTypeTorV2:
×
3586
                        service, portStr, err := net.SplitHostPort(address)
×
3587
                        if err != nil {
×
3588
                                return false, nil, fmt.Errorf("unable to "+
×
3589
                                        "split tor v3 address: %v",
×
3590
                                        addr.Address)
×
3591
                        }
×
3592

3593
                        port, err := strconv.Atoi(portStr)
×
3594
                        if err != nil {
×
3595
                                return false, nil, err
×
3596
                        }
×
3597

3598
                        addresses = append(addresses, &tor.OnionAddr{
×
3599
                                OnionService: service,
×
3600
                                Port:         port,
×
3601
                        })
×
3602

3603
                case addressTypeOpaque:
×
3604
                        opaque, err := hex.DecodeString(address)
×
3605
                        if err != nil {
×
3606
                                return false, nil, fmt.Errorf("unable to "+
×
3607
                                        "decode opaque address: %v", addr)
×
3608
                        }
×
3609

3610
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
3611
                                Payload: opaque,
×
3612
                        })
×
3613

3614
                default:
×
3615
                        return false, nil, fmt.Errorf("unknown address "+
×
3616
                                "type: %v", addr.Type)
×
3617
                }
3618
        }
3619

3620
        // If we have no addresses, then we'll return nil instead of an
3621
        // empty slice.
3622
        if len(addresses) == 0 {
×
3623
                addresses = nil
×
3624
        }
×
3625

3626
        return true, addresses, nil
×
3627
}
3628

3629
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3630
// database. This includes updating any existing types, inserting any new types,
3631
// and deleting any types that are no longer present.
3632
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3633
        nodeID int64, extraFields map[uint64][]byte) error {
×
3634

×
3635
        // Get any existing extra signed fields for the node.
×
3636
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3637
        if err != nil {
×
3638
                return err
×
3639
        }
×
3640

3641
        // Make a lookup map of the existing field types so that we can use it
3642
        // to keep track of any fields we should delete.
3643
        m := make(map[uint64]bool)
×
3644
        for _, field := range existingFields {
×
3645
                m[uint64(field.Type)] = true
×
3646
        }
×
3647

3648
        // For all the new fields, we'll upsert them and remove them from the
3649
        // map of existing fields.
3650
        for tlvType, value := range extraFields {
×
3651
                err = db.UpsertNodeExtraType(
×
3652
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3653
                                NodeID: nodeID,
×
3654
                                Type:   int64(tlvType),
×
3655
                                Value:  value,
×
3656
                        },
×
3657
                )
×
3658
                if err != nil {
×
3659
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3660
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3661
                }
×
3662

3663
                // Remove the field from the map of existing fields if it was
3664
                // present.
3665
                delete(m, tlvType)
×
3666
        }
3667

3668
        // For all the fields that are left in the map of existing fields, we'll
3669
        // delete them as they are no longer present in the new set of fields.
3670
        for tlvType := range m {
×
3671
                err = db.DeleteExtraNodeType(
×
3672
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3673
                                NodeID: nodeID,
×
3674
                                Type:   int64(tlvType),
×
3675
                        },
×
3676
                )
×
3677
                if err != nil {
×
3678
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3679
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3680
                }
×
3681
        }
3682

3683
        return nil
×
3684
}
3685

3686
// srcNodeInfo holds the information about the source node of the graph.
3687
type srcNodeInfo struct {
3688
        // id is the DB level ID of the source node entry in the "nodes" table.
3689
        id int64
3690

3691
        // pub is the public key of the source node.
3692
        pub route.Vertex
3693
}
3694

3695
// sourceNode returns the DB node ID and pub key of the source node for the
3696
// specified protocol version.
3697
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3698
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3699

×
3700
        s.srcNodeMu.Lock()
×
3701
        defer s.srcNodeMu.Unlock()
×
3702

×
3703
        // If we already have the source node ID and pub key cached, then
×
3704
        // return them.
×
3705
        if info, ok := s.srcNodes[version]; ok {
×
3706
                return info.id, info.pub, nil
×
3707
        }
×
3708

3709
        var pubKey route.Vertex
×
3710

×
3711
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3712
        if err != nil {
×
3713
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3714
                        err)
×
3715
        }
×
3716

3717
        if len(nodes) == 0 {
×
3718
                return 0, pubKey, ErrSourceNodeNotSet
×
3719
        } else if len(nodes) > 1 {
×
3720
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3721
                        "protocol %s found", version)
×
3722
        }
×
3723

3724
        copy(pubKey[:], nodes[0].PubKey)
×
3725

×
3726
        s.srcNodes[version] = &srcNodeInfo{
×
3727
                id:  nodes[0].NodeID,
×
3728
                pub: pubKey,
×
3729
        }
×
3730

×
3731
        return nodes[0].NodeID, pubKey, nil
×
3732
}
3733

3734
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3735
// This then produces a map from TLV type to value. If the input is not a
3736
// valid TLV stream, then an error is returned.
3737
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3738
        r := bytes.NewReader(data)
×
3739

×
3740
        tlvStream, err := tlv.NewStream()
×
3741
        if err != nil {
×
3742
                return nil, err
×
3743
        }
×
3744

3745
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3746
        // pass it into the P2P decoding variant.
3747
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3748
        if err != nil {
×
3749
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3750
        }
×
3751
        if len(parsedTypes) == 0 {
×
3752
                return nil, nil
×
3753
        }
×
3754

3755
        records := make(map[uint64][]byte)
×
3756
        for k, v := range parsedTypes {
×
3757
                records[uint64(k)] = v
×
3758
        }
×
3759

3760
        return records, nil
×
3761
}
3762

3763
// dbChanInfo bundles the database-level identifiers produced when a channel is
// inserted: the channel row's ID and the row IDs of its two endpoint nodes.
type dbChanInfo struct {
	// channelID is the primary key of the channel row.
	channelID int64

	// node1ID and node2ID are the DB IDs of the channel's first and
	// second nodes, respectively.
	node1ID int64
	node2ID int64
}
3770

3771
// insertChannel inserts a new channel record into the database.
3772
func insertChannel(ctx context.Context, db SQLQueries,
3773
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
3774

×
3775
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3776

×
3777
        // Make sure that the channel doesn't already exist. We do this
×
3778
        // explicitly instead of relying on catching a unique constraint error
×
3779
        // because relying on SQL to throw that error would abort the entire
×
3780
        // batch of transactions.
×
3781
        _, err := db.GetChannelBySCID(
×
3782
                ctx, sqlc.GetChannelBySCIDParams{
×
3783
                        Scid:    chanIDB,
×
3784
                        Version: int16(ProtocolV1),
×
3785
                },
×
3786
        )
×
3787
        if err == nil {
×
3788
                return nil, ErrEdgeAlreadyExist
×
3789
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3790
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
×
3791
        }
×
3792

3793
        // Make sure that at least a "shell" entry for each node is present in
3794
        // the nodes table.
3795
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3796
        if err != nil {
×
3797
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3798
        }
×
3799

3800
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3801
        if err != nil {
×
3802
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3803
        }
×
3804

3805
        var capacity sql.NullInt64
×
3806
        if edge.Capacity != 0 {
×
3807
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3808
        }
×
3809

3810
        createParams := sqlc.CreateChannelParams{
×
3811
                Version:     int16(ProtocolV1),
×
3812
                Scid:        chanIDB,
×
3813
                NodeID1:     node1DBID,
×
3814
                NodeID2:     node2DBID,
×
3815
                Outpoint:    edge.ChannelPoint.String(),
×
3816
                Capacity:    capacity,
×
3817
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3818
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3819
        }
×
3820

×
3821
        if edge.AuthProof != nil {
×
3822
                proof := edge.AuthProof
×
3823

×
3824
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3825
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3826
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3827
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3828
        }
×
3829

3830
        // Insert the new channel record.
3831
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3832
        if err != nil {
×
3833
                return nil, err
×
3834
        }
×
3835

3836
        // Insert any channel features.
NEW
3837
        for feature := range edge.Features.Features() {
×
NEW
3838
                err = db.InsertChannelFeature(
×
NEW
3839
                        ctx, sqlc.InsertChannelFeatureParams{
×
NEW
3840
                                ChannelID:  dbChanID,
×
NEW
3841
                                FeatureBit: int32(feature),
×
NEW
3842
                        },
×
NEW
3843
                )
×
3844
                if err != nil {
×
NEW
3845
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
×
NEW
3846
                                "feature(%v): %w", dbChanID, feature, err)
×
UNCOV
3847
                }
×
3848
        }
3849

3850
        // Finally, insert any extra TLV fields in the channel announcement.
3851
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3852
        if err != nil {
×
3853
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
3854
                        "data: %w", err)
×
3855
        }
×
3856

3857
        for tlvType, value := range extra {
×
3858
                err := db.CreateChannelExtraType(
×
3859
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3860
                                ChannelID: dbChanID,
×
3861
                                Type:      int64(tlvType),
×
3862
                                Value:     value,
×
3863
                        },
×
3864
                )
×
3865
                if err != nil {
×
3866
                        return nil, fmt.Errorf("unable to upsert "+
×
3867
                                "channel(%d) extra signed field(%v): %w",
×
3868
                                edge.ChannelID, tlvType, err)
×
3869
                }
×
3870
        }
3871

3872
        return &dbChanInfo{
×
3873
                channelID: dbChanID,
×
3874
                node1ID:   node1DBID,
×
3875
                node2ID:   node2DBID,
×
3876
        }, nil
×
3877
}
3878

3879
// maybeCreateShellNode checks if a shell node entry exists for the
3880
// given public key. If it does not exist, then a new shell node entry is
3881
// created. The ID of the node is returned. A shell node only has a protocol
3882
// version and public key persisted.
3883
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3884
        pubKey route.Vertex) (int64, error) {
×
3885

×
3886
        dbNode, err := db.GetNodeByPubKey(
×
3887
                ctx, sqlc.GetNodeByPubKeyParams{
×
3888
                        PubKey:  pubKey[:],
×
3889
                        Version: int16(ProtocolV1),
×
3890
                },
×
3891
        )
×
3892
        // The node exists. Return the ID.
×
3893
        if err == nil {
×
3894
                return dbNode.ID, nil
×
3895
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3896
                return 0, err
×
3897
        }
×
3898

3899
        // Otherwise, the node does not exist, so we create a shell entry for
3900
        // it.
3901
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3902
                Version: int16(ProtocolV1),
×
3903
                PubKey:  pubKey[:],
×
3904
        })
×
3905
        if err != nil {
×
3906
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3907
        }
×
3908

3909
        return id, nil
×
3910
}
3911

3912
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3913
// the database. This includes deleting any existing types and then inserting
3914
// the new types.
3915
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3916
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3917

×
3918
        // Delete all existing extra signed fields for the channel policy.
×
3919
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
3920
        if err != nil {
×
3921
                return fmt.Errorf("unable to delete "+
×
3922
                        "existing policy extra signed fields for policy %d: %w",
×
3923
                        chanPolicyID, err)
×
3924
        }
×
3925

3926
        // Insert all new extra signed fields for the channel policy.
3927
        for tlvType, value := range extraFields {
×
3928
                err = db.InsertChanPolicyExtraType(
×
3929
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
3930
                                ChannelPolicyID: chanPolicyID,
×
3931
                                Type:            int64(tlvType),
×
3932
                                Value:           value,
×
3933
                        },
×
3934
                )
×
3935
                if err != nil {
×
3936
                        return fmt.Errorf("unable to insert "+
×
3937
                                "channel_policy(%d) extra signed field(%v): %w",
×
3938
                                chanPolicyID, tlvType, err)
×
3939
                }
×
3940
        }
3941

3942
        return nil
×
3943
}
3944

3945
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3946
// provided dbChanRow and also fetches any other required information
3947
// to construct the edge info.
3948
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
3949
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.Channel, node1,
3950
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
3951

×
3952
        if dbChan.Version != int16(ProtocolV1) {
×
3953
                return nil, fmt.Errorf("unsupported channel version: %d",
×
3954
                        dbChan.Version)
×
3955
        }
×
3956

3957
        fv, extras, err := getChanFeaturesAndExtras(
×
3958
                ctx, db, dbChanID,
×
3959
        )
×
3960
        if err != nil {
×
3961
                return nil, err
×
3962
        }
×
3963

3964
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
3965
        if err != nil {
×
3966
                return nil, err
×
3967
        }
×
3968

3969
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
3970
        if err != nil {
×
3971
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3972
                        "fields: %w", err)
×
3973
        }
×
3974
        if recs == nil {
×
3975
                recs = make([]byte, 0)
×
3976
        }
×
3977

3978
        var btcKey1, btcKey2 route.Vertex
×
3979
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
3980
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
3981

×
3982
        channel := &models.ChannelEdgeInfo{
×
3983
                ChainHash:        chain,
×
3984
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
3985
                NodeKey1Bytes:    node1,
×
3986
                NodeKey2Bytes:    node2,
×
3987
                BitcoinKey1Bytes: btcKey1,
×
3988
                BitcoinKey2Bytes: btcKey2,
×
3989
                ChannelPoint:     *op,
×
3990
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
NEW
3991
                Features:         fv,
×
3992
                ExtraOpaqueData:  recs,
×
3993
        }
×
3994

×
3995
        // We always set all the signatures at the same time, so we can
×
3996
        // safely check if one signature is present to determine if we have the
×
3997
        // rest of the signatures for the auth proof.
×
3998
        if len(dbChan.Bitcoin1Signature) > 0 {
×
3999
                channel.AuthProof = &models.ChannelAuthProof{
×
4000
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4001
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4002
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4003
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4004
                }
×
4005
        }
×
4006

4007
        return channel, nil
×
4008
}
4009

4010
// buildNodeVertices is a helper that converts raw node public keys
4011
// into route.Vertex instances.
4012
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4013
        route.Vertex, error) {
×
4014

×
4015
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4016
        if err != nil {
×
4017
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4018
                        "create vertex from node1 pubkey: %w", err)
×
4019
        }
×
4020

4021
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4022
        if err != nil {
×
4023
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4024
                        "create vertex from node2 pubkey: %w", err)
×
4025
        }
×
4026

4027
        return node1Vertex, node2Vertex, nil
×
4028
}
4029

4030
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
4031
// for a channel with the given ID.
4032
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
4033
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
4034

×
4035
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
4036
        if err != nil {
×
4037
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
4038
                        "features and extras: %w", err)
×
4039
        }
×
4040

4041
        var (
×
4042
                fv     = lnwire.EmptyFeatureVector()
×
4043
                extras = make(map[uint64][]byte)
×
4044
        )
×
4045
        for _, row := range rows {
×
4046
                if row.IsFeature {
×
4047
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
4048

×
4049
                        continue
×
4050
                }
4051

4052
                tlvType, ok := row.ExtraKey.(int64)
×
4053
                if !ok {
×
4054
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4055
                                "TLV type: %T", row.ExtraKey)
×
4056
                }
×
4057

4058
                valueBytes, ok := row.Value.([]byte)
×
4059
                if !ok {
×
4060
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4061
                                "Value: %T", row.Value)
×
4062
                }
×
4063

4064
                extras[uint64(tlvType)] = valueBytes
×
4065
        }
4066

4067
        return fv, extras, nil
×
4068
}
4069

4070
// getAndBuildChanPolicies uses the given sqlc.ChannelPolicy and also retrieves
4071
// all the extra info required to build the complete models.ChannelEdgePolicy
4072
// types. It returns two policies, which may be nil if the provided
4073
// sqlc.ChannelPolicy records are nil.
4074
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4075
        dbPol1, dbPol2 *sqlc.ChannelPolicy, channelID uint64, node1,
4076
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4077
        *models.ChannelEdgePolicy, error) {
×
4078

×
4079
        if dbPol1 == nil && dbPol2 == nil {
×
4080
                return nil, nil, nil
×
4081
        }
×
4082

4083
        var (
×
4084
                policy1ID int64
×
4085
                policy2ID int64
×
4086
        )
×
4087
        if dbPol1 != nil {
×
4088
                policy1ID = dbPol1.ID
×
4089
        }
×
4090
        if dbPol2 != nil {
×
4091
                policy2ID = dbPol2.ID
×
4092
        }
×
4093
        rows, err := db.GetChannelPolicyExtraTypes(
×
4094
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
4095
                        ID:   policy1ID,
×
4096
                        ID_2: policy2ID,
×
4097
                },
×
4098
        )
×
4099
        if err != nil {
×
4100
                return nil, nil, err
×
4101
        }
×
4102

4103
        var (
×
4104
                dbPol1Extras = make(map[uint64][]byte)
×
4105
                dbPol2Extras = make(map[uint64][]byte)
×
4106
        )
×
4107
        for _, row := range rows {
×
4108
                switch row.PolicyID {
×
4109
                case policy1ID:
×
4110
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
4111
                case policy2ID:
×
4112
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
4113
                default:
×
4114
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
4115
                                "in row: %v", row.PolicyID, row)
×
4116
                }
4117
        }
4118

4119
        var pol1, pol2 *models.ChannelEdgePolicy
×
4120
        if dbPol1 != nil {
×
4121
                pol1, err = buildChanPolicy(
×
4122
                        *dbPol1, channelID, dbPol1Extras, node2,
×
4123
                )
×
4124
                if err != nil {
×
4125
                        return nil, nil, err
×
4126
                }
×
4127
        }
4128
        if dbPol2 != nil {
×
4129
                pol2, err = buildChanPolicy(
×
4130
                        *dbPol2, channelID, dbPol2Extras, node1,
×
4131
                )
×
4132
                if err != nil {
×
4133
                        return nil, nil, err
×
4134
                }
×
4135
        }
4136

4137
        return pol1, pol2, nil
×
4138
}
4139

4140
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4141
// provided sqlc.ChannelPolicy and other required information.
4142
func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
4143
        extras map[uint64][]byte,
4144
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4145

×
4146
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4147
        if err != nil {
×
4148
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4149
                        "fields: %w", err)
×
4150
        }
×
4151

4152
        var inboundFee fn.Option[lnwire.Fee]
×
4153
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4154
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4155

×
4156
                inboundFee = fn.Some(lnwire.Fee{
×
4157
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4158
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4159
                })
×
4160
        }
×
4161

4162
        return &models.ChannelEdgePolicy{
×
4163
                SigBytes:  dbPolicy.Signature,
×
4164
                ChannelID: channelID,
×
4165
                LastUpdate: time.Unix(
×
4166
                        dbPolicy.LastUpdate.Int64, 0,
×
4167
                ),
×
4168
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4169
                        dbPolicy.MessageFlags,
×
4170
                ),
×
4171
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4172
                        dbPolicy.ChannelFlags,
×
4173
                ),
×
4174
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4175
                MinHTLC: lnwire.MilliSatoshi(
×
4176
                        dbPolicy.MinHtlcMsat,
×
4177
                ),
×
4178
                MaxHTLC: lnwire.MilliSatoshi(
×
4179
                        dbPolicy.MaxHtlcMsat.Int64,
×
4180
                ),
×
4181
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4182
                        dbPolicy.BaseFeeMsat,
×
4183
                ),
×
4184
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4185
                ToNode:                    toNode,
×
4186
                InboundFee:                inboundFee,
×
4187
                ExtraOpaqueData:           recs,
×
4188
        }, nil
×
4189
}
4190

4191
// buildNodes builds the models.LightningNode instances for the
4192
// given row which is expected to be a sqlc type that contains node information.
4193
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4194
        dbNode2 sqlc.Node) (*models.LightningNode, *models.LightningNode,
4195
        error) {
×
4196

×
4197
        node1, err := buildNode(ctx, db, &dbNode1)
×
4198
        if err != nil {
×
4199
                return nil, nil, err
×
4200
        }
×
4201

4202
        node2, err := buildNode(ctx, db, &dbNode2)
×
4203
        if err != nil {
×
4204
                return nil, nil, err
×
4205
        }
×
4206

4207
        return node1, node2, nil
×
4208
}
4209

4210
// extractChannelPolicies extracts the sqlc.ChannelPolicy records from the give
4211
// row which is expected to be a sqlc type that contains channel policy
4212
// information. It returns two policies, which may be nil if the policy
4213
// information is not present in the row.
4214
//
4215
//nolint:ll,dupl,funlen
4216
func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
4217
        error) {
×
4218

×
4219
        var policy1, policy2 *sqlc.ChannelPolicy
×
4220
        switch r := row.(type) {
×
4221
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4222
                if r.Policy1ID.Valid {
×
4223
                        policy1 = &sqlc.ChannelPolicy{
×
4224
                                ID:                      r.Policy1ID.Int64,
×
4225
                                Version:                 r.Policy1Version.Int16,
×
4226
                                ChannelID:               r.Channel.ID,
×
4227
                                NodeID:                  r.Policy1NodeID.Int64,
×
4228
                                Timelock:                r.Policy1Timelock.Int32,
×
4229
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4230
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4231
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4232
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4233
                                LastUpdate:              r.Policy1LastUpdate,
×
4234
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4235
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4236
                                Disabled:                r.Policy1Disabled,
×
4237
                                MessageFlags:            r.Policy1MessageFlags,
×
4238
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4239
                                Signature:               r.Policy1Signature,
×
4240
                        }
×
4241
                }
×
4242
                if r.Policy2ID.Valid {
×
4243
                        policy2 = &sqlc.ChannelPolicy{
×
4244
                                ID:                      r.Policy2ID.Int64,
×
4245
                                Version:                 r.Policy2Version.Int16,
×
4246
                                ChannelID:               r.Channel.ID,
×
4247
                                NodeID:                  r.Policy2NodeID.Int64,
×
4248
                                Timelock:                r.Policy2Timelock.Int32,
×
4249
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4250
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4251
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4252
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4253
                                LastUpdate:              r.Policy2LastUpdate,
×
4254
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4255
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4256
                                Disabled:                r.Policy2Disabled,
×
4257
                                MessageFlags:            r.Policy2MessageFlags,
×
4258
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4259
                                Signature:               r.Policy2Signature,
×
4260
                        }
×
4261
                }
×
4262

4263
                return policy1, policy2, nil
×
4264

4265
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4266
                if r.Policy1ID.Valid {
×
4267
                        policy1 = &sqlc.ChannelPolicy{
×
4268
                                ID:                      r.Policy1ID.Int64,
×
4269
                                Version:                 r.Policy1Version.Int16,
×
4270
                                ChannelID:               r.Channel.ID,
×
4271
                                NodeID:                  r.Policy1NodeID.Int64,
×
4272
                                Timelock:                r.Policy1Timelock.Int32,
×
4273
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4274
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4275
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4276
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4277
                                LastUpdate:              r.Policy1LastUpdate,
×
4278
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4279
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4280
                                Disabled:                r.Policy1Disabled,
×
4281
                                MessageFlags:            r.Policy1MessageFlags,
×
4282
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4283
                                Signature:               r.Policy1Signature,
×
4284
                        }
×
4285
                }
×
4286
                if r.Policy2ID.Valid {
×
4287
                        policy2 = &sqlc.ChannelPolicy{
×
4288
                                ID:                      r.Policy2ID.Int64,
×
4289
                                Version:                 r.Policy2Version.Int16,
×
4290
                                ChannelID:               r.Channel.ID,
×
4291
                                NodeID:                  r.Policy2NodeID.Int64,
×
4292
                                Timelock:                r.Policy2Timelock.Int32,
×
4293
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4294
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4295
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4296
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4297
                                LastUpdate:              r.Policy2LastUpdate,
×
4298
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4299
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4300
                                Disabled:                r.Policy2Disabled,
×
4301
                                MessageFlags:            r.Policy2MessageFlags,
×
4302
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4303
                                Signature:               r.Policy2Signature,
×
4304
                        }
×
4305
                }
×
4306

4307
                return policy1, policy2, nil
×
4308

4309
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4310
                if r.Policy1ID.Valid {
×
4311
                        policy1 = &sqlc.ChannelPolicy{
×
4312
                                ID:                      r.Policy1ID.Int64,
×
4313
                                Version:                 r.Policy1Version.Int16,
×
4314
                                ChannelID:               r.Channel.ID,
×
4315
                                NodeID:                  r.Policy1NodeID.Int64,
×
4316
                                Timelock:                r.Policy1Timelock.Int32,
×
4317
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4318
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4319
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4320
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4321
                                LastUpdate:              r.Policy1LastUpdate,
×
4322
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4323
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4324
                                Disabled:                r.Policy1Disabled,
×
4325
                                MessageFlags:            r.Policy1MessageFlags,
×
4326
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4327
                                Signature:               r.Policy1Signature,
×
4328
                        }
×
4329
                }
×
4330
                if r.Policy2ID.Valid {
×
4331
                        policy2 = &sqlc.ChannelPolicy{
×
4332
                                ID:                      r.Policy2ID.Int64,
×
4333
                                Version:                 r.Policy2Version.Int16,
×
4334
                                ChannelID:               r.Channel.ID,
×
4335
                                NodeID:                  r.Policy2NodeID.Int64,
×
4336
                                Timelock:                r.Policy2Timelock.Int32,
×
4337
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4338
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4339
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4340
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4341
                                LastUpdate:              r.Policy2LastUpdate,
×
4342
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4343
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4344
                                Disabled:                r.Policy2Disabled,
×
4345
                                MessageFlags:            r.Policy2MessageFlags,
×
4346
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4347
                                Signature:               r.Policy2Signature,
×
4348
                        }
×
4349
                }
×
4350

4351
                return policy1, policy2, nil
×
4352

4353
        case sqlc.ListChannelsByNodeIDRow:
×
4354
                if r.Policy1ID.Valid {
×
4355
                        policy1 = &sqlc.ChannelPolicy{
×
4356
                                ID:                      r.Policy1ID.Int64,
×
4357
                                Version:                 r.Policy1Version.Int16,
×
4358
                                ChannelID:               r.Channel.ID,
×
4359
                                NodeID:                  r.Policy1NodeID.Int64,
×
4360
                                Timelock:                r.Policy1Timelock.Int32,
×
4361
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4362
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4363
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4364
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4365
                                LastUpdate:              r.Policy1LastUpdate,
×
4366
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4367
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4368
                                Disabled:                r.Policy1Disabled,
×
4369
                                MessageFlags:            r.Policy1MessageFlags,
×
4370
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4371
                                Signature:               r.Policy1Signature,
×
4372
                        }
×
4373
                }
×
4374
                if r.Policy2ID.Valid {
×
4375
                        policy2 = &sqlc.ChannelPolicy{
×
4376
                                ID:                      r.Policy2ID.Int64,
×
4377
                                Version:                 r.Policy2Version.Int16,
×
4378
                                ChannelID:               r.Channel.ID,
×
4379
                                NodeID:                  r.Policy2NodeID.Int64,
×
4380
                                Timelock:                r.Policy2Timelock.Int32,
×
4381
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4382
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4383
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4384
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4385
                                LastUpdate:              r.Policy2LastUpdate,
×
4386
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4387
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4388
                                Disabled:                r.Policy2Disabled,
×
4389
                                MessageFlags:            r.Policy2MessageFlags,
×
4390
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4391
                                Signature:               r.Policy2Signature,
×
4392
                        }
×
4393
                }
×
4394

4395
                return policy1, policy2, nil
×
4396

4397
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4398
                if r.Policy1ID.Valid {
×
4399
                        policy1 = &sqlc.ChannelPolicy{
×
4400
                                ID:                      r.Policy1ID.Int64,
×
4401
                                Version:                 r.Policy1Version.Int16,
×
4402
                                ChannelID:               r.Channel.ID,
×
4403
                                NodeID:                  r.Policy1NodeID.Int64,
×
4404
                                Timelock:                r.Policy1Timelock.Int32,
×
4405
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4406
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4407
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4408
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4409
                                LastUpdate:              r.Policy1LastUpdate,
×
4410
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4411
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4412
                                Disabled:                r.Policy1Disabled,
×
4413
                                MessageFlags:            r.Policy1MessageFlags,
×
4414
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4415
                                Signature:               r.Policy1Signature,
×
4416
                        }
×
4417
                }
×
4418
                if r.Policy2ID.Valid {
×
4419
                        policy2 = &sqlc.ChannelPolicy{
×
4420
                                ID:                      r.Policy2ID.Int64,
×
4421
                                Version:                 r.Policy2Version.Int16,
×
4422
                                ChannelID:               r.Channel.ID,
×
4423
                                NodeID:                  r.Policy2NodeID.Int64,
×
4424
                                Timelock:                r.Policy2Timelock.Int32,
×
4425
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4426
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4427
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4428
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4429
                                LastUpdate:              r.Policy2LastUpdate,
×
4430
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4431
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4432
                                Disabled:                r.Policy2Disabled,
×
4433
                                MessageFlags:            r.Policy2MessageFlags,
×
4434
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4435
                                Signature:               r.Policy2Signature,
×
4436
                        }
×
4437
                }
×
4438

4439
                return policy1, policy2, nil
×
4440
        default:
×
4441
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4442
                        "extractChannelPolicies: %T", r)
×
4443
        }
4444
}
4445

4446
// channelIDToBytes converts a channel ID (SCID) to a byte array
4447
// representation.
4448
func channelIDToBytes(channelID uint64) []byte {
×
4449
        var chanIDB [8]byte
×
4450
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4451

×
4452
        return chanIDB[:]
×
4453
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc