• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15978799235

30 Jun 2025 04:47PM UTC coverage: 57.813% (-9.8%) from 67.608%
15978799235

Pull #10011

github

web-flow
Merge d0538fdbe into e54206f8c
Pull Request #10011: refactor+graph/db: refactor preparations required for incoming SQL migration code

18 of 69 new or added lines in 2 files covered. (26.09%)

28400 existing lines in 458 files now uncovered.

98467 of 170321 relevant lines covered (57.81%)

1.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// pageSize caps the number of records a single paginated query returns. The
// value is provisional and may be tuned once benchmarks are available.
const pageSize = 2000

// ProtocolVersion identifies which gossip protocol version a message belongs
// to.
type ProtocolVersion uint8

const (
	// ProtocolV1 is the gossip protocol version defined in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String formats the protocol version as "V<n>", e.g. "V1" for ProtocolV1.
func (v ProtocolVersion) String() string {
	return fmt.Sprintf("V%d", uint8(v))
}
×
51

52
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
53
// execute queries against the SQL graph tables.
54
//
55
//nolint:ll,interfacebloat
56
type SQLQueries interface {
57
        /*
58
                Node queries.
59
        */
60
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
61
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
62
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
63
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
64
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error)
65
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
66
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
67
        GetUnconnectedNodes(ctx context.Context) ([]sqlc.GetUnconnectedNodesRow, error)
68
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
69
        DeleteNode(ctx context.Context, id int64) error
70

71
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
72
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
73
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
74

75
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
76
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
77
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
78

79
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
80
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
81
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
82
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
83

84
        /*
85
                Source node queries.
86
        */
87
        AddSourceNode(ctx context.Context, nodeID int64) error
88
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
89

90
        /*
91
                Channel queries.
92
        */
93
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
94
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
95
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
96
        GetChannelByOutpoint(ctx context.Context, outpoint string) (sqlc.GetChannelByOutpointRow, error)
97
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
98
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
99
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
100
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
101
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
102
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
103
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
104
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
105
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
106
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
107
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
108
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
109
        DeleteChannel(ctx context.Context, id int64) error
110

111
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
112
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
113

114
        /*
115
                Channel Policy table queries.
116
        */
117
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
118
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.ChannelPolicy, error)
119
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
120

121
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
122
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
123
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
124

125
        /*
126
                Zombie index queries.
127
        */
128
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
129
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.ZombieChannel, error)
130
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
131
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
132
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
133

134
        /*
135
                Prune log table queries.
136
        */
137
        GetPruneTip(ctx context.Context) (sqlc.PruneLog, error)
138
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
139
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
140

141
        /*
142
                Closed SCID table queries.
143
        */
144
        InsertClosedChannel(ctx context.Context, scid []byte) error
145
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
146
}
147

148
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
149
// database operations.
150
type BatchedSQLQueries interface {
151
        SQLQueries
152
        sqldb.BatchedTx[SQLQueries]
153
}
154

155
// SQLStore is an implementation of the V1Store interface that uses a SQL
156
// database as the backend.
157
//
158
// NOTE: currently, this temporarily embeds the KVStore struct so that we can
159
// implement the V1Store interface incrementally. For any method not
160
// implemented,  things will fall back to the KVStore. This is ONLY the case
161
// for the time being while this struct is purely used in unit tests only.
162
type SQLStore struct {
163
        cfg *SQLStoreConfig
164
        db  BatchedSQLQueries
165

166
        // cacheMu guards all caches (rejectCache and chanCache). If
167
        // this mutex will be acquired at the same time as the DB mutex then
168
        // the cacheMu MUST be acquired first to prevent deadlock.
169
        cacheMu     sync.RWMutex
170
        rejectCache *rejectCache
171
        chanCache   *channelCache
172

173
        chanScheduler batch.Scheduler[SQLQueries]
174
        nodeScheduler batch.Scheduler[SQLQueries]
175

176
        srcNodes  map[ProtocolVersion]*srcNodeInfo
177
        srcNodeMu sync.Mutex
178
}
179

180
// A compile-time assertion to ensure that SQLStore implements the V1Store
181
// interface.
182
var _ V1Store = (*SQLStore)(nil)
183

184
// SQLStoreConfig holds the configuration for the SQLStore.
185
type SQLStoreConfig struct {
186
        // ChainHash is the genesis hash for the chain that all the gossip
187
        // messages in this store are aimed at.
188
        ChainHash chainhash.Hash
189
}
190

191
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
192
// storage backend.
193
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
194
        options ...StoreOptionModifier) (*SQLStore, error) {
×
195

×
196
        opts := DefaultOptions()
×
197
        for _, o := range options {
×
198
                o(opts)
×
199
        }
×
200

201
        if opts.NoMigration {
×
202
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
203
                        "supported for SQL stores")
×
204
        }
×
205

206
        s := &SQLStore{
×
207
                cfg:         cfg,
×
208
                db:          db,
×
209
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
210
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
211
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
212
        }
×
213

×
214
        s.chanScheduler = batch.NewTimeScheduler(
×
215
                db, &s.cacheMu, opts.BatchCommitInterval,
×
216
        )
×
217
        s.nodeScheduler = batch.NewTimeScheduler(
×
218
                db, nil, opts.BatchCommitInterval,
×
219
        )
×
220

×
221
        return s, nil
×
222
}
223

224
// AddLightningNode adds a vertex/node to the graph database. If the node is not
225
// in the database from before, this will add a new, unconnected one to the
226
// graph. If it is present from before, this will update that node's
227
// information.
228
//
229
// NOTE: part of the V1Store interface.
230
func (s *SQLStore) AddLightningNode(ctx context.Context,
231
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
232

×
233
        r := &batch.Request[SQLQueries]{
×
234
                Opts: batch.NewSchedulerOptions(opts...),
×
235
                Do: func(queries SQLQueries) error {
×
236
                        _, err := upsertNode(ctx, queries, node)
×
237
                        return err
×
238
                },
×
239
        }
240

241
        return s.nodeScheduler.Execute(ctx, r)
×
242
}
243

244
// FetchLightningNode attempts to look up a target node by its identity public
245
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
246
// returned.
247
//
248
// NOTE: part of the V1Store interface.
249
func (s *SQLStore) FetchLightningNode(ctx context.Context,
250
        pubKey route.Vertex) (*models.LightningNode, error) {
×
251

×
252
        var node *models.LightningNode
×
253
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
254
                var err error
×
255
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
256

×
257
                return err
×
258
        }, sqldb.NoOpReset)
×
259
        if err != nil {
×
260
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
261
        }
×
262

263
        return node, nil
×
264
}
265

266
// HasLightningNode determines if the graph has a vertex identified by the
267
// target node identity public key. If the node exists in the database, a
268
// timestamp of when the data for the node was lasted updated is returned along
269
// with a true boolean. Otherwise, an empty time.Time is returned with a false
270
// boolean.
271
//
272
// NOTE: part of the V1Store interface.
273
func (s *SQLStore) HasLightningNode(ctx context.Context,
274
        pubKey [33]byte) (time.Time, bool, error) {
×
275

×
276
        var (
×
277
                exists     bool
×
278
                lastUpdate time.Time
×
279
        )
×
280
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
281
                dbNode, err := db.GetNodeByPubKey(
×
282
                        ctx, sqlc.GetNodeByPubKeyParams{
×
283
                                Version: int16(ProtocolV1),
×
284
                                PubKey:  pubKey[:],
×
285
                        },
×
286
                )
×
287
                if errors.Is(err, sql.ErrNoRows) {
×
288
                        return nil
×
289
                } else if err != nil {
×
290
                        return fmt.Errorf("unable to fetch node: %w", err)
×
291
                }
×
292

293
                exists = true
×
294

×
295
                if dbNode.LastUpdate.Valid {
×
296
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
297
                }
×
298

299
                return nil
×
300
        }, sqldb.NoOpReset)
301
        if err != nil {
×
302
                return time.Time{}, false,
×
303
                        fmt.Errorf("unable to fetch node: %w", err)
×
304
        }
×
305

306
        return lastUpdate, exists, nil
×
307
}
308

309
// AddrsForNode returns all known addresses for the target node public key
310
// that the graph DB is aware of. The returned boolean indicates if the
311
// given node is unknown to the graph DB or not.
312
//
313
// NOTE: part of the V1Store interface.
314
func (s *SQLStore) AddrsForNode(ctx context.Context,
315
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
316

×
317
        var (
×
318
                addresses []net.Addr
×
319
                known     bool
×
320
        )
×
321
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
322
                var err error
×
323
                known, addresses, err = getNodeAddresses(
×
324
                        ctx, db, nodePub.SerializeCompressed(),
×
325
                )
×
326
                if err != nil {
×
327
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
328
                                err)
×
329
                }
×
330

331
                return nil
×
332
        }, sqldb.NoOpReset)
333
        if err != nil {
×
334
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
335
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
336
        }
×
337

338
        return known, addresses, nil
×
339
}
340

341
// DeleteLightningNode starts a new database transaction to remove a vertex/node
342
// from the database according to the node's public key.
343
//
344
// NOTE: part of the V1Store interface.
345
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
346
        pubKey route.Vertex) error {
×
347

×
348
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
349
                res, err := db.DeleteNodeByPubKey(
×
350
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
351
                                Version: int16(ProtocolV1),
×
352
                                PubKey:  pubKey[:],
×
353
                        },
×
354
                )
×
355
                if err != nil {
×
356
                        return err
×
357
                }
×
358

359
                rows, err := res.RowsAffected()
×
360
                if err != nil {
×
361
                        return err
×
362
                }
×
363

364
                if rows == 0 {
×
365
                        return ErrGraphNodeNotFound
×
366
                } else if rows > 1 {
×
367
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
368
                }
×
369

370
                return err
×
371
        }, sqldb.NoOpReset)
372
        if err != nil {
×
373
                return fmt.Errorf("unable to delete node: %w", err)
×
374
        }
×
375

376
        return nil
×
377
}
378

379
// FetchNodeFeatures returns the features of the given node. If no features are
380
// known for the node, an empty feature vector is returned.
381
//
382
// NOTE: this is part of the graphdb.NodeTraverser interface.
383
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
384
        *lnwire.FeatureVector, error) {
×
385

×
386
        ctx := context.TODO()
×
387

×
388
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
389
}
×
390

391
// DisabledChannelIDs returns the channel ids of disabled channels.
392
// A channel is disabled when two of the associated ChanelEdgePolicies
393
// have their disabled bit on.
394
//
395
// NOTE: part of the V1Store interface.
396
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
397
        var (
×
398
                ctx     = context.TODO()
×
399
                chanIDs []uint64
×
400
        )
×
401
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
402
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
403
                if err != nil {
×
404
                        return fmt.Errorf("unable to fetch disabled "+
×
405
                                "channels: %w", err)
×
406
                }
×
407

408
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
409

×
410
                return nil
×
411
        }, sqldb.NoOpReset)
412
        if err != nil {
×
413
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
414
                        err)
×
415
        }
×
416

417
        return chanIDs, nil
×
418
}
419

420
// LookupAlias attempts to return the alias as advertised by the target node.
421
//
422
// NOTE: part of the V1Store interface.
423
func (s *SQLStore) LookupAlias(ctx context.Context,
424
        pub *btcec.PublicKey) (string, error) {
×
425

×
426
        var alias string
×
427
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
428
                dbNode, err := db.GetNodeByPubKey(
×
429
                        ctx, sqlc.GetNodeByPubKeyParams{
×
430
                                Version: int16(ProtocolV1),
×
431
                                PubKey:  pub.SerializeCompressed(),
×
432
                        },
×
433
                )
×
434
                if errors.Is(err, sql.ErrNoRows) {
×
435
                        return ErrNodeAliasNotFound
×
436
                } else if err != nil {
×
437
                        return fmt.Errorf("unable to fetch node: %w", err)
×
438
                }
×
439

440
                if !dbNode.Alias.Valid {
×
441
                        return ErrNodeAliasNotFound
×
442
                }
×
443

444
                alias = dbNode.Alias.String
×
445

×
446
                return nil
×
447
        }, sqldb.NoOpReset)
448
        if err != nil {
×
449
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
450
        }
×
451

452
        return alias, nil
×
453
}
454

455
// SourceNode returns the source node of the graph. The source node is treated
456
// as the center node within a star-graph. This method may be used to kick off
457
// a path finding algorithm in order to explore the reachability of another
458
// node based off the source node.
459
//
460
// NOTE: part of the V1Store interface.
461
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
462
        error) {
×
463

×
464
        var node *models.LightningNode
×
465
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
466
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
467
                if err != nil {
×
468
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
469
                                err)
×
470
                }
×
471

472
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
473

×
474
                return err
×
475
        }, sqldb.NoOpReset)
476
        if err != nil {
×
477
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
478
        }
×
479

480
        return node, nil
×
481
}
482

483
// SetSourceNode sets the source node within the graph database. The source
484
// node is to be used as the center of a star-graph within path finding
485
// algorithms.
486
//
487
// NOTE: part of the V1Store interface.
488
func (s *SQLStore) SetSourceNode(ctx context.Context,
489
        node *models.LightningNode) error {
×
490

×
491
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
492
                id, err := upsertNode(ctx, db, node)
×
493
                if err != nil {
×
494
                        return fmt.Errorf("unable to upsert source node: %w",
×
495
                                err)
×
496
                }
×
497

498
                // Make sure that if a source node for this version is already
499
                // set, then the ID is the same as the one we are about to set.
500
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
501
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
502
                        return fmt.Errorf("unable to fetch source node: %w",
×
503
                                err)
×
504
                } else if err == nil {
×
505
                        if dbSourceNodeID != id {
×
506
                                return fmt.Errorf("v1 source node already "+
×
507
                                        "set to a different node: %d vs %d",
×
508
                                        dbSourceNodeID, id)
×
509
                        }
×
510

511
                        return nil
×
512
                }
513

514
                return db.AddSourceNode(ctx, id)
×
515
        }, sqldb.NoOpReset)
516
}
517

518
// NodeUpdatesInHorizon returns all the known lightning node which have an
519
// update timestamp within the passed range. This method can be used by two
520
// nodes to quickly determine if they have the same set of up to date node
521
// announcements.
522
//
523
// NOTE: This is part of the V1Store interface.
524
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
525
        endTime time.Time) ([]models.LightningNode, error) {
×
526

×
527
        ctx := context.TODO()
×
528

×
529
        var nodes []models.LightningNode
×
530
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
531
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
532
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
533
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
534
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
535
                        },
×
536
                )
×
537
                if err != nil {
×
538
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
539
                }
×
540

541
                for _, dbNode := range dbNodes {
×
542
                        node, err := buildNode(ctx, db, &dbNode)
×
543
                        if err != nil {
×
544
                                return fmt.Errorf("unable to build node: %w",
×
545
                                        err)
×
546
                        }
×
547

548
                        nodes = append(nodes, *node)
×
549
                }
550

551
                return nil
×
552
        }, sqldb.NoOpReset)
553
        if err != nil {
×
554
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
555
        }
×
556

557
        return nodes, nil
×
558
}
559

560
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
561
// undirected edge from the two target nodes are created. The information stored
562
// denotes the static attributes of the channel, such as the channelID, the keys
563
// involved in creation of the channel, and the set of features that the channel
564
// supports. The chanPoint and chanID are used to uniquely identify the edge
565
// globally within the database.
566
//
567
// NOTE: part of the V1Store interface.
568
func (s *SQLStore) AddChannelEdge(ctx context.Context,
569
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
570

×
571
        var alreadyExists bool
×
572
        r := &batch.Request[SQLQueries]{
×
573
                Opts: batch.NewSchedulerOptions(opts...),
×
574
                Reset: func() {
×
575
                        alreadyExists = false
×
576
                },
×
577
                Do: func(tx SQLQueries) error {
×
NEW
578
                        _, err := insertChannel(ctx, tx, edge)
×
579

×
580
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
581
                        // succeed, but propagate the error via local state.
×
582
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
583
                                alreadyExists = true
×
584
                                return nil
×
585
                        }
×
586

587
                        return err
×
588
                },
589
                OnCommit: func(err error) error {
×
590
                        switch {
×
591
                        case err != nil:
×
592
                                return err
×
593
                        case alreadyExists:
×
594
                                return ErrEdgeAlreadyExist
×
595
                        default:
×
596
                                s.rejectCache.remove(edge.ChannelID)
×
597
                                s.chanCache.remove(edge.ChannelID)
×
598
                                return nil
×
599
                        }
600
                },
601
        }
602

603
        return s.chanScheduler.Execute(ctx, r)
×
604
}
605

606
// HighestChanID returns the "highest" known channel ID in the channel graph.
607
// This represents the "newest" channel from the PoV of the chain. This method
608
// can be used by peers to quickly determine if their graphs are in sync.
609
//
610
// NOTE: This is part of the V1Store interface.
611
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
612
        var highestChanID uint64
×
613
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
614
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
615
                if errors.Is(err, sql.ErrNoRows) {
×
616
                        return nil
×
617
                } else if err != nil {
×
618
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
619
                                err)
×
620
                }
×
621

622
                highestChanID = byteOrder.Uint64(chanID)
×
623

×
624
                return nil
×
625
        }, sqldb.NoOpReset)
626
        if err != nil {
×
627
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
628
        }
×
629

630
        return highestChanID, nil
×
631
}
632

633
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
634
// within the database for the referenced channel. The `flags` attribute within
635
// the ChannelEdgePolicy determines which of the directed edges are being
636
// updated. If the flag is 1, then the first node's information is being
637
// updated, otherwise it's the second node's information. The node ordering is
638
// determined by the lexicographical ordering of the identity public keys of the
639
// nodes on either side of the channel.
640
//
641
// NOTE: part of the V1Store interface.
642
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
643
        edge *models.ChannelEdgePolicy,
644
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
645

×
646
        var (
×
647
                isUpdate1    bool
×
648
                edgeNotFound bool
×
649
                from, to     route.Vertex
×
650
        )
×
651

×
652
        r := &batch.Request[SQLQueries]{
×
653
                Opts: batch.NewSchedulerOptions(opts...),
×
654
                Reset: func() {
×
655
                        isUpdate1 = false
×
656
                        edgeNotFound = false
×
657
                },
×
658
                Do: func(tx SQLQueries) error {
×
659
                        var err error
×
660
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
661
                                ctx, tx, edge,
×
662
                        )
×
663
                        if err != nil {
×
664
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
665
                        }
×
666

667
                        // Silence ErrEdgeNotFound so that the batch can
668
                        // succeed, but propagate the error via local state.
669
                        if errors.Is(err, ErrEdgeNotFound) {
×
670
                                edgeNotFound = true
×
671
                                return nil
×
672
                        }
×
673

674
                        return err
×
675
                },
676
                OnCommit: func(err error) error {
×
677
                        switch {
×
678
                        case err != nil:
×
679
                                return err
×
680
                        case edgeNotFound:
×
681
                                return ErrEdgeNotFound
×
682
                        default:
×
683
                                s.updateEdgeCache(edge, isUpdate1)
×
684
                                return nil
×
685
                        }
686
                },
687
        }
688

689
        err := s.chanScheduler.Execute(ctx, r)
×
690

×
691
        return from, to, err
×
692
}
693

694
// updateEdgeCache updates our reject and channel caches with the new
695
// edge policy information.
696
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
697
        isUpdate1 bool) {
×
698

×
699
        // If an entry for this channel is found in reject cache, we'll modify
×
700
        // the entry with the updated timestamp for the direction that was just
×
701
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
702
        // during the next query for this edge.
×
703
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
704
                if isUpdate1 {
×
705
                        entry.upd1Time = e.LastUpdate.Unix()
×
706
                } else {
×
707
                        entry.upd2Time = e.LastUpdate.Unix()
×
708
                }
×
709
                s.rejectCache.insert(e.ChannelID, entry)
×
710
        }
711

712
        // If an entry for this channel is found in channel cache, we'll modify
713
        // the entry with the updated policy for the direction that was just
714
        // written. If the edge doesn't exist, we'll defer loading the info and
715
        // policies and lazily read from disk during the next query.
716
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
717
                if isUpdate1 {
×
718
                        channel.Policy1 = e
×
719
                } else {
×
720
                        channel.Policy2 = e
×
721
                }
×
722
                s.chanCache.insert(e.ChannelID, channel)
×
723
        }
724
}
725

726
// ForEachSourceNodeChannel iterates through all channels of the source node,
727
// executing the passed callback on each. The call-back is provided with the
728
// channel's outpoint, whether we have a policy for the channel and the channel
729
// peer's node information.
730
//
731
// NOTE: part of the V1Store interface.
732
func (s *SQLStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint,
733
        havePolicy bool, otherNode *models.LightningNode) error) error {
×
734

×
735
        var ctx = context.TODO()
×
736

×
737
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
738
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
739
                if err != nil {
×
740
                        return fmt.Errorf("unable to fetch source node: %w",
×
741
                                err)
×
742
                }
×
743

744
                return forEachNodeChannel(
×
745
                        ctx, db, s.cfg.ChainHash, nodeID,
×
746
                        func(info *models.ChannelEdgeInfo,
×
747
                                outPolicy *models.ChannelEdgePolicy,
×
748
                                _ *models.ChannelEdgePolicy) error {
×
749

×
750
                                // Fetch the other node.
×
751
                                var (
×
752
                                        otherNodePub [33]byte
×
753
                                        node1        = info.NodeKey1Bytes
×
754
                                        node2        = info.NodeKey2Bytes
×
755
                                )
×
756
                                switch {
×
757
                                case bytes.Equal(node1[:], nodePub[:]):
×
758
                                        otherNodePub = node2
×
759
                                case bytes.Equal(node2[:], nodePub[:]):
×
760
                                        otherNodePub = node1
×
761
                                default:
×
762
                                        return fmt.Errorf("node not " +
×
763
                                                "participating in this channel")
×
764
                                }
765

766
                                _, otherNode, err := getNodeByPubKey(
×
767
                                        ctx, db, otherNodePub,
×
768
                                )
×
769
                                if err != nil {
×
770
                                        return fmt.Errorf("unable to fetch "+
×
771
                                                "other node(%x): %w",
×
772
                                                otherNodePub, err)
×
773
                                }
×
774

775
                                return cb(
×
776
                                        info.ChannelPoint, outPolicy != nil,
×
777
                                        otherNode,
×
778
                                )
×
779
                        },
780
                )
781
        }, sqldb.NoOpReset)
782
}
783

784
// ForEachNode iterates through all the stored vertices/nodes in the graph,
785
// executing the passed callback with each node encountered. If the callback
786
// returns an error, then the transaction is aborted and the iteration stops
787
// early. Any operations performed on the NodeTx passed to the call-back are
788
// executed under the same read transaction and so, methods on the NodeTx object
789
// _MUST_ only be called from within the call-back.
790
//
791
// NOTE: part of the V1Store interface.
792
func (s *SQLStore) ForEachNode(cb func(tx NodeRTx) error) error {
×
793
        var (
×
794
                ctx          = context.TODO()
×
795
                lastID int64 = 0
×
796
        )
×
797

×
798
        handleNode := func(db SQLQueries, dbNode sqlc.Node) error {
×
799
                node, err := buildNode(ctx, db, &dbNode)
×
800
                if err != nil {
×
801
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
802
                                dbNode.ID, err)
×
803
                }
×
804

805
                err = cb(
×
806
                        newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
×
807
                )
×
808
                if err != nil {
×
809
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
810
                                dbNode.ID, err)
×
811
                }
×
812

813
                return nil
×
814
        }
815

816
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
817
                for {
×
818
                        nodes, err := db.ListNodesPaginated(
×
819
                                ctx, sqlc.ListNodesPaginatedParams{
×
820
                                        Version: int16(ProtocolV1),
×
821
                                        ID:      lastID,
×
822
                                        Limit:   pageSize,
×
823
                                },
×
824
                        )
×
825
                        if err != nil {
×
826
                                return fmt.Errorf("unable to fetch nodes: %w",
×
827
                                        err)
×
828
                        }
×
829

830
                        if len(nodes) == 0 {
×
831
                                break
×
832
                        }
833

834
                        for _, dbNode := range nodes {
×
835
                                err = handleNode(db, dbNode)
×
836
                                if err != nil {
×
837
                                        return err
×
838
                                }
×
839

840
                                lastID = dbNode.ID
×
841
                        }
842
                }
843

844
                return nil
×
845
        }, sqldb.NoOpReset)
846
}
847

848
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
849
// SQLStore and a SQL transaction.
850
type sqlGraphNodeTx struct {
851
        db    SQLQueries
852
        id    int64
853
        node  *models.LightningNode
854
        chain chainhash.Hash
855
}
856

857
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
858
// interface.
859
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
860

861
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
862
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
863

×
864
        return &sqlGraphNodeTx{
×
865
                db:    db,
×
866
                chain: chain,
×
867
                id:    id,
×
868
                node:  node,
×
869
        }
×
870
}
×
871

872
// Node returns the raw information of the node.
873
//
874
// NOTE: This is a part of the NodeRTx interface.
875
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
876
        return s.node
×
877
}
×
878

879
// ForEachChannel can be used to iterate over the node's channels under the same
880
// transaction used to fetch the node.
881
//
882
// NOTE: This is a part of the NodeRTx interface.
883
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
884
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
885

×
886
        ctx := context.TODO()
×
887

×
888
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
889
}
×
890

891
// FetchNode fetches the node with the given pub key under the same transaction
892
// used to fetch the current node. The returned node is also a NodeRTx and any
893
// operations on that NodeRTx will also be done under the same transaction.
894
//
895
// NOTE: This is a part of the NodeRTx interface.
896
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
897
        ctx := context.TODO()
×
898

×
899
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
900
        if err != nil {
×
901
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
902
                        nodePub, err)
×
903
        }
×
904

905
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
906
}
907

908
// ForEachNodeDirectedChannel iterates through all channels of a given node,
909
// executing the passed callback on the directed edge representing the channel
910
// and its incoming policy. If the callback returns an error, then the iteration
911
// is halted with the error propagated back up to the caller.
912
//
913
// Unknown policies are passed into the callback as nil values.
914
//
915
// NOTE: this is part of the graphdb.NodeTraverser interface.
916
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
917
        cb func(channel *DirectedChannel) error) error {
×
918

×
919
        var ctx = context.TODO()
×
920

×
921
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
922
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
923
        }, sqldb.NoOpReset)
×
924
}
925

926
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
927
// graph, executing the passed callback with each node encountered. If the
928
// callback returns an error, then the transaction is aborted and the iteration
929
// stops early.
930
//
931
// NOTE: This is a part of the V1Store interface.
932
func (s *SQLStore) ForEachNodeCacheable(cb func(route.Vertex,
933
        *lnwire.FeatureVector) error) error {
×
934

×
935
        ctx := context.TODO()
×
936

×
937
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
938
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
939
                        nodePub route.Vertex) error {
×
940

×
941
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
942
                        if err != nil {
×
943
                                return fmt.Errorf("unable to fetch node "+
×
944
                                        "features: %w", err)
×
945
                        }
×
946

947
                        return cb(nodePub, features)
×
948
                })
949
        }, sqldb.NoOpReset)
950
        if err != nil {
×
951
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
952
        }
×
953

954
        return nil
×
955
}
956

957
// ForEachNodeChannel iterates through all channels of the given node,
958
// executing the passed callback with an edge info structure and the policies
959
// of each end of the channel. The first edge policy is the outgoing edge *to*
960
// the connecting node, while the second is the incoming edge *from* the
961
// connecting node. If the callback returns an error, then the iteration is
962
// halted with the error propagated back up to the caller.
963
//
964
// Unknown policies are passed into the callback as nil values.
965
//
966
// NOTE: part of the V1Store interface.
967
func (s *SQLStore) ForEachNodeChannel(nodePub route.Vertex,
968
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
969
                *models.ChannelEdgePolicy) error) error {
×
970

×
971
        var ctx = context.TODO()
×
972

×
973
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
974
                dbNode, err := db.GetNodeByPubKey(
×
975
                        ctx, sqlc.GetNodeByPubKeyParams{
×
976
                                Version: int16(ProtocolV1),
×
977
                                PubKey:  nodePub[:],
×
978
                        },
×
979
                )
×
980
                if errors.Is(err, sql.ErrNoRows) {
×
981
                        return nil
×
982
                } else if err != nil {
×
983
                        return fmt.Errorf("unable to fetch node: %w", err)
×
984
                }
×
985

986
                return forEachNodeChannel(
×
987
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
988
                )
×
989
        }, sqldb.NoOpReset)
990
}
991

992
// ChanUpdatesInHorizon returns all the known channel edges which have at least
993
// one edge that has an update timestamp within the specified horizon.
994
//
995
// NOTE: This is part of the V1Store interface.
996
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
997
        endTime time.Time) ([]ChannelEdge, error) {
×
998

×
999
        s.cacheMu.Lock()
×
1000
        defer s.cacheMu.Unlock()
×
1001

×
1002
        var (
×
1003
                ctx = context.TODO()
×
1004
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
1005
                // an additional map to keep track of the edges already seen to
×
1006
                // prevent re-adding it.
×
1007
                edgesSeen    = make(map[uint64]struct{})
×
1008
                edgesToCache = make(map[uint64]ChannelEdge)
×
1009
                edges        []ChannelEdge
×
1010
                hits         int
×
1011
        )
×
1012
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1013
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1014
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1015
                                Version:   int16(ProtocolV1),
×
1016
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
1017
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
1018
                        },
×
1019
                )
×
1020
                if err != nil {
×
1021
                        return err
×
1022
                }
×
1023

1024
                for _, row := range rows {
×
1025
                        // If we've already retrieved the info and policies for
×
1026
                        // this edge, then we can skip it as we don't need to do
×
1027
                        // so again.
×
1028
                        chanIDInt := byteOrder.Uint64(row.Channel.Scid)
×
1029
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1030
                                continue
×
1031
                        }
1032

1033
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1034
                                hits++
×
1035
                                edgesSeen[chanIDInt] = struct{}{}
×
1036
                                edges = append(edges, channel)
×
1037

×
1038
                                continue
×
1039
                        }
1040

1041
                        node1, node2, err := buildNodes(
×
1042
                                ctx, db, row.Node, row.Node_2,
×
1043
                        )
×
1044
                        if err != nil {
×
1045
                                return err
×
1046
                        }
×
1047

1048
                        channel, err := getAndBuildEdgeInfo(
×
1049
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1050
                                row.Channel, node1.PubKeyBytes,
×
1051
                                node2.PubKeyBytes,
×
1052
                        )
×
1053
                        if err != nil {
×
1054
                                return fmt.Errorf("unable to build channel "+
×
1055
                                        "info: %w", err)
×
1056
                        }
×
1057

1058
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1059
                        if err != nil {
×
1060
                                return fmt.Errorf("unable to extract channel "+
×
1061
                                        "policies: %w", err)
×
1062
                        }
×
1063

1064
                        p1, p2, err := getAndBuildChanPolicies(
×
1065
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1066
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1067
                        )
×
1068
                        if err != nil {
×
1069
                                return fmt.Errorf("unable to build channel "+
×
1070
                                        "policies: %w", err)
×
1071
                        }
×
1072

1073
                        edgesSeen[chanIDInt] = struct{}{}
×
1074
                        chanEdge := ChannelEdge{
×
1075
                                Info:    channel,
×
1076
                                Policy1: p1,
×
1077
                                Policy2: p2,
×
1078
                                Node1:   node1,
×
1079
                                Node2:   node2,
×
1080
                        }
×
1081
                        edges = append(edges, chanEdge)
×
1082
                        edgesToCache[chanIDInt] = chanEdge
×
1083
                }
1084

1085
                return nil
×
1086
        }, func() {
×
1087
                edgesSeen = make(map[uint64]struct{})
×
1088
                edgesToCache = make(map[uint64]ChannelEdge)
×
1089
                edges = nil
×
1090
        })
×
1091
        if err != nil {
×
1092
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1093
        }
×
1094

1095
        // Insert any edges loaded from disk into the cache.
1096
        for chanid, channel := range edgesToCache {
×
1097
                s.chanCache.insert(chanid, channel)
×
1098
        }
×
1099

1100
        if len(edges) > 0 {
×
1101
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
1102
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1103
        } else {
×
1104
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1105
                        "horizon (%s, %s)", startTime, endTime)
×
1106
        }
×
1107

1108
        return edges, nil
×
1109
}
1110

1111
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1112
// data to the call-back.
1113
//
1114
// NOTE: The callback contents MUST not be modified.
1115
//
1116
// NOTE: part of the V1Store interface.
1117
func (s *SQLStore) ForEachNodeCached(cb func(node route.Vertex,
1118
        chans map[uint64]*DirectedChannel) error) error {
×
1119

×
1120
        var ctx = context.TODO()
×
1121

×
1122
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1123
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
1124
                        nodePub route.Vertex) error {
×
1125

×
1126
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
1127
                        if err != nil {
×
1128
                                return fmt.Errorf("unable to fetch "+
×
1129
                                        "node(id=%d) features: %w", nodeID, err)
×
1130
                        }
×
1131

1132
                        toNodeCallback := func() route.Vertex {
×
1133
                                return nodePub
×
1134
                        }
×
1135

1136
                        rows, err := db.ListChannelsByNodeID(
×
1137
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
1138
                                        Version: int16(ProtocolV1),
×
1139
                                        NodeID1: nodeID,
×
1140
                                },
×
1141
                        )
×
1142
                        if err != nil {
×
1143
                                return fmt.Errorf("unable to fetch channels "+
×
1144
                                        "of node(id=%d): %w", nodeID, err)
×
1145
                        }
×
1146

1147
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
1148
                        for _, row := range rows {
×
1149
                                node1, node2, err := buildNodeVertices(
×
1150
                                        row.Node1Pubkey, row.Node2Pubkey,
×
1151
                                )
×
1152
                                if err != nil {
×
1153
                                        return err
×
1154
                                }
×
1155

1156
                                e, err := getAndBuildEdgeInfo(
×
1157
                                        ctx, db, s.cfg.ChainHash,
×
1158
                                        row.Channel.ID, row.Channel, node1,
×
1159
                                        node2,
×
1160
                                )
×
1161
                                if err != nil {
×
1162
                                        return fmt.Errorf("unable to build "+
×
1163
                                                "channel info: %w", err)
×
1164
                                }
×
1165

1166
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1167
                                        row,
×
1168
                                )
×
1169
                                if err != nil {
×
1170
                                        return fmt.Errorf("unable to "+
×
1171
                                                "extract channel "+
×
1172
                                                "policies: %w", err)
×
1173
                                }
×
1174

1175
                                p1, p2, err := getAndBuildChanPolicies(
×
1176
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
1177
                                        node1, node2,
×
1178
                                )
×
1179
                                if err != nil {
×
1180
                                        return fmt.Errorf("unable to "+
×
1181
                                                "build channel policies: %w",
×
1182
                                                err)
×
1183
                                }
×
1184

1185
                                // Determine the outgoing and incoming policy
1186
                                // for this channel and node combo.
1187
                                outPolicy, inPolicy := p1, p2
×
1188
                                if p1 != nil && p1.ToNode == nodePub {
×
1189
                                        outPolicy, inPolicy = p2, p1
×
1190
                                } else if p2 != nil && p2.ToNode != nodePub {
×
1191
                                        outPolicy, inPolicy = p2, p1
×
1192
                                }
×
1193

1194
                                var cachedInPolicy *models.CachedEdgePolicy
×
1195
                                if inPolicy != nil {
×
1196
                                        cachedInPolicy = models.NewCachedPolicy(
×
1197
                                                p2,
×
1198
                                        )
×
1199
                                        cachedInPolicy.ToNodePubKey =
×
1200
                                                toNodeCallback
×
1201
                                        cachedInPolicy.ToNodeFeatures =
×
1202
                                                features
×
1203
                                }
×
1204

1205
                                var inboundFee lnwire.Fee
×
1206
                                outPolicy.InboundFee.WhenSome(
×
1207
                                        func(fee lnwire.Fee) {
×
1208
                                                inboundFee = fee
×
1209
                                        },
×
1210
                                )
1211

1212
                                directedChannel := &DirectedChannel{
×
1213
                                        ChannelID: e.ChannelID,
×
1214
                                        IsNode1: nodePub ==
×
1215
                                                e.NodeKey1Bytes,
×
1216
                                        OtherNode:    e.NodeKey2Bytes,
×
1217
                                        Capacity:     e.Capacity,
×
1218
                                        OutPolicySet: p1 != nil,
×
1219
                                        InPolicy:     cachedInPolicy,
×
1220
                                        InboundFee:   inboundFee,
×
1221
                                }
×
1222

×
1223
                                if nodePub == e.NodeKey2Bytes {
×
1224
                                        directedChannel.OtherNode =
×
1225
                                                e.NodeKey1Bytes
×
1226
                                }
×
1227

1228
                                channels[e.ChannelID] = directedChannel
×
1229
                        }
1230

1231
                        return cb(nodePub, channels)
×
1232
                })
1233
        }, sqldb.NoOpReset)
1234
}
1235

1236
// ForEachChannelCacheable iterates through all the channel edges stored
1237
// within the graph and invokes the passed callback for each edge. The
1238
// callback takes two edges as since this is a directed graph, both the
1239
// in/out edges are visited. If the callback returns an error, then the
1240
// transaction is aborted and the iteration stops early.
1241
//
1242
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1243
// pointer for that particular channel edge routing policy will be
1244
// passed into the callback.
1245
//
1246
// NOTE: this method is like ForEachChannel but fetches only the data
1247
// required for the graph cache.
1248
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1249
        *models.CachedEdgePolicy,
1250
        *models.CachedEdgePolicy) error) error {
×
1251

×
1252
        ctx := context.TODO()
×
1253

×
1254
        handleChannel := func(db SQLQueries,
×
1255
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1256

×
1257
                node1, node2, err := buildNodeVertices(
×
1258
                        row.Node1Pubkey, row.Node2Pubkey,
×
1259
                )
×
1260
                if err != nil {
×
1261
                        return err
×
1262
                }
×
1263

1264
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
1265

×
1266
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1267
                if err != nil {
×
1268
                        return err
×
1269
                }
×
1270

1271
                var pol1, pol2 *models.CachedEdgePolicy
×
1272
                if dbPol1 != nil {
×
1273
                        policy1, err := buildChanPolicy(
×
1274
                                *dbPol1, edge.ChannelID, nil, node2, true,
×
1275
                        )
×
1276
                        if err != nil {
×
1277
                                return err
×
1278
                        }
×
1279

1280
                        pol1 = models.NewCachedPolicy(policy1)
×
1281
                }
1282
                if dbPol2 != nil {
×
1283
                        policy2, err := buildChanPolicy(
×
1284
                                *dbPol2, edge.ChannelID, nil, node1, false,
×
1285
                        )
×
1286
                        if err != nil {
×
1287
                                return err
×
1288
                        }
×
1289

1290
                        pol2 = models.NewCachedPolicy(policy2)
×
1291
                }
1292

1293
                if err := cb(edge, pol1, pol2); err != nil {
×
1294
                        return err
×
1295
                }
×
1296

1297
                return nil
×
1298
        }
1299

1300
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1301
                lastID := int64(-1)
×
1302
                for {
×
1303
                        //nolint:ll
×
1304
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1305
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1306
                                        Version: int16(ProtocolV1),
×
1307
                                        ID:      lastID,
×
1308
                                        Limit:   pageSize,
×
1309
                                },
×
1310
                        )
×
1311
                        if err != nil {
×
1312
                                return err
×
1313
                        }
×
1314

1315
                        if len(rows) == 0 {
×
1316
                                break
×
1317
                        }
1318

1319
                        for _, row := range rows {
×
1320
                                err := handleChannel(db, row)
×
1321
                                if err != nil {
×
1322
                                        return err
×
1323
                                }
×
1324

1325
                                lastID = row.Channel.ID
×
1326
                        }
1327
                }
1328

1329
                return nil
×
1330
        }, sqldb.NoOpReset)
1331
}
1332

1333
// ForEachChannel iterates through all the channel edges stored within the
1334
// graph and invokes the passed callback for each edge. The callback takes two
1335
// edges as since this is a directed graph, both the in/out edges are visited.
1336
// If the callback returns an error, then the transaction is aborted and the
1337
// iteration stops early.
1338
//
1339
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1340
// for that particular channel edge routing policy will be passed into the
1341
// callback.
1342
//
1343
// NOTE: part of the V1Store interface.
1344
func (s *SQLStore) ForEachChannel(cb func(*models.ChannelEdgeInfo,
1345
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
1346

×
1347
        ctx := context.TODO()
×
1348

×
1349
        handleChannel := func(db SQLQueries,
×
1350
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1351

×
1352
                node1, node2, err := buildNodeVertices(
×
1353
                        row.Node1Pubkey, row.Node2Pubkey,
×
1354
                )
×
1355
                if err != nil {
×
1356
                        return fmt.Errorf("unable to build node vertices: %w",
×
1357
                                err)
×
1358
                }
×
1359

1360
                edge, err := getAndBuildEdgeInfo(
×
1361
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1362
                        node1, node2,
×
1363
                )
×
1364
                if err != nil {
×
1365
                        return fmt.Errorf("unable to build channel info: %w",
×
1366
                                err)
×
1367
                }
×
1368

1369
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1370
                if err != nil {
×
1371
                        return fmt.Errorf("unable to extract channel "+
×
1372
                                "policies: %w", err)
×
1373
                }
×
1374

1375
                p1, p2, err := getAndBuildChanPolicies(
×
1376
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1377
                )
×
1378
                if err != nil {
×
1379
                        return fmt.Errorf("unable to build channel "+
×
1380
                                "policies: %w", err)
×
1381
                }
×
1382

1383
                err = cb(edge, p1, p2)
×
1384
                if err != nil {
×
1385
                        return fmt.Errorf("callback failed for channel "+
×
1386
                                "id=%d: %w", edge.ChannelID, err)
×
1387
                }
×
1388

1389
                return nil
×
1390
        }
1391

1392
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1393
                lastID := int64(-1)
×
1394
                for {
×
1395
                        //nolint:ll
×
1396
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1397
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1398
                                        Version: int16(ProtocolV1),
×
1399
                                        ID:      lastID,
×
1400
                                        Limit:   pageSize,
×
1401
                                },
×
1402
                        )
×
1403
                        if err != nil {
×
1404
                                return err
×
1405
                        }
×
1406

1407
                        if len(rows) == 0 {
×
1408
                                break
×
1409
                        }
1410

1411
                        for _, row := range rows {
×
1412
                                err := handleChannel(db, row)
×
1413
                                if err != nil {
×
1414
                                        return err
×
1415
                                }
×
1416

1417
                                lastID = row.Channel.ID
×
1418
                        }
1419
                }
1420

1421
                return nil
×
1422
        }, sqldb.NoOpReset)
1423
}
1424

1425
// FilterChannelRange returns the channel ID's of all known channels which were
1426
// mined in a block height within the passed range. The channel IDs are grouped
1427
// by their common block height. This method can be used to quickly share with a
1428
// peer the set of channels we know of within a particular range to catch them
1429
// up after a period of time offline. If withTimestamps is true then the
1430
// timestamp info of the latest received channel update messages of the channel
1431
// will be included in the response.
1432
//
1433
// NOTE: This is part of the V1Store interface.
1434
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1435
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1436

×
1437
        var (
×
1438
                ctx       = context.TODO()
×
1439
                startSCID = &lnwire.ShortChannelID{
×
1440
                        BlockHeight: startHeight,
×
1441
                }
×
1442
                endSCID = lnwire.ShortChannelID{
×
1443
                        BlockHeight: endHeight,
×
1444
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1445
                        TxPosition:  math.MaxUint16,
×
1446
                }
×
1447
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1448
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1449
        )
×
1450

×
1451
        // 1) get all channels where channelID is between start and end chan ID.
×
1452
        // 2) skip if not public (ie, no channel_proof)
×
1453
        // 3) collect that channel.
×
1454
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1455
        //    and add those timestamps to the collected channel.
×
1456
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1457
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1458
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1459
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
NEW
1460
                                StartScid: chanIDStart,
×
NEW
1461
                                EndScid:   chanIDEnd,
×
1462
                        },
×
1463
                )
×
1464
                if err != nil {
×
1465
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1466
                                err)
×
1467
                }
×
1468

1469
                for _, dbChan := range dbChans {
×
1470
                        cid := lnwire.NewShortChanIDFromInt(
×
1471
                                byteOrder.Uint64(dbChan.Scid),
×
1472
                        )
×
1473
                        chanInfo := NewChannelUpdateInfo(
×
1474
                                cid, time.Time{}, time.Time{},
×
1475
                        )
×
1476

×
1477
                        if !withTimestamps {
×
1478
                                channelsPerBlock[cid.BlockHeight] = append(
×
1479
                                        channelsPerBlock[cid.BlockHeight],
×
1480
                                        chanInfo,
×
1481
                                )
×
1482

×
1483
                                continue
×
1484
                        }
1485

1486
                        //nolint:ll
1487
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1488
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1489
                                        Version:   int16(ProtocolV1),
×
1490
                                        ChannelID: dbChan.ID,
×
1491
                                        NodeID:    dbChan.NodeID1,
×
1492
                                },
×
1493
                        )
×
1494
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1495
                                return fmt.Errorf("unable to fetch node1 "+
×
1496
                                        "policy: %w", err)
×
1497
                        } else if err == nil {
×
1498
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1499
                                        node1Policy.LastUpdate.Int64, 0,
×
1500
                                )
×
1501
                        }
×
1502

1503
                        //nolint:ll
1504
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1505
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1506
                                        Version:   int16(ProtocolV1),
×
1507
                                        ChannelID: dbChan.ID,
×
1508
                                        NodeID:    dbChan.NodeID2,
×
1509
                                },
×
1510
                        )
×
1511
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1512
                                return fmt.Errorf("unable to fetch node2 "+
×
1513
                                        "policy: %w", err)
×
1514
                        } else if err == nil {
×
1515
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1516
                                        node2Policy.LastUpdate.Int64, 0,
×
1517
                                )
×
1518
                        }
×
1519

1520
                        channelsPerBlock[cid.BlockHeight] = append(
×
1521
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1522
                        )
×
1523
                }
1524

1525
                return nil
×
1526
        }, func() {
×
1527
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1528
        })
×
1529
        if err != nil {
×
1530
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1531
        }
×
1532

1533
        if len(channelsPerBlock) == 0 {
×
1534
                return nil, nil
×
1535
        }
×
1536

1537
        // Return the channel ranges in ascending block height order.
1538
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1539
        slices.Sort(blocks)
×
1540

×
1541
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1542
                return BlockChannelRange{
×
1543
                        Height:   block,
×
1544
                        Channels: channelsPerBlock[block],
×
1545
                }
×
1546
        }), nil
×
1547
}
1548

1549
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1550
// zombie. This method is used on an ad-hoc basis, when channels need to be
1551
// marked as zombies outside the normal pruning cycle.
1552
//
1553
// NOTE: part of the V1Store interface.
1554
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1555
        pubKey1, pubKey2 [33]byte) error {
×
1556

×
1557
        ctx := context.TODO()
×
1558

×
1559
        s.cacheMu.Lock()
×
1560
        defer s.cacheMu.Unlock()
×
1561

×
1562
        chanIDB := channelIDToBytes(chanID)
×
1563

×
1564
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1565
                return db.UpsertZombieChannel(
×
1566
                        ctx, sqlc.UpsertZombieChannelParams{
×
1567
                                Version:  int16(ProtocolV1),
×
NEW
1568
                                Scid:     chanIDB,
×
1569
                                NodeKey1: pubKey1[:],
×
1570
                                NodeKey2: pubKey2[:],
×
1571
                        },
×
1572
                )
×
1573
        }, sqldb.NoOpReset)
×
1574
        if err != nil {
×
1575
                return fmt.Errorf("unable to upsert zombie channel "+
×
1576
                        "(channel_id=%d): %w", chanID, err)
×
1577
        }
×
1578

1579
        s.rejectCache.remove(chanID)
×
1580
        s.chanCache.remove(chanID)
×
1581

×
1582
        return nil
×
1583
}
1584

1585
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1586
//
1587
// NOTE: part of the V1Store interface.
1588
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1589
        s.cacheMu.Lock()
×
1590
        defer s.cacheMu.Unlock()
×
1591

×
1592
        var (
×
1593
                ctx     = context.TODO()
×
1594
                chanIDB = channelIDToBytes(chanID)
×
1595
        )
×
1596

×
1597
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1598
                res, err := db.DeleteZombieChannel(
×
1599
                        ctx, sqlc.DeleteZombieChannelParams{
×
NEW
1600
                                Scid:    chanIDB,
×
1601
                                Version: int16(ProtocolV1),
×
1602
                        },
×
1603
                )
×
1604
                if err != nil {
×
1605
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1606
                                err)
×
1607
                }
×
1608

1609
                rows, err := res.RowsAffected()
×
1610
                if err != nil {
×
1611
                        return err
×
1612
                }
×
1613

1614
                if rows == 0 {
×
1615
                        return ErrZombieEdgeNotFound
×
1616
                } else if rows > 1 {
×
1617
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1618
                                "expected 1", rows)
×
1619
                }
×
1620

1621
                return nil
×
1622
        }, sqldb.NoOpReset)
1623
        if err != nil {
×
1624
                return fmt.Errorf("unable to mark edge live "+
×
1625
                        "(channel_id=%d): %w", chanID, err)
×
1626
        }
×
1627

1628
        s.rejectCache.remove(chanID)
×
1629
        s.chanCache.remove(chanID)
×
1630

×
1631
        return err
×
1632
}
1633

1634
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1635
// zombie, then the two node public keys corresponding to this edge are also
1636
// returned.
1637
//
1638
// NOTE: part of the V1Store interface.
1639
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1640
        error) {
×
1641

×
1642
        var (
×
1643
                ctx              = context.TODO()
×
1644
                isZombie         bool
×
1645
                pubKey1, pubKey2 route.Vertex
×
1646
                chanIDB          = channelIDToBytes(chanID)
×
1647
        )
×
1648

×
1649
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1650
                zombie, err := db.GetZombieChannel(
×
1651
                        ctx, sqlc.GetZombieChannelParams{
×
NEW
1652
                                Scid:    chanIDB,
×
1653
                                Version: int16(ProtocolV1),
×
1654
                        },
×
1655
                )
×
1656
                if errors.Is(err, sql.ErrNoRows) {
×
1657
                        return nil
×
1658
                }
×
1659
                if err != nil {
×
1660
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1661
                                err)
×
1662
                }
×
1663

1664
                copy(pubKey1[:], zombie.NodeKey1)
×
1665
                copy(pubKey2[:], zombie.NodeKey2)
×
1666
                isZombie = true
×
1667

×
1668
                return nil
×
1669
        }, sqldb.NoOpReset)
1670
        if err != nil {
×
1671
                return false, route.Vertex{}, route.Vertex{},
×
1672
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1673
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1674
        }
×
1675

1676
        return isZombie, pubKey1, pubKey2, nil
×
1677
}
1678

1679
// NumZombies returns the current number of zombie channels in the graph.
1680
//
1681
// NOTE: part of the V1Store interface.
1682
func (s *SQLStore) NumZombies() (uint64, error) {
×
1683
        var (
×
1684
                ctx        = context.TODO()
×
1685
                numZombies uint64
×
1686
        )
×
1687
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1688
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1689
                if err != nil {
×
1690
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1691
                                err)
×
1692
                }
×
1693

1694
                numZombies = uint64(count)
×
1695

×
1696
                return nil
×
1697
        }, sqldb.NoOpReset)
1698
        if err != nil {
×
1699
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1700
        }
×
1701

1702
        return numZombies, nil
×
1703
}
1704

1705
// DeleteChannelEdges removes edges with the given channel IDs from the
1706
// database and marks them as zombies. This ensures that we're unable to re-add
1707
// it to our database once again. If an edge does not exist within the
1708
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1709
// true, then when we mark these edges as zombies, we'll set up the keys such
1710
// that we require the node that failed to send the fresh update to be the one
1711
// that resurrects the channel from its zombie state. The markZombie bool
1712
// denotes whether to mark the channel as a zombie.
1713
//
1714
// NOTE: part of the V1Store interface.
1715
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1716
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1717

×
1718
        s.cacheMu.Lock()
×
1719
        defer s.cacheMu.Unlock()
×
1720

×
1721
        var (
×
1722
                ctx     = context.TODO()
×
1723
                deleted []*models.ChannelEdgeInfo
×
1724
        )
×
1725
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1726
                for _, chanID := range chanIDs {
×
1727
                        chanIDB := channelIDToBytes(chanID)
×
1728

×
1729
                        row, err := db.GetChannelBySCIDWithPolicies(
×
1730
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
NEW
1731
                                        Scid:    chanIDB,
×
1732
                                        Version: int16(ProtocolV1),
×
1733
                                },
×
1734
                        )
×
1735
                        if errors.Is(err, sql.ErrNoRows) {
×
1736
                                return ErrEdgeNotFound
×
1737
                        } else if err != nil {
×
1738
                                return fmt.Errorf("unable to fetch channel: %w",
×
1739
                                        err)
×
1740
                        }
×
1741

1742
                        node1, node2, err := buildNodeVertices(
×
1743
                                row.Node.PubKey, row.Node_2.PubKey,
×
1744
                        )
×
1745
                        if err != nil {
×
1746
                                return err
×
1747
                        }
×
1748

1749
                        info, err := getAndBuildEdgeInfo(
×
1750
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1751
                                row.Channel, node1, node2,
×
1752
                        )
×
1753
                        if err != nil {
×
1754
                                return err
×
1755
                        }
×
1756

1757
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
1758
                        if err != nil {
×
1759
                                return fmt.Errorf("unable to delete "+
×
1760
                                        "channel: %w", err)
×
1761
                        }
×
1762

1763
                        deleted = append(deleted, info)
×
1764

×
1765
                        if !markZombie {
×
1766
                                continue
×
1767
                        }
1768

1769
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1770
                                info.NodeKey2Bytes
×
1771
                        if strictZombiePruning {
×
1772
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1773
                                if row.Policy1LastUpdate.Valid {
×
1774
                                        e1Time := time.Unix(
×
1775
                                                row.Policy1LastUpdate.Int64, 0,
×
1776
                                        )
×
1777
                                        e1UpdateTime = &e1Time
×
1778
                                }
×
1779
                                if row.Policy2LastUpdate.Valid {
×
1780
                                        e2Time := time.Unix(
×
1781
                                                row.Policy2LastUpdate.Int64, 0,
×
1782
                                        )
×
1783
                                        e2UpdateTime = &e2Time
×
1784
                                }
×
1785

1786
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1787
                                        info, e1UpdateTime, e2UpdateTime,
×
1788
                                )
×
1789
                        }
1790

1791
                        err = db.UpsertZombieChannel(
×
1792
                                ctx, sqlc.UpsertZombieChannelParams{
×
1793
                                        Version:  int16(ProtocolV1),
×
NEW
1794
                                        Scid:     chanIDB,
×
1795
                                        NodeKey1: nodeKey1[:],
×
1796
                                        NodeKey2: nodeKey2[:],
×
1797
                                },
×
1798
                        )
×
1799
                        if err != nil {
×
1800
                                return fmt.Errorf("unable to mark channel as "+
×
1801
                                        "zombie: %w", err)
×
1802
                        }
×
1803
                }
1804

1805
                return nil
×
1806
        }, func() {
×
1807
                deleted = nil
×
1808
        })
×
1809
        if err != nil {
×
1810
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1811
                        err)
×
1812
        }
×
1813

1814
        for _, chanID := range chanIDs {
×
1815
                s.rejectCache.remove(chanID)
×
1816
                s.chanCache.remove(chanID)
×
1817
        }
×
1818

1819
        return deleted, nil
×
1820
}
1821

1822
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1823
// channel identified by the channel ID. If the channel can't be found, then
1824
// ErrEdgeNotFound is returned. A struct which houses the general information
1825
// for the channel itself is returned as well as two structs that contain the
1826
// routing policies for the channel in either direction.
1827
//
1828
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1829
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1830
// the ChannelEdgeInfo will only include the public keys of each node.
1831
//
1832
// NOTE: part of the V1Store interface.
1833
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1834
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1835
        *models.ChannelEdgePolicy, error) {
×
1836

×
1837
        var (
×
1838
                ctx              = context.TODO()
×
1839
                edge             *models.ChannelEdgeInfo
×
1840
                policy1, policy2 *models.ChannelEdgePolicy
×
NEW
1841
                chanIDB          = channelIDToBytes(chanID)
×
1842
        )
×
1843
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1844
                row, err := db.GetChannelBySCIDWithPolicies(
×
1845
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
NEW
1846
                                Scid:    chanIDB,
×
1847
                                Version: int16(ProtocolV1),
×
1848
                        },
×
1849
                )
×
1850
                if errors.Is(err, sql.ErrNoRows) {
×
1851
                        // First check if this edge is perhaps in the zombie
×
1852
                        // index.
×
1853
                        isZombie, err := db.IsZombieChannel(
×
1854
                                ctx, sqlc.IsZombieChannelParams{
×
NEW
1855
                                        Scid:    chanIDB,
×
1856
                                        Version: int16(ProtocolV1),
×
1857
                                },
×
1858
                        )
×
1859
                        if err != nil {
×
1860
                                return fmt.Errorf("unable to check if "+
×
1861
                                        "channel is zombie: %w", err)
×
1862
                        } else if isZombie {
×
1863
                                return ErrZombieEdge
×
1864
                        }
×
1865

1866
                        return ErrEdgeNotFound
×
1867
                } else if err != nil {
×
1868
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1869
                }
×
1870

1871
                node1, node2, err := buildNodeVertices(
×
1872
                        row.Node.PubKey, row.Node_2.PubKey,
×
1873
                )
×
1874
                if err != nil {
×
1875
                        return err
×
1876
                }
×
1877

1878
                edge, err = getAndBuildEdgeInfo(
×
1879
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1880
                        node1, node2,
×
1881
                )
×
1882
                if err != nil {
×
1883
                        return fmt.Errorf("unable to build channel info: %w",
×
1884
                                err)
×
1885
                }
×
1886

1887
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1888
                if err != nil {
×
1889
                        return fmt.Errorf("unable to extract channel "+
×
1890
                                "policies: %w", err)
×
1891
                }
×
1892

1893
                policy1, policy2, err = getAndBuildChanPolicies(
×
1894
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1895
                )
×
1896
                if err != nil {
×
1897
                        return fmt.Errorf("unable to build channel "+
×
1898
                                "policies: %w", err)
×
1899
                }
×
1900

1901
                return nil
×
1902
        }, sqldb.NoOpReset)
1903
        if err != nil {
×
1904
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1905
                        err)
×
1906
        }
×
1907

1908
        return edge, policy1, policy2, nil
×
1909
}
1910

1911
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1912
// the channel identified by the funding outpoint. If the channel can't be
1913
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1914
// information for the channel itself is returned as well as two structs that
1915
// contain the routing policies for the channel in either direction.
1916
//
1917
// NOTE: part of the V1Store interface.
1918
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1919
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1920
        *models.ChannelEdgePolicy, error) {
×
1921

×
1922
        var (
×
1923
                ctx              = context.TODO()
×
1924
                edge             *models.ChannelEdgeInfo
×
1925
                policy1, policy2 *models.ChannelEdgePolicy
×
1926
        )
×
1927
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1928
                row, err := db.GetChannelByOutpointWithPolicies(
×
1929
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1930
                                Outpoint: op.String(),
×
1931
                                Version:  int16(ProtocolV1),
×
1932
                        },
×
1933
                )
×
1934
                if errors.Is(err, sql.ErrNoRows) {
×
1935
                        return ErrEdgeNotFound
×
1936
                } else if err != nil {
×
1937
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1938
                }
×
1939

1940
                node1, node2, err := buildNodeVertices(
×
1941
                        row.Node1Pubkey, row.Node2Pubkey,
×
1942
                )
×
1943
                if err != nil {
×
1944
                        return err
×
1945
                }
×
1946

1947
                edge, err = getAndBuildEdgeInfo(
×
1948
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1949
                        node1, node2,
×
1950
                )
×
1951
                if err != nil {
×
1952
                        return fmt.Errorf("unable to build channel info: %w",
×
1953
                                err)
×
1954
                }
×
1955

1956
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1957
                if err != nil {
×
1958
                        return fmt.Errorf("unable to extract channel "+
×
1959
                                "policies: %w", err)
×
1960
                }
×
1961

1962
                policy1, policy2, err = getAndBuildChanPolicies(
×
1963
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1964
                )
×
1965
                if err != nil {
×
1966
                        return fmt.Errorf("unable to build channel "+
×
1967
                                "policies: %w", err)
×
1968
                }
×
1969

1970
                return nil
×
1971
        }, sqldb.NoOpReset)
1972
        if err != nil {
×
1973
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1974
                        err)
×
1975
        }
×
1976

1977
        return edge, policy1, policy2, nil
×
1978
}
1979

1980
// HasChannelEdge returns true if the database knows of a channel edge with the
1981
// passed channel ID, and false otherwise. If an edge with that ID is found
1982
// within the graph, then two time stamps representing the last time the edge
1983
// was updated for both directed edges are returned along with the boolean. If
1984
// it is not found, then the zombie index is checked and its result is returned
1985
// as the second boolean.
1986
//
1987
// NOTE: part of the V1Store interface.
1988
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1989
        bool, error) {
×
1990

×
1991
        ctx := context.TODO()
×
1992

×
1993
        var (
×
1994
                exists          bool
×
1995
                isZombie        bool
×
1996
                node1LastUpdate time.Time
×
1997
                node2LastUpdate time.Time
×
1998
        )
×
1999

×
2000
        // We'll query the cache with the shared lock held to allow multiple
×
2001
        // readers to access values in the cache concurrently if they exist.
×
2002
        s.cacheMu.RLock()
×
2003
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2004
                s.cacheMu.RUnlock()
×
2005
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2006
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2007
                exists, isZombie = entry.flags.unpack()
×
2008

×
2009
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2010
        }
×
2011
        s.cacheMu.RUnlock()
×
2012

×
2013
        s.cacheMu.Lock()
×
2014
        defer s.cacheMu.Unlock()
×
2015

×
2016
        // The item was not found with the shared lock, so we'll acquire the
×
2017
        // exclusive lock and check the cache again in case another method added
×
2018
        // the entry to the cache while no lock was held.
×
2019
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2020
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2021
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2022
                exists, isZombie = entry.flags.unpack()
×
2023

×
2024
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2025
        }
×
2026

NEW
2027
        chanIDB := channelIDToBytes(chanID)
×
2028
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2029
                channel, err := db.GetChannelBySCID(
×
2030
                        ctx, sqlc.GetChannelBySCIDParams{
×
NEW
2031
                                Scid:    chanIDB,
×
2032
                                Version: int16(ProtocolV1),
×
2033
                        },
×
2034
                )
×
2035
                if errors.Is(err, sql.ErrNoRows) {
×
2036
                        // Check if it is a zombie channel.
×
2037
                        isZombie, err = db.IsZombieChannel(
×
2038
                                ctx, sqlc.IsZombieChannelParams{
×
NEW
2039
                                        Scid:    chanIDB,
×
2040
                                        Version: int16(ProtocolV1),
×
2041
                                },
×
2042
                        )
×
2043
                        if err != nil {
×
2044
                                return fmt.Errorf("could not check if channel "+
×
2045
                                        "is zombie: %w", err)
×
2046
                        }
×
2047

2048
                        return nil
×
2049
                } else if err != nil {
×
2050
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2051
                }
×
2052

2053
                exists = true
×
2054

×
2055
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2056
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2057
                                Version:   int16(ProtocolV1),
×
2058
                                ChannelID: channel.ID,
×
2059
                                NodeID:    channel.NodeID1,
×
2060
                        },
×
2061
                )
×
2062
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2063
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2064
                                err)
×
2065
                } else if err == nil {
×
2066
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2067
                }
×
2068

2069
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2070
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2071
                                Version:   int16(ProtocolV1),
×
2072
                                ChannelID: channel.ID,
×
2073
                                NodeID:    channel.NodeID2,
×
2074
                        },
×
2075
                )
×
2076
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2077
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2078
                                err)
×
2079
                } else if err == nil {
×
2080
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2081
                }
×
2082

2083
                return nil
×
2084
        }, sqldb.NoOpReset)
2085
        if err != nil {
×
2086
                return time.Time{}, time.Time{}, false, false,
×
2087
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2088
        }
×
2089

2090
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2091
                upd1Time: node1LastUpdate.Unix(),
×
2092
                upd2Time: node2LastUpdate.Unix(),
×
2093
                flags:    packRejectFlags(exists, isZombie),
×
2094
        })
×
2095

×
2096
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2097
}
2098

2099
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2100
// passed channel point (outpoint). If the passed channel doesn't exist within
2101
// the database, then ErrEdgeNotFound is returned.
2102
//
2103
// NOTE: part of the V1Store interface.
2104
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2105
        var (
×
2106
                ctx       = context.TODO()
×
2107
                channelID uint64
×
2108
        )
×
2109
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2110
                chanID, err := db.GetSCIDByOutpoint(
×
2111
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2112
                                Outpoint: chanPoint.String(),
×
2113
                                Version:  int16(ProtocolV1),
×
2114
                        },
×
2115
                )
×
2116
                if errors.Is(err, sql.ErrNoRows) {
×
2117
                        return ErrEdgeNotFound
×
2118
                } else if err != nil {
×
2119
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2120
                                err)
×
2121
                }
×
2122

2123
                channelID = byteOrder.Uint64(chanID)
×
2124

×
2125
                return nil
×
2126
        }, sqldb.NoOpReset)
2127
        if err != nil {
×
2128
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2129
        }
×
2130

2131
        return channelID, nil
×
2132
}
2133

2134
// IsPublicNode is a helper method that determines whether the node with the
2135
// given public key is seen as a public node in the graph from the graph's
2136
// source node's point of view.
2137
//
2138
// NOTE: part of the V1Store interface.
2139
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2140
        ctx := context.TODO()
×
2141

×
2142
        var isPublic bool
×
2143
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2144
                var err error
×
2145
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2146

×
2147
                return err
×
2148
        }, sqldb.NoOpReset)
×
2149
        if err != nil {
×
2150
                return false, fmt.Errorf("unable to check if node is "+
×
2151
                        "public: %w", err)
×
2152
        }
×
2153

2154
        return isPublic, nil
×
2155
}
2156

2157
// FetchChanInfos returns the set of channel edges that correspond to the passed
2158
// channel ID's. If an edge is the query is unknown to the database, it will
2159
// skipped and the result will contain only those edges that exist at the time
2160
// of the query. This can be used to respond to peer queries that are seeking to
2161
// fill in gaps in their view of the channel graph.
2162
//
2163
// NOTE: part of the V1Store interface.
2164
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2165
        var (
×
2166
                ctx   = context.TODO()
×
2167
                edges []ChannelEdge
×
2168
        )
×
2169
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2170
                for _, chanID := range chanIDs {
×
NEW
2171
                        chanIDB := channelIDToBytes(chanID)
×
2172

×
2173
                        // TODO(elle): potentially optimize this by using
×
2174
                        //  sqlc.slice() once that works for both SQLite and
×
2175
                        //  Postgres.
×
2176
                        row, err := db.GetChannelBySCIDWithPolicies(
×
2177
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
NEW
2178
                                        Scid:    chanIDB,
×
2179
                                        Version: int16(ProtocolV1),
×
2180
                                },
×
2181
                        )
×
2182
                        if errors.Is(err, sql.ErrNoRows) {
×
2183
                                continue
×
2184
                        } else if err != nil {
×
2185
                                return fmt.Errorf("unable to fetch channel: %w",
×
2186
                                        err)
×
2187
                        }
×
2188

2189
                        node1, node2, err := buildNodes(
×
2190
                                ctx, db, row.Node, row.Node_2,
×
2191
                        )
×
2192
                        if err != nil {
×
2193
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2194
                                        err)
×
2195
                        }
×
2196

2197
                        edge, err := getAndBuildEdgeInfo(
×
2198
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2199
                                row.Channel, node1.PubKeyBytes,
×
2200
                                node2.PubKeyBytes,
×
2201
                        )
×
2202
                        if err != nil {
×
2203
                                return fmt.Errorf("unable to build "+
×
2204
                                        "channel info: %w", err)
×
2205
                        }
×
2206

2207
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2208
                        if err != nil {
×
2209
                                return fmt.Errorf("unable to extract channel "+
×
2210
                                        "policies: %w", err)
×
2211
                        }
×
2212

2213
                        p1, p2, err := getAndBuildChanPolicies(
×
2214
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2215
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2216
                        )
×
2217
                        if err != nil {
×
2218
                                return fmt.Errorf("unable to build channel "+
×
2219
                                        "policies: %w", err)
×
2220
                        }
×
2221

2222
                        edges = append(edges, ChannelEdge{
×
2223
                                Info:    edge,
×
2224
                                Policy1: p1,
×
2225
                                Policy2: p2,
×
2226
                                Node1:   node1,
×
2227
                                Node2:   node2,
×
2228
                        })
×
2229
                }
2230

2231
                return nil
×
2232
        }, func() {
×
2233
                edges = nil
×
2234
        })
×
2235
        if err != nil {
×
2236
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2237
        }
×
2238

2239
        return edges, nil
×
2240
}
2241

2242
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2243
// ID's that we don't know and are not known zombies of the passed set. In other
2244
// words, we perform a set difference of our set of chan ID's and the ones
2245
// passed in. This method can be used by callers to determine the set of
2246
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2247
// known zombies is also returned.
2248
//
2249
// NOTE: part of the V1Store interface.
2250
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2251
        []ChannelUpdateInfo, error) {
×
2252

×
2253
        var (
×
2254
                ctx          = context.TODO()
×
2255
                newChanIDs   []uint64
×
2256
                knownZombies []ChannelUpdateInfo
×
2257
        )
×
2258
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2259
                for _, chanInfo := range chansInfo {
×
2260
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
NEW
2261
                        chanIDB := channelIDToBytes(channelID)
×
2262

×
2263
                        // TODO(elle): potentially optimize this by using
×
2264
                        //  sqlc.slice() once that works for both SQLite and
×
2265
                        //  Postgres.
×
2266
                        _, err := db.GetChannelBySCID(
×
2267
                                ctx, sqlc.GetChannelBySCIDParams{
×
2268
                                        Version: int16(ProtocolV1),
×
NEW
2269
                                        Scid:    chanIDB,
×
2270
                                },
×
2271
                        )
×
2272
                        if err == nil {
×
2273
                                continue
×
2274
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
2275
                                return fmt.Errorf("unable to fetch channel: %w",
×
2276
                                        err)
×
2277
                        }
×
2278

2279
                        isZombie, err := db.IsZombieChannel(
×
2280
                                ctx, sqlc.IsZombieChannelParams{
×
NEW
2281
                                        Scid:    chanIDB,
×
2282
                                        Version: int16(ProtocolV1),
×
2283
                                },
×
2284
                        )
×
2285
                        if err != nil {
×
2286
                                return fmt.Errorf("unable to fetch zombie "+
×
2287
                                        "channel: %w", err)
×
2288
                        }
×
2289

2290
                        if isZombie {
×
2291
                                knownZombies = append(knownZombies, chanInfo)
×
2292

×
2293
                                continue
×
2294
                        }
2295

2296
                        newChanIDs = append(newChanIDs, channelID)
×
2297
                }
2298

2299
                return nil
×
2300
        }, func() {
×
2301
                newChanIDs = nil
×
2302
                knownZombies = nil
×
2303
        })
×
2304
        if err != nil {
×
2305
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2306
        }
×
2307

2308
        return newChanIDs, knownZombies, nil
×
2309
}
2310

2311
// PruneGraphNodes is a garbage collection method which attempts to prune out
2312
// any nodes from the channel graph that are currently unconnected. This ensure
2313
// that we only maintain a graph of reachable nodes. In the event that a pruned
2314
// node gains more channels, it will be re-added back to the graph.
2315
//
2316
// NOTE: this prunes nodes across protocol versions. It will never prune the
2317
// source nodes.
2318
//
2319
// NOTE: part of the V1Store interface.
2320
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2321
        var ctx = context.TODO()
×
2322

×
2323
        var prunedNodes []route.Vertex
×
2324
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2325
                var err error
×
2326
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2327

×
2328
                return err
×
2329
        }, func() {
×
2330
                prunedNodes = nil
×
2331
        })
×
2332
        if err != nil {
×
2333
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2334
        }
×
2335

2336
        return prunedNodes, nil
×
2337
}
2338

2339
// PruneGraph prunes newly closed channels from the channel graph in response
2340
// to a new block being solved on the network. Any transactions which spend the
2341
// funding output of any known channels within he graph will be deleted.
2342
// Additionally, the "prune tip", or the last block which has been used to
2343
// prune the graph is stored so callers can ensure the graph is fully in sync
2344
// with the current UTXO state. A slice of channels that have been closed by
2345
// the target block along with any pruned nodes are returned if the function
2346
// succeeds without error.
2347
//
2348
// NOTE: part of the V1Store interface.
2349
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2350
        blockHash *chainhash.Hash, blockHeight uint32) (
2351
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2352

×
2353
        ctx := context.TODO()
×
2354

×
2355
        s.cacheMu.Lock()
×
2356
        defer s.cacheMu.Unlock()
×
2357

×
2358
        var (
×
2359
                closedChans []*models.ChannelEdgeInfo
×
2360
                prunedNodes []route.Vertex
×
2361
        )
×
2362
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2363
                for _, outpoint := range spentOutputs {
×
2364
                        // TODO(elle): potentially optimize this by using
×
2365
                        //  sqlc.slice() once that works for both SQLite and
×
2366
                        //  Postgres.
×
2367
                        //
×
2368
                        // NOTE: this fetches channels for all protocol
×
2369
                        // versions.
×
2370
                        row, err := db.GetChannelByOutpoint(
×
2371
                                ctx, outpoint.String(),
×
2372
                        )
×
2373
                        if errors.Is(err, sql.ErrNoRows) {
×
2374
                                continue
×
2375
                        } else if err != nil {
×
2376
                                return fmt.Errorf("unable to fetch channel: %w",
×
2377
                                        err)
×
2378
                        }
×
2379

2380
                        node1, node2, err := buildNodeVertices(
×
2381
                                row.Node1Pubkey, row.Node2Pubkey,
×
2382
                        )
×
2383
                        if err != nil {
×
2384
                                return err
×
2385
                        }
×
2386

2387
                        info, err := getAndBuildEdgeInfo(
×
2388
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2389
                                row.Channel, node1, node2,
×
2390
                        )
×
2391
                        if err != nil {
×
2392
                                return err
×
2393
                        }
×
2394

2395
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2396
                        if err != nil {
×
2397
                                return fmt.Errorf("unable to delete "+
×
2398
                                        "channel: %w", err)
×
2399
                        }
×
2400

2401
                        closedChans = append(closedChans, info)
×
2402
                }
2403

2404
                err := db.UpsertPruneLogEntry(
×
2405
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2406
                                BlockHash:   blockHash[:],
×
2407
                                BlockHeight: int64(blockHeight),
×
2408
                        },
×
2409
                )
×
2410
                if err != nil {
×
2411
                        return fmt.Errorf("unable to insert prune log "+
×
2412
                                "entry: %w", err)
×
2413
                }
×
2414

2415
                // Now that we've pruned some channels, we'll also prune any
2416
                // nodes that no longer have any channels.
2417
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2418
                if err != nil {
×
2419
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2420
                                err)
×
2421
                }
×
2422

2423
                return nil
×
2424
        }, func() {
×
2425
                prunedNodes = nil
×
2426
                closedChans = nil
×
2427
        })
×
2428
        if err != nil {
×
2429
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2430
        }
×
2431

2432
        for _, channel := range closedChans {
×
2433
                s.rejectCache.remove(channel.ChannelID)
×
2434
                s.chanCache.remove(channel.ChannelID)
×
2435
        }
×
2436

2437
        return closedChans, prunedNodes, nil
×
2438
}
2439

2440
// ChannelView returns the verifiable edge information for each active channel
2441
// within the known channel graph. The set of UTXOs (along with their scripts)
2442
// returned are the ones that need to be watched on chain to detect channel
2443
// closes on the resident blockchain.
2444
//
2445
// NOTE: part of the V1Store interface.
2446
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2447
        var (
×
2448
                ctx        = context.TODO()
×
2449
                edgePoints []EdgePoint
×
2450
        )
×
2451

×
2452
        handleChannel := func(db SQLQueries,
×
2453
                channel sqlc.ListChannelsPaginatedRow) error {
×
2454

×
2455
                pkScript, err := genMultiSigP2WSH(
×
2456
                        channel.BitcoinKey1, channel.BitcoinKey2,
×
2457
                )
×
2458
                if err != nil {
×
2459
                        return err
×
2460
                }
×
2461

2462
                op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2463
                if err != nil {
×
2464
                        return err
×
2465
                }
×
2466

2467
                edgePoints = append(edgePoints, EdgePoint{
×
2468
                        FundingPkScript: pkScript,
×
2469
                        OutPoint:        *op,
×
2470
                })
×
2471

×
2472
                return nil
×
2473
        }
2474

2475
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2476
                lastID := int64(-1)
×
2477
                for {
×
2478
                        rows, err := db.ListChannelsPaginated(
×
2479
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2480
                                        Version: int16(ProtocolV1),
×
2481
                                        ID:      lastID,
×
2482
                                        Limit:   pageSize,
×
2483
                                },
×
2484
                        )
×
2485
                        if err != nil {
×
2486
                                return err
×
2487
                        }
×
2488

2489
                        if len(rows) == 0 {
×
2490
                                break
×
2491
                        }
2492

2493
                        for _, row := range rows {
×
2494
                                err := handleChannel(db, row)
×
2495
                                if err != nil {
×
2496
                                        return err
×
2497
                                }
×
2498

2499
                                lastID = row.ID
×
2500
                        }
2501
                }
2502

2503
                return nil
×
2504
        }, func() {
×
2505
                edgePoints = nil
×
2506
        })
×
2507
        if err != nil {
×
2508
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2509
        }
×
2510

2511
        return edgePoints, nil
×
2512
}
2513

2514
// PruneTip returns the block height and hash of the latest block that has been
2515
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2516
// to tell if the graph is currently in sync with the current best known UTXO
2517
// state.
2518
//
2519
// NOTE: part of the V1Store interface.
2520
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2521
        var (
×
2522
                ctx       = context.TODO()
×
2523
                tipHash   chainhash.Hash
×
2524
                tipHeight uint32
×
2525
        )
×
2526
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2527
                pruneTip, err := db.GetPruneTip(ctx)
×
2528
                if errors.Is(err, sql.ErrNoRows) {
×
2529
                        return ErrGraphNeverPruned
×
2530
                } else if err != nil {
×
2531
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2532
                }
×
2533

2534
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2535
                tipHeight = uint32(pruneTip.BlockHeight)
×
2536

×
2537
                return nil
×
2538
        }, sqldb.NoOpReset)
2539
        if err != nil {
×
2540
                return nil, 0, err
×
2541
        }
×
2542

2543
        return &tipHash, tipHeight, nil
×
2544
}
2545

2546
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2547
//
2548
// NOTE: this prunes nodes across protocol versions. It will never prune the
2549
// source nodes.
2550
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2551
        db SQLQueries) ([]route.Vertex, error) {
×
2552

×
2553
        // Fetch all un-connected nodes from the database.
×
2554
        // NOTE: this will not include any nodes that are listed in the
×
2555
        // source table.
×
2556
        nodes, err := db.GetUnconnectedNodes(ctx)
×
2557
        if err != nil {
×
2558
                return nil, fmt.Errorf("unable to fetch unconnected nodes: %w",
×
2559
                        err)
×
2560
        }
×
2561

2562
        prunedNodes := make([]route.Vertex, 0, len(nodes))
×
2563
        for _, node := range nodes {
×
2564
                // TODO(elle): update to use sqlc.slice() once that works.
×
2565
                if err = db.DeleteNode(ctx, node.ID); err != nil {
×
2566
                        return nil, fmt.Errorf("unable to delete "+
×
2567
                                "node(id=%d): %w", node.ID, err)
×
2568
                }
×
2569

2570
                pubKey, err := route.NewVertexFromBytes(node.PubKey)
×
2571
                if err != nil {
×
2572
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2573
                                "for node(id=%d): %w", node.ID, err)
×
2574
                }
×
2575

2576
                prunedNodes = append(prunedNodes, pubKey)
×
2577
        }
2578

2579
        return prunedNodes, nil
×
2580
}
2581

2582
// DisconnectBlockAtHeight is used to indicate that the block specified
2583
// by the passed height has been disconnected from the main chain. This
2584
// will "rewind" the graph back to the height below, deleting channels
2585
// that are no longer confirmed from the graph. The prune log will be
2586
// set to the last prune height valid for the remaining chain.
2587
// Channels that were removed from the graph resulting from the
2588
// disconnected block are returned.
2589
//
2590
// NOTE: part of the V1Store interface.
2591
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2592
        []*models.ChannelEdgeInfo, error) {
×
2593

×
2594
        ctx := context.TODO()
×
2595

×
2596
        var (
×
2597
                // Every channel having a ShortChannelID starting at 'height'
×
2598
                // will no longer be confirmed.
×
2599
                startShortChanID = lnwire.ShortChannelID{
×
2600
                        BlockHeight: height,
×
2601
                }
×
2602

×
2603
                // Delete everything after this height from the db up until the
×
2604
                // SCID alias range.
×
2605
                endShortChanID = aliasmgr.StartingAlias
×
2606

×
2607
                removedChans []*models.ChannelEdgeInfo
×
2608

×
NEW
2609
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
NEW
2610
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
NEW
2611
        )
×
2612

×
2613
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2614
                rows, err := db.GetChannelsBySCIDRange(
×
2615
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
NEW
2616
                                StartScid: chanIDStart,
×
NEW
2617
                                EndScid:   chanIDEnd,
×
2618
                        },
×
2619
                )
×
2620
                if err != nil {
×
2621
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2622
                }
×
2623

2624
                for _, row := range rows {
×
2625
                        node1, node2, err := buildNodeVertices(
×
2626
                                row.Node1PubKey, row.Node2PubKey,
×
2627
                        )
×
2628
                        if err != nil {
×
2629
                                return err
×
2630
                        }
×
2631

2632
                        channel, err := getAndBuildEdgeInfo(
×
2633
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2634
                                row.Channel, node1, node2,
×
2635
                        )
×
2636
                        if err != nil {
×
2637
                                return err
×
2638
                        }
×
2639

2640
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2641
                        if err != nil {
×
2642
                                return fmt.Errorf("unable to delete "+
×
2643
                                        "channel: %w", err)
×
2644
                        }
×
2645

2646
                        removedChans = append(removedChans, channel)
×
2647
                }
2648

2649
                return db.DeletePruneLogEntriesInRange(
×
2650
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2651
                                StartHeight: int64(height),
×
2652
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2653
                        },
×
2654
                )
×
2655
        }, func() {
×
2656
                removedChans = nil
×
2657
        })
×
2658
        if err != nil {
×
2659
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2660
                        "height: %w", err)
×
2661
        }
×
2662

2663
        for _, channel := range removedChans {
×
2664
                s.rejectCache.remove(channel.ChannelID)
×
2665
                s.chanCache.remove(channel.ChannelID)
×
2666
        }
×
2667

2668
        return removedChans, nil
×
2669
}
2670

2671
// AddEdgeProof sets the proof of an existing edge in the graph database.
2672
//
2673
// NOTE: part of the V1Store interface.
2674
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2675
        proof *models.ChannelAuthProof) error {
×
2676

×
2677
        var (
×
2678
                ctx       = context.TODO()
×
2679
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2680
        )
×
2681

×
2682
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2683
                res, err := db.AddV1ChannelProof(
×
2684
                        ctx, sqlc.AddV1ChannelProofParams{
×
NEW
2685
                                Scid:              scidBytes,
×
2686
                                Node1Signature:    proof.NodeSig1Bytes,
×
2687
                                Node2Signature:    proof.NodeSig2Bytes,
×
2688
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2689
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2690
                        },
×
2691
                )
×
2692
                if err != nil {
×
2693
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2694
                }
×
2695

2696
                n, err := res.RowsAffected()
×
2697
                if err != nil {
×
2698
                        return err
×
2699
                }
×
2700

2701
                if n == 0 {
×
2702
                        return fmt.Errorf("no rows affected when adding edge "+
×
2703
                                "proof for SCID %v", scid)
×
2704
                } else if n > 1 {
×
2705
                        return fmt.Errorf("multiple rows affected when adding "+
×
2706
                                "edge proof for SCID %v: %d rows affected",
×
2707
                                scid, n)
×
2708
                }
×
2709

2710
                return nil
×
2711
        }, sqldb.NoOpReset)
2712
        if err != nil {
×
2713
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2714
        }
×
2715

2716
        return nil
×
2717
}
2718

2719
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2720
// that we can ignore channel announcements that we know to be closed without
2721
// having to validate them and fetch a block.
2722
//
2723
// NOTE: part of the V1Store interface.
2724
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2725
        var (
×
2726
                ctx     = context.TODO()
×
2727
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2728
        )
×
2729

×
2730
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2731
                return db.InsertClosedChannel(ctx, chanIDB)
×
2732
        }, sqldb.NoOpReset)
×
2733
}
2734

2735
// IsClosedScid checks whether a channel identified by the passed in scid is
2736
// closed. This helps avoid having to perform expensive validation checks.
2737
//
2738
// NOTE: part of the V1Store interface.
2739
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2740
        var (
×
2741
                ctx      = context.TODO()
×
2742
                isClosed bool
×
2743
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2744
        )
×
2745
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2746
                var err error
×
NEW
2747
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2748
                if err != nil {
×
2749
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2750
                                err)
×
2751
                }
×
2752

2753
                return nil
×
2754
        }, sqldb.NoOpReset)
2755
        if err != nil {
×
2756
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2757
                        err)
×
2758
        }
×
2759

2760
        return isClosed, nil
×
2761
}
2762

2763
// GraphSession will provide the call-back with access to a NodeTraverser
2764
// instance which can be used to perform queries against the channel graph.
2765
//
2766
// NOTE: part of the V1Store interface.
2767
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error) error {
×
2768
        var ctx = context.TODO()
×
2769

×
2770
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2771
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2772
        }, sqldb.NoOpReset)
×
2773
}
2774

2775
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2776
// read only transaction for a consistent view of the graph.
2777
type sqlNodeTraverser struct {
2778
        db    SQLQueries
2779
        chain chainhash.Hash
2780
}
2781

2782
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2783
// NodeTraverser interface.
2784
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2785

2786
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2787
func newSQLNodeTraverser(db SQLQueries,
2788
        chain chainhash.Hash) *sqlNodeTraverser {
×
2789

×
2790
        return &sqlNodeTraverser{
×
2791
                db:    db,
×
2792
                chain: chain,
×
2793
        }
×
2794
}
×
2795

2796
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2797
// node.
2798
//
2799
// NOTE: Part of the NodeTraverser interface.
2800
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2801
        cb func(channel *DirectedChannel) error) error {
×
2802

×
2803
        ctx := context.TODO()
×
2804

×
2805
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2806
}
×
2807

2808
// FetchNodeFeatures returns the features of the given node. If the node is
2809
// unknown, assume no additional features are supported.
2810
//
2811
// NOTE: Part of the NodeTraverser interface.
2812
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2813
        *lnwire.FeatureVector, error) {
×
2814

×
2815
        ctx := context.TODO()
×
2816

×
2817
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2818
}
×
2819

2820
// forEachNodeDirectedChannel iterates through all channels of a given
2821
// node, executing the passed callback on the directed edge representing the
2822
// channel and its incoming policy. If the node is not found, no error is
2823
// returned.
2824
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2825
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2826

×
2827
        toNodeCallback := func() route.Vertex {
×
2828
                return nodePub
×
2829
        }
×
2830

2831
        dbID, err := db.GetNodeIDByPubKey(
×
2832
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2833
                        Version: int16(ProtocolV1),
×
2834
                        PubKey:  nodePub[:],
×
2835
                },
×
2836
        )
×
2837
        if errors.Is(err, sql.ErrNoRows) {
×
2838
                return nil
×
2839
        } else if err != nil {
×
2840
                return fmt.Errorf("unable to fetch node: %w", err)
×
2841
        }
×
2842

2843
        rows, err := db.ListChannelsByNodeID(
×
2844
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2845
                        Version: int16(ProtocolV1),
×
2846
                        NodeID1: dbID,
×
2847
                },
×
2848
        )
×
2849
        if err != nil {
×
2850
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2851
        }
×
2852

2853
        // Exit early if there are no channels for this node so we don't
2854
        // do the unnecessary feature fetching.
2855
        if len(rows) == 0 {
×
2856
                return nil
×
2857
        }
×
2858

2859
        features, err := getNodeFeatures(ctx, db, dbID)
×
2860
        if err != nil {
×
2861
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2862
        }
×
2863

2864
        for _, row := range rows {
×
2865
                node1, node2, err := buildNodeVertices(
×
2866
                        row.Node1Pubkey, row.Node2Pubkey,
×
2867
                )
×
2868
                if err != nil {
×
2869
                        return fmt.Errorf("unable to build node vertices: %w",
×
2870
                                err)
×
2871
                }
×
2872

2873
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
2874

×
2875
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2876
                if err != nil {
×
2877
                        return err
×
2878
                }
×
2879

2880
                var p1, p2 *models.CachedEdgePolicy
×
2881
                if dbPol1 != nil {
×
2882
                        policy1, err := buildChanPolicy(
×
2883
                                *dbPol1, edge.ChannelID, nil, node2, true,
×
2884
                        )
×
2885
                        if err != nil {
×
2886
                                return err
×
2887
                        }
×
2888

2889
                        p1 = models.NewCachedPolicy(policy1)
×
2890
                }
2891
                if dbPol2 != nil {
×
2892
                        policy2, err := buildChanPolicy(
×
2893
                                *dbPol2, edge.ChannelID, nil, node1, false,
×
2894
                        )
×
2895
                        if err != nil {
×
2896
                                return err
×
2897
                        }
×
2898

2899
                        p2 = models.NewCachedPolicy(policy2)
×
2900
                }
2901

2902
                // Determine the outgoing and incoming policy for this
2903
                // channel and node combo.
2904
                outPolicy, inPolicy := p1, p2
×
2905
                if p1 != nil && node2 == nodePub {
×
2906
                        outPolicy, inPolicy = p2, p1
×
2907
                } else if p2 != nil && node1 != nodePub {
×
2908
                        outPolicy, inPolicy = p2, p1
×
2909
                }
×
2910

2911
                var cachedInPolicy *models.CachedEdgePolicy
×
2912
                if inPolicy != nil {
×
2913
                        cachedInPolicy = inPolicy
×
2914
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
2915
                        cachedInPolicy.ToNodeFeatures = features
×
2916
                }
×
2917

2918
                directedChannel := &DirectedChannel{
×
2919
                        ChannelID:    edge.ChannelID,
×
2920
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2921
                        OtherNode:    edge.NodeKey2Bytes,
×
2922
                        Capacity:     edge.Capacity,
×
2923
                        OutPolicySet: outPolicy != nil,
×
2924
                        InPolicy:     cachedInPolicy,
×
2925
                }
×
2926
                if outPolicy != nil {
×
2927
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2928
                                directedChannel.InboundFee = fee
×
2929
                        })
×
2930
                }
2931

2932
                if nodePub == edge.NodeKey2Bytes {
×
2933
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
2934
                }
×
2935

2936
                if err := cb(directedChannel); err != nil {
×
2937
                        return err
×
2938
                }
×
2939
        }
2940

2941
        return nil
×
2942
}
2943

2944
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
2945
// and executes the provided callback for each node.
2946
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
2947
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
2948

×
2949
        lastID := int64(-1)
×
2950

×
2951
        for {
×
2952
                nodes, err := db.ListNodeIDsAndPubKeys(
×
2953
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
2954
                                Version: int16(ProtocolV1),
×
2955
                                ID:      lastID,
×
2956
                                Limit:   pageSize,
×
2957
                        },
×
2958
                )
×
2959
                if err != nil {
×
2960
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
2961
                }
×
2962

2963
                if len(nodes) == 0 {
×
2964
                        break
×
2965
                }
2966

2967
                for _, node := range nodes {
×
2968
                        var pub route.Vertex
×
2969
                        copy(pub[:], node.PubKey)
×
2970

×
2971
                        if err := cb(node.ID, pub); err != nil {
×
2972
                                return fmt.Errorf("forEachNodeCacheable "+
×
2973
                                        "callback failed for node(id=%d): %w",
×
2974
                                        node.ID, err)
×
2975
                        }
×
2976

2977
                        lastID = node.ID
×
2978
                }
2979
        }
2980

2981
        return nil
×
2982
}
2983

2984
// forEachNodeChannel iterates through all channels of a node, executing
2985
// the passed callback on each. The call-back is provided with the channel's
2986
// edge information, the outgoing policy and the incoming policy for the
2987
// channel and node combo.
2988
func forEachNodeChannel(ctx context.Context, db SQLQueries,
2989
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
2990
                *models.ChannelEdgePolicy,
2991
                *models.ChannelEdgePolicy) error) error {
×
2992

×
2993
        // Get all the V1 channels for this node.Add commentMore actions
×
2994
        rows, err := db.ListChannelsByNodeID(
×
2995
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2996
                        Version: int16(ProtocolV1),
×
2997
                        NodeID1: id,
×
2998
                },
×
2999
        )
×
3000
        if err != nil {
×
3001
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3002
        }
×
3003

3004
        // Call the call-back for each channel and its known policies.
3005
        for _, row := range rows {
×
3006
                node1, node2, err := buildNodeVertices(
×
3007
                        row.Node1Pubkey, row.Node2Pubkey,
×
3008
                )
×
3009
                if err != nil {
×
3010
                        return fmt.Errorf("unable to build node vertices: %w",
×
3011
                                err)
×
3012
                }
×
3013

3014
                edge, err := getAndBuildEdgeInfo(
×
3015
                        ctx, db, chain, row.Channel.ID, row.Channel, node1,
×
3016
                        node2,
×
3017
                )
×
3018
                if err != nil {
×
3019
                        return fmt.Errorf("unable to build channel info: %w",
×
3020
                                err)
×
3021
                }
×
3022

3023
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3024
                if err != nil {
×
3025
                        return fmt.Errorf("unable to extract channel "+
×
3026
                                "policies: %w", err)
×
3027
                }
×
3028

3029
                p1, p2, err := getAndBuildChanPolicies(
×
3030
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3031
                )
×
3032
                if err != nil {
×
3033
                        return fmt.Errorf("unable to build channel "+
×
3034
                                "policies: %w", err)
×
3035
                }
×
3036

3037
                // Determine the outgoing and incoming policy for this
3038
                // channel and node combo.
3039
                p1ToNode := row.Channel.NodeID2
×
3040
                p2ToNode := row.Channel.NodeID1
×
3041
                outPolicy, inPolicy := p1, p2
×
3042
                if (p1 != nil && p1ToNode == id) ||
×
3043
                        (p2 != nil && p2ToNode != id) {
×
3044

×
3045
                        outPolicy, inPolicy = p2, p1
×
3046
                }
×
3047

3048
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3049
                        return err
×
3050
                }
×
3051
        }
3052

3053
        return nil
×
3054
}
3055

3056
// updateChanEdgePolicy upserts the channel policy info we have stored for
3057
// a channel we already know of.
3058
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3059
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3060
        error) {
×
3061

×
3062
        var (
×
3063
                node1Pub, node2Pub route.Vertex
×
3064
                isNode1            bool
×
3065
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3066
        )
×
3067

×
3068
        // Check that this edge policy refers to a channel that we already
×
3069
        // know of. We do this explicitly so that we can return the appropriate
×
3070
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3071
        // abort the transaction which would abort the entire batch.
×
3072
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3073
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
NEW
3074
                        Scid:    chanIDB,
×
3075
                        Version: int16(ProtocolV1),
×
3076
                },
×
3077
        )
×
3078
        if errors.Is(err, sql.ErrNoRows) {
×
3079
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3080
        } else if err != nil {
×
3081
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3082
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3083
        }
×
3084

3085
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3086
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3087

×
3088
        // Figure out which node this edge is from.
×
3089
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3090
        nodeID := dbChan.NodeID1
×
3091
        if !isNode1 {
×
3092
                nodeID = dbChan.NodeID2
×
3093
        }
×
3094

3095
        var (
×
3096
                inboundBase sql.NullInt64
×
3097
                inboundRate sql.NullInt64
×
3098
        )
×
3099
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3100
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3101
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3102
        })
×
3103

3104
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3105
                Version:     int16(ProtocolV1),
×
3106
                ChannelID:   dbChan.ID,
×
3107
                NodeID:      nodeID,
×
3108
                Timelock:    int32(edge.TimeLockDelta),
×
3109
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3110
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3111
                MinHtlcMsat: int64(edge.MinHTLC),
×
3112
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3113
                Disabled: sql.NullBool{
×
3114
                        Valid: true,
×
3115
                        Bool:  edge.IsDisabled(),
×
3116
                },
×
3117
                MaxHtlcMsat: sql.NullInt64{
×
3118
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3119
                        Int64: int64(edge.MaxHTLC),
×
3120
                },
×
3121
                InboundBaseFeeMsat:      inboundBase,
×
3122
                InboundFeeRateMilliMsat: inboundRate,
×
3123
                Signature:               edge.SigBytes,
×
3124
        })
×
3125
        if err != nil {
×
3126
                return node1Pub, node2Pub, isNode1,
×
3127
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3128
        }
×
3129

3130
        // Convert the flat extra opaque data into a map of TLV types to
3131
        // values.
3132
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3133
        if err != nil {
×
3134
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3135
                        "marshal extra opaque data: %w", err)
×
3136
        }
×
3137

3138
        // Update the channel policy's extra signed fields.
3139
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3140
        if err != nil {
×
3141
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3142
                        "policy extra TLVs: %w", err)
×
3143
        }
×
3144

3145
        return node1Pub, node2Pub, isNode1, nil
×
3146
}
3147

3148
// getNodeByPubKey attempts to look up a target node by its public key.
3149
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3150
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3151

×
3152
        dbNode, err := db.GetNodeByPubKey(
×
3153
                ctx, sqlc.GetNodeByPubKeyParams{
×
3154
                        Version: int16(ProtocolV1),
×
3155
                        PubKey:  pubKey[:],
×
3156
                },
×
3157
        )
×
3158
        if errors.Is(err, sql.ErrNoRows) {
×
3159
                return 0, nil, ErrGraphNodeNotFound
×
3160
        } else if err != nil {
×
3161
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3162
        }
×
3163

3164
        node, err := buildNode(ctx, db, &dbNode)
×
3165
        if err != nil {
×
3166
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3167
        }
×
3168

3169
        return dbNode.ID, node, nil
×
3170
}
3171

3172
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3173
// provided database channel row and the public keys of the two nodes
3174
// involved in the channel.
3175
func buildCacheableChannelInfo(dbChan sqlc.Channel, node1Pub,
3176
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3177

×
3178
        return &models.CachedEdgeInfo{
×
3179
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
3180
                NodeKey1Bytes: node1Pub,
×
3181
                NodeKey2Bytes: node2Pub,
×
3182
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
3183
        }
×
3184
}
×
3185

3186
// buildNode constructs a LightningNode instance from the given database node
3187
// record. The node's features, addresses and extra signed fields are also
3188
// fetched from the database and set on the node.
3189
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
3190
        *models.LightningNode, error) {
×
3191

×
3192
        if dbNode.Version != int16(ProtocolV1) {
×
3193
                return nil, fmt.Errorf("unsupported node version: %d",
×
3194
                        dbNode.Version)
×
3195
        }
×
3196

3197
        var pub [33]byte
×
3198
        copy(pub[:], dbNode.PubKey)
×
3199

×
3200
        node := &models.LightningNode{
×
3201
                PubKeyBytes: pub,
×
3202
                Features:    lnwire.EmptyFeatureVector(),
×
3203
                LastUpdate:  time.Unix(0, 0),
×
3204
        }
×
3205

×
3206
        if len(dbNode.Signature) == 0 {
×
3207
                return node, nil
×
3208
        }
×
3209

3210
        node.HaveNodeAnnouncement = true
×
3211
        node.AuthSigBytes = dbNode.Signature
×
3212
        node.Alias = dbNode.Alias.String
×
3213
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3214

×
3215
        var err error
×
3216
        if dbNode.Color.Valid {
×
3217
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3218
                if err != nil {
×
3219
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3220
                                err)
×
3221
                }
×
3222
        }
3223

3224
        // Fetch the node's features.
3225
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
3226
        if err != nil {
×
3227
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3228
                        "features: %w", dbNode.ID, err)
×
3229
        }
×
3230

3231
        // Fetch the node's addresses.
3232
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
3233
        if err != nil {
×
3234
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3235
                        "addresses: %w", dbNode.ID, err)
×
3236
        }
×
3237

3238
        // Fetch the node's extra signed fields.
3239
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
3240
        if err != nil {
×
3241
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3242
                        "extra signed fields: %w", dbNode.ID, err)
×
3243
        }
×
3244

3245
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
3246
        if err != nil {
×
3247
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3248
                        "fields: %w", err)
×
3249
        }
×
3250

3251
        if len(recs) != 0 {
×
3252
                node.ExtraOpaqueData = recs
×
3253
        }
×
3254

3255
        return node, nil
×
3256
}
3257

3258
// getNodeFeatures fetches the feature bits and constructs the feature vector
3259
// for a node with the given DB ID.
3260
func getNodeFeatures(ctx context.Context, db SQLQueries,
3261
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3262

×
3263
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3264
        if err != nil {
×
3265
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3266
                        nodeID, err)
×
3267
        }
×
3268

3269
        features := lnwire.EmptyFeatureVector()
×
3270
        for _, feature := range rows {
×
3271
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3272
        }
×
3273

3274
        return features, nil
×
3275
}
3276

3277
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
3278
// given DB ID.
3279
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3280
        nodeID int64) (map[uint64][]byte, error) {
×
3281

×
3282
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3283
        if err != nil {
×
3284
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
3285
                        "signed fields: %w", nodeID, err)
×
3286
        }
×
3287

3288
        extraFields := make(map[uint64][]byte)
×
3289
        for _, field := range fields {
×
3290
                extraFields[uint64(field.Type)] = field.Value
×
3291
        }
×
3292

3293
        return extraFields, nil
×
3294
}
3295

3296
// upsertNode upserts the node record into the database. If the node already
3297
// exists, then the node's information is updated. If the node doesn't exist,
3298
// then a new node is created. The node's features, addresses and extra TLV
3299
// types are also updated. The node's DB ID is returned.
3300
func upsertNode(ctx context.Context, db SQLQueries,
3301
        node *models.LightningNode) (int64, error) {
×
3302

×
3303
        params := sqlc.UpsertNodeParams{
×
3304
                Version: int16(ProtocolV1),
×
3305
                PubKey:  node.PubKeyBytes[:],
×
3306
        }
×
3307

×
3308
        if node.HaveNodeAnnouncement {
×
3309
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3310
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3311
                params.Alias = sqldb.SQLStr(node.Alias)
×
3312
                params.Signature = node.AuthSigBytes
×
3313
        }
×
3314

3315
        nodeID, err := db.UpsertNode(ctx, params)
×
3316
        if err != nil {
×
3317
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3318
                        err)
×
3319
        }
×
3320

3321
        // We can exit here if we don't have the announcement yet.
3322
        if !node.HaveNodeAnnouncement {
×
3323
                return nodeID, nil
×
3324
        }
×
3325

3326
        // Update the node's features.
3327
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3328
        if err != nil {
×
3329
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3330
        }
×
3331

3332
        // Update the node's addresses.
3333
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3334
        if err != nil {
×
3335
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3336
        }
×
3337

3338
        // Convert the flat extra opaque data into a map of TLV types to
3339
        // values.
3340
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3341
        if err != nil {
×
3342
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3343
                        err)
×
3344
        }
×
3345

3346
        // Update the node's extra signed fields.
3347
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3348
        if err != nil {
×
3349
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3350
        }
×
3351

3352
        return nodeID, nil
×
3353
}
3354

3355
// upsertNodeFeatures updates the node's features node_features table. This
3356
// includes deleting any feature bits no longer present and inserting any new
3357
// feature bits. If the feature bit does not yet exist in the features table,
3358
// then an entry is created in that table first.
3359
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3360
        features *lnwire.FeatureVector) error {
×
3361

×
3362
        // Get any existing features for the node.
×
3363
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3364
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3365
                return err
×
3366
        }
×
3367

3368
        // Copy the nodes latest set of feature bits.
3369
        newFeatures := make(map[int32]struct{})
×
3370
        if features != nil {
×
3371
                for feature := range features.Features() {
×
3372
                        newFeatures[int32(feature)] = struct{}{}
×
3373
                }
×
3374
        }
3375

3376
        // For any current feature that already exists in the DB, remove it from
3377
        // the in-memory map. For any existing feature that does not exist in
3378
        // the in-memory map, delete it from the database.
3379
        for _, feature := range existingFeatures {
×
3380
                // The feature is still present, so there are no updates to be
×
3381
                // made.
×
3382
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3383
                        delete(newFeatures, feature.FeatureBit)
×
3384
                        continue
×
3385
                }
3386

3387
                // The feature is no longer present, so we remove it from the
3388
                // database.
3389
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3390
                        NodeID:     nodeID,
×
3391
                        FeatureBit: feature.FeatureBit,
×
3392
                })
×
3393
                if err != nil {
×
3394
                        return fmt.Errorf("unable to delete node(%d) "+
×
3395
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3396
                                err)
×
3397
                }
×
3398
        }
3399

3400
        // Any remaining entries in newFeatures are new features that need to be
3401
        // added to the database for the first time.
3402
        for feature := range newFeatures {
×
3403
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3404
                        NodeID:     nodeID,
×
3405
                        FeatureBit: feature,
×
3406
                })
×
3407
                if err != nil {
×
3408
                        return fmt.Errorf("unable to insert node(%d) "+
×
3409
                                "feature(%v): %w", nodeID, feature, err)
×
3410
                }
×
3411
        }
3412

3413
        return nil
×
3414
}
3415

3416
// fetchNodeFeatures fetches the features for a node with the given public key.
3417
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3418
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3419

×
3420
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3421
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3422
                        PubKey:  nodePub[:],
×
3423
                        Version: int16(ProtocolV1),
×
3424
                },
×
3425
        )
×
3426
        if err != nil {
×
3427
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3428
                        nodePub, err)
×
3429
        }
×
3430

3431
        features := lnwire.EmptyFeatureVector()
×
3432
        for _, bit := range rows {
×
3433
                features.Set(lnwire.FeatureBit(bit))
×
3434
        }
×
3435

3436
        return features, nil
×
3437
}
3438

3439
// dbAddressType enumerates the kinds of address persisted in the
// node_addresses table. The stored type selects how an address is serialized
// when written and deserialized when read back.
type dbAddressType uint8

const (
	// addressTypeIPv4 marks a TCP address whose host is an IPv4 address.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 marks a TCP address whose host is an IPv6 address.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 marks a Tor onion address of v2 length.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 marks a Tor onion address of v3 length.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeOpaque marks an opaque address payload. It is pinned to
	// math.MaxInt8 (127), well clear of the explicitly typed addresses.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3451

3452
// upsertNodeAddresses updates the node's addresses in the database. This
3453
// includes deleting any existing addresses and inserting the new set of
3454
// addresses. The deletion is necessary since the ordering of the addresses may
3455
// change, and we need to ensure that the database reflects the latest set of
3456
// addresses so that at the time of reconstructing the node announcement, the
3457
// order is preserved and the signature over the message remains valid.
3458
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3459
        addresses []net.Addr) error {
×
3460

×
3461
        // Delete any existing addresses for the node. This is required since
×
3462
        // even if the new set of addresses is the same, the ordering may have
×
3463
        // changed for a given address type.
×
3464
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3465
        if err != nil {
×
3466
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3467
                        nodeID, err)
×
3468
        }
×
3469

3470
        // Copy the nodes latest set of addresses.
3471
        newAddresses := map[dbAddressType][]string{
×
3472
                addressTypeIPv4:   {},
×
3473
                addressTypeIPv6:   {},
×
3474
                addressTypeTorV2:  {},
×
3475
                addressTypeTorV3:  {},
×
3476
                addressTypeOpaque: {},
×
3477
        }
×
3478
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3479
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3480
        }
×
3481

3482
        for _, address := range addresses {
×
3483
                switch addr := address.(type) {
×
3484
                case *net.TCPAddr:
×
3485
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3486
                                addAddr(addressTypeIPv4, addr)
×
3487
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3488
                                addAddr(addressTypeIPv6, addr)
×
3489
                        } else {
×
3490
                                return fmt.Errorf("unhandled IP address: %v",
×
3491
                                        addr)
×
3492
                        }
×
3493

3494
                case *tor.OnionAddr:
×
3495
                        switch len(addr.OnionService) {
×
3496
                        case tor.V2Len:
×
3497
                                addAddr(addressTypeTorV2, addr)
×
3498
                        case tor.V3Len:
×
3499
                                addAddr(addressTypeTorV3, addr)
×
3500
                        default:
×
3501
                                return fmt.Errorf("invalid length for a tor " +
×
3502
                                        "address")
×
3503
                        }
3504

3505
                case *lnwire.OpaqueAddrs:
×
3506
                        addAddr(addressTypeOpaque, addr)
×
3507

3508
                default:
×
3509
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3510
                }
3511
        }
3512

3513
        // Any remaining entries in newAddresses are new addresses that need to
3514
        // be added to the database for the first time.
3515
        for addrType, addrList := range newAddresses {
×
3516
                for position, addr := range addrList {
×
3517
                        err := db.InsertNodeAddress(
×
3518
                                ctx, sqlc.InsertNodeAddressParams{
×
3519
                                        NodeID:   nodeID,
×
3520
                                        Type:     int16(addrType),
×
3521
                                        Address:  addr,
×
3522
                                        Position: int32(position),
×
3523
                                },
×
3524
                        )
×
3525
                        if err != nil {
×
3526
                                return fmt.Errorf("unable to insert "+
×
3527
                                        "node(%d) address(%v): %w", nodeID,
×
3528
                                        addr, err)
×
3529
                        }
×
3530
                }
3531
        }
3532

3533
        return nil
×
3534
}
3535

3536
// getNodeAddresses fetches the addresses for a node with the given public key.
3537
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3538
        []net.Addr, error) {
×
3539

×
3540
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
3541
        // are returned in the same order as they were inserted.
×
3542
        rows, err := db.GetNodeAddressesByPubKey(
×
3543
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
3544
                        Version: int16(ProtocolV1),
×
3545
                        PubKey:  nodePub,
×
3546
                },
×
3547
        )
×
3548
        if err != nil {
×
3549
                return false, nil, err
×
3550
        }
×
3551

3552
        // GetNodeAddressesByPubKey uses a left join so there should always be
3553
        // at least one row returned if the node exists even if it has no
3554
        // addresses.
3555
        if len(rows) == 0 {
×
3556
                return false, nil, nil
×
3557
        }
×
3558

3559
        addresses := make([]net.Addr, 0, len(rows))
×
3560
        for _, addr := range rows {
×
3561
                if !(addr.Type.Valid && addr.Address.Valid) {
×
3562
                        continue
×
3563
                }
3564

3565
                address := addr.Address.String
×
3566

×
3567
                switch dbAddressType(addr.Type.Int16) {
×
3568
                case addressTypeIPv4:
×
3569
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
3570
                        if err != nil {
×
3571
                                return false, nil, nil
×
3572
                        }
×
3573
                        tcp.IP = tcp.IP.To4()
×
3574

×
3575
                        addresses = append(addresses, tcp)
×
3576

3577
                case addressTypeIPv6:
×
3578
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
3579
                        if err != nil {
×
3580
                                return false, nil, nil
×
3581
                        }
×
3582
                        addresses = append(addresses, tcp)
×
3583

3584
                case addressTypeTorV3, addressTypeTorV2:
×
3585
                        service, portStr, err := net.SplitHostPort(address)
×
3586
                        if err != nil {
×
3587
                                return false, nil, fmt.Errorf("unable to "+
×
3588
                                        "split tor v3 address: %v",
×
3589
                                        addr.Address)
×
3590
                        }
×
3591

3592
                        port, err := strconv.Atoi(portStr)
×
3593
                        if err != nil {
×
3594
                                return false, nil, err
×
3595
                        }
×
3596

3597
                        addresses = append(addresses, &tor.OnionAddr{
×
3598
                                OnionService: service,
×
3599
                                Port:         port,
×
3600
                        })
×
3601

3602
                case addressTypeOpaque:
×
3603
                        opaque, err := hex.DecodeString(address)
×
3604
                        if err != nil {
×
3605
                                return false, nil, fmt.Errorf("unable to "+
×
3606
                                        "decode opaque address: %v", addr)
×
3607
                        }
×
3608

3609
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
3610
                                Payload: opaque,
×
3611
                        })
×
3612

3613
                default:
×
3614
                        return false, nil, fmt.Errorf("unknown address "+
×
3615
                                "type: %v", addr.Type)
×
3616
                }
3617
        }
3618

3619
        return true, addresses, nil
×
3620
}
3621

3622
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3623
// database. This includes updating any existing types, inserting any new types,
3624
// and deleting any types that are no longer present.
3625
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3626
        nodeID int64, extraFields map[uint64][]byte) error {
×
3627

×
3628
        // Get any existing extra signed fields for the node.
×
3629
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3630
        if err != nil {
×
3631
                return err
×
3632
        }
×
3633

3634
        // Make a lookup map of the existing field types so that we can use it
3635
        // to keep track of any fields we should delete.
3636
        m := make(map[uint64]bool)
×
3637
        for _, field := range existingFields {
×
3638
                m[uint64(field.Type)] = true
×
3639
        }
×
3640

3641
        // For all the new fields, we'll upsert them and remove them from the
3642
        // map of existing fields.
3643
        for tlvType, value := range extraFields {
×
3644
                err = db.UpsertNodeExtraType(
×
3645
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3646
                                NodeID: nodeID,
×
3647
                                Type:   int64(tlvType),
×
3648
                                Value:  value,
×
3649
                        },
×
3650
                )
×
3651
                if err != nil {
×
3652
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3653
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3654
                }
×
3655

3656
                // Remove the field from the map of existing fields if it was
3657
                // present.
3658
                delete(m, tlvType)
×
3659
        }
3660

3661
        // For all the fields that are left in the map of existing fields, we'll
3662
        // delete them as they are no longer present in the new set of fields.
3663
        for tlvType := range m {
×
3664
                err = db.DeleteExtraNodeType(
×
3665
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3666
                                NodeID: nodeID,
×
3667
                                Type:   int64(tlvType),
×
3668
                        },
×
3669
                )
×
3670
                if err != nil {
×
3671
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3672
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3673
                }
×
3674
        }
3675

3676
        return nil
×
3677
}
3678

3679
// srcNodeInfo holds the information about the source node of the graph.
3680
type srcNodeInfo struct {
3681
        // id is the DB level ID of the source node entry in the "nodes" table.
3682
        id int64
3683

3684
        // pub is the public key of the source node.
3685
        pub route.Vertex
3686
}
3687

3688
// getSourceNode returns the DB node ID and pub key of the source node for the
3689
// specified protocol version.
3690
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3691
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3692

×
3693
        s.srcNodeMu.Lock()
×
3694
        defer s.srcNodeMu.Unlock()
×
3695

×
3696
        // If we already have the source node ID and pub key cached, then
×
3697
        // return them.
×
3698
        if info, ok := s.srcNodes[version]; ok {
×
3699
                return info.id, info.pub, nil
×
3700
        }
×
3701

3702
        var pubKey route.Vertex
×
3703

×
3704
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3705
        if err != nil {
×
3706
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3707
                        err)
×
3708
        }
×
3709

3710
        if len(nodes) == 0 {
×
3711
                return 0, pubKey, ErrSourceNodeNotSet
×
3712
        } else if len(nodes) > 1 {
×
3713
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3714
                        "protocol %s found", version)
×
3715
        }
×
3716

3717
        copy(pubKey[:], nodes[0].PubKey)
×
3718

×
3719
        s.srcNodes[version] = &srcNodeInfo{
×
3720
                id:  nodes[0].NodeID,
×
3721
                pub: pubKey,
×
3722
        }
×
3723

×
3724
        return nodes[0].NodeID, pubKey, nil
×
3725
}
3726

3727
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3728
// This then produces a map from TLV type to value. If the input is not a
3729
// valid TLV stream, then an error is returned.
3730
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3731
        r := bytes.NewReader(data)
×
3732

×
3733
        tlvStream, err := tlv.NewStream()
×
3734
        if err != nil {
×
3735
                return nil, err
×
3736
        }
×
3737

3738
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3739
        // pass it into the P2P decoding variant.
3740
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3741
        if err != nil {
×
3742
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3743
        }
×
3744
        if len(parsedTypes) == 0 {
×
3745
                return nil, nil
×
3746
        }
×
3747

3748
        records := make(map[uint64][]byte)
×
3749
        for k, v := range parsedTypes {
×
3750
                records[uint64(k)] = v
×
3751
        }
×
3752

3753
        return records, nil
×
3754
}
3755

3756
// dbChanInfo carries the database-level row IDs of a channel and of the two
// nodes it connects.
type dbChanInfo struct {
	// channelID is the channel's row ID (not its SCID).
	channelID int64

	// node1ID is the row ID of the channel's first node.
	node1ID int64

	// node2ID is the row ID of the channel's second node.
	node2ID int64
}
3763

3764
// insertChannel inserts a new channel record into the database.
3765
func insertChannel(ctx context.Context, db SQLQueries,
NEW
3766
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
3767

×
3768
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3769

×
3770
        // Make sure that the channel doesn't already exist. We do this
×
3771
        // explicitly instead of relying on catching a unique constraint error
×
3772
        // because relying on SQL to throw that error would abort the entire
×
3773
        // batch of transactions.
×
3774
        _, err := db.GetChannelBySCID(
×
3775
                ctx, sqlc.GetChannelBySCIDParams{
×
NEW
3776
                        Scid:    chanIDB,
×
3777
                        Version: int16(ProtocolV1),
×
3778
                },
×
3779
        )
×
3780
        if err == nil {
×
NEW
3781
                return nil, ErrEdgeAlreadyExist
×
3782
        } else if !errors.Is(err, sql.ErrNoRows) {
×
NEW
3783
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
×
3784
        }
×
3785

3786
        // Make sure that at least a "shell" entry for each node is present in
3787
        // the nodes table.
3788
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3789
        if err != nil {
×
NEW
3790
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3791
        }
×
3792

3793
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3794
        if err != nil {
×
NEW
3795
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3796
        }
×
3797

3798
        var capacity sql.NullInt64
×
3799
        if edge.Capacity != 0 {
×
3800
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3801
        }
×
3802

3803
        createParams := sqlc.CreateChannelParams{
×
3804
                Version:     int16(ProtocolV1),
×
NEW
3805
                Scid:        chanIDB,
×
3806
                NodeID1:     node1DBID,
×
3807
                NodeID2:     node2DBID,
×
3808
                Outpoint:    edge.ChannelPoint.String(),
×
3809
                Capacity:    capacity,
×
3810
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3811
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3812
        }
×
3813

×
3814
        if edge.AuthProof != nil {
×
3815
                proof := edge.AuthProof
×
3816

×
3817
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3818
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3819
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3820
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3821
        }
×
3822

3823
        // Insert the new channel record.
3824
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3825
        if err != nil {
×
NEW
3826
                return nil, err
×
3827
        }
×
3828

3829
        // Insert any channel features.
3830
        if len(edge.Features) != 0 {
×
3831
                chanFeatures := lnwire.NewRawFeatureVector()
×
3832
                err := chanFeatures.Decode(bytes.NewReader(edge.Features))
×
3833
                if err != nil {
×
NEW
3834
                        return nil, err
×
3835
                }
×
3836

3837
                fv := lnwire.NewFeatureVector(chanFeatures, lnwire.Features)
×
3838
                for feature := range fv.Features() {
×
3839
                        err = db.InsertChannelFeature(
×
3840
                                ctx, sqlc.InsertChannelFeatureParams{
×
3841
                                        ChannelID:  dbChanID,
×
3842
                                        FeatureBit: int32(feature),
×
3843
                                },
×
3844
                        )
×
3845
                        if err != nil {
×
NEW
3846
                                return nil, fmt.Errorf("unable to insert "+
×
3847
                                        "channel(%d) feature(%v): %w", dbChanID,
×
3848
                                        feature, err)
×
3849
                        }
×
3850
                }
3851
        }
3852

3853
        // Finally, insert any extra TLV fields in the channel announcement.
3854
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3855
        if err != nil {
×
NEW
3856
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
NEW
3857
                        "data: %w", err)
×
UNCOV
3858
        }
×
3859

3860
        for tlvType, value := range extra {
×
3861
                err := db.CreateChannelExtraType(
×
3862
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3863
                                ChannelID: dbChanID,
×
3864
                                Type:      int64(tlvType),
×
3865
                                Value:     value,
×
3866
                        },
×
3867
                )
×
3868
                if err != nil {
×
NEW
3869
                        return nil, fmt.Errorf("unable to upsert "+
×
NEW
3870
                                "channel(%d) extra signed field(%v): %w",
×
NEW
3871
                                edge.ChannelID, tlvType, err)
×
UNCOV
3872
                }
×
3873
        }
3874

NEW
3875
        return &dbChanInfo{
×
NEW
3876
                channelID: dbChanID,
×
NEW
3877
                node1ID:   node1DBID,
×
NEW
3878
                node2ID:   node2DBID,
×
NEW
3879
        }, nil
×
3880
}
3881

3882
// maybeCreateShellNode checks if a shell node entry exists for the
3883
// given public key. If it does not exist, then a new shell node entry is
3884
// created. The ID of the node is returned. A shell node only has a protocol
3885
// version and public key persisted.
3886
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3887
        pubKey route.Vertex) (int64, error) {
×
3888

×
3889
        dbNode, err := db.GetNodeByPubKey(
×
3890
                ctx, sqlc.GetNodeByPubKeyParams{
×
3891
                        PubKey:  pubKey[:],
×
3892
                        Version: int16(ProtocolV1),
×
3893
                },
×
3894
        )
×
3895
        // The node exists. Return the ID.
×
3896
        if err == nil {
×
3897
                return dbNode.ID, nil
×
3898
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3899
                return 0, err
×
3900
        }
×
3901

3902
        // Otherwise, the node does not exist, so we create a shell entry for
3903
        // it.
3904
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3905
                Version: int16(ProtocolV1),
×
3906
                PubKey:  pubKey[:],
×
3907
        })
×
3908
        if err != nil {
×
3909
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3910
        }
×
3911

3912
        return id, nil
×
3913
}
3914

3915
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3916
// the database. This includes deleting any existing types and then inserting
3917
// the new types.
3918
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3919
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3920

×
3921
        // Delete all existing extra signed fields for the channel policy.
×
3922
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
3923
        if err != nil {
×
3924
                return fmt.Errorf("unable to delete "+
×
3925
                        "existing policy extra signed fields for policy %d: %w",
×
3926
                        chanPolicyID, err)
×
3927
        }
×
3928

3929
        // Insert all new extra signed fields for the channel policy.
3930
        for tlvType, value := range extraFields {
×
3931
                err = db.InsertChanPolicyExtraType(
×
3932
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
3933
                                ChannelPolicyID: chanPolicyID,
×
3934
                                Type:            int64(tlvType),
×
3935
                                Value:           value,
×
3936
                        },
×
3937
                )
×
3938
                if err != nil {
×
3939
                        return fmt.Errorf("unable to insert "+
×
3940
                                "channel_policy(%d) extra signed field(%v): %w",
×
3941
                                chanPolicyID, tlvType, err)
×
3942
                }
×
3943
        }
3944

3945
        return nil
×
3946
}
3947

3948
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3949
// provided dbChanRow and also fetches any other required information
3950
// to construct the edge info.
3951
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
3952
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.Channel, node1,
3953
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
3954

×
3955
        if dbChan.Version != int16(ProtocolV1) {
×
3956
                return nil, fmt.Errorf("unsupported channel version: %d",
×
3957
                        dbChan.Version)
×
3958
        }
×
3959

3960
        fv, extras, err := getChanFeaturesAndExtras(
×
3961
                ctx, db, dbChanID,
×
3962
        )
×
3963
        if err != nil {
×
3964
                return nil, err
×
3965
        }
×
3966

3967
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
3968
        if err != nil {
×
3969
                return nil, err
×
3970
        }
×
3971

3972
        var featureBuf bytes.Buffer
×
3973
        if err := fv.Encode(&featureBuf); err != nil {
×
3974
                return nil, fmt.Errorf("unable to encode features: %w", err)
×
3975
        }
×
3976

3977
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
3978
        if err != nil {
×
3979
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3980
                        "fields: %w", err)
×
3981
        }
×
3982
        if recs == nil {
×
3983
                recs = make([]byte, 0)
×
3984
        }
×
3985

3986
        var btcKey1, btcKey2 route.Vertex
×
3987
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
3988
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
3989

×
3990
        channel := &models.ChannelEdgeInfo{
×
3991
                ChainHash:        chain,
×
3992
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
3993
                NodeKey1Bytes:    node1,
×
3994
                NodeKey2Bytes:    node2,
×
3995
                BitcoinKey1Bytes: btcKey1,
×
3996
                BitcoinKey2Bytes: btcKey2,
×
3997
                ChannelPoint:     *op,
×
3998
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
3999
                Features:         featureBuf.Bytes(),
×
4000
                ExtraOpaqueData:  recs,
×
4001
        }
×
4002

×
4003
        // We always set all the signatures at the same time, so we can
×
4004
        // safely check if one signature is present to determine if we have the
×
4005
        // rest of the signatures for the auth proof.
×
4006
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4007
                channel.AuthProof = &models.ChannelAuthProof{
×
4008
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4009
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4010
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4011
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4012
                }
×
4013
        }
×
4014

4015
        return channel, nil
×
4016
}
4017

4018
// buildNodeVertices is a helper that converts raw node public keys
4019
// into route.Vertex instances.
4020
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4021
        route.Vertex, error) {
×
4022

×
4023
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4024
        if err != nil {
×
4025
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4026
                        "create vertex from node1 pubkey: %w", err)
×
4027
        }
×
4028

4029
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4030
        if err != nil {
×
4031
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4032
                        "create vertex from node2 pubkey: %w", err)
×
4033
        }
×
4034

4035
        return node1Vertex, node2Vertex, nil
×
4036
}
4037

4038
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
4039
// for a channel with the given ID.
4040
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
4041
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
4042

×
4043
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
4044
        if err != nil {
×
4045
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
4046
                        "features and extras: %w", err)
×
4047
        }
×
4048

4049
        var (
×
4050
                fv     = lnwire.EmptyFeatureVector()
×
4051
                extras = make(map[uint64][]byte)
×
4052
        )
×
4053
        for _, row := range rows {
×
4054
                if row.IsFeature {
×
4055
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
4056

×
4057
                        continue
×
4058
                }
4059

4060
                tlvType, ok := row.ExtraKey.(int64)
×
4061
                if !ok {
×
4062
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4063
                                "TLV type: %T", row.ExtraKey)
×
4064
                }
×
4065

4066
                valueBytes, ok := row.Value.([]byte)
×
4067
                if !ok {
×
4068
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4069
                                "Value: %T", row.Value)
×
4070
                }
×
4071

4072
                extras[uint64(tlvType)] = valueBytes
×
4073
        }
4074

4075
        return fv, extras, nil
×
4076
}
4077

4078
// getAndBuildChanPolicies uses the given sqlc.ChannelPolicy and also retrieves
4079
// all the extra info required to build the complete models.ChannelEdgePolicy
4080
// types. It returns two policies, which may be nil if the provided
4081
// sqlc.ChannelPolicy records are nil.
4082
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4083
        dbPol1, dbPol2 *sqlc.ChannelPolicy, channelID uint64, node1,
4084
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4085
        *models.ChannelEdgePolicy, error) {
×
4086

×
4087
        if dbPol1 == nil && dbPol2 == nil {
×
4088
                return nil, nil, nil
×
4089
        }
×
4090

4091
        var (
×
4092
                policy1ID int64
×
4093
                policy2ID int64
×
4094
        )
×
4095
        if dbPol1 != nil {
×
4096
                policy1ID = dbPol1.ID
×
4097
        }
×
4098
        if dbPol2 != nil {
×
4099
                policy2ID = dbPol2.ID
×
4100
        }
×
4101
        rows, err := db.GetChannelPolicyExtraTypes(
×
4102
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
4103
                        ID:   policy1ID,
×
4104
                        ID_2: policy2ID,
×
4105
                },
×
4106
        )
×
4107
        if err != nil {
×
4108
                return nil, nil, err
×
4109
        }
×
4110

4111
        var (
×
4112
                dbPol1Extras = make(map[uint64][]byte)
×
4113
                dbPol2Extras = make(map[uint64][]byte)
×
4114
        )
×
4115
        for _, row := range rows {
×
4116
                switch row.PolicyID {
×
4117
                case policy1ID:
×
4118
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
4119
                case policy2ID:
×
4120
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
4121
                default:
×
4122
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
4123
                                "in row: %v", row.PolicyID, row)
×
4124
                }
4125
        }
4126

4127
        var pol1, pol2 *models.ChannelEdgePolicy
×
4128
        if dbPol1 != nil {
×
4129
                pol1, err = buildChanPolicy(
×
4130
                        *dbPol1, channelID, dbPol1Extras, node2, true,
×
4131
                )
×
4132
                if err != nil {
×
4133
                        return nil, nil, err
×
4134
                }
×
4135
        }
4136
        if dbPol2 != nil {
×
4137
                pol2, err = buildChanPolicy(
×
4138
                        *dbPol2, channelID, dbPol2Extras, node1, false,
×
4139
                )
×
4140
                if err != nil {
×
4141
                        return nil, nil, err
×
4142
                }
×
4143
        }
4144

4145
        return pol1, pol2, nil
×
4146
}
4147

4148
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4149
// provided sqlc.ChannelPolicy and other required information.
4150
func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
4151
        extras map[uint64][]byte, toNode route.Vertex,
4152
        isNode1 bool) (*models.ChannelEdgePolicy, error) {
×
4153

×
4154
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4155
        if err != nil {
×
4156
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4157
                        "fields: %w", err)
×
4158
        }
×
4159

4160
        var msgFlags lnwire.ChanUpdateMsgFlags
×
4161
        if dbPolicy.MaxHtlcMsat.Valid {
×
4162
                msgFlags |= lnwire.ChanUpdateRequiredMaxHtlc
×
4163
        }
×
4164

4165
        var chanFlags lnwire.ChanUpdateChanFlags
×
4166
        if !isNode1 {
×
4167
                chanFlags |= lnwire.ChanUpdateDirection
×
4168
        }
×
4169
        if dbPolicy.Disabled.Bool {
×
4170
                chanFlags |= lnwire.ChanUpdateDisabled
×
4171
        }
×
4172

4173
        var inboundFee fn.Option[lnwire.Fee]
×
4174
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4175
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4176

×
4177
                inboundFee = fn.Some(lnwire.Fee{
×
4178
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4179
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4180
                })
×
4181
        }
×
4182

4183
        return &models.ChannelEdgePolicy{
×
4184
                SigBytes:  dbPolicy.Signature,
×
4185
                ChannelID: channelID,
×
4186
                LastUpdate: time.Unix(
×
4187
                        dbPolicy.LastUpdate.Int64, 0,
×
4188
                ),
×
4189
                MessageFlags:  msgFlags,
×
4190
                ChannelFlags:  chanFlags,
×
4191
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4192
                MinHTLC: lnwire.MilliSatoshi(
×
4193
                        dbPolicy.MinHtlcMsat,
×
4194
                ),
×
4195
                MaxHTLC: lnwire.MilliSatoshi(
×
4196
                        dbPolicy.MaxHtlcMsat.Int64,
×
4197
                ),
×
4198
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4199
                        dbPolicy.BaseFeeMsat,
×
4200
                ),
×
4201
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4202
                ToNode:                    toNode,
×
4203
                InboundFee:                inboundFee,
×
4204
                ExtraOpaqueData:           recs,
×
4205
        }, nil
×
4206
}
4207

4208
// buildNodes builds the models.LightningNode instances for the
4209
// given row which is expected to be a sqlc type that contains node information.
4210
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4211
        dbNode2 sqlc.Node) (*models.LightningNode, *models.LightningNode,
4212
        error) {
×
4213

×
4214
        node1, err := buildNode(ctx, db, &dbNode1)
×
4215
        if err != nil {
×
4216
                return nil, nil, err
×
4217
        }
×
4218

4219
        node2, err := buildNode(ctx, db, &dbNode2)
×
4220
        if err != nil {
×
4221
                return nil, nil, err
×
4222
        }
×
4223

4224
        return node1, node2, nil
×
4225
}
4226

4227
// extractChannelPolicies extracts the sqlc.ChannelPolicy records from the give
4228
// row which is expected to be a sqlc type that contains channel policy
4229
// information. It returns two policies, which may be nil if the policy
4230
// information is not present in the row.
4231
//
4232
//nolint:ll,dupl,funlen
4233
func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
4234
        error) {
×
4235

×
4236
        var policy1, policy2 *sqlc.ChannelPolicy
×
4237
        switch r := row.(type) {
×
4238
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4239
                if r.Policy1ID.Valid {
×
4240
                        policy1 = &sqlc.ChannelPolicy{
×
4241
                                ID:                      r.Policy1ID.Int64,
×
4242
                                Version:                 r.Policy1Version.Int16,
×
4243
                                ChannelID:               r.Channel.ID,
×
4244
                                NodeID:                  r.Policy1NodeID.Int64,
×
4245
                                Timelock:                r.Policy1Timelock.Int32,
×
4246
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4247
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4248
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4249
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4250
                                LastUpdate:              r.Policy1LastUpdate,
×
4251
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4252
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4253
                                Disabled:                r.Policy1Disabled,
×
4254
                                Signature:               r.Policy1Signature,
×
4255
                        }
×
4256
                }
×
4257
                if r.Policy2ID.Valid {
×
4258
                        policy2 = &sqlc.ChannelPolicy{
×
4259
                                ID:                      r.Policy2ID.Int64,
×
4260
                                Version:                 r.Policy2Version.Int16,
×
4261
                                ChannelID:               r.Channel.ID,
×
4262
                                NodeID:                  r.Policy2NodeID.Int64,
×
4263
                                Timelock:                r.Policy2Timelock.Int32,
×
4264
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4265
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4266
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4267
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4268
                                LastUpdate:              r.Policy2LastUpdate,
×
4269
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4270
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4271
                                Disabled:                r.Policy2Disabled,
×
4272
                                Signature:               r.Policy2Signature,
×
4273
                        }
×
4274
                }
×
4275

4276
                return policy1, policy2, nil
×
4277

4278
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4279
                if r.Policy1ID.Valid {
×
4280
                        policy1 = &sqlc.ChannelPolicy{
×
4281
                                ID:                      r.Policy1ID.Int64,
×
4282
                                Version:                 r.Policy1Version.Int16,
×
4283
                                ChannelID:               r.Channel.ID,
×
4284
                                NodeID:                  r.Policy1NodeID.Int64,
×
4285
                                Timelock:                r.Policy1Timelock.Int32,
×
4286
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4287
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4288
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4289
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4290
                                LastUpdate:              r.Policy1LastUpdate,
×
4291
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4292
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4293
                                Disabled:                r.Policy1Disabled,
×
4294
                                Signature:               r.Policy1Signature,
×
4295
                        }
×
4296
                }
×
4297
                if r.Policy2ID.Valid {
×
4298
                        policy2 = &sqlc.ChannelPolicy{
×
4299
                                ID:                      r.Policy2ID.Int64,
×
4300
                                Version:                 r.Policy2Version.Int16,
×
4301
                                ChannelID:               r.Channel.ID,
×
4302
                                NodeID:                  r.Policy2NodeID.Int64,
×
4303
                                Timelock:                r.Policy2Timelock.Int32,
×
4304
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4305
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4306
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4307
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4308
                                LastUpdate:              r.Policy2LastUpdate,
×
4309
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4310
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4311
                                Disabled:                r.Policy2Disabled,
×
4312
                                Signature:               r.Policy2Signature,
×
4313
                        }
×
4314
                }
×
4315

4316
                return policy1, policy2, nil
×
4317

4318
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4319
                if r.Policy1ID.Valid {
×
4320
                        policy1 = &sqlc.ChannelPolicy{
×
4321
                                ID:                      r.Policy1ID.Int64,
×
4322
                                Version:                 r.Policy1Version.Int16,
×
4323
                                ChannelID:               r.Channel.ID,
×
4324
                                NodeID:                  r.Policy1NodeID.Int64,
×
4325
                                Timelock:                r.Policy1Timelock.Int32,
×
4326
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4327
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4328
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4329
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4330
                                LastUpdate:              r.Policy1LastUpdate,
×
4331
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4332
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4333
                                Disabled:                r.Policy1Disabled,
×
4334
                                Signature:               r.Policy1Signature,
×
4335
                        }
×
4336
                }
×
4337
                if r.Policy2ID.Valid {
×
4338
                        policy2 = &sqlc.ChannelPolicy{
×
4339
                                ID:                      r.Policy2ID.Int64,
×
4340
                                Version:                 r.Policy2Version.Int16,
×
4341
                                ChannelID:               r.Channel.ID,
×
4342
                                NodeID:                  r.Policy2NodeID.Int64,
×
4343
                                Timelock:                r.Policy2Timelock.Int32,
×
4344
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4345
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4346
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4347
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4348
                                LastUpdate:              r.Policy2LastUpdate,
×
4349
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4350
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4351
                                Disabled:                r.Policy2Disabled,
×
4352
                                Signature:               r.Policy2Signature,
×
4353
                        }
×
4354
                }
×
4355

4356
                return policy1, policy2, nil
×
4357

4358
        case sqlc.ListChannelsByNodeIDRow:
×
4359
                if r.Policy1ID.Valid {
×
4360
                        policy1 = &sqlc.ChannelPolicy{
×
4361
                                ID:                      r.Policy1ID.Int64,
×
4362
                                Version:                 r.Policy1Version.Int16,
×
4363
                                ChannelID:               r.Channel.ID,
×
4364
                                NodeID:                  r.Policy1NodeID.Int64,
×
4365
                                Timelock:                r.Policy1Timelock.Int32,
×
4366
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4367
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4368
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4369
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4370
                                LastUpdate:              r.Policy1LastUpdate,
×
4371
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4372
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4373
                                Disabled:                r.Policy1Disabled,
×
4374
                                Signature:               r.Policy1Signature,
×
4375
                        }
×
4376
                }
×
4377
                if r.Policy2ID.Valid {
×
4378
                        policy2 = &sqlc.ChannelPolicy{
×
4379
                                ID:                      r.Policy2ID.Int64,
×
4380
                                Version:                 r.Policy2Version.Int16,
×
4381
                                ChannelID:               r.Channel.ID,
×
4382
                                NodeID:                  r.Policy2NodeID.Int64,
×
4383
                                Timelock:                r.Policy2Timelock.Int32,
×
4384
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4385
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4386
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4387
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4388
                                LastUpdate:              r.Policy2LastUpdate,
×
4389
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4390
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4391
                                Disabled:                r.Policy2Disabled,
×
4392
                                Signature:               r.Policy2Signature,
×
4393
                        }
×
4394
                }
×
4395

4396
                return policy1, policy2, nil
×
4397

4398
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4399
                if r.Policy1ID.Valid {
×
4400
                        policy1 = &sqlc.ChannelPolicy{
×
4401
                                ID:                      r.Policy1ID.Int64,
×
4402
                                Version:                 r.Policy1Version.Int16,
×
4403
                                ChannelID:               r.Channel.ID,
×
4404
                                NodeID:                  r.Policy1NodeID.Int64,
×
4405
                                Timelock:                r.Policy1Timelock.Int32,
×
4406
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4407
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4408
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4409
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4410
                                LastUpdate:              r.Policy1LastUpdate,
×
4411
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4412
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4413
                                Disabled:                r.Policy1Disabled,
×
4414
                                Signature:               r.Policy1Signature,
×
4415
                        }
×
4416
                }
×
4417
                if r.Policy2ID.Valid {
×
4418
                        policy2 = &sqlc.ChannelPolicy{
×
4419
                                ID:                      r.Policy2ID.Int64,
×
4420
                                Version:                 r.Policy2Version.Int16,
×
4421
                                ChannelID:               r.Channel.ID,
×
4422
                                NodeID:                  r.Policy2NodeID.Int64,
×
4423
                                Timelock:                r.Policy2Timelock.Int32,
×
4424
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4425
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4426
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4427
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4428
                                LastUpdate:              r.Policy2LastUpdate,
×
4429
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4430
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4431
                                Disabled:                r.Policy2Disabled,
×
4432
                                Signature:               r.Policy2Signature,
×
4433
                        }
×
4434
                }
×
4435

4436
                return policy1, policy2, nil
×
4437
        default:
×
4438
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4439
                        "extractChannelPolicies: %T", r)
×
4440
        }
4441
}
4442

4443
// channelIDToBytes converts a channel ID (SCID) to a byte array
4444
// representation.
NEW
4445
func channelIDToBytes(channelID uint64) []byte {
×
4446
        var chanIDB [8]byte
×
4447
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4448

×
NEW
4449
        return chanIDB[:]
×
4450
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc