• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15993432885

01 Jul 2025 08:01AM UTC coverage: 57.796% (+1.1%) from 56.662%
15993432885

push

github

web-flow
Merge pull request #10006 from ellemouton/fixSQLFetchChannelEdgesByID

graph/db: let FetchChannelEdgesByID behave as promised

0 of 14 new or added lines in 1 file covered. (0.0%)

2 existing lines in 1 file now uncovered.

98443 of 170327 relevant lines covered (57.8%)

1.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// pageSize caps how many rows a single paginated query may return. The value
// is a starting point and may be tuned once benchmarks are available.
const (
	pageSize = 2000
)
37

38
// ProtocolVersion enumerates the versions of the gossip protocol that a
// message may belong to.
type ProtocolVersion uint8

const (
	// ProtocolV1 identifies the gossip protocol as specified in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String implements fmt.Stringer, rendering the version as the letter "V"
// followed by its numeric value (for example "V1").
func (v ProtocolVersion) String() string {
	// Convert to the underlying integer so the numeric verb never consults
	// this String method.
	return fmt.Sprintf("V%d", uint8(v))
}
×
51

52
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
53
// execute queries against the SQL graph tables.
54
//
55
//nolint:ll,interfacebloat
56
type SQLQueries interface {
57
        /*
58
                Node queries.
59
        */
60
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
61
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
62
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
63
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
64
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error)
65
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
66
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
67
        GetUnconnectedNodes(ctx context.Context) ([]sqlc.GetUnconnectedNodesRow, error)
68
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
69
        DeleteNode(ctx context.Context, id int64) error
70

71
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
72
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
73
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
74

75
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
76
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
77
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
78

79
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
80
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
81
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
82
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
83

84
        /*
85
                Source node queries.
86
        */
87
        AddSourceNode(ctx context.Context, nodeID int64) error
88
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
89

90
        /*
91
                Channel queries.
92
        */
93
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
94
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
95
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
96
        GetChannelByOutpoint(ctx context.Context, outpoint string) (sqlc.GetChannelByOutpointRow, error)
97
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
98
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
99
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
100
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
101
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
102
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
103
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
104
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
105
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
106
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
107
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
108
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
109
        DeleteChannel(ctx context.Context, id int64) error
110

111
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
112
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
113

114
        /*
115
                Channel Policy table queries.
116
        */
117
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
118
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.ChannelPolicy, error)
119
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
120

121
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
122
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
123
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
124

125
        /*
126
                Zombie index queries.
127
        */
128
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
129
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.ZombieChannel, error)
130
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
131
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
132
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
133

134
        /*
135
                Prune log table queries.
136
        */
137
        GetPruneTip(ctx context.Context) (sqlc.PruneLog, error)
138
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
139
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
140

141
        /*
142
                Closed SCID table queries.
143
        */
144
        InsertClosedChannel(ctx context.Context, scid []byte) error
145
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
146
}
147

148
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
149
// database operations.
150
type BatchedSQLQueries interface {
151
        SQLQueries
152
        sqldb.BatchedTx[SQLQueries]
153
}
154

155
// SQLStore is an implementation of the V1Store interface that uses a SQL
156
// database as the backend.
157
//
158
// NOTE: currently, this temporarily embeds the KVStore struct so that we can
159
// implement the V1Store interface incrementally. For any method not
160
// implemented,  things will fall back to the KVStore. This is ONLY the case
161
// for the time being while this struct is purely used in unit tests only.
162
type SQLStore struct {
163
        cfg *SQLStoreConfig
164
        db  BatchedSQLQueries
165

166
        // cacheMu guards all caches (rejectCache and chanCache). If
167
        // this mutex will be acquired at the same time as the DB mutex then
168
        // the cacheMu MUST be acquired first to prevent deadlock.
169
        cacheMu     sync.RWMutex
170
        rejectCache *rejectCache
171
        chanCache   *channelCache
172

173
        chanScheduler batch.Scheduler[SQLQueries]
174
        nodeScheduler batch.Scheduler[SQLQueries]
175

176
        srcNodes  map[ProtocolVersion]*srcNodeInfo
177
        srcNodeMu sync.Mutex
178
}
179

180
// A compile-time assertion to ensure that SQLStore implements the V1Store
181
// interface.
182
var _ V1Store = (*SQLStore)(nil)
183

184
// SQLStoreConfig holds the configuration for the SQLStore.
185
type SQLStoreConfig struct {
186
        // ChainHash is the genesis hash for the chain that all the gossip
187
        // messages in this store are aimed at.
188
        ChainHash chainhash.Hash
189
}
190

191
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
192
// storage backend.
193
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
194
        options ...StoreOptionModifier) (*SQLStore, error) {
×
195

×
196
        opts := DefaultOptions()
×
197
        for _, o := range options {
×
198
                o(opts)
×
199
        }
×
200

201
        if opts.NoMigration {
×
202
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
203
                        "supported for SQL stores")
×
204
        }
×
205

206
        s := &SQLStore{
×
207
                cfg:         cfg,
×
208
                db:          db,
×
209
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
210
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
211
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
212
        }
×
213

×
214
        s.chanScheduler = batch.NewTimeScheduler(
×
215
                db, &s.cacheMu, opts.BatchCommitInterval,
×
216
        )
×
217
        s.nodeScheduler = batch.NewTimeScheduler(
×
218
                db, nil, opts.BatchCommitInterval,
×
219
        )
×
220

×
221
        return s, nil
×
222
}
223

224
// AddLightningNode adds a vertex/node to the graph database. If the node is not
225
// in the database from before, this will add a new, unconnected one to the
226
// graph. If it is present from before, this will update that node's
227
// information.
228
//
229
// NOTE: part of the V1Store interface.
230
func (s *SQLStore) AddLightningNode(ctx context.Context,
231
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
232

×
233
        r := &batch.Request[SQLQueries]{
×
234
                Opts: batch.NewSchedulerOptions(opts...),
×
235
                Do: func(queries SQLQueries) error {
×
236
                        _, err := upsertNode(ctx, queries, node)
×
237
                        return err
×
238
                },
×
239
        }
240

241
        return s.nodeScheduler.Execute(ctx, r)
×
242
}
243

244
// FetchLightningNode attempts to look up a target node by its identity public
245
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
246
// returned.
247
//
248
// NOTE: part of the V1Store interface.
249
func (s *SQLStore) FetchLightningNode(ctx context.Context,
250
        pubKey route.Vertex) (*models.LightningNode, error) {
×
251

×
252
        var node *models.LightningNode
×
253
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
254
                var err error
×
255
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
256

×
257
                return err
×
258
        }, sqldb.NoOpReset)
×
259
        if err != nil {
×
260
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
261
        }
×
262

263
        return node, nil
×
264
}
265

266
// HasLightningNode determines if the graph has a vertex identified by the
267
// target node identity public key. If the node exists in the database, a
268
// timestamp of when the data for the node was lasted updated is returned along
269
// with a true boolean. Otherwise, an empty time.Time is returned with a false
270
// boolean.
271
//
272
// NOTE: part of the V1Store interface.
273
func (s *SQLStore) HasLightningNode(ctx context.Context,
274
        pubKey [33]byte) (time.Time, bool, error) {
×
275

×
276
        var (
×
277
                exists     bool
×
278
                lastUpdate time.Time
×
279
        )
×
280
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
281
                dbNode, err := db.GetNodeByPubKey(
×
282
                        ctx, sqlc.GetNodeByPubKeyParams{
×
283
                                Version: int16(ProtocolV1),
×
284
                                PubKey:  pubKey[:],
×
285
                        },
×
286
                )
×
287
                if errors.Is(err, sql.ErrNoRows) {
×
288
                        return nil
×
289
                } else if err != nil {
×
290
                        return fmt.Errorf("unable to fetch node: %w", err)
×
291
                }
×
292

293
                exists = true
×
294

×
295
                if dbNode.LastUpdate.Valid {
×
296
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
297
                }
×
298

299
                return nil
×
300
        }, sqldb.NoOpReset)
301
        if err != nil {
×
302
                return time.Time{}, false,
×
303
                        fmt.Errorf("unable to fetch node: %w", err)
×
304
        }
×
305

306
        return lastUpdate, exists, nil
×
307
}
308

309
// AddrsForNode returns all known addresses for the target node public key
310
// that the graph DB is aware of. The returned boolean indicates if the
311
// given node is unknown to the graph DB or not.
312
//
313
// NOTE: part of the V1Store interface.
314
func (s *SQLStore) AddrsForNode(ctx context.Context,
315
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
316

×
317
        var (
×
318
                addresses []net.Addr
×
319
                known     bool
×
320
        )
×
321
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
322
                var err error
×
323
                known, addresses, err = getNodeAddresses(
×
324
                        ctx, db, nodePub.SerializeCompressed(),
×
325
                )
×
326
                if err != nil {
×
327
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
328
                                err)
×
329
                }
×
330

331
                return nil
×
332
        }, sqldb.NoOpReset)
333
        if err != nil {
×
334
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
335
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
336
        }
×
337

338
        return known, addresses, nil
×
339
}
340

341
// DeleteLightningNode starts a new database transaction to remove a vertex/node
342
// from the database according to the node's public key.
343
//
344
// NOTE: part of the V1Store interface.
345
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
346
        pubKey route.Vertex) error {
×
347

×
348
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
349
                res, err := db.DeleteNodeByPubKey(
×
350
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
351
                                Version: int16(ProtocolV1),
×
352
                                PubKey:  pubKey[:],
×
353
                        },
×
354
                )
×
355
                if err != nil {
×
356
                        return err
×
357
                }
×
358

359
                rows, err := res.RowsAffected()
×
360
                if err != nil {
×
361
                        return err
×
362
                }
×
363

364
                if rows == 0 {
×
365
                        return ErrGraphNodeNotFound
×
366
                } else if rows > 1 {
×
367
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
368
                }
×
369

370
                return err
×
371
        }, sqldb.NoOpReset)
372
        if err != nil {
×
373
                return fmt.Errorf("unable to delete node: %w", err)
×
374
        }
×
375

376
        return nil
×
377
}
378

379
// FetchNodeFeatures returns the features of the given node. If no features are
380
// known for the node, an empty feature vector is returned.
381
//
382
// NOTE: this is part of the graphdb.NodeTraverser interface.
383
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
384
        *lnwire.FeatureVector, error) {
×
385

×
386
        ctx := context.TODO()
×
387

×
388
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
389
}
×
390

391
// DisabledChannelIDs returns the channel ids of disabled channels.
392
// A channel is disabled when two of the associated ChanelEdgePolicies
393
// have their disabled bit on.
394
//
395
// NOTE: part of the V1Store interface.
396
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
397
        var (
×
398
                ctx     = context.TODO()
×
399
                chanIDs []uint64
×
400
        )
×
401
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
402
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
403
                if err != nil {
×
404
                        return fmt.Errorf("unable to fetch disabled "+
×
405
                                "channels: %w", err)
×
406
                }
×
407

408
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
409

×
410
                return nil
×
411
        }, sqldb.NoOpReset)
412
        if err != nil {
×
413
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
414
                        err)
×
415
        }
×
416

417
        return chanIDs, nil
×
418
}
419

420
// LookupAlias attempts to return the alias as advertised by the target node.
421
//
422
// NOTE: part of the V1Store interface.
423
func (s *SQLStore) LookupAlias(ctx context.Context,
424
        pub *btcec.PublicKey) (string, error) {
×
425

×
426
        var alias string
×
427
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
428
                dbNode, err := db.GetNodeByPubKey(
×
429
                        ctx, sqlc.GetNodeByPubKeyParams{
×
430
                                Version: int16(ProtocolV1),
×
431
                                PubKey:  pub.SerializeCompressed(),
×
432
                        },
×
433
                )
×
434
                if errors.Is(err, sql.ErrNoRows) {
×
435
                        return ErrNodeAliasNotFound
×
436
                } else if err != nil {
×
437
                        return fmt.Errorf("unable to fetch node: %w", err)
×
438
                }
×
439

440
                if !dbNode.Alias.Valid {
×
441
                        return ErrNodeAliasNotFound
×
442
                }
×
443

444
                alias = dbNode.Alias.String
×
445

×
446
                return nil
×
447
        }, sqldb.NoOpReset)
448
        if err != nil {
×
449
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
450
        }
×
451

452
        return alias, nil
×
453
}
454

455
// SourceNode returns the source node of the graph. The source node is treated
456
// as the center node within a star-graph. This method may be used to kick off
457
// a path finding algorithm in order to explore the reachability of another
458
// node based off the source node.
459
//
460
// NOTE: part of the V1Store interface.
461
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
462
        error) {
×
463

×
464
        var node *models.LightningNode
×
465
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
466
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
467
                if err != nil {
×
468
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
469
                                err)
×
470
                }
×
471

472
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
473

×
474
                return err
×
475
        }, sqldb.NoOpReset)
476
        if err != nil {
×
477
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
478
        }
×
479

480
        return node, nil
×
481
}
482

483
// SetSourceNode sets the source node within the graph database. The source
484
// node is to be used as the center of a star-graph within path finding
485
// algorithms.
486
//
487
// NOTE: part of the V1Store interface.
488
func (s *SQLStore) SetSourceNode(ctx context.Context,
489
        node *models.LightningNode) error {
×
490

×
491
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
492
                id, err := upsertNode(ctx, db, node)
×
493
                if err != nil {
×
494
                        return fmt.Errorf("unable to upsert source node: %w",
×
495
                                err)
×
496
                }
×
497

498
                // Make sure that if a source node for this version is already
499
                // set, then the ID is the same as the one we are about to set.
500
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
501
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
502
                        return fmt.Errorf("unable to fetch source node: %w",
×
503
                                err)
×
504
                } else if err == nil {
×
505
                        if dbSourceNodeID != id {
×
506
                                return fmt.Errorf("v1 source node already "+
×
507
                                        "set to a different node: %d vs %d",
×
508
                                        dbSourceNodeID, id)
×
509
                        }
×
510

511
                        return nil
×
512
                }
513

514
                return db.AddSourceNode(ctx, id)
×
515
        }, sqldb.NoOpReset)
516
}
517

518
// NodeUpdatesInHorizon returns all the known lightning node which have an
519
// update timestamp within the passed range. This method can be used by two
520
// nodes to quickly determine if they have the same set of up to date node
521
// announcements.
522
//
523
// NOTE: This is part of the V1Store interface.
524
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
525
        endTime time.Time) ([]models.LightningNode, error) {
×
526

×
527
        ctx := context.TODO()
×
528

×
529
        var nodes []models.LightningNode
×
530
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
531
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
532
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
533
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
534
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
535
                        },
×
536
                )
×
537
                if err != nil {
×
538
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
539
                }
×
540

541
                for _, dbNode := range dbNodes {
×
542
                        node, err := buildNode(ctx, db, &dbNode)
×
543
                        if err != nil {
×
544
                                return fmt.Errorf("unable to build node: %w",
×
545
                                        err)
×
546
                        }
×
547

548
                        nodes = append(nodes, *node)
×
549
                }
550

551
                return nil
×
552
        }, sqldb.NoOpReset)
553
        if err != nil {
×
554
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
555
        }
×
556

557
        return nodes, nil
×
558
}
559

560
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
561
// undirected edge from the two target nodes are created. The information stored
562
// denotes the static attributes of the channel, such as the channelID, the keys
563
// involved in creation of the channel, and the set of features that the channel
564
// supports. The chanPoint and chanID are used to uniquely identify the edge
565
// globally within the database.
566
//
567
// NOTE: part of the V1Store interface.
568
func (s *SQLStore) AddChannelEdge(ctx context.Context,
569
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
570

×
571
        var alreadyExists bool
×
572
        r := &batch.Request[SQLQueries]{
×
573
                Opts: batch.NewSchedulerOptions(opts...),
×
574
                Reset: func() {
×
575
                        alreadyExists = false
×
576
                },
×
577
                Do: func(tx SQLQueries) error {
×
578
                        err := insertChannel(ctx, tx, edge)
×
579

×
580
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
581
                        // succeed, but propagate the error via local state.
×
582
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
583
                                alreadyExists = true
×
584
                                return nil
×
585
                        }
×
586

587
                        return err
×
588
                },
589
                OnCommit: func(err error) error {
×
590
                        switch {
×
591
                        case err != nil:
×
592
                                return err
×
593
                        case alreadyExists:
×
594
                                return ErrEdgeAlreadyExist
×
595
                        default:
×
596
                                s.rejectCache.remove(edge.ChannelID)
×
597
                                s.chanCache.remove(edge.ChannelID)
×
598
                                return nil
×
599
                        }
600
                },
601
        }
602

603
        return s.chanScheduler.Execute(ctx, r)
×
604
}
605

606
// HighestChanID returns the "highest" known channel ID in the channel graph.
607
// This represents the "newest" channel from the PoV of the chain. This method
608
// can be used by peers to quickly determine if their graphs are in sync.
609
//
610
// NOTE: This is part of the V1Store interface.
611
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
612
        var highestChanID uint64
×
613
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
614
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
615
                if errors.Is(err, sql.ErrNoRows) {
×
616
                        return nil
×
617
                } else if err != nil {
×
618
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
619
                                err)
×
620
                }
×
621

622
                highestChanID = byteOrder.Uint64(chanID)
×
623

×
624
                return nil
×
625
        }, sqldb.NoOpReset)
626
        if err != nil {
×
627
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
628
        }
×
629

630
        return highestChanID, nil
×
631
}
632

633
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
634
// within the database for the referenced channel. The `flags` attribute within
635
// the ChannelEdgePolicy determines which of the directed edges are being
636
// updated. If the flag is 1, then the first node's information is being
637
// updated, otherwise it's the second node's information. The node ordering is
638
// determined by the lexicographical ordering of the identity public keys of the
639
// nodes on either side of the channel.
640
//
641
// NOTE: part of the V1Store interface.
642
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
643
        edge *models.ChannelEdgePolicy,
644
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
645

×
646
        var (
×
647
                isUpdate1    bool
×
648
                edgeNotFound bool
×
649
                from, to     route.Vertex
×
650
        )
×
651

×
652
        r := &batch.Request[SQLQueries]{
×
653
                Opts: batch.NewSchedulerOptions(opts...),
×
654
                Reset: func() {
×
655
                        isUpdate1 = false
×
656
                        edgeNotFound = false
×
657
                },
×
658
                Do: func(tx SQLQueries) error {
×
659
                        var err error
×
660
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
661
                                ctx, tx, edge,
×
662
                        )
×
663
                        if err != nil {
×
664
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
665
                        }
×
666

667
                        // Silence ErrEdgeNotFound so that the batch can
668
                        // succeed, but propagate the error via local state.
669
                        if errors.Is(err, ErrEdgeNotFound) {
×
670
                                edgeNotFound = true
×
671
                                return nil
×
672
                        }
×
673

674
                        return err
×
675
                },
676
                OnCommit: func(err error) error {
×
677
                        switch {
×
678
                        case err != nil:
×
679
                                return err
×
680
                        case edgeNotFound:
×
681
                                return ErrEdgeNotFound
×
682
                        default:
×
683
                                s.updateEdgeCache(edge, isUpdate1)
×
684
                                return nil
×
685
                        }
686
                },
687
        }
688

689
        err := s.chanScheduler.Execute(ctx, r)
×
690

×
691
        return from, to, err
×
692
}
693

694
// updateEdgeCache updates our reject and channel caches with the new
695
// edge policy information.
696
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
697
        isUpdate1 bool) {
×
698

×
699
        // If an entry for this channel is found in reject cache, we'll modify
×
700
        // the entry with the updated timestamp for the direction that was just
×
701
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
702
        // during the next query for this edge.
×
703
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
704
                if isUpdate1 {
×
705
                        entry.upd1Time = e.LastUpdate.Unix()
×
706
                } else {
×
707
                        entry.upd2Time = e.LastUpdate.Unix()
×
708
                }
×
709
                s.rejectCache.insert(e.ChannelID, entry)
×
710
        }
711

712
        // If an entry for this channel is found in channel cache, we'll modify
713
        // the entry with the updated policy for the direction that was just
714
        // written. If the edge doesn't exist, we'll defer loading the info and
715
        // policies and lazily read from disk during the next query.
716
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
717
                if isUpdate1 {
×
718
                        channel.Policy1 = e
×
719
                } else {
×
720
                        channel.Policy2 = e
×
721
                }
×
722
                s.chanCache.insert(e.ChannelID, channel)
×
723
        }
724
}
725

726
// ForEachSourceNodeChannel iterates through all channels of the source node,
727
// executing the passed callback on each. The call-back is provided with the
728
// channel's outpoint, whether we have a policy for the channel and the channel
729
// peer's node information.
730
//
731
// NOTE: part of the V1Store interface.
732
func (s *SQLStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint,
733
        havePolicy bool, otherNode *models.LightningNode) error) error {
×
734

×
735
        var ctx = context.TODO()
×
736

×
737
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
738
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
739
                if err != nil {
×
740
                        return fmt.Errorf("unable to fetch source node: %w",
×
741
                                err)
×
742
                }
×
743

744
                return forEachNodeChannel(
×
745
                        ctx, db, s.cfg.ChainHash, nodeID,
×
746
                        func(info *models.ChannelEdgeInfo,
×
747
                                outPolicy *models.ChannelEdgePolicy,
×
748
                                _ *models.ChannelEdgePolicy) error {
×
749

×
750
                                // Fetch the other node.
×
751
                                var (
×
752
                                        otherNodePub [33]byte
×
753
                                        node1        = info.NodeKey1Bytes
×
754
                                        node2        = info.NodeKey2Bytes
×
755
                                )
×
756
                                switch {
×
757
                                case bytes.Equal(node1[:], nodePub[:]):
×
758
                                        otherNodePub = node2
×
759
                                case bytes.Equal(node2[:], nodePub[:]):
×
760
                                        otherNodePub = node1
×
761
                                default:
×
762
                                        return fmt.Errorf("node not " +
×
763
                                                "participating in this channel")
×
764
                                }
765

766
                                _, otherNode, err := getNodeByPubKey(
×
767
                                        ctx, db, otherNodePub,
×
768
                                )
×
769
                                if err != nil {
×
770
                                        return fmt.Errorf("unable to fetch "+
×
771
                                                "other node(%x): %w",
×
772
                                                otherNodePub, err)
×
773
                                }
×
774

775
                                return cb(
×
776
                                        info.ChannelPoint, outPolicy != nil,
×
777
                                        otherNode,
×
778
                                )
×
779
                        },
780
                )
781
        }, sqldb.NoOpReset)
782
}
783

784
// ForEachNode iterates through all the stored vertices/nodes in the graph,
785
// executing the passed callback with each node encountered. If the callback
786
// returns an error, then the transaction is aborted and the iteration stops
787
// early. Any operations performed on the NodeTx passed to the call-back are
788
// executed under the same read transaction and so, methods on the NodeTx object
789
// _MUST_ only be called from within the call-back.
790
//
791
// NOTE: part of the V1Store interface.
792
func (s *SQLStore) ForEachNode(cb func(tx NodeRTx) error) error {
×
793
        var (
×
794
                ctx          = context.TODO()
×
795
                lastID int64 = 0
×
796
        )
×
797

×
798
        handleNode := func(db SQLQueries, dbNode sqlc.Node) error {
×
799
                node, err := buildNode(ctx, db, &dbNode)
×
800
                if err != nil {
×
801
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
802
                                dbNode.ID, err)
×
803
                }
×
804

805
                err = cb(
×
806
                        newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
×
807
                )
×
808
                if err != nil {
×
809
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
810
                                dbNode.ID, err)
×
811
                }
×
812

813
                return nil
×
814
        }
815

816
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
817
                for {
×
818
                        nodes, err := db.ListNodesPaginated(
×
819
                                ctx, sqlc.ListNodesPaginatedParams{
×
820
                                        Version: int16(ProtocolV1),
×
821
                                        ID:      lastID,
×
822
                                        Limit:   pageSize,
×
823
                                },
×
824
                        )
×
825
                        if err != nil {
×
826
                                return fmt.Errorf("unable to fetch nodes: %w",
×
827
                                        err)
×
828
                        }
×
829

830
                        if len(nodes) == 0 {
×
831
                                break
×
832
                        }
833

834
                        for _, dbNode := range nodes {
×
835
                                err = handleNode(db, dbNode)
×
836
                                if err != nil {
×
837
                                        return err
×
838
                                }
×
839

840
                                lastID = dbNode.ID
×
841
                        }
842
                }
843

844
                return nil
×
845
        }, sqldb.NoOpReset)
846
}
847

848
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
849
// SQLStore and a SQL transaction.
850
type sqlGraphNodeTx struct {
851
        db    SQLQueries
852
        id    int64
853
        node  *models.LightningNode
854
        chain chainhash.Hash
855
}
856

857
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
858
// interface.
859
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
860

861
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
862
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
863

×
864
        return &sqlGraphNodeTx{
×
865
                db:    db,
×
866
                chain: chain,
×
867
                id:    id,
×
868
                node:  node,
×
869
        }
×
870
}
×
871

872
// Node returns the raw information of the node.
873
//
874
// NOTE: This is a part of the NodeRTx interface.
875
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
876
        return s.node
×
877
}
×
878

879
// ForEachChannel can be used to iterate over the node's channels under the same
880
// transaction used to fetch the node.
881
//
882
// NOTE: This is a part of the NodeRTx interface.
883
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
884
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
885

×
886
        ctx := context.TODO()
×
887

×
888
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
889
}
×
890

891
// FetchNode fetches the node with the given pub key under the same transaction
892
// used to fetch the current node. The returned node is also a NodeRTx and any
893
// operations on that NodeRTx will also be done under the same transaction.
894
//
895
// NOTE: This is a part of the NodeRTx interface.
896
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
897
        ctx := context.TODO()
×
898

×
899
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
900
        if err != nil {
×
901
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
902
                        nodePub, err)
×
903
        }
×
904

905
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
906
}
907

908
// ForEachNodeDirectedChannel iterates through all channels of a given node,
909
// executing the passed callback on the directed edge representing the channel
910
// and its incoming policy. If the callback returns an error, then the iteration
911
// is halted with the error propagated back up to the caller.
912
//
913
// Unknown policies are passed into the callback as nil values.
914
//
915
// NOTE: this is part of the graphdb.NodeTraverser interface.
916
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
917
        cb func(channel *DirectedChannel) error) error {
×
918

×
919
        var ctx = context.TODO()
×
920

×
921
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
922
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
923
        }, sqldb.NoOpReset)
×
924
}
925

926
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
927
// graph, executing the passed callback with each node encountered. If the
928
// callback returns an error, then the transaction is aborted and the iteration
929
// stops early.
930
//
931
// NOTE: This is a part of the V1Store interface.
932
func (s *SQLStore) ForEachNodeCacheable(cb func(route.Vertex,
933
        *lnwire.FeatureVector) error) error {
×
934

×
935
        ctx := context.TODO()
×
936

×
937
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
938
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
939
                        nodePub route.Vertex) error {
×
940

×
941
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
942
                        if err != nil {
×
943
                                return fmt.Errorf("unable to fetch node "+
×
944
                                        "features: %w", err)
×
945
                        }
×
946

947
                        return cb(nodePub, features)
×
948
                })
949
        }, sqldb.NoOpReset)
950
        if err != nil {
×
951
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
952
        }
×
953

954
        return nil
×
955
}
956

957
// ForEachNodeChannel iterates through all channels of the given node,
958
// executing the passed callback with an edge info structure and the policies
959
// of each end of the channel. The first edge policy is the outgoing edge *to*
960
// the connecting node, while the second is the incoming edge *from* the
961
// connecting node. If the callback returns an error, then the iteration is
962
// halted with the error propagated back up to the caller.
963
//
964
// Unknown policies are passed into the callback as nil values.
965
//
966
// NOTE: part of the V1Store interface.
967
func (s *SQLStore) ForEachNodeChannel(nodePub route.Vertex,
968
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
969
                *models.ChannelEdgePolicy) error) error {
×
970

×
971
        var ctx = context.TODO()
×
972

×
973
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
974
                dbNode, err := db.GetNodeByPubKey(
×
975
                        ctx, sqlc.GetNodeByPubKeyParams{
×
976
                                Version: int16(ProtocolV1),
×
977
                                PubKey:  nodePub[:],
×
978
                        },
×
979
                )
×
980
                if errors.Is(err, sql.ErrNoRows) {
×
981
                        return nil
×
982
                } else if err != nil {
×
983
                        return fmt.Errorf("unable to fetch node: %w", err)
×
984
                }
×
985

986
                return forEachNodeChannel(
×
987
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
988
                )
×
989
        }, sqldb.NoOpReset)
990
}
991

992
// ChanUpdatesInHorizon returns all the known channel edges which have at least
993
// one edge that has an update timestamp within the specified horizon.
994
//
995
// NOTE: This is part of the V1Store interface.
996
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
997
        endTime time.Time) ([]ChannelEdge, error) {
×
998

×
999
        s.cacheMu.Lock()
×
1000
        defer s.cacheMu.Unlock()
×
1001

×
1002
        var (
×
1003
                ctx = context.TODO()
×
1004
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
1005
                // an additional map to keep track of the edges already seen to
×
1006
                // prevent re-adding it.
×
1007
                edgesSeen    = make(map[uint64]struct{})
×
1008
                edgesToCache = make(map[uint64]ChannelEdge)
×
1009
                edges        []ChannelEdge
×
1010
                hits         int
×
1011
        )
×
1012
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1013
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1014
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1015
                                Version:   int16(ProtocolV1),
×
1016
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
1017
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
1018
                        },
×
1019
                )
×
1020
                if err != nil {
×
1021
                        return err
×
1022
                }
×
1023

1024
                for _, row := range rows {
×
1025
                        // If we've already retrieved the info and policies for
×
1026
                        // this edge, then we can skip it as we don't need to do
×
1027
                        // so again.
×
1028
                        chanIDInt := byteOrder.Uint64(row.Channel.Scid)
×
1029
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1030
                                continue
×
1031
                        }
1032

1033
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1034
                                hits++
×
1035
                                edgesSeen[chanIDInt] = struct{}{}
×
1036
                                edges = append(edges, channel)
×
1037

×
1038
                                continue
×
1039
                        }
1040

1041
                        node1, node2, err := buildNodes(
×
1042
                                ctx, db, row.Node, row.Node_2,
×
1043
                        )
×
1044
                        if err != nil {
×
1045
                                return err
×
1046
                        }
×
1047

1048
                        channel, err := getAndBuildEdgeInfo(
×
1049
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1050
                                row.Channel, node1.PubKeyBytes,
×
1051
                                node2.PubKeyBytes,
×
1052
                        )
×
1053
                        if err != nil {
×
1054
                                return fmt.Errorf("unable to build channel "+
×
1055
                                        "info: %w", err)
×
1056
                        }
×
1057

1058
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1059
                        if err != nil {
×
1060
                                return fmt.Errorf("unable to extract channel "+
×
1061
                                        "policies: %w", err)
×
1062
                        }
×
1063

1064
                        p1, p2, err := getAndBuildChanPolicies(
×
1065
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1066
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1067
                        )
×
1068
                        if err != nil {
×
1069
                                return fmt.Errorf("unable to build channel "+
×
1070
                                        "policies: %w", err)
×
1071
                        }
×
1072

1073
                        edgesSeen[chanIDInt] = struct{}{}
×
1074
                        chanEdge := ChannelEdge{
×
1075
                                Info:    channel,
×
1076
                                Policy1: p1,
×
1077
                                Policy2: p2,
×
1078
                                Node1:   node1,
×
1079
                                Node2:   node2,
×
1080
                        }
×
1081
                        edges = append(edges, chanEdge)
×
1082
                        edgesToCache[chanIDInt] = chanEdge
×
1083
                }
1084

1085
                return nil
×
1086
        }, func() {
×
1087
                edgesSeen = make(map[uint64]struct{})
×
1088
                edgesToCache = make(map[uint64]ChannelEdge)
×
1089
                edges = nil
×
1090
        })
×
1091
        if err != nil {
×
1092
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1093
        }
×
1094

1095
        // Insert any edges loaded from disk into the cache.
1096
        for chanid, channel := range edgesToCache {
×
1097
                s.chanCache.insert(chanid, channel)
×
1098
        }
×
1099

1100
        if len(edges) > 0 {
×
1101
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
1102
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1103
        } else {
×
1104
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1105
                        "horizon (%s, %s)", startTime, endTime)
×
1106
        }
×
1107

1108
        return edges, nil
×
1109
}
1110

1111
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1112
// data to the call-back.
1113
//
1114
// NOTE: The callback contents MUST not be modified.
1115
//
1116
// NOTE: part of the V1Store interface.
1117
func (s *SQLStore) ForEachNodeCached(cb func(node route.Vertex,
1118
        chans map[uint64]*DirectedChannel) error) error {
×
1119

×
1120
        var ctx = context.TODO()
×
1121

×
1122
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1123
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
1124
                        nodePub route.Vertex) error {
×
1125

×
1126
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
1127
                        if err != nil {
×
1128
                                return fmt.Errorf("unable to fetch "+
×
1129
                                        "node(id=%d) features: %w", nodeID, err)
×
1130
                        }
×
1131

1132
                        toNodeCallback := func() route.Vertex {
×
1133
                                return nodePub
×
1134
                        }
×
1135

1136
                        rows, err := db.ListChannelsByNodeID(
×
1137
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
1138
                                        Version: int16(ProtocolV1),
×
1139
                                        NodeID1: nodeID,
×
1140
                                },
×
1141
                        )
×
1142
                        if err != nil {
×
1143
                                return fmt.Errorf("unable to fetch channels "+
×
1144
                                        "of node(id=%d): %w", nodeID, err)
×
1145
                        }
×
1146

1147
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
1148
                        for _, row := range rows {
×
1149
                                node1, node2, err := buildNodeVertices(
×
1150
                                        row.Node1Pubkey, row.Node2Pubkey,
×
1151
                                )
×
1152
                                if err != nil {
×
1153
                                        return err
×
1154
                                }
×
1155

1156
                                e, err := getAndBuildEdgeInfo(
×
1157
                                        ctx, db, s.cfg.ChainHash,
×
1158
                                        row.Channel.ID, row.Channel, node1,
×
1159
                                        node2,
×
1160
                                )
×
1161
                                if err != nil {
×
1162
                                        return fmt.Errorf("unable to build "+
×
1163
                                                "channel info: %w", err)
×
1164
                                }
×
1165

1166
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1167
                                        row,
×
1168
                                )
×
1169
                                if err != nil {
×
1170
                                        return fmt.Errorf("unable to "+
×
1171
                                                "extract channel "+
×
1172
                                                "policies: %w", err)
×
1173
                                }
×
1174

1175
                                p1, p2, err := getAndBuildChanPolicies(
×
1176
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
1177
                                        node1, node2,
×
1178
                                )
×
1179
                                if err != nil {
×
1180
                                        return fmt.Errorf("unable to "+
×
1181
                                                "build channel policies: %w",
×
1182
                                                err)
×
1183
                                }
×
1184

1185
                                // Determine the outgoing and incoming policy
1186
                                // for this channel and node combo.
1187
                                outPolicy, inPolicy := p1, p2
×
1188
                                if p1 != nil && p1.ToNode == nodePub {
×
1189
                                        outPolicy, inPolicy = p2, p1
×
1190
                                } else if p2 != nil && p2.ToNode != nodePub {
×
1191
                                        outPolicy, inPolicy = p2, p1
×
1192
                                }
×
1193

1194
                                var cachedInPolicy *models.CachedEdgePolicy
×
1195
                                if inPolicy != nil {
×
1196
                                        cachedInPolicy = models.NewCachedPolicy(
×
1197
                                                p2,
×
1198
                                        )
×
1199
                                        cachedInPolicy.ToNodePubKey =
×
1200
                                                toNodeCallback
×
1201
                                        cachedInPolicy.ToNodeFeatures =
×
1202
                                                features
×
1203
                                }
×
1204

1205
                                var inboundFee lnwire.Fee
×
1206
                                outPolicy.InboundFee.WhenSome(
×
1207
                                        func(fee lnwire.Fee) {
×
1208
                                                inboundFee = fee
×
1209
                                        },
×
1210
                                )
1211

1212
                                directedChannel := &DirectedChannel{
×
1213
                                        ChannelID: e.ChannelID,
×
1214
                                        IsNode1: nodePub ==
×
1215
                                                e.NodeKey1Bytes,
×
1216
                                        OtherNode:    e.NodeKey2Bytes,
×
1217
                                        Capacity:     e.Capacity,
×
1218
                                        OutPolicySet: p1 != nil,
×
1219
                                        InPolicy:     cachedInPolicy,
×
1220
                                        InboundFee:   inboundFee,
×
1221
                                }
×
1222

×
1223
                                if nodePub == e.NodeKey2Bytes {
×
1224
                                        directedChannel.OtherNode =
×
1225
                                                e.NodeKey1Bytes
×
1226
                                }
×
1227

1228
                                channels[e.ChannelID] = directedChannel
×
1229
                        }
1230

1231
                        return cb(nodePub, channels)
×
1232
                })
1233
        }, sqldb.NoOpReset)
1234
}
1235

1236
// ForEachChannelCacheable iterates through all the channel edges stored
1237
// within the graph and invokes the passed callback for each edge. The
1238
// callback takes two edges as since this is a directed graph, both the
1239
// in/out edges are visited. If the callback returns an error, then the
1240
// transaction is aborted and the iteration stops early.
1241
//
1242
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1243
// pointer for that particular channel edge routing policy will be
1244
// passed into the callback.
1245
//
1246
// NOTE: this method is like ForEachChannel but fetches only the data
1247
// required for the graph cache.
1248
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1249
        *models.CachedEdgePolicy,
1250
        *models.CachedEdgePolicy) error) error {
×
1251

×
1252
        ctx := context.TODO()
×
1253

×
1254
        handleChannel := func(db SQLQueries,
×
1255
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1256

×
1257
                node1, node2, err := buildNodeVertices(
×
1258
                        row.Node1Pubkey, row.Node2Pubkey,
×
1259
                )
×
1260
                if err != nil {
×
1261
                        return err
×
1262
                }
×
1263

1264
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
1265

×
1266
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1267
                if err != nil {
×
1268
                        return err
×
1269
                }
×
1270

1271
                var pol1, pol2 *models.CachedEdgePolicy
×
1272
                if dbPol1 != nil {
×
1273
                        policy1, err := buildChanPolicy(
×
1274
                                *dbPol1, edge.ChannelID, nil, node2, true,
×
1275
                        )
×
1276
                        if err != nil {
×
1277
                                return err
×
1278
                        }
×
1279

1280
                        pol1 = models.NewCachedPolicy(policy1)
×
1281
                }
1282
                if dbPol2 != nil {
×
1283
                        policy2, err := buildChanPolicy(
×
1284
                                *dbPol2, edge.ChannelID, nil, node1, false,
×
1285
                        )
×
1286
                        if err != nil {
×
1287
                                return err
×
1288
                        }
×
1289

1290
                        pol2 = models.NewCachedPolicy(policy2)
×
1291
                }
1292

1293
                if err := cb(edge, pol1, pol2); err != nil {
×
1294
                        return err
×
1295
                }
×
1296

1297
                return nil
×
1298
        }
1299

1300
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1301
                lastID := int64(-1)
×
1302
                for {
×
1303
                        //nolint:ll
×
1304
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1305
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1306
                                        Version: int16(ProtocolV1),
×
1307
                                        ID:      lastID,
×
1308
                                        Limit:   pageSize,
×
1309
                                },
×
1310
                        )
×
1311
                        if err != nil {
×
1312
                                return err
×
1313
                        }
×
1314

1315
                        if len(rows) == 0 {
×
1316
                                break
×
1317
                        }
1318

1319
                        for _, row := range rows {
×
1320
                                err := handleChannel(db, row)
×
1321
                                if err != nil {
×
1322
                                        return err
×
1323
                                }
×
1324

1325
                                lastID = row.Channel.ID
×
1326
                        }
1327
                }
1328

1329
                return nil
×
1330
        }, sqldb.NoOpReset)
1331
}
1332

1333
// ForEachChannel iterates through all the channel edges stored within the
1334
// graph and invokes the passed callback for each edge. The callback takes two
1335
// edges as since this is a directed graph, both the in/out edges are visited.
1336
// If the callback returns an error, then the transaction is aborted and the
1337
// iteration stops early.
1338
//
1339
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1340
// for that particular channel edge routing policy will be passed into the
1341
// callback.
1342
//
1343
// NOTE: part of the V1Store interface.
1344
func (s *SQLStore) ForEachChannel(cb func(*models.ChannelEdgeInfo,
1345
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
1346

×
1347
        ctx := context.TODO()
×
1348

×
1349
        handleChannel := func(db SQLQueries,
×
1350
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1351

×
1352
                node1, node2, err := buildNodeVertices(
×
1353
                        row.Node1Pubkey, row.Node2Pubkey,
×
1354
                )
×
1355
                if err != nil {
×
1356
                        return fmt.Errorf("unable to build node vertices: %w",
×
1357
                                err)
×
1358
                }
×
1359

1360
                edge, err := getAndBuildEdgeInfo(
×
1361
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1362
                        node1, node2,
×
1363
                )
×
1364
                if err != nil {
×
1365
                        return fmt.Errorf("unable to build channel info: %w",
×
1366
                                err)
×
1367
                }
×
1368

1369
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1370
                if err != nil {
×
1371
                        return fmt.Errorf("unable to extract channel "+
×
1372
                                "policies: %w", err)
×
1373
                }
×
1374

1375
                p1, p2, err := getAndBuildChanPolicies(
×
1376
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1377
                )
×
1378
                if err != nil {
×
1379
                        return fmt.Errorf("unable to build channel "+
×
1380
                                "policies: %w", err)
×
1381
                }
×
1382

1383
                err = cb(edge, p1, p2)
×
1384
                if err != nil {
×
1385
                        return fmt.Errorf("callback failed for channel "+
×
1386
                                "id=%d: %w", edge.ChannelID, err)
×
1387
                }
×
1388

1389
                return nil
×
1390
        }
1391

1392
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1393
                lastID := int64(-1)
×
1394
                for {
×
1395
                        //nolint:ll
×
1396
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1397
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1398
                                        Version: int16(ProtocolV1),
×
1399
                                        ID:      lastID,
×
1400
                                        Limit:   pageSize,
×
1401
                                },
×
1402
                        )
×
1403
                        if err != nil {
×
1404
                                return err
×
1405
                        }
×
1406

1407
                        if len(rows) == 0 {
×
1408
                                break
×
1409
                        }
1410

1411
                        for _, row := range rows {
×
1412
                                err := handleChannel(db, row)
×
1413
                                if err != nil {
×
1414
                                        return err
×
1415
                                }
×
1416

1417
                                lastID = row.Channel.ID
×
1418
                        }
1419
                }
1420

1421
                return nil
×
1422
        }, sqldb.NoOpReset)
1423
}
1424

1425
// FilterChannelRange returns the channel ID's of all known channels which were
1426
// mined in a block height within the passed range. The channel IDs are grouped
1427
// by their common block height. This method can be used to quickly share with a
1428
// peer the set of channels we know of within a particular range to catch them
1429
// up after a period of time offline. If withTimestamps is true then the
1430
// timestamp info of the latest received channel update messages of the channel
1431
// will be included in the response.
1432
//
1433
// NOTE: This is part of the V1Store interface.
1434
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1435
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1436

×
1437
        var (
×
1438
                ctx       = context.TODO()
×
1439
                startSCID = &lnwire.ShortChannelID{
×
1440
                        BlockHeight: startHeight,
×
1441
                }
×
1442
                endSCID = lnwire.ShortChannelID{
×
1443
                        BlockHeight: endHeight,
×
1444
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1445
                        TxPosition:  math.MaxUint16,
×
1446
                }
×
1447
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1448
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1449
        )
×
1450

×
1451
        // 1) get all channels where channelID is between start and end chan ID.
×
1452
        // 2) skip if not public (ie, no channel_proof)
×
1453
        // 3) collect that channel.
×
1454
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1455
        //    and add those timestamps to the collected channel.
×
1456
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1457
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1458
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1459
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1460
                                StartScid: chanIDStart[:],
×
1461
                                EndScid:   chanIDEnd[:],
×
1462
                        },
×
1463
                )
×
1464
                if err != nil {
×
1465
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1466
                                err)
×
1467
                }
×
1468

1469
                for _, dbChan := range dbChans {
×
1470
                        cid := lnwire.NewShortChanIDFromInt(
×
1471
                                byteOrder.Uint64(dbChan.Scid),
×
1472
                        )
×
1473
                        chanInfo := NewChannelUpdateInfo(
×
1474
                                cid, time.Time{}, time.Time{},
×
1475
                        )
×
1476

×
1477
                        if !withTimestamps {
×
1478
                                channelsPerBlock[cid.BlockHeight] = append(
×
1479
                                        channelsPerBlock[cid.BlockHeight],
×
1480
                                        chanInfo,
×
1481
                                )
×
1482

×
1483
                                continue
×
1484
                        }
1485

1486
                        //nolint:ll
1487
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1488
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1489
                                        Version:   int16(ProtocolV1),
×
1490
                                        ChannelID: dbChan.ID,
×
1491
                                        NodeID:    dbChan.NodeID1,
×
1492
                                },
×
1493
                        )
×
1494
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1495
                                return fmt.Errorf("unable to fetch node1 "+
×
1496
                                        "policy: %w", err)
×
1497
                        } else if err == nil {
×
1498
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1499
                                        node1Policy.LastUpdate.Int64, 0,
×
1500
                                )
×
1501
                        }
×
1502

1503
                        //nolint:ll
1504
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1505
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1506
                                        Version:   int16(ProtocolV1),
×
1507
                                        ChannelID: dbChan.ID,
×
1508
                                        NodeID:    dbChan.NodeID2,
×
1509
                                },
×
1510
                        )
×
1511
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1512
                                return fmt.Errorf("unable to fetch node2 "+
×
1513
                                        "policy: %w", err)
×
1514
                        } else if err == nil {
×
1515
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1516
                                        node2Policy.LastUpdate.Int64, 0,
×
1517
                                )
×
1518
                        }
×
1519

1520
                        channelsPerBlock[cid.BlockHeight] = append(
×
1521
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1522
                        )
×
1523
                }
1524

1525
                return nil
×
1526
        }, func() {
×
1527
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1528
        })
×
1529
        if err != nil {
×
1530
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1531
        }
×
1532

1533
        if len(channelsPerBlock) == 0 {
×
1534
                return nil, nil
×
1535
        }
×
1536

1537
        // Return the channel ranges in ascending block height order.
1538
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1539
        slices.Sort(blocks)
×
1540

×
1541
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1542
                return BlockChannelRange{
×
1543
                        Height:   block,
×
1544
                        Channels: channelsPerBlock[block],
×
1545
                }
×
1546
        }), nil
×
1547
}
1548

1549
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1550
// zombie. This method is used on an ad-hoc basis, when channels need to be
1551
// marked as zombies outside the normal pruning cycle.
1552
//
1553
// NOTE: part of the V1Store interface.
1554
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1555
        pubKey1, pubKey2 [33]byte) error {
×
1556

×
1557
        ctx := context.TODO()
×
1558

×
1559
        s.cacheMu.Lock()
×
1560
        defer s.cacheMu.Unlock()
×
1561

×
1562
        chanIDB := channelIDToBytes(chanID)
×
1563

×
1564
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1565
                return db.UpsertZombieChannel(
×
1566
                        ctx, sqlc.UpsertZombieChannelParams{
×
1567
                                Version:  int16(ProtocolV1),
×
1568
                                Scid:     chanIDB[:],
×
1569
                                NodeKey1: pubKey1[:],
×
1570
                                NodeKey2: pubKey2[:],
×
1571
                        },
×
1572
                )
×
1573
        }, sqldb.NoOpReset)
×
1574
        if err != nil {
×
1575
                return fmt.Errorf("unable to upsert zombie channel "+
×
1576
                        "(channel_id=%d): %w", chanID, err)
×
1577
        }
×
1578

1579
        s.rejectCache.remove(chanID)
×
1580
        s.chanCache.remove(chanID)
×
1581

×
1582
        return nil
×
1583
}
1584

1585
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1586
//
1587
// NOTE: part of the V1Store interface.
1588
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1589
        s.cacheMu.Lock()
×
1590
        defer s.cacheMu.Unlock()
×
1591

×
1592
        var (
×
1593
                ctx     = context.TODO()
×
1594
                chanIDB = channelIDToBytes(chanID)
×
1595
        )
×
1596

×
1597
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1598
                res, err := db.DeleteZombieChannel(
×
1599
                        ctx, sqlc.DeleteZombieChannelParams{
×
1600
                                Scid:    chanIDB[:],
×
1601
                                Version: int16(ProtocolV1),
×
1602
                        },
×
1603
                )
×
1604
                if err != nil {
×
1605
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1606
                                err)
×
1607
                }
×
1608

1609
                rows, err := res.RowsAffected()
×
1610
                if err != nil {
×
1611
                        return err
×
1612
                }
×
1613

1614
                if rows == 0 {
×
1615
                        return ErrZombieEdgeNotFound
×
1616
                } else if rows > 1 {
×
1617
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1618
                                "expected 1", rows)
×
1619
                }
×
1620

1621
                return nil
×
1622
        }, sqldb.NoOpReset)
1623
        if err != nil {
×
1624
                return fmt.Errorf("unable to mark edge live "+
×
1625
                        "(channel_id=%d): %w", chanID, err)
×
1626
        }
×
1627

1628
        s.rejectCache.remove(chanID)
×
1629
        s.chanCache.remove(chanID)
×
1630

×
1631
        return err
×
1632
}
1633

1634
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1635
// zombie, then the two node public keys corresponding to this edge are also
1636
// returned.
1637
//
1638
// NOTE: part of the V1Store interface.
1639
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1640
        error) {
×
1641

×
1642
        var (
×
1643
                ctx              = context.TODO()
×
1644
                isZombie         bool
×
1645
                pubKey1, pubKey2 route.Vertex
×
1646
                chanIDB          = channelIDToBytes(chanID)
×
1647
        )
×
1648

×
1649
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1650
                zombie, err := db.GetZombieChannel(
×
1651
                        ctx, sqlc.GetZombieChannelParams{
×
1652
                                Scid:    chanIDB[:],
×
1653
                                Version: int16(ProtocolV1),
×
1654
                        },
×
1655
                )
×
1656
                if errors.Is(err, sql.ErrNoRows) {
×
1657
                        return nil
×
1658
                }
×
1659
                if err != nil {
×
1660
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1661
                                err)
×
1662
                }
×
1663

1664
                copy(pubKey1[:], zombie.NodeKey1)
×
1665
                copy(pubKey2[:], zombie.NodeKey2)
×
1666
                isZombie = true
×
1667

×
1668
                return nil
×
1669
        }, sqldb.NoOpReset)
1670
        if err != nil {
×
1671
                return false, route.Vertex{}, route.Vertex{},
×
1672
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1673
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1674
        }
×
1675

1676
        return isZombie, pubKey1, pubKey2, nil
×
1677
}
1678

1679
// NumZombies returns the current number of zombie channels in the graph.
1680
//
1681
// NOTE: part of the V1Store interface.
1682
func (s *SQLStore) NumZombies() (uint64, error) {
×
1683
        var (
×
1684
                ctx        = context.TODO()
×
1685
                numZombies uint64
×
1686
        )
×
1687
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1688
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1689
                if err != nil {
×
1690
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1691
                                err)
×
1692
                }
×
1693

1694
                numZombies = uint64(count)
×
1695

×
1696
                return nil
×
1697
        }, sqldb.NoOpReset)
1698
        if err != nil {
×
1699
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1700
        }
×
1701

1702
        return numZombies, nil
×
1703
}
1704

1705
// DeleteChannelEdges removes edges with the given channel IDs from the
1706
// database and marks them as zombies. This ensures that we're unable to re-add
1707
// it to our database once again. If an edge does not exist within the
1708
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1709
// true, then when we mark these edges as zombies, we'll set up the keys such
1710
// that we require the node that failed to send the fresh update to be the one
1711
// that resurrects the channel from its zombie state. The markZombie bool
1712
// denotes whether to mark the channel as a zombie.
1713
//
1714
// NOTE: part of the V1Store interface.
1715
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1716
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1717

×
1718
        s.cacheMu.Lock()
×
1719
        defer s.cacheMu.Unlock()
×
1720

×
1721
        var (
×
1722
                ctx     = context.TODO()
×
1723
                deleted []*models.ChannelEdgeInfo
×
1724
        )
×
1725
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1726
                for _, chanID := range chanIDs {
×
1727
                        chanIDB := channelIDToBytes(chanID)
×
1728

×
1729
                        row, err := db.GetChannelBySCIDWithPolicies(
×
1730
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1731
                                        Scid:    chanIDB[:],
×
1732
                                        Version: int16(ProtocolV1),
×
1733
                                },
×
1734
                        )
×
1735
                        if errors.Is(err, sql.ErrNoRows) {
×
1736
                                return ErrEdgeNotFound
×
1737
                        } else if err != nil {
×
1738
                                return fmt.Errorf("unable to fetch channel: %w",
×
1739
                                        err)
×
1740
                        }
×
1741

1742
                        node1, node2, err := buildNodeVertices(
×
1743
                                row.Node.PubKey, row.Node_2.PubKey,
×
1744
                        )
×
1745
                        if err != nil {
×
1746
                                return err
×
1747
                        }
×
1748

1749
                        info, err := getAndBuildEdgeInfo(
×
1750
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1751
                                row.Channel, node1, node2,
×
1752
                        )
×
1753
                        if err != nil {
×
1754
                                return err
×
1755
                        }
×
1756

1757
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
1758
                        if err != nil {
×
1759
                                return fmt.Errorf("unable to delete "+
×
1760
                                        "channel: %w", err)
×
1761
                        }
×
1762

1763
                        deleted = append(deleted, info)
×
1764

×
1765
                        if !markZombie {
×
1766
                                continue
×
1767
                        }
1768

1769
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1770
                                info.NodeKey2Bytes
×
1771
                        if strictZombiePruning {
×
1772
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1773
                                if row.Policy1LastUpdate.Valid {
×
1774
                                        e1Time := time.Unix(
×
1775
                                                row.Policy1LastUpdate.Int64, 0,
×
1776
                                        )
×
1777
                                        e1UpdateTime = &e1Time
×
1778
                                }
×
1779
                                if row.Policy2LastUpdate.Valid {
×
1780
                                        e2Time := time.Unix(
×
1781
                                                row.Policy2LastUpdate.Int64, 0,
×
1782
                                        )
×
1783
                                        e2UpdateTime = &e2Time
×
1784
                                }
×
1785

1786
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1787
                                        info, e1UpdateTime, e2UpdateTime,
×
1788
                                )
×
1789
                        }
1790

1791
                        err = db.UpsertZombieChannel(
×
1792
                                ctx, sqlc.UpsertZombieChannelParams{
×
1793
                                        Version:  int16(ProtocolV1),
×
1794
                                        Scid:     chanIDB[:],
×
1795
                                        NodeKey1: nodeKey1[:],
×
1796
                                        NodeKey2: nodeKey2[:],
×
1797
                                },
×
1798
                        )
×
1799
                        if err != nil {
×
1800
                                return fmt.Errorf("unable to mark channel as "+
×
1801
                                        "zombie: %w", err)
×
1802
                        }
×
1803
                }
1804

1805
                return nil
×
1806
        }, func() {
×
1807
                deleted = nil
×
1808
        })
×
1809
        if err != nil {
×
1810
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1811
                        err)
×
1812
        }
×
1813

1814
        for _, chanID := range chanIDs {
×
1815
                s.rejectCache.remove(chanID)
×
1816
                s.chanCache.remove(chanID)
×
1817
        }
×
1818

1819
        return deleted, nil
×
1820
}
1821

1822
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1823
// channel identified by the channel ID. If the channel can't be found, then
1824
// ErrEdgeNotFound is returned. A struct which houses the general information
1825
// for the channel itself is returned as well as two structs that contain the
1826
// routing policies for the channel in either direction.
1827
//
1828
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1829
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1830
// the ChannelEdgeInfo will only include the public keys of each node.
1831
//
1832
// NOTE: part of the V1Store interface.
1833
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1834
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1835
        *models.ChannelEdgePolicy, error) {
×
1836

×
1837
        var (
×
1838
                ctx              = context.TODO()
×
1839
                edge             *models.ChannelEdgeInfo
×
1840
                policy1, policy2 *models.ChannelEdgePolicy
×
1841
        )
×
1842
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1843
                var chanIDB [8]byte
×
1844
                byteOrder.PutUint64(chanIDB[:], chanID)
×
1845

×
1846
                row, err := db.GetChannelBySCIDWithPolicies(
×
1847
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1848
                                Scid:    chanIDB[:],
×
1849
                                Version: int16(ProtocolV1),
×
1850
                        },
×
1851
                )
×
1852
                if errors.Is(err, sql.ErrNoRows) {
×
1853
                        // First check if this edge is perhaps in the zombie
×
1854
                        // index.
×
NEW
1855
                        zombie, err := db.GetZombieChannel(
×
NEW
1856
                                ctx, sqlc.GetZombieChannelParams{
×
1857
                                        Scid:    chanIDB[:],
×
1858
                                        Version: int16(ProtocolV1),
×
1859
                                },
×
1860
                        )
×
NEW
1861
                        if errors.Is(err, sql.ErrNoRows) {
×
NEW
1862
                                return ErrEdgeNotFound
×
NEW
1863
                        } else if err != nil {
×
1864
                                return fmt.Errorf("unable to check if "+
×
1865
                                        "channel is zombie: %w", err)
×
1866
                        }
×
1867

1868
                        // At this point, we know the channel is a zombie, so
1869
                        // we'll return an error indicating this, and we will
1870
                        // populate the edge info with the public keys of each
1871
                        // party as this is the only information we have about
1872
                        // it.
NEW
1873
                        edge = &models.ChannelEdgeInfo{}
×
NEW
1874
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
NEW
1875
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
NEW
1876

×
NEW
1877
                        return ErrZombieEdge
×
1878
                } else if err != nil {
×
1879
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1880
                }
×
1881

1882
                node1, node2, err := buildNodeVertices(
×
1883
                        row.Node.PubKey, row.Node_2.PubKey,
×
1884
                )
×
1885
                if err != nil {
×
1886
                        return err
×
1887
                }
×
1888

1889
                edge, err = getAndBuildEdgeInfo(
×
1890
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1891
                        node1, node2,
×
1892
                )
×
1893
                if err != nil {
×
1894
                        return fmt.Errorf("unable to build channel info: %w",
×
1895
                                err)
×
1896
                }
×
1897

1898
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1899
                if err != nil {
×
1900
                        return fmt.Errorf("unable to extract channel "+
×
1901
                                "policies: %w", err)
×
1902
                }
×
1903

1904
                policy1, policy2, err = getAndBuildChanPolicies(
×
1905
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1906
                )
×
1907
                if err != nil {
×
1908
                        return fmt.Errorf("unable to build channel "+
×
1909
                                "policies: %w", err)
×
1910
                }
×
1911

1912
                return nil
×
1913
        }, sqldb.NoOpReset)
1914
        if err != nil {
×
NEW
1915
                // If we are returning the ErrZombieEdge, then we also need to
×
NEW
1916
                // return the edge info as the method comment indicates that
×
NEW
1917
                // this will be populated when the edge is a zombie.
×
NEW
1918
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1919
                        err)
×
1920
        }
×
1921

1922
        return edge, policy1, policy2, nil
×
1923
}
1924

1925
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1926
// the channel identified by the funding outpoint. If the channel can't be
1927
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1928
// information for the channel itself is returned as well as two structs that
1929
// contain the routing policies for the channel in either direction.
1930
//
1931
// NOTE: part of the V1Store interface.
1932
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1933
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1934
        *models.ChannelEdgePolicy, error) {
×
1935

×
1936
        var (
×
1937
                ctx              = context.TODO()
×
1938
                edge             *models.ChannelEdgeInfo
×
1939
                policy1, policy2 *models.ChannelEdgePolicy
×
1940
        )
×
1941
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1942
                row, err := db.GetChannelByOutpointWithPolicies(
×
1943
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1944
                                Outpoint: op.String(),
×
1945
                                Version:  int16(ProtocolV1),
×
1946
                        },
×
1947
                )
×
1948
                if errors.Is(err, sql.ErrNoRows) {
×
1949
                        return ErrEdgeNotFound
×
1950
                } else if err != nil {
×
1951
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1952
                }
×
1953

1954
                node1, node2, err := buildNodeVertices(
×
1955
                        row.Node1Pubkey, row.Node2Pubkey,
×
1956
                )
×
1957
                if err != nil {
×
1958
                        return err
×
1959
                }
×
1960

1961
                edge, err = getAndBuildEdgeInfo(
×
1962
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1963
                        node1, node2,
×
1964
                )
×
1965
                if err != nil {
×
1966
                        return fmt.Errorf("unable to build channel info: %w",
×
1967
                                err)
×
1968
                }
×
1969

1970
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1971
                if err != nil {
×
1972
                        return fmt.Errorf("unable to extract channel "+
×
1973
                                "policies: %w", err)
×
1974
                }
×
1975

1976
                policy1, policy2, err = getAndBuildChanPolicies(
×
1977
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1978
                )
×
1979
                if err != nil {
×
1980
                        return fmt.Errorf("unable to build channel "+
×
1981
                                "policies: %w", err)
×
1982
                }
×
1983

1984
                return nil
×
1985
        }, sqldb.NoOpReset)
1986
        if err != nil {
×
1987
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1988
                        err)
×
1989
        }
×
1990

1991
        return edge, policy1, policy2, nil
×
1992
}
1993

1994
// HasChannelEdge returns true if the database knows of a channel edge with the
1995
// passed channel ID, and false otherwise. If an edge with that ID is found
1996
// within the graph, then two time stamps representing the last time the edge
1997
// was updated for both directed edges are returned along with the boolean. If
1998
// it is not found, then the zombie index is checked and its result is returned
1999
// as the second boolean.
2000
//
2001
// NOTE: part of the V1Store interface.
2002
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
2003
        bool, error) {
×
2004

×
2005
        ctx := context.TODO()
×
2006

×
2007
        var (
×
2008
                exists          bool
×
2009
                isZombie        bool
×
2010
                node1LastUpdate time.Time
×
2011
                node2LastUpdate time.Time
×
2012
        )
×
2013

×
2014
        // We'll query the cache with the shared lock held to allow multiple
×
2015
        // readers to access values in the cache concurrently if they exist.
×
2016
        s.cacheMu.RLock()
×
2017
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2018
                s.cacheMu.RUnlock()
×
2019
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2020
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2021
                exists, isZombie = entry.flags.unpack()
×
2022

×
2023
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2024
        }
×
2025
        s.cacheMu.RUnlock()
×
2026

×
2027
        s.cacheMu.Lock()
×
2028
        defer s.cacheMu.Unlock()
×
2029

×
2030
        // The item was not found with the shared lock, so we'll acquire the
×
2031
        // exclusive lock and check the cache again in case another method added
×
2032
        // the entry to the cache while no lock was held.
×
2033
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2034
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2035
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2036
                exists, isZombie = entry.flags.unpack()
×
2037

×
2038
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2039
        }
×
2040

2041
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2042
                var chanIDB [8]byte
×
2043
                byteOrder.PutUint64(chanIDB[:], chanID)
×
2044

×
2045
                channel, err := db.GetChannelBySCID(
×
2046
                        ctx, sqlc.GetChannelBySCIDParams{
×
2047
                                Scid:    chanIDB[:],
×
2048
                                Version: int16(ProtocolV1),
×
2049
                        },
×
2050
                )
×
2051
                if errors.Is(err, sql.ErrNoRows) {
×
2052
                        // Check if it is a zombie channel.
×
2053
                        isZombie, err = db.IsZombieChannel(
×
2054
                                ctx, sqlc.IsZombieChannelParams{
×
2055
                                        Scid:    chanIDB[:],
×
2056
                                        Version: int16(ProtocolV1),
×
2057
                                },
×
2058
                        )
×
2059
                        if err != nil {
×
2060
                                return fmt.Errorf("could not check if channel "+
×
2061
                                        "is zombie: %w", err)
×
2062
                        }
×
2063

2064
                        return nil
×
2065
                } else if err != nil {
×
2066
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2067
                }
×
2068

2069
                exists = true
×
2070

×
2071
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2072
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2073
                                Version:   int16(ProtocolV1),
×
2074
                                ChannelID: channel.ID,
×
2075
                                NodeID:    channel.NodeID1,
×
2076
                        },
×
2077
                )
×
2078
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2079
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2080
                                err)
×
2081
                } else if err == nil {
×
2082
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2083
                }
×
2084

2085
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2086
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2087
                                Version:   int16(ProtocolV1),
×
2088
                                ChannelID: channel.ID,
×
2089
                                NodeID:    channel.NodeID2,
×
2090
                        },
×
2091
                )
×
2092
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2093
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2094
                                err)
×
2095
                } else if err == nil {
×
2096
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2097
                }
×
2098

2099
                return nil
×
2100
        }, sqldb.NoOpReset)
2101
        if err != nil {
×
2102
                return time.Time{}, time.Time{}, false, false,
×
2103
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2104
        }
×
2105

2106
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2107
                upd1Time: node1LastUpdate.Unix(),
×
2108
                upd2Time: node2LastUpdate.Unix(),
×
2109
                flags:    packRejectFlags(exists, isZombie),
×
2110
        })
×
2111

×
2112
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2113
}
2114

2115
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2116
// passed channel point (outpoint). If the passed channel doesn't exist within
2117
// the database, then ErrEdgeNotFound is returned.
2118
//
2119
// NOTE: part of the V1Store interface.
2120
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2121
        var (
×
2122
                ctx       = context.TODO()
×
2123
                channelID uint64
×
2124
        )
×
2125
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2126
                chanID, err := db.GetSCIDByOutpoint(
×
2127
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2128
                                Outpoint: chanPoint.String(),
×
2129
                                Version:  int16(ProtocolV1),
×
2130
                        },
×
2131
                )
×
2132
                if errors.Is(err, sql.ErrNoRows) {
×
2133
                        return ErrEdgeNotFound
×
2134
                } else if err != nil {
×
2135
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2136
                                err)
×
2137
                }
×
2138

2139
                channelID = byteOrder.Uint64(chanID)
×
2140

×
2141
                return nil
×
2142
        }, sqldb.NoOpReset)
2143
        if err != nil {
×
2144
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2145
        }
×
2146

2147
        return channelID, nil
×
2148
}
2149

2150
// IsPublicNode is a helper method that determines whether the node with the
2151
// given public key is seen as a public node in the graph from the graph's
2152
// source node's point of view.
2153
//
2154
// NOTE: part of the V1Store interface.
2155
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2156
        ctx := context.TODO()
×
2157

×
2158
        var isPublic bool
×
2159
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2160
                var err error
×
2161
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2162

×
2163
                return err
×
2164
        }, sqldb.NoOpReset)
×
2165
        if err != nil {
×
2166
                return false, fmt.Errorf("unable to check if node is "+
×
2167
                        "public: %w", err)
×
2168
        }
×
2169

2170
        return isPublic, nil
×
2171
}
2172

2173
// FetchChanInfos returns the set of channel edges that correspond to the passed
2174
// channel ID's. If an edge is the query is unknown to the database, it will
2175
// skipped and the result will contain only those edges that exist at the time
2176
// of the query. This can be used to respond to peer queries that are seeking to
2177
// fill in gaps in their view of the channel graph.
2178
//
2179
// NOTE: part of the V1Store interface.
2180
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2181
        var (
×
2182
                ctx   = context.TODO()
×
2183
                edges []ChannelEdge
×
2184
        )
×
2185
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2186
                for _, chanID := range chanIDs {
×
2187
                        var chanIDB [8]byte
×
2188
                        byteOrder.PutUint64(chanIDB[:], chanID)
×
2189

×
2190
                        // TODO(elle): potentially optimize this by using
×
2191
                        //  sqlc.slice() once that works for both SQLite and
×
2192
                        //  Postgres.
×
2193
                        row, err := db.GetChannelBySCIDWithPolicies(
×
2194
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
2195
                                        Scid:    chanIDB[:],
×
2196
                                        Version: int16(ProtocolV1),
×
2197
                                },
×
2198
                        )
×
2199
                        if errors.Is(err, sql.ErrNoRows) {
×
2200
                                continue
×
2201
                        } else if err != nil {
×
2202
                                return fmt.Errorf("unable to fetch channel: %w",
×
2203
                                        err)
×
2204
                        }
×
2205

2206
                        node1, node2, err := buildNodes(
×
2207
                                ctx, db, row.Node, row.Node_2,
×
2208
                        )
×
2209
                        if err != nil {
×
2210
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2211
                                        err)
×
2212
                        }
×
2213

2214
                        edge, err := getAndBuildEdgeInfo(
×
2215
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2216
                                row.Channel, node1.PubKeyBytes,
×
2217
                                node2.PubKeyBytes,
×
2218
                        )
×
2219
                        if err != nil {
×
2220
                                return fmt.Errorf("unable to build "+
×
2221
                                        "channel info: %w", err)
×
2222
                        }
×
2223

2224
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2225
                        if err != nil {
×
2226
                                return fmt.Errorf("unable to extract channel "+
×
2227
                                        "policies: %w", err)
×
2228
                        }
×
2229

2230
                        p1, p2, err := getAndBuildChanPolicies(
×
2231
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2232
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2233
                        )
×
2234
                        if err != nil {
×
2235
                                return fmt.Errorf("unable to build channel "+
×
2236
                                        "policies: %w", err)
×
2237
                        }
×
2238

2239
                        edges = append(edges, ChannelEdge{
×
2240
                                Info:    edge,
×
2241
                                Policy1: p1,
×
2242
                                Policy2: p2,
×
2243
                                Node1:   node1,
×
2244
                                Node2:   node2,
×
2245
                        })
×
2246
                }
2247

2248
                return nil
×
2249
        }, func() {
×
2250
                edges = nil
×
2251
        })
×
2252
        if err != nil {
×
2253
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2254
        }
×
2255

2256
        return edges, nil
×
2257
}
2258

2259
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2260
// ID's that we don't know and are not known zombies of the passed set. In other
2261
// words, we perform a set difference of our set of chan ID's and the ones
2262
// passed in. This method can be used by callers to determine the set of
2263
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2264
// known zombies is also returned.
2265
//
2266
// NOTE: part of the V1Store interface.
2267
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2268
        []ChannelUpdateInfo, error) {
×
2269

×
2270
        var (
×
2271
                ctx          = context.TODO()
×
2272
                newChanIDs   []uint64
×
2273
                knownZombies []ChannelUpdateInfo
×
2274
        )
×
2275
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2276
                for _, chanInfo := range chansInfo {
×
2277
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2278
                        var chanIDB [8]byte
×
2279
                        byteOrder.PutUint64(chanIDB[:], channelID)
×
2280

×
2281
                        // TODO(elle): potentially optimize this by using
×
2282
                        //  sqlc.slice() once that works for both SQLite and
×
2283
                        //  Postgres.
×
2284
                        _, err := db.GetChannelBySCID(
×
2285
                                ctx, sqlc.GetChannelBySCIDParams{
×
2286
                                        Version: int16(ProtocolV1),
×
2287
                                        Scid:    chanIDB[:],
×
2288
                                },
×
2289
                        )
×
2290
                        if err == nil {
×
2291
                                continue
×
2292
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
2293
                                return fmt.Errorf("unable to fetch channel: %w",
×
2294
                                        err)
×
2295
                        }
×
2296

2297
                        isZombie, err := db.IsZombieChannel(
×
2298
                                ctx, sqlc.IsZombieChannelParams{
×
2299
                                        Scid:    chanIDB[:],
×
2300
                                        Version: int16(ProtocolV1),
×
2301
                                },
×
2302
                        )
×
2303
                        if err != nil {
×
2304
                                return fmt.Errorf("unable to fetch zombie "+
×
2305
                                        "channel: %w", err)
×
2306
                        }
×
2307

2308
                        if isZombie {
×
2309
                                knownZombies = append(knownZombies, chanInfo)
×
2310

×
2311
                                continue
×
2312
                        }
2313

2314
                        newChanIDs = append(newChanIDs, channelID)
×
2315
                }
2316

2317
                return nil
×
2318
        }, func() {
×
2319
                newChanIDs = nil
×
2320
                knownZombies = nil
×
2321
        })
×
2322
        if err != nil {
×
2323
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2324
        }
×
2325

2326
        return newChanIDs, knownZombies, nil
×
2327
}
2328

2329
// PruneGraphNodes is a garbage collection method which attempts to prune out
2330
// any nodes from the channel graph that are currently unconnected. This ensure
2331
// that we only maintain a graph of reachable nodes. In the event that a pruned
2332
// node gains more channels, it will be re-added back to the graph.
2333
//
2334
// NOTE: this prunes nodes across protocol versions. It will never prune the
2335
// source nodes.
2336
//
2337
// NOTE: part of the V1Store interface.
2338
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2339
        var ctx = context.TODO()
×
2340

×
2341
        var prunedNodes []route.Vertex
×
2342
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2343
                var err error
×
2344
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2345

×
2346
                return err
×
2347
        }, func() {
×
2348
                prunedNodes = nil
×
2349
        })
×
2350
        if err != nil {
×
2351
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2352
        }
×
2353

2354
        return prunedNodes, nil
×
2355
}
2356

2357
// PruneGraph prunes newly closed channels from the channel graph in response
2358
// to a new block being solved on the network. Any transactions which spend the
2359
// funding output of any known channels within he graph will be deleted.
2360
// Additionally, the "prune tip", or the last block which has been used to
2361
// prune the graph is stored so callers can ensure the graph is fully in sync
2362
// with the current UTXO state. A slice of channels that have been closed by
2363
// the target block along with any pruned nodes are returned if the function
2364
// succeeds without error.
2365
//
2366
// NOTE: part of the V1Store interface.
2367
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2368
        blockHash *chainhash.Hash, blockHeight uint32) (
2369
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2370

×
2371
        ctx := context.TODO()
×
2372

×
2373
        s.cacheMu.Lock()
×
2374
        defer s.cacheMu.Unlock()
×
2375

×
2376
        var (
×
2377
                closedChans []*models.ChannelEdgeInfo
×
2378
                prunedNodes []route.Vertex
×
2379
        )
×
2380
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2381
                for _, outpoint := range spentOutputs {
×
2382
                        // TODO(elle): potentially optimize this by using
×
2383
                        //  sqlc.slice() once that works for both SQLite and
×
2384
                        //  Postgres.
×
2385
                        //
×
2386
                        // NOTE: this fetches channels for all protocol
×
2387
                        // versions.
×
2388
                        row, err := db.GetChannelByOutpoint(
×
2389
                                ctx, outpoint.String(),
×
2390
                        )
×
2391
                        if errors.Is(err, sql.ErrNoRows) {
×
2392
                                continue
×
2393
                        } else if err != nil {
×
2394
                                return fmt.Errorf("unable to fetch channel: %w",
×
2395
                                        err)
×
2396
                        }
×
2397

2398
                        node1, node2, err := buildNodeVertices(
×
2399
                                row.Node1Pubkey, row.Node2Pubkey,
×
2400
                        )
×
2401
                        if err != nil {
×
2402
                                return err
×
2403
                        }
×
2404

2405
                        info, err := getAndBuildEdgeInfo(
×
2406
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2407
                                row.Channel, node1, node2,
×
2408
                        )
×
2409
                        if err != nil {
×
2410
                                return err
×
2411
                        }
×
2412

2413
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2414
                        if err != nil {
×
2415
                                return fmt.Errorf("unable to delete "+
×
2416
                                        "channel: %w", err)
×
2417
                        }
×
2418

2419
                        closedChans = append(closedChans, info)
×
2420
                }
2421

2422
                err := db.UpsertPruneLogEntry(
×
2423
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2424
                                BlockHash:   blockHash[:],
×
2425
                                BlockHeight: int64(blockHeight),
×
2426
                        },
×
2427
                )
×
2428
                if err != nil {
×
2429
                        return fmt.Errorf("unable to insert prune log "+
×
2430
                                "entry: %w", err)
×
2431
                }
×
2432

2433
                // Now that we've pruned some channels, we'll also prune any
2434
                // nodes that no longer have any channels.
2435
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2436
                if err != nil {
×
2437
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2438
                                err)
×
2439
                }
×
2440

2441
                return nil
×
2442
        }, func() {
×
2443
                prunedNodes = nil
×
2444
                closedChans = nil
×
2445
        })
×
2446
        if err != nil {
×
2447
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2448
        }
×
2449

2450
        for _, channel := range closedChans {
×
2451
                s.rejectCache.remove(channel.ChannelID)
×
2452
                s.chanCache.remove(channel.ChannelID)
×
2453
        }
×
2454

2455
        return closedChans, prunedNodes, nil
×
2456
}
2457

2458
// ChannelView returns the verifiable edge information for each active channel
2459
// within the known channel graph. The set of UTXOs (along with their scripts)
2460
// returned are the ones that need to be watched on chain to detect channel
2461
// closes on the resident blockchain.
2462
//
2463
// NOTE: part of the V1Store interface.
2464
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2465
        var (
×
2466
                ctx        = context.TODO()
×
2467
                edgePoints []EdgePoint
×
2468
        )
×
2469

×
2470
        handleChannel := func(db SQLQueries,
×
2471
                channel sqlc.ListChannelsPaginatedRow) error {
×
2472

×
2473
                pkScript, err := genMultiSigP2WSH(
×
2474
                        channel.BitcoinKey1, channel.BitcoinKey2,
×
2475
                )
×
2476
                if err != nil {
×
2477
                        return err
×
2478
                }
×
2479

2480
                op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2481
                if err != nil {
×
2482
                        return err
×
2483
                }
×
2484

2485
                edgePoints = append(edgePoints, EdgePoint{
×
2486
                        FundingPkScript: pkScript,
×
2487
                        OutPoint:        *op,
×
2488
                })
×
2489

×
2490
                return nil
×
2491
        }
2492

2493
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2494
                lastID := int64(-1)
×
2495
                for {
×
2496
                        rows, err := db.ListChannelsPaginated(
×
2497
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2498
                                        Version: int16(ProtocolV1),
×
2499
                                        ID:      lastID,
×
2500
                                        Limit:   pageSize,
×
2501
                                },
×
2502
                        )
×
2503
                        if err != nil {
×
2504
                                return err
×
2505
                        }
×
2506

2507
                        if len(rows) == 0 {
×
2508
                                break
×
2509
                        }
2510

2511
                        for _, row := range rows {
×
2512
                                err := handleChannel(db, row)
×
2513
                                if err != nil {
×
2514
                                        return err
×
2515
                                }
×
2516

2517
                                lastID = row.ID
×
2518
                        }
2519
                }
2520

2521
                return nil
×
2522
        }, func() {
×
2523
                edgePoints = nil
×
2524
        })
×
2525
        if err != nil {
×
2526
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2527
        }
×
2528

2529
        return edgePoints, nil
×
2530
}
2531

2532
// PruneTip returns the block height and hash of the latest block that has been
2533
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2534
// to tell if the graph is currently in sync with the current best known UTXO
2535
// state.
2536
//
2537
// NOTE: part of the V1Store interface.
2538
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2539
        var (
×
2540
                ctx       = context.TODO()
×
2541
                tipHash   chainhash.Hash
×
2542
                tipHeight uint32
×
2543
        )
×
2544
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2545
                pruneTip, err := db.GetPruneTip(ctx)
×
2546
                if errors.Is(err, sql.ErrNoRows) {
×
2547
                        return ErrGraphNeverPruned
×
2548
                } else if err != nil {
×
2549
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2550
                }
×
2551

2552
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2553
                tipHeight = uint32(pruneTip.BlockHeight)
×
2554

×
2555
                return nil
×
2556
        }, sqldb.NoOpReset)
2557
        if err != nil {
×
2558
                return nil, 0, err
×
2559
        }
×
2560

2561
        return &tipHash, tipHeight, nil
×
2562
}
2563

2564
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2565
//
2566
// NOTE: this prunes nodes across protocol versions. It will never prune the
2567
// source nodes.
2568
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2569
        db SQLQueries) ([]route.Vertex, error) {
×
2570

×
2571
        // Fetch all un-connected nodes from the database.
×
2572
        // NOTE: this will not include any nodes that are listed in the
×
2573
        // source table.
×
2574
        nodes, err := db.GetUnconnectedNodes(ctx)
×
2575
        if err != nil {
×
2576
                return nil, fmt.Errorf("unable to fetch unconnected nodes: %w",
×
2577
                        err)
×
2578
        }
×
2579

2580
        prunedNodes := make([]route.Vertex, 0, len(nodes))
×
2581
        for _, node := range nodes {
×
2582
                // TODO(elle): update to use sqlc.slice() once that works.
×
2583
                if err = db.DeleteNode(ctx, node.ID); err != nil {
×
2584
                        return nil, fmt.Errorf("unable to delete "+
×
2585
                                "node(id=%d): %w", node.ID, err)
×
2586
                }
×
2587

2588
                pubKey, err := route.NewVertexFromBytes(node.PubKey)
×
2589
                if err != nil {
×
2590
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2591
                                "for node(id=%d): %w", node.ID, err)
×
2592
                }
×
2593

2594
                prunedNodes = append(prunedNodes, pubKey)
×
2595
        }
2596

2597
        return prunedNodes, nil
×
2598
}
2599

2600
// DisconnectBlockAtHeight is used to indicate that the block specified
2601
// by the passed height has been disconnected from the main chain. This
2602
// will "rewind" the graph back to the height below, deleting channels
2603
// that are no longer confirmed from the graph. The prune log will be
2604
// set to the last prune height valid for the remaining chain.
2605
// Channels that were removed from the graph resulting from the
2606
// disconnected block are returned.
2607
//
2608
// NOTE: part of the V1Store interface.
2609
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2610
        []*models.ChannelEdgeInfo, error) {
×
2611

×
2612
        ctx := context.TODO()
×
2613

×
2614
        var (
×
2615
                // Every channel having a ShortChannelID starting at 'height'
×
2616
                // will no longer be confirmed.
×
2617
                startShortChanID = lnwire.ShortChannelID{
×
2618
                        BlockHeight: height,
×
2619
                }
×
2620

×
2621
                // Delete everything after this height from the db up until the
×
2622
                // SCID alias range.
×
2623
                endShortChanID = aliasmgr.StartingAlias
×
2624

×
2625
                removedChans []*models.ChannelEdgeInfo
×
2626
        )
×
2627

×
2628
        var chanIDStart [8]byte
×
2629
        byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64())
×
2630
        var chanIDEnd [8]byte
×
2631
        byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64())
×
2632

×
2633
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2634
                rows, err := db.GetChannelsBySCIDRange(
×
2635
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2636
                                StartScid: chanIDStart[:],
×
2637
                                EndScid:   chanIDEnd[:],
×
2638
                        },
×
2639
                )
×
2640
                if err != nil {
×
2641
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2642
                }
×
2643

2644
                for _, row := range rows {
×
2645
                        node1, node2, err := buildNodeVertices(
×
2646
                                row.Node1PubKey, row.Node2PubKey,
×
2647
                        )
×
2648
                        if err != nil {
×
2649
                                return err
×
2650
                        }
×
2651

2652
                        channel, err := getAndBuildEdgeInfo(
×
2653
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2654
                                row.Channel, node1, node2,
×
2655
                        )
×
2656
                        if err != nil {
×
2657
                                return err
×
2658
                        }
×
2659

2660
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2661
                        if err != nil {
×
2662
                                return fmt.Errorf("unable to delete "+
×
2663
                                        "channel: %w", err)
×
2664
                        }
×
2665

2666
                        removedChans = append(removedChans, channel)
×
2667
                }
2668

2669
                return db.DeletePruneLogEntriesInRange(
×
2670
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2671
                                StartHeight: int64(height),
×
2672
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2673
                        },
×
2674
                )
×
2675
        }, func() {
×
2676
                removedChans = nil
×
2677
        })
×
2678
        if err != nil {
×
2679
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2680
                        "height: %w", err)
×
2681
        }
×
2682

2683
        for _, channel := range removedChans {
×
2684
                s.rejectCache.remove(channel.ChannelID)
×
2685
                s.chanCache.remove(channel.ChannelID)
×
2686
        }
×
2687

2688
        return removedChans, nil
×
2689
}
2690

2691
// AddEdgeProof sets the proof of an existing edge in the graph database.
2692
//
2693
// NOTE: part of the V1Store interface.
2694
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2695
        proof *models.ChannelAuthProof) error {
×
2696

×
2697
        var (
×
2698
                ctx       = context.TODO()
×
2699
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2700
        )
×
2701

×
2702
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2703
                res, err := db.AddV1ChannelProof(
×
2704
                        ctx, sqlc.AddV1ChannelProofParams{
×
2705
                                Scid:              scidBytes[:],
×
2706
                                Node1Signature:    proof.NodeSig1Bytes,
×
2707
                                Node2Signature:    proof.NodeSig2Bytes,
×
2708
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2709
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2710
                        },
×
2711
                )
×
2712
                if err != nil {
×
2713
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2714
                }
×
2715

2716
                n, err := res.RowsAffected()
×
2717
                if err != nil {
×
2718
                        return err
×
2719
                }
×
2720

2721
                if n == 0 {
×
2722
                        return fmt.Errorf("no rows affected when adding edge "+
×
2723
                                "proof for SCID %v", scid)
×
2724
                } else if n > 1 {
×
2725
                        return fmt.Errorf("multiple rows affected when adding "+
×
2726
                                "edge proof for SCID %v: %d rows affected",
×
2727
                                scid, n)
×
2728
                }
×
2729

2730
                return nil
×
2731
        }, sqldb.NoOpReset)
2732
        if err != nil {
×
2733
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2734
        }
×
2735

2736
        return nil
×
2737
}
2738

2739
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2740
// that we can ignore channel announcements that we know to be closed without
2741
// having to validate them and fetch a block.
2742
//
2743
// NOTE: part of the V1Store interface.
2744
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2745
        var (
×
2746
                ctx     = context.TODO()
×
2747
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2748
        )
×
2749

×
2750
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2751
                return db.InsertClosedChannel(ctx, chanIDB[:])
×
2752
        }, sqldb.NoOpReset)
×
2753
}
2754

2755
// IsClosedScid checks whether a channel identified by the passed in scid is
2756
// closed. This helps avoid having to perform expensive validation checks.
2757
//
2758
// NOTE: part of the V1Store interface.
2759
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2760
        var (
×
2761
                ctx      = context.TODO()
×
2762
                isClosed bool
×
2763
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2764
        )
×
2765
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2766
                var err error
×
2767
                isClosed, err = db.IsClosedChannel(ctx, chanIDB[:])
×
2768
                if err != nil {
×
2769
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2770
                                err)
×
2771
                }
×
2772

2773
                return nil
×
2774
        }, sqldb.NoOpReset)
2775
        if err != nil {
×
2776
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2777
                        err)
×
2778
        }
×
2779

2780
        return isClosed, nil
×
2781
}
2782

2783
// GraphSession will provide the call-back with access to a NodeTraverser
2784
// instance which can be used to perform queries against the channel graph.
2785
//
2786
// NOTE: part of the V1Store interface.
2787
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error) error {
×
2788
        var ctx = context.TODO()
×
2789

×
2790
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2791
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2792
        }, sqldb.NoOpReset)
×
2793
}
2794

2795
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2796
// read only transaction for a consistent view of the graph.
2797
type sqlNodeTraverser struct {
2798
        db    SQLQueries
2799
        chain chainhash.Hash
2800
}
2801

2802
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2803
// NodeTraverser interface.
2804
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2805

2806
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2807
func newSQLNodeTraverser(db SQLQueries,
2808
        chain chainhash.Hash) *sqlNodeTraverser {
×
2809

×
2810
        return &sqlNodeTraverser{
×
2811
                db:    db,
×
2812
                chain: chain,
×
2813
        }
×
2814
}
×
2815

2816
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2817
// node.
2818
//
2819
// NOTE: Part of the NodeTraverser interface.
2820
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2821
        cb func(channel *DirectedChannel) error) error {
×
2822

×
2823
        ctx := context.TODO()
×
2824

×
2825
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2826
}
×
2827

2828
// FetchNodeFeatures returns the features of the given node. If the node is
2829
// unknown, assume no additional features are supported.
2830
//
2831
// NOTE: Part of the NodeTraverser interface.
2832
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2833
        *lnwire.FeatureVector, error) {
×
2834

×
2835
        ctx := context.TODO()
×
2836

×
2837
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2838
}
×
2839

2840
// forEachNodeDirectedChannel iterates through all channels of a given
2841
// node, executing the passed callback on the directed edge representing the
2842
// channel and its incoming policy. If the node is not found, no error is
2843
// returned.
2844
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2845
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2846

×
2847
        toNodeCallback := func() route.Vertex {
×
2848
                return nodePub
×
2849
        }
×
2850

2851
        dbID, err := db.GetNodeIDByPubKey(
×
2852
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2853
                        Version: int16(ProtocolV1),
×
2854
                        PubKey:  nodePub[:],
×
2855
                },
×
2856
        )
×
2857
        if errors.Is(err, sql.ErrNoRows) {
×
2858
                return nil
×
2859
        } else if err != nil {
×
2860
                return fmt.Errorf("unable to fetch node: %w", err)
×
2861
        }
×
2862

2863
        rows, err := db.ListChannelsByNodeID(
×
2864
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2865
                        Version: int16(ProtocolV1),
×
2866
                        NodeID1: dbID,
×
2867
                },
×
2868
        )
×
2869
        if err != nil {
×
2870
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2871
        }
×
2872

2873
        // Exit early if there are no channels for this node so we don't
2874
        // do the unnecessary feature fetching.
2875
        if len(rows) == 0 {
×
2876
                return nil
×
2877
        }
×
2878

2879
        features, err := getNodeFeatures(ctx, db, dbID)
×
2880
        if err != nil {
×
2881
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2882
        }
×
2883

2884
        for _, row := range rows {
×
2885
                node1, node2, err := buildNodeVertices(
×
2886
                        row.Node1Pubkey, row.Node2Pubkey,
×
2887
                )
×
2888
                if err != nil {
×
2889
                        return fmt.Errorf("unable to build node vertices: %w",
×
2890
                                err)
×
2891
                }
×
2892

2893
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
2894

×
2895
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2896
                if err != nil {
×
2897
                        return err
×
2898
                }
×
2899

2900
                var p1, p2 *models.CachedEdgePolicy
×
2901
                if dbPol1 != nil {
×
2902
                        policy1, err := buildChanPolicy(
×
2903
                                *dbPol1, edge.ChannelID, nil, node2, true,
×
2904
                        )
×
2905
                        if err != nil {
×
2906
                                return err
×
2907
                        }
×
2908

2909
                        p1 = models.NewCachedPolicy(policy1)
×
2910
                }
2911
                if dbPol2 != nil {
×
2912
                        policy2, err := buildChanPolicy(
×
2913
                                *dbPol2, edge.ChannelID, nil, node1, false,
×
2914
                        )
×
2915
                        if err != nil {
×
2916
                                return err
×
2917
                        }
×
2918

2919
                        p2 = models.NewCachedPolicy(policy2)
×
2920
                }
2921

2922
                // Determine the outgoing and incoming policy for this
2923
                // channel and node combo.
2924
                outPolicy, inPolicy := p1, p2
×
2925
                if p1 != nil && node2 == nodePub {
×
2926
                        outPolicy, inPolicy = p2, p1
×
2927
                } else if p2 != nil && node1 != nodePub {
×
2928
                        outPolicy, inPolicy = p2, p1
×
2929
                }
×
2930

2931
                var cachedInPolicy *models.CachedEdgePolicy
×
2932
                if inPolicy != nil {
×
2933
                        cachedInPolicy = inPolicy
×
2934
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
2935
                        cachedInPolicy.ToNodeFeatures = features
×
2936
                }
×
2937

2938
                directedChannel := &DirectedChannel{
×
2939
                        ChannelID:    edge.ChannelID,
×
2940
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2941
                        OtherNode:    edge.NodeKey2Bytes,
×
2942
                        Capacity:     edge.Capacity,
×
2943
                        OutPolicySet: outPolicy != nil,
×
2944
                        InPolicy:     cachedInPolicy,
×
2945
                }
×
2946
                if outPolicy != nil {
×
2947
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2948
                                directedChannel.InboundFee = fee
×
2949
                        })
×
2950
                }
2951

2952
                if nodePub == edge.NodeKey2Bytes {
×
2953
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
2954
                }
×
2955

2956
                if err := cb(directedChannel); err != nil {
×
2957
                        return err
×
2958
                }
×
2959
        }
2960

2961
        return nil
×
2962
}
2963

2964
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
2965
// and executes the provided callback for each node.
2966
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
2967
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
2968

×
2969
        lastID := int64(-1)
×
2970

×
2971
        for {
×
2972
                nodes, err := db.ListNodeIDsAndPubKeys(
×
2973
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
2974
                                Version: int16(ProtocolV1),
×
2975
                                ID:      lastID,
×
2976
                                Limit:   pageSize,
×
2977
                        },
×
2978
                )
×
2979
                if err != nil {
×
2980
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
2981
                }
×
2982

2983
                if len(nodes) == 0 {
×
2984
                        break
×
2985
                }
2986

2987
                for _, node := range nodes {
×
2988
                        var pub route.Vertex
×
2989
                        copy(pub[:], node.PubKey)
×
2990

×
2991
                        if err := cb(node.ID, pub); err != nil {
×
2992
                                return fmt.Errorf("forEachNodeCacheable "+
×
2993
                                        "callback failed for node(id=%d): %w",
×
2994
                                        node.ID, err)
×
2995
                        }
×
2996

2997
                        lastID = node.ID
×
2998
                }
2999
        }
3000

3001
        return nil
×
3002
}
3003

3004
// forEachNodeChannel iterates through all channels of a node, executing
3005
// the passed callback on each. The call-back is provided with the channel's
3006
// edge information, the outgoing policy and the incoming policy for the
3007
// channel and node combo.
3008
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3009
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
3010
                *models.ChannelEdgePolicy,
3011
                *models.ChannelEdgePolicy) error) error {
×
3012

×
3013
        // Get all the V1 channels for this node.Add commentMore actions
×
3014
        rows, err := db.ListChannelsByNodeID(
×
3015
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3016
                        Version: int16(ProtocolV1),
×
3017
                        NodeID1: id,
×
3018
                },
×
3019
        )
×
3020
        if err != nil {
×
3021
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3022
        }
×
3023

3024
        // Call the call-back for each channel and its known policies.
3025
        for _, row := range rows {
×
3026
                node1, node2, err := buildNodeVertices(
×
3027
                        row.Node1Pubkey, row.Node2Pubkey,
×
3028
                )
×
3029
                if err != nil {
×
3030
                        return fmt.Errorf("unable to build node vertices: %w",
×
3031
                                err)
×
3032
                }
×
3033

3034
                edge, err := getAndBuildEdgeInfo(
×
3035
                        ctx, db, chain, row.Channel.ID, row.Channel, node1,
×
3036
                        node2,
×
3037
                )
×
3038
                if err != nil {
×
3039
                        return fmt.Errorf("unable to build channel info: %w",
×
3040
                                err)
×
3041
                }
×
3042

3043
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3044
                if err != nil {
×
3045
                        return fmt.Errorf("unable to extract channel "+
×
3046
                                "policies: %w", err)
×
3047
                }
×
3048

3049
                p1, p2, err := getAndBuildChanPolicies(
×
3050
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3051
                )
×
3052
                if err != nil {
×
3053
                        return fmt.Errorf("unable to build channel "+
×
3054
                                "policies: %w", err)
×
3055
                }
×
3056

3057
                // Determine the outgoing and incoming policy for this
3058
                // channel and node combo.
3059
                p1ToNode := row.Channel.NodeID2
×
3060
                p2ToNode := row.Channel.NodeID1
×
3061
                outPolicy, inPolicy := p1, p2
×
3062
                if (p1 != nil && p1ToNode == id) ||
×
3063
                        (p2 != nil && p2ToNode != id) {
×
3064

×
3065
                        outPolicy, inPolicy = p2, p1
×
3066
                }
×
3067

3068
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3069
                        return err
×
3070
                }
×
3071
        }
3072

3073
        return nil
×
3074
}
3075

3076
// updateChanEdgePolicy upserts the channel policy info we have stored for
3077
// a channel we already know of.
3078
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3079
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3080
        error) {
×
3081

×
3082
        var (
×
3083
                node1Pub, node2Pub route.Vertex
×
3084
                isNode1            bool
×
3085
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3086
        )
×
3087

×
3088
        // Check that this edge policy refers to a channel that we already
×
3089
        // know of. We do this explicitly so that we can return the appropriate
×
3090
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3091
        // abort the transaction which would abort the entire batch.
×
3092
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3093
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3094
                        Scid:    chanIDB[:],
×
3095
                        Version: int16(ProtocolV1),
×
3096
                },
×
3097
        )
×
3098
        if errors.Is(err, sql.ErrNoRows) {
×
3099
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3100
        } else if err != nil {
×
3101
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3102
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3103
        }
×
3104

3105
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3106
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3107

×
3108
        // Figure out which node this edge is from.
×
3109
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3110
        nodeID := dbChan.NodeID1
×
3111
        if !isNode1 {
×
3112
                nodeID = dbChan.NodeID2
×
3113
        }
×
3114

3115
        var (
×
3116
                inboundBase sql.NullInt64
×
3117
                inboundRate sql.NullInt64
×
3118
        )
×
3119
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3120
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3121
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3122
        })
×
3123

3124
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3125
                Version:     int16(ProtocolV1),
×
3126
                ChannelID:   dbChan.ID,
×
3127
                NodeID:      nodeID,
×
3128
                Timelock:    int32(edge.TimeLockDelta),
×
3129
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3130
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3131
                MinHtlcMsat: int64(edge.MinHTLC),
×
3132
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3133
                Disabled: sql.NullBool{
×
3134
                        Valid: true,
×
3135
                        Bool:  edge.IsDisabled(),
×
3136
                },
×
3137
                MaxHtlcMsat: sql.NullInt64{
×
3138
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3139
                        Int64: int64(edge.MaxHTLC),
×
3140
                },
×
3141
                InboundBaseFeeMsat:      inboundBase,
×
3142
                InboundFeeRateMilliMsat: inboundRate,
×
3143
                Signature:               edge.SigBytes,
×
3144
        })
×
3145
        if err != nil {
×
3146
                return node1Pub, node2Pub, isNode1,
×
3147
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3148
        }
×
3149

3150
        // Convert the flat extra opaque data into a map of TLV types to
3151
        // values.
3152
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3153
        if err != nil {
×
3154
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3155
                        "marshal extra opaque data: %w", err)
×
3156
        }
×
3157

3158
        // Update the channel policy's extra signed fields.
3159
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3160
        if err != nil {
×
3161
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3162
                        "policy extra TLVs: %w", err)
×
3163
        }
×
3164

3165
        return node1Pub, node2Pub, isNode1, nil
×
3166
}
3167

3168
// getNodeByPubKey attempts to look up a target node by its public key.
3169
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3170
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3171

×
3172
        dbNode, err := db.GetNodeByPubKey(
×
3173
                ctx, sqlc.GetNodeByPubKeyParams{
×
3174
                        Version: int16(ProtocolV1),
×
3175
                        PubKey:  pubKey[:],
×
3176
                },
×
3177
        )
×
3178
        if errors.Is(err, sql.ErrNoRows) {
×
3179
                return 0, nil, ErrGraphNodeNotFound
×
3180
        } else if err != nil {
×
3181
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3182
        }
×
3183

3184
        node, err := buildNode(ctx, db, &dbNode)
×
3185
        if err != nil {
×
3186
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3187
        }
×
3188

3189
        return dbNode.ID, node, nil
×
3190
}
3191

3192
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3193
// provided database channel row and the public keys of the two nodes
3194
// involved in the channel.
3195
func buildCacheableChannelInfo(dbChan sqlc.Channel, node1Pub,
3196
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3197

×
3198
        return &models.CachedEdgeInfo{
×
3199
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
3200
                NodeKey1Bytes: node1Pub,
×
3201
                NodeKey2Bytes: node2Pub,
×
3202
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
3203
        }
×
3204
}
×
3205

3206
// buildNode constructs a LightningNode instance from the given database node
3207
// record. The node's features, addresses and extra signed fields are also
3208
// fetched from the database and set on the node.
3209
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
3210
        *models.LightningNode, error) {
×
3211

×
3212
        if dbNode.Version != int16(ProtocolV1) {
×
3213
                return nil, fmt.Errorf("unsupported node version: %d",
×
3214
                        dbNode.Version)
×
3215
        }
×
3216

3217
        var pub [33]byte
×
3218
        copy(pub[:], dbNode.PubKey)
×
3219

×
3220
        node := &models.LightningNode{
×
3221
                PubKeyBytes: pub,
×
3222
                Features:    lnwire.EmptyFeatureVector(),
×
3223
                LastUpdate:  time.Unix(0, 0),
×
3224
        }
×
3225

×
3226
        if len(dbNode.Signature) == 0 {
×
3227
                return node, nil
×
3228
        }
×
3229

3230
        node.HaveNodeAnnouncement = true
×
3231
        node.AuthSigBytes = dbNode.Signature
×
3232
        node.Alias = dbNode.Alias.String
×
3233
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3234

×
3235
        var err error
×
3236
        if dbNode.Color.Valid {
×
3237
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3238
                if err != nil {
×
3239
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3240
                                err)
×
3241
                }
×
3242
        }
3243

3244
        // Fetch the node's features.
3245
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
3246
        if err != nil {
×
3247
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3248
                        "features: %w", dbNode.ID, err)
×
3249
        }
×
3250

3251
        // Fetch the node's addresses.
3252
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
3253
        if err != nil {
×
3254
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3255
                        "addresses: %w", dbNode.ID, err)
×
3256
        }
×
3257

3258
        // Fetch the node's extra signed fields.
3259
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
3260
        if err != nil {
×
3261
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3262
                        "extra signed fields: %w", dbNode.ID, err)
×
3263
        }
×
3264

3265
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
3266
        if err != nil {
×
3267
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3268
                        "fields: %w", err)
×
3269
        }
×
3270

3271
        if len(recs) != 0 {
×
3272
                node.ExtraOpaqueData = recs
×
3273
        }
×
3274

3275
        return node, nil
×
3276
}
3277

3278
// getNodeFeatures fetches the feature bits and constructs the feature vector
3279
// for a node with the given DB ID.
3280
func getNodeFeatures(ctx context.Context, db SQLQueries,
3281
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3282

×
3283
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3284
        if err != nil {
×
3285
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3286
                        nodeID, err)
×
3287
        }
×
3288

3289
        features := lnwire.EmptyFeatureVector()
×
3290
        for _, feature := range rows {
×
3291
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3292
        }
×
3293

3294
        return features, nil
×
3295
}
3296

3297
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
3298
// given DB ID.
3299
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3300
        nodeID int64) (map[uint64][]byte, error) {
×
3301

×
3302
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3303
        if err != nil {
×
3304
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
3305
                        "signed fields: %w", nodeID, err)
×
3306
        }
×
3307

3308
        extraFields := make(map[uint64][]byte)
×
3309
        for _, field := range fields {
×
3310
                extraFields[uint64(field.Type)] = field.Value
×
3311
        }
×
3312

3313
        return extraFields, nil
×
3314
}
3315

3316
// upsertNode upserts the node record into the database. If the node already
3317
// exists, then the node's information is updated. If the node doesn't exist,
3318
// then a new node is created. The node's features, addresses and extra TLV
3319
// types are also updated. The node's DB ID is returned.
3320
func upsertNode(ctx context.Context, db SQLQueries,
3321
        node *models.LightningNode) (int64, error) {
×
3322

×
3323
        params := sqlc.UpsertNodeParams{
×
3324
                Version: int16(ProtocolV1),
×
3325
                PubKey:  node.PubKeyBytes[:],
×
3326
        }
×
3327

×
3328
        if node.HaveNodeAnnouncement {
×
3329
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3330
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3331
                params.Alias = sqldb.SQLStr(node.Alias)
×
3332
                params.Signature = node.AuthSigBytes
×
3333
        }
×
3334

3335
        nodeID, err := db.UpsertNode(ctx, params)
×
3336
        if err != nil {
×
3337
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3338
                        err)
×
3339
        }
×
3340

3341
        // We can exit here if we don't have the announcement yet.
3342
        if !node.HaveNodeAnnouncement {
×
3343
                return nodeID, nil
×
3344
        }
×
3345

3346
        // Update the node's features.
3347
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3348
        if err != nil {
×
3349
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3350
        }
×
3351

3352
        // Update the node's addresses.
3353
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3354
        if err != nil {
×
3355
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3356
        }
×
3357

3358
        // Convert the flat extra opaque data into a map of TLV types to
3359
        // values.
3360
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3361
        if err != nil {
×
3362
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3363
                        err)
×
3364
        }
×
3365

3366
        // Update the node's extra signed fields.
3367
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3368
        if err != nil {
×
3369
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3370
        }
×
3371

3372
        return nodeID, nil
×
3373
}
3374

3375
// upsertNodeFeatures updates the node's features node_features table. This
3376
// includes deleting any feature bits no longer present and inserting any new
3377
// feature bits. If the feature bit does not yet exist in the features table,
3378
// then an entry is created in that table first.
3379
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3380
        features *lnwire.FeatureVector) error {
×
3381

×
3382
        // Get any existing features for the node.
×
3383
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3384
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3385
                return err
×
3386
        }
×
3387

3388
        // Copy the nodes latest set of feature bits.
3389
        newFeatures := make(map[int32]struct{})
×
3390
        if features != nil {
×
3391
                for feature := range features.Features() {
×
3392
                        newFeatures[int32(feature)] = struct{}{}
×
3393
                }
×
3394
        }
3395

3396
        // For any current feature that already exists in the DB, remove it from
3397
        // the in-memory map. For any existing feature that does not exist in
3398
        // the in-memory map, delete it from the database.
3399
        for _, feature := range existingFeatures {
×
3400
                // The feature is still present, so there are no updates to be
×
3401
                // made.
×
3402
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3403
                        delete(newFeatures, feature.FeatureBit)
×
3404
                        continue
×
3405
                }
3406

3407
                // The feature is no longer present, so we remove it from the
3408
                // database.
3409
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3410
                        NodeID:     nodeID,
×
3411
                        FeatureBit: feature.FeatureBit,
×
3412
                })
×
3413
                if err != nil {
×
3414
                        return fmt.Errorf("unable to delete node(%d) "+
×
3415
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3416
                                err)
×
3417
                }
×
3418
        }
3419

3420
        // Any remaining entries in newFeatures are new features that need to be
3421
        // added to the database for the first time.
3422
        for feature := range newFeatures {
×
3423
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3424
                        NodeID:     nodeID,
×
3425
                        FeatureBit: feature,
×
3426
                })
×
3427
                if err != nil {
×
3428
                        return fmt.Errorf("unable to insert node(%d) "+
×
3429
                                "feature(%v): %w", nodeID, feature, err)
×
3430
                }
×
3431
        }
3432

3433
        return nil
×
3434
}
3435

3436
// fetchNodeFeatures fetches the features for a node with the given public key.
3437
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3438
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3439

×
3440
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3441
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3442
                        PubKey:  nodePub[:],
×
3443
                        Version: int16(ProtocolV1),
×
3444
                },
×
3445
        )
×
3446
        if err != nil {
×
3447
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3448
                        nodePub, err)
×
3449
        }
×
3450

3451
        features := lnwire.EmptyFeatureVector()
×
3452
        for _, bit := range rows {
×
3453
                features.Set(lnwire.FeatureBit(bit))
×
3454
        }
×
3455

3456
        return features, nil
×
3457
}
3458

3459
// dbAddressType is an enum type that represents the different address types
3460
// that we store in the node_addresses table. The address type determines how
3461
// the address is to be serialised/deserialize.
3462
type dbAddressType uint8
3463

3464
const (
3465
        addressTypeIPv4   dbAddressType = 1
3466
        addressTypeIPv6   dbAddressType = 2
3467
        addressTypeTorV2  dbAddressType = 3
3468
        addressTypeTorV3  dbAddressType = 4
3469
        addressTypeOpaque dbAddressType = math.MaxInt8
3470
)
3471

3472
// upsertNodeAddresses updates the node's addresses in the database. This
3473
// includes deleting any existing addresses and inserting the new set of
3474
// addresses. The deletion is necessary since the ordering of the addresses may
3475
// change, and we need to ensure that the database reflects the latest set of
3476
// addresses so that at the time of reconstructing the node announcement, the
3477
// order is preserved and the signature over the message remains valid.
3478
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3479
        addresses []net.Addr) error {
×
3480

×
3481
        // Delete any existing addresses for the node. This is required since
×
3482
        // even if the new set of addresses is the same, the ordering may have
×
3483
        // changed for a given address type.
×
3484
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3485
        if err != nil {
×
3486
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3487
                        nodeID, err)
×
3488
        }
×
3489

3490
        // Copy the nodes latest set of addresses.
3491
        newAddresses := map[dbAddressType][]string{
×
3492
                addressTypeIPv4:   {},
×
3493
                addressTypeIPv6:   {},
×
3494
                addressTypeTorV2:  {},
×
3495
                addressTypeTorV3:  {},
×
3496
                addressTypeOpaque: {},
×
3497
        }
×
3498
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3499
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3500
        }
×
3501

3502
        for _, address := range addresses {
×
3503
                switch addr := address.(type) {
×
3504
                case *net.TCPAddr:
×
3505
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3506
                                addAddr(addressTypeIPv4, addr)
×
3507
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3508
                                addAddr(addressTypeIPv6, addr)
×
3509
                        } else {
×
3510
                                return fmt.Errorf("unhandled IP address: %v",
×
3511
                                        addr)
×
3512
                        }
×
3513

3514
                case *tor.OnionAddr:
×
3515
                        switch len(addr.OnionService) {
×
3516
                        case tor.V2Len:
×
3517
                                addAddr(addressTypeTorV2, addr)
×
3518
                        case tor.V3Len:
×
3519
                                addAddr(addressTypeTorV3, addr)
×
3520
                        default:
×
3521
                                return fmt.Errorf("invalid length for a tor " +
×
3522
                                        "address")
×
3523
                        }
3524

3525
                case *lnwire.OpaqueAddrs:
×
3526
                        addAddr(addressTypeOpaque, addr)
×
3527

3528
                default:
×
3529
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3530
                }
3531
        }
3532

3533
        // Any remaining entries in newAddresses are new addresses that need to
3534
        // be added to the database for the first time.
3535
        for addrType, addrList := range newAddresses {
×
3536
                for position, addr := range addrList {
×
3537
                        err := db.InsertNodeAddress(
×
3538
                                ctx, sqlc.InsertNodeAddressParams{
×
3539
                                        NodeID:   nodeID,
×
3540
                                        Type:     int16(addrType),
×
3541
                                        Address:  addr,
×
3542
                                        Position: int32(position),
×
3543
                                },
×
3544
                        )
×
3545
                        if err != nil {
×
3546
                                return fmt.Errorf("unable to insert "+
×
3547
                                        "node(%d) address(%v): %w", nodeID,
×
3548
                                        addr, err)
×
3549
                        }
×
3550
                }
3551
        }
3552

3553
        return nil
×
3554
}
3555

3556
// getNodeAddresses fetches the addresses for a node with the given public key.
3557
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3558
        []net.Addr, error) {
×
3559

×
3560
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
3561
        // are returned in the same order as they were inserted.
×
3562
        rows, err := db.GetNodeAddressesByPubKey(
×
3563
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
3564
                        Version: int16(ProtocolV1),
×
3565
                        PubKey:  nodePub,
×
3566
                },
×
3567
        )
×
3568
        if err != nil {
×
3569
                return false, nil, err
×
3570
        }
×
3571

3572
        // GetNodeAddressesByPubKey uses a left join so there should always be
3573
        // at least one row returned if the node exists even if it has no
3574
        // addresses.
3575
        if len(rows) == 0 {
×
3576
                return false, nil, nil
×
3577
        }
×
3578

3579
        addresses := make([]net.Addr, 0, len(rows))
×
3580
        for _, addr := range rows {
×
3581
                if !(addr.Type.Valid && addr.Address.Valid) {
×
3582
                        continue
×
3583
                }
3584

3585
                address := addr.Address.String
×
3586

×
3587
                switch dbAddressType(addr.Type.Int16) {
×
3588
                case addressTypeIPv4:
×
3589
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
3590
                        if err != nil {
×
3591
                                return false, nil, nil
×
3592
                        }
×
3593
                        tcp.IP = tcp.IP.To4()
×
3594

×
3595
                        addresses = append(addresses, tcp)
×
3596

3597
                case addressTypeIPv6:
×
3598
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
3599
                        if err != nil {
×
3600
                                return false, nil, nil
×
3601
                        }
×
3602
                        addresses = append(addresses, tcp)
×
3603

3604
                case addressTypeTorV3, addressTypeTorV2:
×
3605
                        service, portStr, err := net.SplitHostPort(address)
×
3606
                        if err != nil {
×
3607
                                return false, nil, fmt.Errorf("unable to "+
×
3608
                                        "split tor v3 address: %v",
×
3609
                                        addr.Address)
×
3610
                        }
×
3611

3612
                        port, err := strconv.Atoi(portStr)
×
3613
                        if err != nil {
×
3614
                                return false, nil, err
×
3615
                        }
×
3616

3617
                        addresses = append(addresses, &tor.OnionAddr{
×
3618
                                OnionService: service,
×
3619
                                Port:         port,
×
3620
                        })
×
3621

3622
                case addressTypeOpaque:
×
3623
                        opaque, err := hex.DecodeString(address)
×
3624
                        if err != nil {
×
3625
                                return false, nil, fmt.Errorf("unable to "+
×
3626
                                        "decode opaque address: %v", addr)
×
3627
                        }
×
3628

3629
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
3630
                                Payload: opaque,
×
3631
                        })
×
3632

3633
                default:
×
3634
                        return false, nil, fmt.Errorf("unknown address "+
×
3635
                                "type: %v", addr.Type)
×
3636
                }
3637
        }
3638

3639
        return true, addresses, nil
×
3640
}
3641

3642
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3643
// database. This includes updating any existing types, inserting any new types,
3644
// and deleting any types that are no longer present.
3645
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3646
        nodeID int64, extraFields map[uint64][]byte) error {
×
3647

×
3648
        // Get any existing extra signed fields for the node.
×
3649
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3650
        if err != nil {
×
3651
                return err
×
3652
        }
×
3653

3654
        // Make a lookup map of the existing field types so that we can use it
3655
        // to keep track of any fields we should delete.
3656
        m := make(map[uint64]bool)
×
3657
        for _, field := range existingFields {
×
3658
                m[uint64(field.Type)] = true
×
3659
        }
×
3660

3661
        // For all the new fields, we'll upsert them and remove them from the
3662
        // map of existing fields.
3663
        for tlvType, value := range extraFields {
×
3664
                err = db.UpsertNodeExtraType(
×
3665
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3666
                                NodeID: nodeID,
×
3667
                                Type:   int64(tlvType),
×
3668
                                Value:  value,
×
3669
                        },
×
3670
                )
×
3671
                if err != nil {
×
3672
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3673
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3674
                }
×
3675

3676
                // Remove the field from the map of existing fields if it was
3677
                // present.
3678
                delete(m, tlvType)
×
3679
        }
3680

3681
        // For all the fields that are left in the map of existing fields, we'll
3682
        // delete them as they are no longer present in the new set of fields.
3683
        for tlvType := range m {
×
3684
                err = db.DeleteExtraNodeType(
×
3685
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3686
                                NodeID: nodeID,
×
3687
                                Type:   int64(tlvType),
×
3688
                        },
×
3689
                )
×
3690
                if err != nil {
×
3691
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3692
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3693
                }
×
3694
        }
3695

3696
        return nil
×
3697
}
3698

3699
// srcNodeInfo holds the information about the source node of the graph.
3700
type srcNodeInfo struct {
3701
        // id is the DB level ID of the source node entry in the "nodes" table.
3702
        id int64
3703

3704
        // pub is the public key of the source node.
3705
        pub route.Vertex
3706
}
3707

3708
// getSourceNode returns the DB node ID and pub key of the source node for the
3709
// specified protocol version.
3710
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3711
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3712

×
3713
        s.srcNodeMu.Lock()
×
3714
        defer s.srcNodeMu.Unlock()
×
3715

×
3716
        // If we already have the source node ID and pub key cached, then
×
3717
        // return them.
×
3718
        if info, ok := s.srcNodes[version]; ok {
×
3719
                return info.id, info.pub, nil
×
3720
        }
×
3721

3722
        var pubKey route.Vertex
×
3723

×
3724
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3725
        if err != nil {
×
3726
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3727
                        err)
×
3728
        }
×
3729

3730
        if len(nodes) == 0 {
×
3731
                return 0, pubKey, ErrSourceNodeNotSet
×
3732
        } else if len(nodes) > 1 {
×
3733
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3734
                        "protocol %s found", version)
×
3735
        }
×
3736

3737
        copy(pubKey[:], nodes[0].PubKey)
×
3738

×
3739
        s.srcNodes[version] = &srcNodeInfo{
×
3740
                id:  nodes[0].NodeID,
×
3741
                pub: pubKey,
×
3742
        }
×
3743

×
3744
        return nodes[0].NodeID, pubKey, nil
×
3745
}
3746

3747
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3748
// This then produces a map from TLV type to value. If the input is not a
3749
// valid TLV stream, then an error is returned.
3750
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3751
        r := bytes.NewReader(data)
×
3752

×
3753
        tlvStream, err := tlv.NewStream()
×
3754
        if err != nil {
×
3755
                return nil, err
×
3756
        }
×
3757

3758
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3759
        // pass it into the P2P decoding variant.
3760
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3761
        if err != nil {
×
3762
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3763
        }
×
3764
        if len(parsedTypes) == 0 {
×
3765
                return nil, nil
×
3766
        }
×
3767

3768
        records := make(map[uint64][]byte)
×
3769
        for k, v := range parsedTypes {
×
3770
                records[uint64(k)] = v
×
3771
        }
×
3772

3773
        return records, nil
×
3774
}
3775

3776
// insertChannel inserts a new channel record into the database.
3777
func insertChannel(ctx context.Context, db SQLQueries,
3778
        edge *models.ChannelEdgeInfo) error {
×
3779

×
3780
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3781

×
3782
        // Make sure that the channel doesn't already exist. We do this
×
3783
        // explicitly instead of relying on catching a unique constraint error
×
3784
        // because relying on SQL to throw that error would abort the entire
×
3785
        // batch of transactions.
×
3786
        _, err := db.GetChannelBySCID(
×
3787
                ctx, sqlc.GetChannelBySCIDParams{
×
3788
                        Scid:    chanIDB[:],
×
3789
                        Version: int16(ProtocolV1),
×
3790
                },
×
3791
        )
×
3792
        if err == nil {
×
3793
                return ErrEdgeAlreadyExist
×
3794
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3795
                return fmt.Errorf("unable to fetch channel: %w", err)
×
3796
        }
×
3797

3798
        // Make sure that at least a "shell" entry for each node is present in
3799
        // the nodes table.
3800
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3801
        if err != nil {
×
3802
                return fmt.Errorf("unable to create shell node: %w", err)
×
3803
        }
×
3804

3805
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3806
        if err != nil {
×
3807
                return fmt.Errorf("unable to create shell node: %w", err)
×
3808
        }
×
3809

3810
        var capacity sql.NullInt64
×
3811
        if edge.Capacity != 0 {
×
3812
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3813
        }
×
3814

3815
        createParams := sqlc.CreateChannelParams{
×
3816
                Version:     int16(ProtocolV1),
×
3817
                Scid:        chanIDB[:],
×
3818
                NodeID1:     node1DBID,
×
3819
                NodeID2:     node2DBID,
×
3820
                Outpoint:    edge.ChannelPoint.String(),
×
3821
                Capacity:    capacity,
×
3822
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3823
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3824
        }
×
3825

×
3826
        if edge.AuthProof != nil {
×
3827
                proof := edge.AuthProof
×
3828

×
3829
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3830
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3831
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3832
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3833
        }
×
3834

3835
        // Insert the new channel record.
3836
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3837
        if err != nil {
×
3838
                return err
×
3839
        }
×
3840

3841
        // Insert any channel features.
3842
        if len(edge.Features) != 0 {
×
3843
                chanFeatures := lnwire.NewRawFeatureVector()
×
3844
                err := chanFeatures.Decode(bytes.NewReader(edge.Features))
×
3845
                if err != nil {
×
3846
                        return err
×
3847
                }
×
3848

3849
                fv := lnwire.NewFeatureVector(chanFeatures, lnwire.Features)
×
3850
                for feature := range fv.Features() {
×
3851
                        err = db.InsertChannelFeature(
×
3852
                                ctx, sqlc.InsertChannelFeatureParams{
×
3853
                                        ChannelID:  dbChanID,
×
3854
                                        FeatureBit: int32(feature),
×
3855
                                },
×
3856
                        )
×
3857
                        if err != nil {
×
3858
                                return fmt.Errorf("unable to insert "+
×
3859
                                        "channel(%d) feature(%v): %w", dbChanID,
×
3860
                                        feature, err)
×
3861
                        }
×
3862
                }
3863
        }
3864

3865
        // Finally, insert any extra TLV fields in the channel announcement.
3866
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3867
        if err != nil {
×
3868
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
3869
                        err)
×
3870
        }
×
3871

3872
        for tlvType, value := range extra {
×
3873
                err := db.CreateChannelExtraType(
×
3874
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3875
                                ChannelID: dbChanID,
×
3876
                                Type:      int64(tlvType),
×
3877
                                Value:     value,
×
3878
                        },
×
3879
                )
×
3880
                if err != nil {
×
3881
                        return fmt.Errorf("unable to upsert channel(%d) extra "+
×
3882
                                "signed field(%v): %w", edge.ChannelID,
×
3883
                                tlvType, err)
×
3884
                }
×
3885
        }
3886

3887
        return nil
×
3888
}
3889

3890
// maybeCreateShellNode checks if a shell node entry exists for the
3891
// given public key. If it does not exist, then a new shell node entry is
3892
// created. The ID of the node is returned. A shell node only has a protocol
3893
// version and public key persisted.
3894
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3895
        pubKey route.Vertex) (int64, error) {
×
3896

×
3897
        dbNode, err := db.GetNodeByPubKey(
×
3898
                ctx, sqlc.GetNodeByPubKeyParams{
×
3899
                        PubKey:  pubKey[:],
×
3900
                        Version: int16(ProtocolV1),
×
3901
                },
×
3902
        )
×
3903
        // The node exists. Return the ID.
×
3904
        if err == nil {
×
3905
                return dbNode.ID, nil
×
3906
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3907
                return 0, err
×
3908
        }
×
3909

3910
        // Otherwise, the node does not exist, so we create a shell entry for
3911
        // it.
3912
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3913
                Version: int16(ProtocolV1),
×
3914
                PubKey:  pubKey[:],
×
3915
        })
×
3916
        if err != nil {
×
3917
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3918
        }
×
3919

3920
        return id, nil
×
3921
}
3922

3923
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3924
// the database. This includes deleting any existing types and then inserting
3925
// the new types.
3926
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3927
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3928

×
3929
        // Delete all existing extra signed fields for the channel policy.
×
3930
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
3931
        if err != nil {
×
3932
                return fmt.Errorf("unable to delete "+
×
3933
                        "existing policy extra signed fields for policy %d: %w",
×
3934
                        chanPolicyID, err)
×
3935
        }
×
3936

3937
        // Insert all new extra signed fields for the channel policy.
3938
        for tlvType, value := range extraFields {
×
3939
                err = db.InsertChanPolicyExtraType(
×
3940
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
3941
                                ChannelPolicyID: chanPolicyID,
×
3942
                                Type:            int64(tlvType),
×
3943
                                Value:           value,
×
3944
                        },
×
3945
                )
×
3946
                if err != nil {
×
3947
                        return fmt.Errorf("unable to insert "+
×
3948
                                "channel_policy(%d) extra signed field(%v): %w",
×
3949
                                chanPolicyID, tlvType, err)
×
3950
                }
×
3951
        }
3952

3953
        return nil
×
3954
}
3955

3956
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3957
// provided dbChanRow and also fetches any other required information
3958
// to construct the edge info.
3959
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
3960
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.Channel, node1,
3961
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
3962

×
3963
        if dbChan.Version != int16(ProtocolV1) {
×
3964
                return nil, fmt.Errorf("unsupported channel version: %d",
×
3965
                        dbChan.Version)
×
3966
        }
×
3967

3968
        fv, extras, err := getChanFeaturesAndExtras(
×
3969
                ctx, db, dbChanID,
×
3970
        )
×
3971
        if err != nil {
×
3972
                return nil, err
×
3973
        }
×
3974

3975
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
3976
        if err != nil {
×
3977
                return nil, err
×
3978
        }
×
3979

3980
        var featureBuf bytes.Buffer
×
3981
        if err := fv.Encode(&featureBuf); err != nil {
×
3982
                return nil, fmt.Errorf("unable to encode features: %w", err)
×
3983
        }
×
3984

3985
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
3986
        if err != nil {
×
3987
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3988
                        "fields: %w", err)
×
3989
        }
×
3990
        if recs == nil {
×
3991
                recs = make([]byte, 0)
×
3992
        }
×
3993

3994
        var btcKey1, btcKey2 route.Vertex
×
3995
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
3996
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
3997

×
3998
        channel := &models.ChannelEdgeInfo{
×
3999
                ChainHash:        chain,
×
4000
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4001
                NodeKey1Bytes:    node1,
×
4002
                NodeKey2Bytes:    node2,
×
4003
                BitcoinKey1Bytes: btcKey1,
×
4004
                BitcoinKey2Bytes: btcKey2,
×
4005
                ChannelPoint:     *op,
×
4006
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4007
                Features:         featureBuf.Bytes(),
×
4008
                ExtraOpaqueData:  recs,
×
4009
        }
×
4010

×
4011
        // We always set all the signatures at the same time, so we can
×
4012
        // safely check if one signature is present to determine if we have the
×
4013
        // rest of the signatures for the auth proof.
×
4014
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4015
                channel.AuthProof = &models.ChannelAuthProof{
×
4016
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4017
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4018
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4019
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4020
                }
×
4021
        }
×
4022

4023
        return channel, nil
×
4024
}
4025

4026
// buildNodeVertices is a helper that converts raw node public keys
4027
// into route.Vertex instances.
4028
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4029
        route.Vertex, error) {
×
4030

×
4031
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4032
        if err != nil {
×
4033
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4034
                        "create vertex from node1 pubkey: %w", err)
×
4035
        }
×
4036

4037
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4038
        if err != nil {
×
4039
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4040
                        "create vertex from node2 pubkey: %w", err)
×
4041
        }
×
4042

4043
        return node1Vertex, node2Vertex, nil
×
4044
}
4045

4046
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
4047
// for a channel with the given ID.
4048
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
4049
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
4050

×
4051
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
4052
        if err != nil {
×
4053
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
4054
                        "features and extras: %w", err)
×
4055
        }
×
4056

4057
        var (
×
4058
                fv     = lnwire.EmptyFeatureVector()
×
4059
                extras = make(map[uint64][]byte)
×
4060
        )
×
4061
        for _, row := range rows {
×
4062
                if row.IsFeature {
×
4063
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
4064

×
4065
                        continue
×
4066
                }
4067

4068
                tlvType, ok := row.ExtraKey.(int64)
×
4069
                if !ok {
×
4070
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4071
                                "TLV type: %T", row.ExtraKey)
×
4072
                }
×
4073

4074
                valueBytes, ok := row.Value.([]byte)
×
4075
                if !ok {
×
4076
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4077
                                "Value: %T", row.Value)
×
4078
                }
×
4079

4080
                extras[uint64(tlvType)] = valueBytes
×
4081
        }
4082

4083
        return fv, extras, nil
×
4084
}
4085

4086
// getAndBuildChanPolicies uses the given sqlc.ChannelPolicy and also retrieves
4087
// all the extra info required to build the complete models.ChannelEdgePolicy
4088
// types. It returns two policies, which may be nil if the provided
4089
// sqlc.ChannelPolicy records are nil.
4090
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4091
        dbPol1, dbPol2 *sqlc.ChannelPolicy, channelID uint64, node1,
4092
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4093
        *models.ChannelEdgePolicy, error) {
×
4094

×
4095
        if dbPol1 == nil && dbPol2 == nil {
×
4096
                return nil, nil, nil
×
4097
        }
×
4098

4099
        var (
×
4100
                policy1ID int64
×
4101
                policy2ID int64
×
4102
        )
×
4103
        if dbPol1 != nil {
×
4104
                policy1ID = dbPol1.ID
×
4105
        }
×
4106
        if dbPol2 != nil {
×
4107
                policy2ID = dbPol2.ID
×
4108
        }
×
4109
        rows, err := db.GetChannelPolicyExtraTypes(
×
4110
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
4111
                        ID:   policy1ID,
×
4112
                        ID_2: policy2ID,
×
4113
                },
×
4114
        )
×
4115
        if err != nil {
×
4116
                return nil, nil, err
×
4117
        }
×
4118

4119
        var (
×
4120
                dbPol1Extras = make(map[uint64][]byte)
×
4121
                dbPol2Extras = make(map[uint64][]byte)
×
4122
        )
×
4123
        for _, row := range rows {
×
4124
                switch row.PolicyID {
×
4125
                case policy1ID:
×
4126
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
4127
                case policy2ID:
×
4128
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
4129
                default:
×
4130
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
4131
                                "in row: %v", row.PolicyID, row)
×
4132
                }
4133
        }
4134

4135
        var pol1, pol2 *models.ChannelEdgePolicy
×
4136
        if dbPol1 != nil {
×
4137
                pol1, err = buildChanPolicy(
×
4138
                        *dbPol1, channelID, dbPol1Extras, node2, true,
×
4139
                )
×
4140
                if err != nil {
×
4141
                        return nil, nil, err
×
4142
                }
×
4143
        }
4144
        if dbPol2 != nil {
×
4145
                pol2, err = buildChanPolicy(
×
4146
                        *dbPol2, channelID, dbPol2Extras, node1, false,
×
4147
                )
×
4148
                if err != nil {
×
4149
                        return nil, nil, err
×
4150
                }
×
4151
        }
4152

4153
        return pol1, pol2, nil
×
4154
}
4155

4156
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4157
// provided sqlc.ChannelPolicy and other required information.
4158
func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
4159
        extras map[uint64][]byte, toNode route.Vertex,
4160
        isNode1 bool) (*models.ChannelEdgePolicy, error) {
×
4161

×
4162
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4163
        if err != nil {
×
4164
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4165
                        "fields: %w", err)
×
4166
        }
×
4167

4168
        var msgFlags lnwire.ChanUpdateMsgFlags
×
4169
        if dbPolicy.MaxHtlcMsat.Valid {
×
4170
                msgFlags |= lnwire.ChanUpdateRequiredMaxHtlc
×
4171
        }
×
4172

4173
        var chanFlags lnwire.ChanUpdateChanFlags
×
4174
        if !isNode1 {
×
4175
                chanFlags |= lnwire.ChanUpdateDirection
×
4176
        }
×
4177
        if dbPolicy.Disabled.Bool {
×
4178
                chanFlags |= lnwire.ChanUpdateDisabled
×
4179
        }
×
4180

4181
        var inboundFee fn.Option[lnwire.Fee]
×
4182
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4183
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4184

×
4185
                inboundFee = fn.Some(lnwire.Fee{
×
4186
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4187
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4188
                })
×
4189
        }
×
4190

4191
        return &models.ChannelEdgePolicy{
×
4192
                SigBytes:  dbPolicy.Signature,
×
4193
                ChannelID: channelID,
×
4194
                LastUpdate: time.Unix(
×
4195
                        dbPolicy.LastUpdate.Int64, 0,
×
4196
                ),
×
4197
                MessageFlags:  msgFlags,
×
4198
                ChannelFlags:  chanFlags,
×
4199
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4200
                MinHTLC: lnwire.MilliSatoshi(
×
4201
                        dbPolicy.MinHtlcMsat,
×
4202
                ),
×
4203
                MaxHTLC: lnwire.MilliSatoshi(
×
4204
                        dbPolicy.MaxHtlcMsat.Int64,
×
4205
                ),
×
4206
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4207
                        dbPolicy.BaseFeeMsat,
×
4208
                ),
×
4209
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4210
                ToNode:                    toNode,
×
4211
                InboundFee:                inboundFee,
×
4212
                ExtraOpaqueData:           recs,
×
4213
        }, nil
×
4214
}
4215

4216
// buildNodes builds the models.LightningNode instances for the
4217
// given row which is expected to be a sqlc type that contains node information.
4218
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4219
        dbNode2 sqlc.Node) (*models.LightningNode, *models.LightningNode,
4220
        error) {
×
4221

×
4222
        node1, err := buildNode(ctx, db, &dbNode1)
×
4223
        if err != nil {
×
4224
                return nil, nil, err
×
4225
        }
×
4226

4227
        node2, err := buildNode(ctx, db, &dbNode2)
×
4228
        if err != nil {
×
4229
                return nil, nil, err
×
4230
        }
×
4231

4232
        return node1, node2, nil
×
4233
}
4234

4235
// extractChannelPolicies extracts the sqlc.ChannelPolicy records from the give
4236
// row which is expected to be a sqlc type that contains channel policy
4237
// information. It returns two policies, which may be nil if the policy
4238
// information is not present in the row.
4239
//
4240
//nolint:ll,dupl,funlen
4241
func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
4242
        error) {
×
4243

×
4244
        var policy1, policy2 *sqlc.ChannelPolicy
×
4245
        switch r := row.(type) {
×
4246
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4247
                if r.Policy1ID.Valid {
×
4248
                        policy1 = &sqlc.ChannelPolicy{
×
4249
                                ID:                      r.Policy1ID.Int64,
×
4250
                                Version:                 r.Policy1Version.Int16,
×
4251
                                ChannelID:               r.Channel.ID,
×
4252
                                NodeID:                  r.Policy1NodeID.Int64,
×
4253
                                Timelock:                r.Policy1Timelock.Int32,
×
4254
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4255
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4256
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4257
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4258
                                LastUpdate:              r.Policy1LastUpdate,
×
4259
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4260
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4261
                                Disabled:                r.Policy1Disabled,
×
4262
                                Signature:               r.Policy1Signature,
×
4263
                        }
×
4264
                }
×
4265
                if r.Policy2ID.Valid {
×
4266
                        policy2 = &sqlc.ChannelPolicy{
×
4267
                                ID:                      r.Policy2ID.Int64,
×
4268
                                Version:                 r.Policy2Version.Int16,
×
4269
                                ChannelID:               r.Channel.ID,
×
4270
                                NodeID:                  r.Policy2NodeID.Int64,
×
4271
                                Timelock:                r.Policy2Timelock.Int32,
×
4272
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4273
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4274
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4275
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4276
                                LastUpdate:              r.Policy2LastUpdate,
×
4277
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4278
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4279
                                Disabled:                r.Policy2Disabled,
×
4280
                                Signature:               r.Policy2Signature,
×
4281
                        }
×
4282
                }
×
4283

4284
                return policy1, policy2, nil
×
4285

4286
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4287
                if r.Policy1ID.Valid {
×
4288
                        policy1 = &sqlc.ChannelPolicy{
×
4289
                                ID:                      r.Policy1ID.Int64,
×
4290
                                Version:                 r.Policy1Version.Int16,
×
4291
                                ChannelID:               r.Channel.ID,
×
4292
                                NodeID:                  r.Policy1NodeID.Int64,
×
4293
                                Timelock:                r.Policy1Timelock.Int32,
×
4294
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4295
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4296
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4297
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4298
                                LastUpdate:              r.Policy1LastUpdate,
×
4299
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4300
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4301
                                Disabled:                r.Policy1Disabled,
×
4302
                                Signature:               r.Policy1Signature,
×
4303
                        }
×
4304
                }
×
4305
                if r.Policy2ID.Valid {
×
4306
                        policy2 = &sqlc.ChannelPolicy{
×
4307
                                ID:                      r.Policy2ID.Int64,
×
4308
                                Version:                 r.Policy2Version.Int16,
×
4309
                                ChannelID:               r.Channel.ID,
×
4310
                                NodeID:                  r.Policy2NodeID.Int64,
×
4311
                                Timelock:                r.Policy2Timelock.Int32,
×
4312
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4313
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4314
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4315
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4316
                                LastUpdate:              r.Policy2LastUpdate,
×
4317
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4318
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4319
                                Disabled:                r.Policy2Disabled,
×
4320
                                Signature:               r.Policy2Signature,
×
4321
                        }
×
4322
                }
×
4323

4324
                return policy1, policy2, nil
×
4325

4326
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4327
                if r.Policy1ID.Valid {
×
4328
                        policy1 = &sqlc.ChannelPolicy{
×
4329
                                ID:                      r.Policy1ID.Int64,
×
4330
                                Version:                 r.Policy1Version.Int16,
×
4331
                                ChannelID:               r.Channel.ID,
×
4332
                                NodeID:                  r.Policy1NodeID.Int64,
×
4333
                                Timelock:                r.Policy1Timelock.Int32,
×
4334
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4335
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4336
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4337
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4338
                                LastUpdate:              r.Policy1LastUpdate,
×
4339
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4340
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4341
                                Disabled:                r.Policy1Disabled,
×
4342
                                Signature:               r.Policy1Signature,
×
4343
                        }
×
4344
                }
×
4345
                if r.Policy2ID.Valid {
×
4346
                        policy2 = &sqlc.ChannelPolicy{
×
4347
                                ID:                      r.Policy2ID.Int64,
×
4348
                                Version:                 r.Policy2Version.Int16,
×
4349
                                ChannelID:               r.Channel.ID,
×
4350
                                NodeID:                  r.Policy2NodeID.Int64,
×
4351
                                Timelock:                r.Policy2Timelock.Int32,
×
4352
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4353
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4354
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4355
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4356
                                LastUpdate:              r.Policy2LastUpdate,
×
4357
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4358
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4359
                                Disabled:                r.Policy2Disabled,
×
4360
                                Signature:               r.Policy2Signature,
×
4361
                        }
×
4362
                }
×
4363

4364
                return policy1, policy2, nil
×
4365

4366
        case sqlc.ListChannelsByNodeIDRow:
×
4367
                if r.Policy1ID.Valid {
×
4368
                        policy1 = &sqlc.ChannelPolicy{
×
4369
                                ID:                      r.Policy1ID.Int64,
×
4370
                                Version:                 r.Policy1Version.Int16,
×
4371
                                ChannelID:               r.Channel.ID,
×
4372
                                NodeID:                  r.Policy1NodeID.Int64,
×
4373
                                Timelock:                r.Policy1Timelock.Int32,
×
4374
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4375
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4376
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4377
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4378
                                LastUpdate:              r.Policy1LastUpdate,
×
4379
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4380
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4381
                                Disabled:                r.Policy1Disabled,
×
4382
                                Signature:               r.Policy1Signature,
×
4383
                        }
×
4384
                }
×
4385
                if r.Policy2ID.Valid {
×
4386
                        policy2 = &sqlc.ChannelPolicy{
×
4387
                                ID:                      r.Policy2ID.Int64,
×
4388
                                Version:                 r.Policy2Version.Int16,
×
4389
                                ChannelID:               r.Channel.ID,
×
4390
                                NodeID:                  r.Policy2NodeID.Int64,
×
4391
                                Timelock:                r.Policy2Timelock.Int32,
×
4392
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4393
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4394
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4395
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4396
                                LastUpdate:              r.Policy2LastUpdate,
×
4397
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4398
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4399
                                Disabled:                r.Policy2Disabled,
×
4400
                                Signature:               r.Policy2Signature,
×
4401
                        }
×
4402
                }
×
4403

4404
                return policy1, policy2, nil
×
4405

4406
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4407
                if r.Policy1ID.Valid {
×
4408
                        policy1 = &sqlc.ChannelPolicy{
×
4409
                                ID:                      r.Policy1ID.Int64,
×
4410
                                Version:                 r.Policy1Version.Int16,
×
4411
                                ChannelID:               r.Channel.ID,
×
4412
                                NodeID:                  r.Policy1NodeID.Int64,
×
4413
                                Timelock:                r.Policy1Timelock.Int32,
×
4414
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4415
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4416
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4417
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4418
                                LastUpdate:              r.Policy1LastUpdate,
×
4419
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4420
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4421
                                Disabled:                r.Policy1Disabled,
×
4422
                                Signature:               r.Policy1Signature,
×
4423
                        }
×
4424
                }
×
4425
                if r.Policy2ID.Valid {
×
4426
                        policy2 = &sqlc.ChannelPolicy{
×
4427
                                ID:                      r.Policy2ID.Int64,
×
4428
                                Version:                 r.Policy2Version.Int16,
×
4429
                                ChannelID:               r.Channel.ID,
×
4430
                                NodeID:                  r.Policy2NodeID.Int64,
×
4431
                                Timelock:                r.Policy2Timelock.Int32,
×
4432
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4433
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4434
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4435
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4436
                                LastUpdate:              r.Policy2LastUpdate,
×
4437
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4438
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4439
                                Disabled:                r.Policy2Disabled,
×
4440
                                Signature:               r.Policy2Signature,
×
4441
                        }
×
4442
                }
×
4443

4444
                return policy1, policy2, nil
×
4445
        default:
×
4446
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4447
                        "extractChannelPolicies: %T", r)
×
4448
        }
4449
}
4450

4451
// channelIDToBytes converts a channel ID (SCID) to a byte array
4452
// representation.
4453
func channelIDToBytes(channelID uint64) [8]byte {
×
4454
        var chanIDB [8]byte
×
4455
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4456

×
4457
        return chanIDB
×
4458
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc