• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15973895448

30 Jun 2025 01:12PM UTC coverage: 67.627% (+0.05%) from 67.577%
15973895448

Pull #10007

github

web-flow
Merge a10fe7711 into 01dfee6f8
Pull Request #10007: graph/db: explicitly store bitfields for channel_update message & channel flags

0 of 65 new or added lines in 3 files covered. (0.0%)

40 existing lines in 12 files now uncovered.

135274 of 200031 relevant lines covered (67.63%)

21833.89 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// pageSize caps how many rows a single paginated query may return. The value
// is a starting point and may be tuned once benchmarks are available.
const pageSize = 2000
37

38
// ProtocolVersion enumerates the gossip protocol versions that a stored
// message may belong to.
type ProtocolVersion uint8

const (
	// ProtocolV1 identifies the gossip protocol specified in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String renders the protocol version as the letter "V" followed by its
// decimal value, e.g. "V1".
func (v ProtocolVersion) String() string {
	return "V" + strconv.FormatUint(uint64(v), 10)
}
×
51

52
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
53
// execute queries against the SQL graph tables.
54
//
55
//nolint:ll,interfacebloat
56
type SQLQueries interface {
57
        /*
58
                Node queries.
59
        */
60
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
61
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
62
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
63
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
64
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error)
65
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
66
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
67
        GetUnconnectedNodes(ctx context.Context) ([]sqlc.GetUnconnectedNodesRow, error)
68
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
69
        DeleteNode(ctx context.Context, id int64) error
70

71
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
72
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
73
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
74

75
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
76
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
77
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
78

79
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
80
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
81
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
82
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
83

84
        /*
85
                Source node queries.
86
        */
87
        AddSourceNode(ctx context.Context, nodeID int64) error
88
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
89

90
        /*
91
                Channel queries.
92
        */
93
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
94
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
95
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
96
        GetChannelByOutpoint(ctx context.Context, outpoint string) (sqlc.GetChannelByOutpointRow, error)
97
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
98
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
99
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
100
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
101
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
102
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
103
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
104
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
105
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
106
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
107
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
108
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
109
        DeleteChannel(ctx context.Context, id int64) error
110

111
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
112
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
113

114
        /*
115
                Channel Policy table queries.
116
        */
117
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
118
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.ChannelPolicy, error)
119
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
120

121
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
122
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
123
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
124

125
        /*
126
                Zombie index queries.
127
        */
128
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
129
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.ZombieChannel, error)
130
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
131
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
132
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
133

134
        /*
135
                Prune log table queries.
136
        */
137
        GetPruneTip(ctx context.Context) (sqlc.PruneLog, error)
138
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
139
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
140

141
        /*
142
                Closed SCID table queries.
143
        */
144
        InsertClosedChannel(ctx context.Context, scid []byte) error
145
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
146
}
147

148
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
149
// database operations.
150
type BatchedSQLQueries interface {
151
        SQLQueries
152
        sqldb.BatchedTx[SQLQueries]
153
}
154

155
// SQLStore is an implementation of the V1Store interface that uses a SQL
156
// database as the backend.
157
//
158
// NOTE: currently, this temporarily embeds the KVStore struct so that we can
159
// implement the V1Store interface incrementally. For any method not
160
// implemented,  things will fall back to the KVStore. This is ONLY the case
161
// for the time being while this struct is purely used in unit tests only.
162
type SQLStore struct {
163
        cfg *SQLStoreConfig
164
        db  BatchedSQLQueries
165

166
        // cacheMu guards all caches (rejectCache and chanCache). If
167
        // this mutex will be acquired at the same time as the DB mutex then
168
        // the cacheMu MUST be acquired first to prevent deadlock.
169
        cacheMu     sync.RWMutex
170
        rejectCache *rejectCache
171
        chanCache   *channelCache
172

173
        chanScheduler batch.Scheduler[SQLQueries]
174
        nodeScheduler batch.Scheduler[SQLQueries]
175

176
        srcNodes  map[ProtocolVersion]*srcNodeInfo
177
        srcNodeMu sync.Mutex
178
}
179

180
// A compile-time assertion to ensure that SQLStore implements the V1Store
181
// interface.
182
var _ V1Store = (*SQLStore)(nil)
183

184
// SQLStoreConfig holds the configuration for the SQLStore.
185
type SQLStoreConfig struct {
186
        // ChainHash is the genesis hash for the chain that all the gossip
187
        // messages in this store are aimed at.
188
        ChainHash chainhash.Hash
189
}
190

191
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
192
// storage backend.
193
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
194
        options ...StoreOptionModifier) (*SQLStore, error) {
×
195

×
196
        opts := DefaultOptions()
×
197
        for _, o := range options {
×
198
                o(opts)
×
199
        }
×
200

201
        if opts.NoMigration {
×
202
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
203
                        "supported for SQL stores")
×
204
        }
×
205

206
        s := &SQLStore{
×
207
                cfg:         cfg,
×
208
                db:          db,
×
209
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
210
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
211
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
212
        }
×
213

×
214
        s.chanScheduler = batch.NewTimeScheduler(
×
215
                db, &s.cacheMu, opts.BatchCommitInterval,
×
216
        )
×
217
        s.nodeScheduler = batch.NewTimeScheduler(
×
218
                db, nil, opts.BatchCommitInterval,
×
219
        )
×
220

×
221
        return s, nil
×
222
}
223

224
// AddLightningNode adds a vertex/node to the graph database. If the node is not
225
// in the database from before, this will add a new, unconnected one to the
226
// graph. If it is present from before, this will update that node's
227
// information.
228
//
229
// NOTE: part of the V1Store interface.
230
func (s *SQLStore) AddLightningNode(ctx context.Context,
231
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
232

×
233
        r := &batch.Request[SQLQueries]{
×
234
                Opts: batch.NewSchedulerOptions(opts...),
×
235
                Do: func(queries SQLQueries) error {
×
236
                        _, err := upsertNode(ctx, queries, node)
×
237
                        return err
×
238
                },
×
239
        }
240

241
        return s.nodeScheduler.Execute(ctx, r)
×
242
}
243

244
// FetchLightningNode attempts to look up a target node by its identity public
245
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
246
// returned.
247
//
248
// NOTE: part of the V1Store interface.
249
func (s *SQLStore) FetchLightningNode(ctx context.Context,
250
        pubKey route.Vertex) (*models.LightningNode, error) {
×
251

×
252
        var node *models.LightningNode
×
253
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
254
                var err error
×
255
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
256

×
257
                return err
×
258
        }, sqldb.NoOpReset)
×
259
        if err != nil {
×
260
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
261
        }
×
262

263
        return node, nil
×
264
}
265

266
// HasLightningNode determines if the graph has a vertex identified by the
267
// target node identity public key. If the node exists in the database, a
268
// timestamp of when the data for the node was lasted updated is returned along
269
// with a true boolean. Otherwise, an empty time.Time is returned with a false
270
// boolean.
271
//
272
// NOTE: part of the V1Store interface.
273
func (s *SQLStore) HasLightningNode(ctx context.Context,
274
        pubKey [33]byte) (time.Time, bool, error) {
×
275

×
276
        var (
×
277
                exists     bool
×
278
                lastUpdate time.Time
×
279
        )
×
280
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
281
                dbNode, err := db.GetNodeByPubKey(
×
282
                        ctx, sqlc.GetNodeByPubKeyParams{
×
283
                                Version: int16(ProtocolV1),
×
284
                                PubKey:  pubKey[:],
×
285
                        },
×
286
                )
×
287
                if errors.Is(err, sql.ErrNoRows) {
×
288
                        return nil
×
289
                } else if err != nil {
×
290
                        return fmt.Errorf("unable to fetch node: %w", err)
×
291
                }
×
292

293
                exists = true
×
294

×
295
                if dbNode.LastUpdate.Valid {
×
296
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
297
                }
×
298

299
                return nil
×
300
        }, sqldb.NoOpReset)
301
        if err != nil {
×
302
                return time.Time{}, false,
×
303
                        fmt.Errorf("unable to fetch node: %w", err)
×
304
        }
×
305

306
        return lastUpdate, exists, nil
×
307
}
308

309
// AddrsForNode returns all known addresses for the target node public key
310
// that the graph DB is aware of. The returned boolean indicates if the
311
// given node is unknown to the graph DB or not.
312
//
313
// NOTE: part of the V1Store interface.
314
func (s *SQLStore) AddrsForNode(ctx context.Context,
315
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
316

×
317
        var (
×
318
                addresses []net.Addr
×
319
                known     bool
×
320
        )
×
321
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
322
                var err error
×
323
                known, addresses, err = getNodeAddresses(
×
324
                        ctx, db, nodePub.SerializeCompressed(),
×
325
                )
×
326
                if err != nil {
×
327
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
328
                                err)
×
329
                }
×
330

331
                return nil
×
332
        }, sqldb.NoOpReset)
333
        if err != nil {
×
334
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
335
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
336
        }
×
337

338
        return known, addresses, nil
×
339
}
340

341
// DeleteLightningNode starts a new database transaction to remove a vertex/node
342
// from the database according to the node's public key.
343
//
344
// NOTE: part of the V1Store interface.
345
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
346
        pubKey route.Vertex) error {
×
347

×
348
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
349
                res, err := db.DeleteNodeByPubKey(
×
350
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
351
                                Version: int16(ProtocolV1),
×
352
                                PubKey:  pubKey[:],
×
353
                        },
×
354
                )
×
355
                if err != nil {
×
356
                        return err
×
357
                }
×
358

359
                rows, err := res.RowsAffected()
×
360
                if err != nil {
×
361
                        return err
×
362
                }
×
363

364
                if rows == 0 {
×
365
                        return ErrGraphNodeNotFound
×
366
                } else if rows > 1 {
×
367
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
368
                }
×
369

370
                return err
×
371
        }, sqldb.NoOpReset)
372
        if err != nil {
×
373
                return fmt.Errorf("unable to delete node: %w", err)
×
374
        }
×
375

376
        return nil
×
377
}
378

379
// FetchNodeFeatures returns the features of the given node. If no features are
380
// known for the node, an empty feature vector is returned.
381
//
382
// NOTE: this is part of the graphdb.NodeTraverser interface.
383
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
384
        *lnwire.FeatureVector, error) {
×
385

×
386
        ctx := context.TODO()
×
387

×
388
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
389
}
×
390

391
// DisabledChannelIDs returns the channel ids of disabled channels.
392
// A channel is disabled when two of the associated ChanelEdgePolicies
393
// have their disabled bit on.
394
//
395
// NOTE: part of the V1Store interface.
396
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
397
        var (
×
398
                ctx     = context.TODO()
×
399
                chanIDs []uint64
×
400
        )
×
401
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
402
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
403
                if err != nil {
×
404
                        return fmt.Errorf("unable to fetch disabled "+
×
405
                                "channels: %w", err)
×
406
                }
×
407

408
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
409

×
410
                return nil
×
411
        }, sqldb.NoOpReset)
412
        if err != nil {
×
413
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
414
                        err)
×
415
        }
×
416

417
        return chanIDs, nil
×
418
}
419

420
// LookupAlias attempts to return the alias as advertised by the target node.
421
//
422
// NOTE: part of the V1Store interface.
423
func (s *SQLStore) LookupAlias(ctx context.Context,
424
        pub *btcec.PublicKey) (string, error) {
×
425

×
426
        var alias string
×
427
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
428
                dbNode, err := db.GetNodeByPubKey(
×
429
                        ctx, sqlc.GetNodeByPubKeyParams{
×
430
                                Version: int16(ProtocolV1),
×
431
                                PubKey:  pub.SerializeCompressed(),
×
432
                        },
×
433
                )
×
434
                if errors.Is(err, sql.ErrNoRows) {
×
435
                        return ErrNodeAliasNotFound
×
436
                } else if err != nil {
×
437
                        return fmt.Errorf("unable to fetch node: %w", err)
×
438
                }
×
439

440
                if !dbNode.Alias.Valid {
×
441
                        return ErrNodeAliasNotFound
×
442
                }
×
443

444
                alias = dbNode.Alias.String
×
445

×
446
                return nil
×
447
        }, sqldb.NoOpReset)
448
        if err != nil {
×
449
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
450
        }
×
451

452
        return alias, nil
×
453
}
454

455
// SourceNode returns the source node of the graph. The source node is treated
456
// as the center node within a star-graph. This method may be used to kick off
457
// a path finding algorithm in order to explore the reachability of another
458
// node based off the source node.
459
//
460
// NOTE: part of the V1Store interface.
461
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
462
        error) {
×
463

×
464
        var node *models.LightningNode
×
465
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
466
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
467
                if err != nil {
×
468
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
469
                                err)
×
470
                }
×
471

472
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
473

×
474
                return err
×
475
        }, sqldb.NoOpReset)
476
        if err != nil {
×
477
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
478
        }
×
479

480
        return node, nil
×
481
}
482

483
// SetSourceNode sets the source node within the graph database. The source
484
// node is to be used as the center of a star-graph within path finding
485
// algorithms.
486
//
487
// NOTE: part of the V1Store interface.
488
func (s *SQLStore) SetSourceNode(ctx context.Context,
489
        node *models.LightningNode) error {
×
490

×
491
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
492
                id, err := upsertNode(ctx, db, node)
×
493
                if err != nil {
×
494
                        return fmt.Errorf("unable to upsert source node: %w",
×
495
                                err)
×
496
                }
×
497

498
                // Make sure that if a source node for this version is already
499
                // set, then the ID is the same as the one we are about to set.
500
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
501
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
502
                        return fmt.Errorf("unable to fetch source node: %w",
×
503
                                err)
×
504
                } else if err == nil {
×
505
                        if dbSourceNodeID != id {
×
506
                                return fmt.Errorf("v1 source node already "+
×
507
                                        "set to a different node: %d vs %d",
×
508
                                        dbSourceNodeID, id)
×
509
                        }
×
510

511
                        return nil
×
512
                }
513

514
                return db.AddSourceNode(ctx, id)
×
515
        }, sqldb.NoOpReset)
516
}
517

518
// NodeUpdatesInHorizon returns all the known lightning node which have an
519
// update timestamp within the passed range. This method can be used by two
520
// nodes to quickly determine if they have the same set of up to date node
521
// announcements.
522
//
523
// NOTE: This is part of the V1Store interface.
524
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
525
        endTime time.Time) ([]models.LightningNode, error) {
×
526

×
527
        ctx := context.TODO()
×
528

×
529
        var nodes []models.LightningNode
×
530
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
531
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
532
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
533
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
534
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
535
                        },
×
536
                )
×
537
                if err != nil {
×
538
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
539
                }
×
540

541
                for _, dbNode := range dbNodes {
×
542
                        node, err := buildNode(ctx, db, &dbNode)
×
543
                        if err != nil {
×
544
                                return fmt.Errorf("unable to build node: %w",
×
545
                                        err)
×
546
                        }
×
547

548
                        nodes = append(nodes, *node)
×
549
                }
550

551
                return nil
×
552
        }, sqldb.NoOpReset)
553
        if err != nil {
×
554
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
555
        }
×
556

557
        return nodes, nil
×
558
}
559

560
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
561
// undirected edge from the two target nodes are created. The information stored
562
// denotes the static attributes of the channel, such as the channelID, the keys
563
// involved in creation of the channel, and the set of features that the channel
564
// supports. The chanPoint and chanID are used to uniquely identify the edge
565
// globally within the database.
566
//
567
// NOTE: part of the V1Store interface.
568
func (s *SQLStore) AddChannelEdge(ctx context.Context,
569
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
570

×
571
        var alreadyExists bool
×
572
        r := &batch.Request[SQLQueries]{
×
573
                Opts: batch.NewSchedulerOptions(opts...),
×
574
                Reset: func() {
×
575
                        alreadyExists = false
×
576
                },
×
577
                Do: func(tx SQLQueries) error {
×
578
                        err := insertChannel(ctx, tx, edge)
×
579

×
580
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
581
                        // succeed, but propagate the error via local state.
×
582
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
583
                                alreadyExists = true
×
584
                                return nil
×
585
                        }
×
586

587
                        return err
×
588
                },
589
                OnCommit: func(err error) error {
×
590
                        switch {
×
591
                        case err != nil:
×
592
                                return err
×
593
                        case alreadyExists:
×
594
                                return ErrEdgeAlreadyExist
×
595
                        default:
×
596
                                s.rejectCache.remove(edge.ChannelID)
×
597
                                s.chanCache.remove(edge.ChannelID)
×
598
                                return nil
×
599
                        }
600
                },
601
        }
602

603
        return s.chanScheduler.Execute(ctx, r)
×
604
}
605

606
// HighestChanID returns the "highest" known channel ID in the channel graph.
607
// This represents the "newest" channel from the PoV of the chain. This method
608
// can be used by peers to quickly determine if their graphs are in sync.
609
//
610
// NOTE: This is part of the V1Store interface.
611
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
612
        var highestChanID uint64
×
613
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
614
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
615
                if errors.Is(err, sql.ErrNoRows) {
×
616
                        return nil
×
617
                } else if err != nil {
×
618
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
619
                                err)
×
620
                }
×
621

622
                highestChanID = byteOrder.Uint64(chanID)
×
623

×
624
                return nil
×
625
        }, sqldb.NoOpReset)
626
        if err != nil {
×
627
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
628
        }
×
629

630
        return highestChanID, nil
×
631
}
632

633
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
634
// within the database for the referenced channel. The `flags` attribute within
635
// the ChannelEdgePolicy determines which of the directed edges are being
636
// updated. If the flag is 1, then the first node's information is being
637
// updated, otherwise it's the second node's information. The node ordering is
638
// determined by the lexicographical ordering of the identity public keys of the
639
// nodes on either side of the channel.
640
//
641
// NOTE: part of the V1Store interface.
642
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
643
        edge *models.ChannelEdgePolicy,
644
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
645

×
646
        var (
×
647
                isUpdate1    bool
×
648
                edgeNotFound bool
×
649
                from, to     route.Vertex
×
650
        )
×
651

×
652
        r := &batch.Request[SQLQueries]{
×
653
                Opts: batch.NewSchedulerOptions(opts...),
×
654
                Reset: func() {
×
655
                        isUpdate1 = false
×
656
                        edgeNotFound = false
×
657
                },
×
658
                Do: func(tx SQLQueries) error {
×
659
                        var err error
×
660
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
661
                                ctx, tx, edge,
×
662
                        )
×
663
                        if err != nil {
×
664
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
665
                        }
×
666

667
                        // Silence ErrEdgeNotFound so that the batch can
668
                        // succeed, but propagate the error via local state.
669
                        if errors.Is(err, ErrEdgeNotFound) {
×
670
                                edgeNotFound = true
×
671
                                return nil
×
672
                        }
×
673

674
                        return err
×
675
                },
676
                OnCommit: func(err error) error {
×
677
                        switch {
×
678
                        case err != nil:
×
679
                                return err
×
680
                        case edgeNotFound:
×
681
                                return ErrEdgeNotFound
×
682
                        default:
×
683
                                s.updateEdgeCache(edge, isUpdate1)
×
684
                                return nil
×
685
                        }
686
                },
687
        }
688

689
        err := s.chanScheduler.Execute(ctx, r)
×
690

×
691
        return from, to, err
×
692
}
693

694
// updateEdgeCache updates our reject and channel caches with the new
695
// edge policy information.
696
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
697
        isUpdate1 bool) {
×
698

×
699
        // If an entry for this channel is found in reject cache, we'll modify
×
700
        // the entry with the updated timestamp for the direction that was just
×
701
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
702
        // during the next query for this edge.
×
703
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
704
                if isUpdate1 {
×
705
                        entry.upd1Time = e.LastUpdate.Unix()
×
706
                } else {
×
707
                        entry.upd2Time = e.LastUpdate.Unix()
×
708
                }
×
709
                s.rejectCache.insert(e.ChannelID, entry)
×
710
        }
711

712
        // If an entry for this channel is found in channel cache, we'll modify
713
        // the entry with the updated policy for the direction that was just
714
        // written. If the edge doesn't exist, we'll defer loading the info and
715
        // policies and lazily read from disk during the next query.
716
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
717
                if isUpdate1 {
×
718
                        channel.Policy1 = e
×
719
                } else {
×
720
                        channel.Policy2 = e
×
721
                }
×
722
                s.chanCache.insert(e.ChannelID, channel)
×
723
        }
724
}
725

726
// ForEachSourceNodeChannel iterates through all channels of the source node,
727
// executing the passed callback on each. The call-back is provided with the
728
// channel's outpoint, whether we have a policy for the channel and the channel
729
// peer's node information.
730
//
731
// NOTE: part of the V1Store interface.
732
func (s *SQLStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint,
733
        havePolicy bool, otherNode *models.LightningNode) error) error {
×
734

×
735
        var ctx = context.TODO()
×
736

×
737
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
738
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
739
                if err != nil {
×
740
                        return fmt.Errorf("unable to fetch source node: %w",
×
741
                                err)
×
742
                }
×
743

744
                return forEachNodeChannel(
×
745
                        ctx, db, s.cfg.ChainHash, nodeID,
×
746
                        func(info *models.ChannelEdgeInfo,
×
747
                                outPolicy *models.ChannelEdgePolicy,
×
748
                                _ *models.ChannelEdgePolicy) error {
×
749

×
750
                                // Fetch the other node.
×
751
                                var (
×
752
                                        otherNodePub [33]byte
×
753
                                        node1        = info.NodeKey1Bytes
×
754
                                        node2        = info.NodeKey2Bytes
×
755
                                )
×
756
                                switch {
×
757
                                case bytes.Equal(node1[:], nodePub[:]):
×
758
                                        otherNodePub = node2
×
759
                                case bytes.Equal(node2[:], nodePub[:]):
×
760
                                        otherNodePub = node1
×
761
                                default:
×
762
                                        return fmt.Errorf("node not " +
×
763
                                                "participating in this channel")
×
764
                                }
765

766
                                _, otherNode, err := getNodeByPubKey(
×
767
                                        ctx, db, otherNodePub,
×
768
                                )
×
769
                                if err != nil {
×
770
                                        return fmt.Errorf("unable to fetch "+
×
771
                                                "other node(%x): %w",
×
772
                                                otherNodePub, err)
×
773
                                }
×
774

775
                                return cb(
×
776
                                        info.ChannelPoint, outPolicy != nil,
×
777
                                        otherNode,
×
778
                                )
×
779
                        },
780
                )
781
        }, sqldb.NoOpReset)
782
}
783

784
// ForEachNode iterates through all the stored vertices/nodes in the graph,
785
// executing the passed callback with each node encountered. If the callback
786
// returns an error, then the transaction is aborted and the iteration stops
787
// early. Any operations performed on the NodeTx passed to the call-back are
788
// executed under the same read transaction and so, methods on the NodeTx object
789
// _MUST_ only be called from within the call-back.
790
//
791
// NOTE: part of the V1Store interface.
792
func (s *SQLStore) ForEachNode(cb func(tx NodeRTx) error) error {
×
793
        var (
×
794
                ctx          = context.TODO()
×
795
                lastID int64 = 0
×
796
        )
×
797

×
798
        handleNode := func(db SQLQueries, dbNode sqlc.Node) error {
×
799
                node, err := buildNode(ctx, db, &dbNode)
×
800
                if err != nil {
×
801
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
802
                                dbNode.ID, err)
×
803
                }
×
804

805
                err = cb(
×
806
                        newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
×
807
                )
×
808
                if err != nil {
×
809
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
810
                                dbNode.ID, err)
×
811
                }
×
812

813
                return nil
×
814
        }
815

816
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
817
                for {
×
818
                        nodes, err := db.ListNodesPaginated(
×
819
                                ctx, sqlc.ListNodesPaginatedParams{
×
820
                                        Version: int16(ProtocolV1),
×
821
                                        ID:      lastID,
×
822
                                        Limit:   pageSize,
×
823
                                },
×
824
                        )
×
825
                        if err != nil {
×
826
                                return fmt.Errorf("unable to fetch nodes: %w",
×
827
                                        err)
×
828
                        }
×
829

830
                        if len(nodes) == 0 {
×
831
                                break
×
832
                        }
833

834
                        for _, dbNode := range nodes {
×
835
                                err = handleNode(db, dbNode)
×
836
                                if err != nil {
×
837
                                        return err
×
838
                                }
×
839

840
                                lastID = dbNode.ID
×
841
                        }
842
                }
843

844
                return nil
×
845
        }, sqldb.NoOpReset)
846
}
847

848
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
849
// SQLStore and a SQL transaction.
850
type sqlGraphNodeTx struct {
851
        db    SQLQueries
852
        id    int64
853
        node  *models.LightningNode
854
        chain chainhash.Hash
855
}
856

857
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
858
// interface.
859
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
860

861
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
862
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
863

×
864
        return &sqlGraphNodeTx{
×
865
                db:    db,
×
866
                chain: chain,
×
867
                id:    id,
×
868
                node:  node,
×
869
        }
×
870
}
×
871

872
// Node returns the raw information of the node.
873
//
874
// NOTE: This is a part of the NodeRTx interface.
875
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
876
        return s.node
×
877
}
×
878

879
// ForEachChannel can be used to iterate over the node's channels under the same
880
// transaction used to fetch the node.
881
//
882
// NOTE: This is a part of the NodeRTx interface.
883
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
884
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
885

×
886
        ctx := context.TODO()
×
887

×
888
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
889
}
×
890

891
// FetchNode fetches the node with the given pub key under the same transaction
892
// used to fetch the current node. The returned node is also a NodeRTx and any
893
// operations on that NodeRTx will also be done under the same transaction.
894
//
895
// NOTE: This is a part of the NodeRTx interface.
896
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
897
        ctx := context.TODO()
×
898

×
899
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
900
        if err != nil {
×
901
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
902
                        nodePub, err)
×
903
        }
×
904

905
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
906
}
907

908
// ForEachNodeDirectedChannel iterates through all channels of a given node,
909
// executing the passed callback on the directed edge representing the channel
910
// and its incoming policy. If the callback returns an error, then the iteration
911
// is halted with the error propagated back up to the caller.
912
//
913
// Unknown policies are passed into the callback as nil values.
914
//
915
// NOTE: this is part of the graphdb.NodeTraverser interface.
916
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
917
        cb func(channel *DirectedChannel) error) error {
×
918

×
919
        var ctx = context.TODO()
×
920

×
921
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
922
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
923
        }, sqldb.NoOpReset)
×
924
}
925

926
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
927
// graph, executing the passed callback with each node encountered. If the
928
// callback returns an error, then the transaction is aborted and the iteration
929
// stops early.
930
//
931
// NOTE: This is a part of the V1Store interface.
932
func (s *SQLStore) ForEachNodeCacheable(cb func(route.Vertex,
933
        *lnwire.FeatureVector) error) error {
×
934

×
935
        ctx := context.TODO()
×
936

×
937
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
938
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
939
                        nodePub route.Vertex) error {
×
940

×
941
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
942
                        if err != nil {
×
943
                                return fmt.Errorf("unable to fetch node "+
×
944
                                        "features: %w", err)
×
945
                        }
×
946

947
                        return cb(nodePub, features)
×
948
                })
949
        }, sqldb.NoOpReset)
950
        if err != nil {
×
951
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
952
        }
×
953

954
        return nil
×
955
}
956

957
// ForEachNodeChannel iterates through all channels of the given node,
958
// executing the passed callback with an edge info structure and the policies
959
// of each end of the channel. The first edge policy is the outgoing edge *to*
960
// the connecting node, while the second is the incoming edge *from* the
961
// connecting node. If the callback returns an error, then the iteration is
962
// halted with the error propagated back up to the caller.
963
//
964
// Unknown policies are passed into the callback as nil values.
965
//
966
// NOTE: part of the V1Store interface.
967
func (s *SQLStore) ForEachNodeChannel(nodePub route.Vertex,
968
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
969
                *models.ChannelEdgePolicy) error) error {
×
970

×
971
        var ctx = context.TODO()
×
972

×
973
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
974
                dbNode, err := db.GetNodeByPubKey(
×
975
                        ctx, sqlc.GetNodeByPubKeyParams{
×
976
                                Version: int16(ProtocolV1),
×
977
                                PubKey:  nodePub[:],
×
978
                        },
×
979
                )
×
980
                if errors.Is(err, sql.ErrNoRows) {
×
981
                        return nil
×
982
                } else if err != nil {
×
983
                        return fmt.Errorf("unable to fetch node: %w", err)
×
984
                }
×
985

986
                return forEachNodeChannel(
×
987
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
988
                )
×
989
        }, sqldb.NoOpReset)
990
}
991

992
// ChanUpdatesInHorizon returns all the known channel edges which have at least
993
// one edge that has an update timestamp within the specified horizon.
994
//
995
// NOTE: This is part of the V1Store interface.
996
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
997
        endTime time.Time) ([]ChannelEdge, error) {
×
998

×
999
        s.cacheMu.Lock()
×
1000
        defer s.cacheMu.Unlock()
×
1001

×
1002
        var (
×
1003
                ctx = context.TODO()
×
1004
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
1005
                // an additional map to keep track of the edges already seen to
×
1006
                // prevent re-adding it.
×
1007
                edgesSeen    = make(map[uint64]struct{})
×
1008
                edgesToCache = make(map[uint64]ChannelEdge)
×
1009
                edges        []ChannelEdge
×
1010
                hits         int
×
1011
        )
×
1012
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1013
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1014
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1015
                                Version:   int16(ProtocolV1),
×
1016
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
1017
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
1018
                        },
×
1019
                )
×
1020
                if err != nil {
×
1021
                        return err
×
1022
                }
×
1023

1024
                for _, row := range rows {
×
1025
                        // If we've already retrieved the info and policies for
×
1026
                        // this edge, then we can skip it as we don't need to do
×
1027
                        // so again.
×
1028
                        chanIDInt := byteOrder.Uint64(row.Channel.Scid)
×
1029
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1030
                                continue
×
1031
                        }
1032

1033
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1034
                                hits++
×
1035
                                edgesSeen[chanIDInt] = struct{}{}
×
1036
                                edges = append(edges, channel)
×
1037

×
1038
                                continue
×
1039
                        }
1040

1041
                        node1, node2, err := buildNodes(
×
1042
                                ctx, db, row.Node, row.Node_2,
×
1043
                        )
×
1044
                        if err != nil {
×
1045
                                return err
×
1046
                        }
×
1047

1048
                        channel, err := getAndBuildEdgeInfo(
×
1049
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1050
                                row.Channel, node1.PubKeyBytes,
×
1051
                                node2.PubKeyBytes,
×
1052
                        )
×
1053
                        if err != nil {
×
1054
                                return fmt.Errorf("unable to build channel "+
×
1055
                                        "info: %w", err)
×
1056
                        }
×
1057

1058
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1059
                        if err != nil {
×
1060
                                return fmt.Errorf("unable to extract channel "+
×
1061
                                        "policies: %w", err)
×
1062
                        }
×
1063

1064
                        p1, p2, err := getAndBuildChanPolicies(
×
1065
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1066
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1067
                        )
×
1068
                        if err != nil {
×
1069
                                return fmt.Errorf("unable to build channel "+
×
1070
                                        "policies: %w", err)
×
1071
                        }
×
1072

1073
                        edgesSeen[chanIDInt] = struct{}{}
×
1074
                        chanEdge := ChannelEdge{
×
1075
                                Info:    channel,
×
1076
                                Policy1: p1,
×
1077
                                Policy2: p2,
×
1078
                                Node1:   node1,
×
1079
                                Node2:   node2,
×
1080
                        }
×
1081
                        edges = append(edges, chanEdge)
×
1082
                        edgesToCache[chanIDInt] = chanEdge
×
1083
                }
1084

1085
                return nil
×
1086
        }, func() {
×
1087
                edgesSeen = make(map[uint64]struct{})
×
1088
                edgesToCache = make(map[uint64]ChannelEdge)
×
1089
                edges = nil
×
1090
        })
×
1091
        if err != nil {
×
1092
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1093
        }
×
1094

1095
        // Insert any edges loaded from disk into the cache.
1096
        for chanid, channel := range edgesToCache {
×
1097
                s.chanCache.insert(chanid, channel)
×
1098
        }
×
1099

1100
        if len(edges) > 0 {
×
1101
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
1102
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1103
        } else {
×
1104
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1105
                        "horizon (%s, %s)", startTime, endTime)
×
1106
        }
×
1107

1108
        return edges, nil
×
1109
}
1110

1111
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1112
// data to the call-back.
1113
//
1114
// NOTE: The callback contents MUST not be modified.
1115
//
1116
// NOTE: part of the V1Store interface.
1117
func (s *SQLStore) ForEachNodeCached(cb func(node route.Vertex,
1118
        chans map[uint64]*DirectedChannel) error) error {
×
1119

×
1120
        var ctx = context.TODO()
×
1121

×
1122
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1123
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
1124
                        nodePub route.Vertex) error {
×
1125

×
1126
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
1127
                        if err != nil {
×
1128
                                return fmt.Errorf("unable to fetch "+
×
1129
                                        "node(id=%d) features: %w", nodeID, err)
×
1130
                        }
×
1131

1132
                        toNodeCallback := func() route.Vertex {
×
1133
                                return nodePub
×
1134
                        }
×
1135

1136
                        rows, err := db.ListChannelsByNodeID(
×
1137
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
1138
                                        Version: int16(ProtocolV1),
×
1139
                                        NodeID1: nodeID,
×
1140
                                },
×
1141
                        )
×
1142
                        if err != nil {
×
1143
                                return fmt.Errorf("unable to fetch channels "+
×
1144
                                        "of node(id=%d): %w", nodeID, err)
×
1145
                        }
×
1146

1147
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
1148
                        for _, row := range rows {
×
1149
                                node1, node2, err := buildNodeVertices(
×
1150
                                        row.Node1Pubkey, row.Node2Pubkey,
×
1151
                                )
×
1152
                                if err != nil {
×
1153
                                        return err
×
1154
                                }
×
1155

1156
                                e, err := getAndBuildEdgeInfo(
×
1157
                                        ctx, db, s.cfg.ChainHash,
×
1158
                                        row.Channel.ID, row.Channel, node1,
×
1159
                                        node2,
×
1160
                                )
×
1161
                                if err != nil {
×
1162
                                        return fmt.Errorf("unable to build "+
×
1163
                                                "channel info: %w", err)
×
1164
                                }
×
1165

1166
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1167
                                        row,
×
1168
                                )
×
1169
                                if err != nil {
×
1170
                                        return fmt.Errorf("unable to "+
×
1171
                                                "extract channel "+
×
1172
                                                "policies: %w", err)
×
1173
                                }
×
1174

1175
                                p1, p2, err := getAndBuildChanPolicies(
×
1176
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
1177
                                        node1, node2,
×
1178
                                )
×
1179
                                if err != nil {
×
1180
                                        return fmt.Errorf("unable to "+
×
1181
                                                "build channel policies: %w",
×
1182
                                                err)
×
1183
                                }
×
1184

1185
                                // Determine the outgoing and incoming policy
1186
                                // for this channel and node combo.
1187
                                outPolicy, inPolicy := p1, p2
×
1188
                                if p1 != nil && p1.ToNode == nodePub {
×
1189
                                        outPolicy, inPolicy = p2, p1
×
1190
                                } else if p2 != nil && p2.ToNode != nodePub {
×
1191
                                        outPolicy, inPolicy = p2, p1
×
1192
                                }
×
1193

1194
                                var cachedInPolicy *models.CachedEdgePolicy
×
1195
                                if inPolicy != nil {
×
1196
                                        cachedInPolicy = models.NewCachedPolicy(
×
1197
                                                p2,
×
1198
                                        )
×
1199
                                        cachedInPolicy.ToNodePubKey =
×
1200
                                                toNodeCallback
×
1201
                                        cachedInPolicy.ToNodeFeatures =
×
1202
                                                features
×
1203
                                }
×
1204

1205
                                var inboundFee lnwire.Fee
×
1206
                                outPolicy.InboundFee.WhenSome(
×
1207
                                        func(fee lnwire.Fee) {
×
1208
                                                inboundFee = fee
×
1209
                                        },
×
1210
                                )
1211

1212
                                directedChannel := &DirectedChannel{
×
1213
                                        ChannelID: e.ChannelID,
×
1214
                                        IsNode1: nodePub ==
×
1215
                                                e.NodeKey1Bytes,
×
1216
                                        OtherNode:    e.NodeKey2Bytes,
×
1217
                                        Capacity:     e.Capacity,
×
1218
                                        OutPolicySet: p1 != nil,
×
1219
                                        InPolicy:     cachedInPolicy,
×
1220
                                        InboundFee:   inboundFee,
×
1221
                                }
×
1222

×
1223
                                if nodePub == e.NodeKey2Bytes {
×
1224
                                        directedChannel.OtherNode =
×
1225
                                                e.NodeKey1Bytes
×
1226
                                }
×
1227

1228
                                channels[e.ChannelID] = directedChannel
×
1229
                        }
1230

1231
                        return cb(nodePub, channels)
×
1232
                })
1233
        }, sqldb.NoOpReset)
1234
}
1235

1236
// ForEachChannelCacheable iterates through all the channel edges stored
1237
// within the graph and invokes the passed callback for each edge. The
1238
// callback takes two edges as since this is a directed graph, both the
1239
// in/out edges are visited. If the callback returns an error, then the
1240
// transaction is aborted and the iteration stops early.
1241
//
1242
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1243
// pointer for that particular channel edge routing policy will be
1244
// passed into the callback.
1245
//
1246
// NOTE: this method is like ForEachChannel but fetches only the data
1247
// required for the graph cache.
1248
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1249
        *models.CachedEdgePolicy,
1250
        *models.CachedEdgePolicy) error) error {
×
1251

×
1252
        ctx := context.TODO()
×
1253

×
1254
        handleChannel := func(db SQLQueries,
×
1255
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1256

×
1257
                node1, node2, err := buildNodeVertices(
×
1258
                        row.Node1Pubkey, row.Node2Pubkey,
×
1259
                )
×
1260
                if err != nil {
×
1261
                        return err
×
1262
                }
×
1263

1264
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
1265

×
1266
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1267
                if err != nil {
×
1268
                        return err
×
1269
                }
×
1270

1271
                var pol1, pol2 *models.CachedEdgePolicy
×
1272
                if dbPol1 != nil {
×
1273
                        policy1, err := buildChanPolicy(
×
NEW
1274
                                *dbPol1, edge.ChannelID, nil, node2,
×
1275
                        )
×
1276
                        if err != nil {
×
1277
                                return err
×
1278
                        }
×
1279

1280
                        pol1 = models.NewCachedPolicy(policy1)
×
1281
                }
1282
                if dbPol2 != nil {
×
1283
                        policy2, err := buildChanPolicy(
×
NEW
1284
                                *dbPol2, edge.ChannelID, nil, node1,
×
1285
                        )
×
1286
                        if err != nil {
×
1287
                                return err
×
1288
                        }
×
1289

1290
                        pol2 = models.NewCachedPolicy(policy2)
×
1291
                }
1292

1293
                if err := cb(edge, pol1, pol2); err != nil {
×
1294
                        return err
×
1295
                }
×
1296

1297
                return nil
×
1298
        }
1299

1300
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1301
                lastID := int64(-1)
×
1302
                for {
×
1303
                        //nolint:ll
×
1304
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1305
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1306
                                        Version: int16(ProtocolV1),
×
1307
                                        ID:      lastID,
×
1308
                                        Limit:   pageSize,
×
1309
                                },
×
1310
                        )
×
1311
                        if err != nil {
×
1312
                                return err
×
1313
                        }
×
1314

1315
                        if len(rows) == 0 {
×
1316
                                break
×
1317
                        }
1318

1319
                        for _, row := range rows {
×
1320
                                err := handleChannel(db, row)
×
1321
                                if err != nil {
×
1322
                                        return err
×
1323
                                }
×
1324

1325
                                lastID = row.Channel.ID
×
1326
                        }
1327
                }
1328

1329
                return nil
×
1330
        }, sqldb.NoOpReset)
1331
}
1332

1333
// ForEachChannel iterates through all the channel edges stored within the
1334
// graph and invokes the passed callback for each edge. The callback takes two
1335
// edges as since this is a directed graph, both the in/out edges are visited.
1336
// If the callback returns an error, then the transaction is aborted and the
1337
// iteration stops early.
1338
//
1339
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1340
// for that particular channel edge routing policy will be passed into the
1341
// callback.
1342
//
1343
// NOTE: part of the V1Store interface.
1344
func (s *SQLStore) ForEachChannel(cb func(*models.ChannelEdgeInfo,
1345
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
1346

×
1347
        ctx := context.TODO()
×
1348

×
1349
        handleChannel := func(db SQLQueries,
×
1350
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1351

×
1352
                node1, node2, err := buildNodeVertices(
×
1353
                        row.Node1Pubkey, row.Node2Pubkey,
×
1354
                )
×
1355
                if err != nil {
×
1356
                        return fmt.Errorf("unable to build node vertices: %w",
×
1357
                                err)
×
1358
                }
×
1359

1360
                edge, err := getAndBuildEdgeInfo(
×
1361
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1362
                        node1, node2,
×
1363
                )
×
1364
                if err != nil {
×
1365
                        return fmt.Errorf("unable to build channel info: %w",
×
1366
                                err)
×
1367
                }
×
1368

1369
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1370
                if err != nil {
×
1371
                        return fmt.Errorf("unable to extract channel "+
×
1372
                                "policies: %w", err)
×
1373
                }
×
1374

1375
                p1, p2, err := getAndBuildChanPolicies(
×
1376
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1377
                )
×
1378
                if err != nil {
×
1379
                        return fmt.Errorf("unable to build channel "+
×
1380
                                "policies: %w", err)
×
1381
                }
×
1382

1383
                err = cb(edge, p1, p2)
×
1384
                if err != nil {
×
1385
                        return fmt.Errorf("callback failed for channel "+
×
1386
                                "id=%d: %w", edge.ChannelID, err)
×
1387
                }
×
1388

1389
                return nil
×
1390
        }
1391

1392
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1393
                lastID := int64(-1)
×
1394
                for {
×
1395
                        //nolint:ll
×
1396
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1397
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1398
                                        Version: int16(ProtocolV1),
×
1399
                                        ID:      lastID,
×
1400
                                        Limit:   pageSize,
×
1401
                                },
×
1402
                        )
×
1403
                        if err != nil {
×
1404
                                return err
×
1405
                        }
×
1406

1407
                        if len(rows) == 0 {
×
1408
                                break
×
1409
                        }
1410

1411
                        for _, row := range rows {
×
1412
                                err := handleChannel(db, row)
×
1413
                                if err != nil {
×
1414
                                        return err
×
1415
                                }
×
1416

1417
                                lastID = row.Channel.ID
×
1418
                        }
1419
                }
1420

1421
                return nil
×
1422
        }, sqldb.NoOpReset)
1423
}
1424

1425
// FilterChannelRange returns the channel ID's of all known channels which were
1426
// mined in a block height within the passed range. The channel IDs are grouped
1427
// by their common block height. This method can be used to quickly share with a
1428
// peer the set of channels we know of within a particular range to catch them
1429
// up after a period of time offline. If withTimestamps is true then the
1430
// timestamp info of the latest received channel update messages of the channel
1431
// will be included in the response.
1432
//
1433
// NOTE: This is part of the V1Store interface.
1434
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1435
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1436

×
1437
        var (
×
1438
                ctx       = context.TODO()
×
1439
                startSCID = &lnwire.ShortChannelID{
×
1440
                        BlockHeight: startHeight,
×
1441
                }
×
1442
                endSCID = lnwire.ShortChannelID{
×
1443
                        BlockHeight: endHeight,
×
1444
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1445
                        TxPosition:  math.MaxUint16,
×
1446
                }
×
1447
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1448
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1449
        )
×
1450

×
1451
        // 1) get all channels where channelID is between start and end chan ID.
×
1452
        // 2) skip if not public (ie, no channel_proof)
×
1453
        // 3) collect that channel.
×
1454
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1455
        //    and add those timestamps to the collected channel.
×
1456
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1457
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1458
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1459
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1460
                                StartScid: chanIDStart[:],
×
1461
                                EndScid:   chanIDEnd[:],
×
1462
                        },
×
1463
                )
×
1464
                if err != nil {
×
1465
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1466
                                err)
×
1467
                }
×
1468

1469
                for _, dbChan := range dbChans {
×
1470
                        cid := lnwire.NewShortChanIDFromInt(
×
1471
                                byteOrder.Uint64(dbChan.Scid),
×
1472
                        )
×
1473
                        chanInfo := NewChannelUpdateInfo(
×
1474
                                cid, time.Time{}, time.Time{},
×
1475
                        )
×
1476

×
1477
                        if !withTimestamps {
×
1478
                                channelsPerBlock[cid.BlockHeight] = append(
×
1479
                                        channelsPerBlock[cid.BlockHeight],
×
1480
                                        chanInfo,
×
1481
                                )
×
1482

×
1483
                                continue
×
1484
                        }
1485

1486
                        //nolint:ll
1487
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1488
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1489
                                        Version:   int16(ProtocolV1),
×
1490
                                        ChannelID: dbChan.ID,
×
1491
                                        NodeID:    dbChan.NodeID1,
×
1492
                                },
×
1493
                        )
×
1494
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1495
                                return fmt.Errorf("unable to fetch node1 "+
×
1496
                                        "policy: %w", err)
×
1497
                        } else if err == nil {
×
1498
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1499
                                        node1Policy.LastUpdate.Int64, 0,
×
1500
                                )
×
1501
                        }
×
1502

1503
                        //nolint:ll
1504
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1505
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1506
                                        Version:   int16(ProtocolV1),
×
1507
                                        ChannelID: dbChan.ID,
×
1508
                                        NodeID:    dbChan.NodeID2,
×
1509
                                },
×
1510
                        )
×
1511
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1512
                                return fmt.Errorf("unable to fetch node2 "+
×
1513
                                        "policy: %w", err)
×
1514
                        } else if err == nil {
×
1515
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1516
                                        node2Policy.LastUpdate.Int64, 0,
×
1517
                                )
×
1518
                        }
×
1519

1520
                        channelsPerBlock[cid.BlockHeight] = append(
×
1521
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1522
                        )
×
1523
                }
1524

1525
                return nil
×
1526
        }, func() {
×
1527
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1528
        })
×
1529
        if err != nil {
×
1530
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1531
        }
×
1532

1533
        if len(channelsPerBlock) == 0 {
×
1534
                return nil, nil
×
1535
        }
×
1536

1537
        // Return the channel ranges in ascending block height order.
1538
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1539
        slices.Sort(blocks)
×
1540

×
1541
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1542
                return BlockChannelRange{
×
1543
                        Height:   block,
×
1544
                        Channels: channelsPerBlock[block],
×
1545
                }
×
1546
        }), nil
×
1547
}
1548

1549
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1550
// zombie. This method is used on an ad-hoc basis, when channels need to be
1551
// marked as zombies outside the normal pruning cycle.
1552
//
1553
// NOTE: part of the V1Store interface.
1554
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1555
        pubKey1, pubKey2 [33]byte) error {
×
1556

×
1557
        ctx := context.TODO()
×
1558

×
1559
        s.cacheMu.Lock()
×
1560
        defer s.cacheMu.Unlock()
×
1561

×
1562
        chanIDB := channelIDToBytes(chanID)
×
1563

×
1564
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1565
                return db.UpsertZombieChannel(
×
1566
                        ctx, sqlc.UpsertZombieChannelParams{
×
1567
                                Version:  int16(ProtocolV1),
×
1568
                                Scid:     chanIDB[:],
×
1569
                                NodeKey1: pubKey1[:],
×
1570
                                NodeKey2: pubKey2[:],
×
1571
                        },
×
1572
                )
×
1573
        }, sqldb.NoOpReset)
×
1574
        if err != nil {
×
1575
                return fmt.Errorf("unable to upsert zombie channel "+
×
1576
                        "(channel_id=%d): %w", chanID, err)
×
1577
        }
×
1578

1579
        s.rejectCache.remove(chanID)
×
1580
        s.chanCache.remove(chanID)
×
1581

×
1582
        return nil
×
1583
}
1584

1585
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1586
//
1587
// NOTE: part of the V1Store interface.
1588
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1589
        s.cacheMu.Lock()
×
1590
        defer s.cacheMu.Unlock()
×
1591

×
1592
        var (
×
1593
                ctx     = context.TODO()
×
1594
                chanIDB = channelIDToBytes(chanID)
×
1595
        )
×
1596

×
1597
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1598
                res, err := db.DeleteZombieChannel(
×
1599
                        ctx, sqlc.DeleteZombieChannelParams{
×
1600
                                Scid:    chanIDB[:],
×
1601
                                Version: int16(ProtocolV1),
×
1602
                        },
×
1603
                )
×
1604
                if err != nil {
×
1605
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1606
                                err)
×
1607
                }
×
1608

1609
                rows, err := res.RowsAffected()
×
1610
                if err != nil {
×
1611
                        return err
×
1612
                }
×
1613

1614
                if rows == 0 {
×
1615
                        return ErrZombieEdgeNotFound
×
1616
                } else if rows > 1 {
×
1617
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1618
                                "expected 1", rows)
×
1619
                }
×
1620

1621
                return nil
×
1622
        }, sqldb.NoOpReset)
1623
        if err != nil {
×
1624
                return fmt.Errorf("unable to mark edge live "+
×
1625
                        "(channel_id=%d): %w", chanID, err)
×
1626
        }
×
1627

1628
        s.rejectCache.remove(chanID)
×
1629
        s.chanCache.remove(chanID)
×
1630

×
1631
        return err
×
1632
}
1633

1634
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1635
// zombie, then the two node public keys corresponding to this edge are also
1636
// returned.
1637
//
1638
// NOTE: part of the V1Store interface.
1639
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1640
        error) {
×
1641

×
1642
        var (
×
1643
                ctx              = context.TODO()
×
1644
                isZombie         bool
×
1645
                pubKey1, pubKey2 route.Vertex
×
1646
                chanIDB          = channelIDToBytes(chanID)
×
1647
        )
×
1648

×
1649
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1650
                zombie, err := db.GetZombieChannel(
×
1651
                        ctx, sqlc.GetZombieChannelParams{
×
1652
                                Scid:    chanIDB[:],
×
1653
                                Version: int16(ProtocolV1),
×
1654
                        },
×
1655
                )
×
1656
                if errors.Is(err, sql.ErrNoRows) {
×
1657
                        return nil
×
1658
                }
×
1659
                if err != nil {
×
1660
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1661
                                err)
×
1662
                }
×
1663

1664
                copy(pubKey1[:], zombie.NodeKey1)
×
1665
                copy(pubKey2[:], zombie.NodeKey2)
×
1666
                isZombie = true
×
1667

×
1668
                return nil
×
1669
        }, sqldb.NoOpReset)
1670
        if err != nil {
×
1671
                return false, route.Vertex{}, route.Vertex{},
×
1672
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1673
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1674
        }
×
1675

1676
        return isZombie, pubKey1, pubKey2, nil
×
1677
}
1678

1679
// NumZombies returns the current number of zombie channels in the graph.
1680
//
1681
// NOTE: part of the V1Store interface.
1682
func (s *SQLStore) NumZombies() (uint64, error) {
×
1683
        var (
×
1684
                ctx        = context.TODO()
×
1685
                numZombies uint64
×
1686
        )
×
1687
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1688
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1689
                if err != nil {
×
1690
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1691
                                err)
×
1692
                }
×
1693

1694
                numZombies = uint64(count)
×
1695

×
1696
                return nil
×
1697
        }, sqldb.NoOpReset)
1698
        if err != nil {
×
1699
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1700
        }
×
1701

1702
        return numZombies, nil
×
1703
}
1704

1705
// DeleteChannelEdges removes edges with the given channel IDs from the
1706
// database and marks them as zombies. This ensures that we're unable to re-add
1707
// it to our database once again. If an edge does not exist within the
1708
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1709
// true, then when we mark these edges as zombies, we'll set up the keys such
1710
// that we require the node that failed to send the fresh update to be the one
1711
// that resurrects the channel from its zombie state. The markZombie bool
1712
// denotes whether to mark the channel as a zombie.
1713
//
1714
// NOTE: part of the V1Store interface.
1715
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1716
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1717

×
1718
        s.cacheMu.Lock()
×
1719
        defer s.cacheMu.Unlock()
×
1720

×
1721
        var (
×
1722
                ctx     = context.TODO()
×
1723
                deleted []*models.ChannelEdgeInfo
×
1724
        )
×
1725
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1726
                for _, chanID := range chanIDs {
×
1727
                        chanIDB := channelIDToBytes(chanID)
×
1728

×
1729
                        row, err := db.GetChannelBySCIDWithPolicies(
×
1730
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1731
                                        Scid:    chanIDB[:],
×
1732
                                        Version: int16(ProtocolV1),
×
1733
                                },
×
1734
                        )
×
1735
                        if errors.Is(err, sql.ErrNoRows) {
×
1736
                                return ErrEdgeNotFound
×
1737
                        } else if err != nil {
×
1738
                                return fmt.Errorf("unable to fetch channel: %w",
×
1739
                                        err)
×
1740
                        }
×
1741

1742
                        node1, node2, err := buildNodeVertices(
×
1743
                                row.Node.PubKey, row.Node_2.PubKey,
×
1744
                        )
×
1745
                        if err != nil {
×
1746
                                return err
×
1747
                        }
×
1748

1749
                        info, err := getAndBuildEdgeInfo(
×
1750
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1751
                                row.Channel, node1, node2,
×
1752
                        )
×
1753
                        if err != nil {
×
1754
                                return err
×
1755
                        }
×
1756

1757
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
1758
                        if err != nil {
×
1759
                                return fmt.Errorf("unable to delete "+
×
1760
                                        "channel: %w", err)
×
1761
                        }
×
1762

1763
                        deleted = append(deleted, info)
×
1764

×
1765
                        if !markZombie {
×
1766
                                continue
×
1767
                        }
1768

1769
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1770
                                info.NodeKey2Bytes
×
1771
                        if strictZombiePruning {
×
1772
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1773
                                if row.Policy1LastUpdate.Valid {
×
1774
                                        e1Time := time.Unix(
×
1775
                                                row.Policy1LastUpdate.Int64, 0,
×
1776
                                        )
×
1777
                                        e1UpdateTime = &e1Time
×
1778
                                }
×
1779
                                if row.Policy2LastUpdate.Valid {
×
1780
                                        e2Time := time.Unix(
×
1781
                                                row.Policy2LastUpdate.Int64, 0,
×
1782
                                        )
×
1783
                                        e2UpdateTime = &e2Time
×
1784
                                }
×
1785

1786
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1787
                                        info, e1UpdateTime, e2UpdateTime,
×
1788
                                )
×
1789
                        }
1790

1791
                        err = db.UpsertZombieChannel(
×
1792
                                ctx, sqlc.UpsertZombieChannelParams{
×
1793
                                        Version:  int16(ProtocolV1),
×
1794
                                        Scid:     chanIDB[:],
×
1795
                                        NodeKey1: nodeKey1[:],
×
1796
                                        NodeKey2: nodeKey2[:],
×
1797
                                },
×
1798
                        )
×
1799
                        if err != nil {
×
1800
                                return fmt.Errorf("unable to mark channel as "+
×
1801
                                        "zombie: %w", err)
×
1802
                        }
×
1803
                }
1804

1805
                return nil
×
1806
        }, func() {
×
1807
                deleted = nil
×
1808
        })
×
1809
        if err != nil {
×
1810
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1811
                        err)
×
1812
        }
×
1813

1814
        for _, chanID := range chanIDs {
×
1815
                s.rejectCache.remove(chanID)
×
1816
                s.chanCache.remove(chanID)
×
1817
        }
×
1818

1819
        return deleted, nil
×
1820
}
1821

1822
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1823
// channel identified by the channel ID. If the channel can't be found, then
1824
// ErrEdgeNotFound is returned. A struct which houses the general information
1825
// for the channel itself is returned as well as two structs that contain the
1826
// routing policies for the channel in either direction.
1827
//
1828
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1829
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1830
// the ChannelEdgeInfo will only include the public keys of each node.
1831
//
1832
// NOTE: part of the V1Store interface.
1833
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1834
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1835
        *models.ChannelEdgePolicy, error) {
×
1836

×
1837
        var (
×
1838
                ctx              = context.TODO()
×
1839
                edge             *models.ChannelEdgeInfo
×
1840
                policy1, policy2 *models.ChannelEdgePolicy
×
1841
        )
×
1842
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1843
                var chanIDB [8]byte
×
1844
                byteOrder.PutUint64(chanIDB[:], chanID)
×
1845

×
1846
                row, err := db.GetChannelBySCIDWithPolicies(
×
1847
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1848
                                Scid:    chanIDB[:],
×
1849
                                Version: int16(ProtocolV1),
×
1850
                        },
×
1851
                )
×
1852
                if errors.Is(err, sql.ErrNoRows) {
×
1853
                        // First check if this edge is perhaps in the zombie
×
1854
                        // index.
×
1855
                        isZombie, err := db.IsZombieChannel(
×
1856
                                ctx, sqlc.IsZombieChannelParams{
×
1857
                                        Scid:    chanIDB[:],
×
1858
                                        Version: int16(ProtocolV1),
×
1859
                                },
×
1860
                        )
×
1861
                        if err != nil {
×
1862
                                return fmt.Errorf("unable to check if "+
×
1863
                                        "channel is zombie: %w", err)
×
1864
                        } else if isZombie {
×
1865
                                return ErrZombieEdge
×
1866
                        }
×
1867

1868
                        return ErrEdgeNotFound
×
1869
                } else if err != nil {
×
1870
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1871
                }
×
1872

1873
                node1, node2, err := buildNodeVertices(
×
1874
                        row.Node.PubKey, row.Node_2.PubKey,
×
1875
                )
×
1876
                if err != nil {
×
1877
                        return err
×
1878
                }
×
1879

1880
                edge, err = getAndBuildEdgeInfo(
×
1881
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1882
                        node1, node2,
×
1883
                )
×
1884
                if err != nil {
×
1885
                        return fmt.Errorf("unable to build channel info: %w",
×
1886
                                err)
×
1887
                }
×
1888

1889
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1890
                if err != nil {
×
1891
                        return fmt.Errorf("unable to extract channel "+
×
1892
                                "policies: %w", err)
×
1893
                }
×
1894

1895
                policy1, policy2, err = getAndBuildChanPolicies(
×
1896
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1897
                )
×
1898
                if err != nil {
×
1899
                        return fmt.Errorf("unable to build channel "+
×
1900
                                "policies: %w", err)
×
1901
                }
×
1902

1903
                return nil
×
1904
        }, sqldb.NoOpReset)
1905
        if err != nil {
×
1906
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1907
                        err)
×
1908
        }
×
1909

1910
        return edge, policy1, policy2, nil
×
1911
}
1912

1913
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1914
// the channel identified by the funding outpoint. If the channel can't be
1915
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1916
// information for the channel itself is returned as well as two structs that
1917
// contain the routing policies for the channel in either direction.
1918
//
1919
// NOTE: part of the V1Store interface.
1920
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1921
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1922
        *models.ChannelEdgePolicy, error) {
×
1923

×
1924
        var (
×
1925
                ctx              = context.TODO()
×
1926
                edge             *models.ChannelEdgeInfo
×
1927
                policy1, policy2 *models.ChannelEdgePolicy
×
1928
        )
×
1929
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1930
                row, err := db.GetChannelByOutpointWithPolicies(
×
1931
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1932
                                Outpoint: op.String(),
×
1933
                                Version:  int16(ProtocolV1),
×
1934
                        },
×
1935
                )
×
1936
                if errors.Is(err, sql.ErrNoRows) {
×
1937
                        return ErrEdgeNotFound
×
1938
                } else if err != nil {
×
1939
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1940
                }
×
1941

1942
                node1, node2, err := buildNodeVertices(
×
1943
                        row.Node1Pubkey, row.Node2Pubkey,
×
1944
                )
×
1945
                if err != nil {
×
1946
                        return err
×
1947
                }
×
1948

1949
                edge, err = getAndBuildEdgeInfo(
×
1950
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1951
                        node1, node2,
×
1952
                )
×
1953
                if err != nil {
×
1954
                        return fmt.Errorf("unable to build channel info: %w",
×
1955
                                err)
×
1956
                }
×
1957

1958
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1959
                if err != nil {
×
1960
                        return fmt.Errorf("unable to extract channel "+
×
1961
                                "policies: %w", err)
×
1962
                }
×
1963

1964
                policy1, policy2, err = getAndBuildChanPolicies(
×
1965
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1966
                )
×
1967
                if err != nil {
×
1968
                        return fmt.Errorf("unable to build channel "+
×
1969
                                "policies: %w", err)
×
1970
                }
×
1971

1972
                return nil
×
1973
        }, sqldb.NoOpReset)
1974
        if err != nil {
×
1975
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1976
                        err)
×
1977
        }
×
1978

1979
        return edge, policy1, policy2, nil
×
1980
}
1981

1982
// HasChannelEdge returns true if the database knows of a channel edge with the
1983
// passed channel ID, and false otherwise. If an edge with that ID is found
1984
// within the graph, then two time stamps representing the last time the edge
1985
// was updated for both directed edges are returned along with the boolean. If
1986
// it is not found, then the zombie index is checked and its result is returned
1987
// as the second boolean.
1988
//
1989
// NOTE: part of the V1Store interface.
1990
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1991
        bool, error) {
×
1992

×
1993
        ctx := context.TODO()
×
1994

×
1995
        var (
×
1996
                exists          bool
×
1997
                isZombie        bool
×
1998
                node1LastUpdate time.Time
×
1999
                node2LastUpdate time.Time
×
2000
        )
×
2001

×
2002
        // We'll query the cache with the shared lock held to allow multiple
×
2003
        // readers to access values in the cache concurrently if they exist.
×
2004
        s.cacheMu.RLock()
×
2005
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2006
                s.cacheMu.RUnlock()
×
2007
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2008
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2009
                exists, isZombie = entry.flags.unpack()
×
2010

×
2011
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2012
        }
×
2013
        s.cacheMu.RUnlock()
×
2014

×
2015
        s.cacheMu.Lock()
×
2016
        defer s.cacheMu.Unlock()
×
2017

×
2018
        // The item was not found with the shared lock, so we'll acquire the
×
2019
        // exclusive lock and check the cache again in case another method added
×
2020
        // the entry to the cache while no lock was held.
×
2021
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2022
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2023
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2024
                exists, isZombie = entry.flags.unpack()
×
2025

×
2026
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2027
        }
×
2028

2029
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2030
                var chanIDB [8]byte
×
2031
                byteOrder.PutUint64(chanIDB[:], chanID)
×
2032

×
2033
                channel, err := db.GetChannelBySCID(
×
2034
                        ctx, sqlc.GetChannelBySCIDParams{
×
2035
                                Scid:    chanIDB[:],
×
2036
                                Version: int16(ProtocolV1),
×
2037
                        },
×
2038
                )
×
2039
                if errors.Is(err, sql.ErrNoRows) {
×
2040
                        // Check if it is a zombie channel.
×
2041
                        isZombie, err = db.IsZombieChannel(
×
2042
                                ctx, sqlc.IsZombieChannelParams{
×
2043
                                        Scid:    chanIDB[:],
×
2044
                                        Version: int16(ProtocolV1),
×
2045
                                },
×
2046
                        )
×
2047
                        if err != nil {
×
2048
                                return fmt.Errorf("could not check if channel "+
×
2049
                                        "is zombie: %w", err)
×
2050
                        }
×
2051

2052
                        return nil
×
2053
                } else if err != nil {
×
2054
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2055
                }
×
2056

2057
                exists = true
×
2058

×
2059
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2060
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2061
                                Version:   int16(ProtocolV1),
×
2062
                                ChannelID: channel.ID,
×
2063
                                NodeID:    channel.NodeID1,
×
2064
                        },
×
2065
                )
×
2066
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2067
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2068
                                err)
×
2069
                } else if err == nil {
×
2070
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2071
                }
×
2072

2073
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2074
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2075
                                Version:   int16(ProtocolV1),
×
2076
                                ChannelID: channel.ID,
×
2077
                                NodeID:    channel.NodeID2,
×
2078
                        },
×
2079
                )
×
2080
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2081
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2082
                                err)
×
2083
                } else if err == nil {
×
2084
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2085
                }
×
2086

2087
                return nil
×
2088
        }, sqldb.NoOpReset)
2089
        if err != nil {
×
2090
                return time.Time{}, time.Time{}, false, false,
×
2091
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2092
        }
×
2093

2094
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2095
                upd1Time: node1LastUpdate.Unix(),
×
2096
                upd2Time: node2LastUpdate.Unix(),
×
2097
                flags:    packRejectFlags(exists, isZombie),
×
2098
        })
×
2099

×
2100
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2101
}
2102

2103
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2104
// passed channel point (outpoint). If the passed channel doesn't exist within
2105
// the database, then ErrEdgeNotFound is returned.
2106
//
2107
// NOTE: part of the V1Store interface.
2108
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2109
        var (
×
2110
                ctx       = context.TODO()
×
2111
                channelID uint64
×
2112
        )
×
2113
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2114
                chanID, err := db.GetSCIDByOutpoint(
×
2115
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2116
                                Outpoint: chanPoint.String(),
×
2117
                                Version:  int16(ProtocolV1),
×
2118
                        },
×
2119
                )
×
2120
                if errors.Is(err, sql.ErrNoRows) {
×
2121
                        return ErrEdgeNotFound
×
2122
                } else if err != nil {
×
2123
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2124
                                err)
×
2125
                }
×
2126

2127
                channelID = byteOrder.Uint64(chanID)
×
2128

×
2129
                return nil
×
2130
        }, sqldb.NoOpReset)
2131
        if err != nil {
×
2132
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2133
        }
×
2134

2135
        return channelID, nil
×
2136
}
2137

2138
// IsPublicNode is a helper method that determines whether the node with the
2139
// given public key is seen as a public node in the graph from the graph's
2140
// source node's point of view.
2141
//
2142
// NOTE: part of the V1Store interface.
2143
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2144
        ctx := context.TODO()
×
2145

×
2146
        var isPublic bool
×
2147
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2148
                var err error
×
2149
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2150

×
2151
                return err
×
2152
        }, sqldb.NoOpReset)
×
2153
        if err != nil {
×
2154
                return false, fmt.Errorf("unable to check if node is "+
×
2155
                        "public: %w", err)
×
2156
        }
×
2157

2158
        return isPublic, nil
×
2159
}
2160

2161
// FetchChanInfos returns the set of channel edges that correspond to the passed
2162
// channel ID's. If an edge is the query is unknown to the database, it will
2163
// skipped and the result will contain only those edges that exist at the time
2164
// of the query. This can be used to respond to peer queries that are seeking to
2165
// fill in gaps in their view of the channel graph.
2166
//
2167
// NOTE: part of the V1Store interface.
2168
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2169
        var (
×
2170
                ctx   = context.TODO()
×
2171
                edges []ChannelEdge
×
2172
        )
×
2173
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2174
                for _, chanID := range chanIDs {
×
2175
                        var chanIDB [8]byte
×
2176
                        byteOrder.PutUint64(chanIDB[:], chanID)
×
2177

×
2178
                        // TODO(elle): potentially optimize this by using
×
2179
                        //  sqlc.slice() once that works for both SQLite and
×
2180
                        //  Postgres.
×
2181
                        row, err := db.GetChannelBySCIDWithPolicies(
×
2182
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
2183
                                        Scid:    chanIDB[:],
×
2184
                                        Version: int16(ProtocolV1),
×
2185
                                },
×
2186
                        )
×
2187
                        if errors.Is(err, sql.ErrNoRows) {
×
2188
                                continue
×
2189
                        } else if err != nil {
×
2190
                                return fmt.Errorf("unable to fetch channel: %w",
×
2191
                                        err)
×
2192
                        }
×
2193

2194
                        node1, node2, err := buildNodes(
×
2195
                                ctx, db, row.Node, row.Node_2,
×
2196
                        )
×
2197
                        if err != nil {
×
2198
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2199
                                        err)
×
2200
                        }
×
2201

2202
                        edge, err := getAndBuildEdgeInfo(
×
2203
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2204
                                row.Channel, node1.PubKeyBytes,
×
2205
                                node2.PubKeyBytes,
×
2206
                        )
×
2207
                        if err != nil {
×
2208
                                return fmt.Errorf("unable to build "+
×
2209
                                        "channel info: %w", err)
×
2210
                        }
×
2211

2212
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2213
                        if err != nil {
×
2214
                                return fmt.Errorf("unable to extract channel "+
×
2215
                                        "policies: %w", err)
×
2216
                        }
×
2217

2218
                        p1, p2, err := getAndBuildChanPolicies(
×
2219
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2220
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2221
                        )
×
2222
                        if err != nil {
×
2223
                                return fmt.Errorf("unable to build channel "+
×
2224
                                        "policies: %w", err)
×
2225
                        }
×
2226

2227
                        edges = append(edges, ChannelEdge{
×
2228
                                Info:    edge,
×
2229
                                Policy1: p1,
×
2230
                                Policy2: p2,
×
2231
                                Node1:   node1,
×
2232
                                Node2:   node2,
×
2233
                        })
×
2234
                }
2235

2236
                return nil
×
2237
        }, func() {
×
2238
                edges = nil
×
2239
        })
×
2240
        if err != nil {
×
2241
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2242
        }
×
2243

2244
        return edges, nil
×
2245
}
2246

2247
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2248
// ID's that we don't know and are not known zombies of the passed set. In other
2249
// words, we perform a set difference of our set of chan ID's and the ones
2250
// passed in. This method can be used by callers to determine the set of
2251
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2252
// known zombies is also returned.
2253
//
2254
// NOTE: part of the V1Store interface.
2255
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2256
        []ChannelUpdateInfo, error) {
×
2257

×
2258
        var (
×
2259
                ctx          = context.TODO()
×
2260
                newChanIDs   []uint64
×
2261
                knownZombies []ChannelUpdateInfo
×
2262
        )
×
2263
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2264
                for _, chanInfo := range chansInfo {
×
2265
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2266
                        var chanIDB [8]byte
×
2267
                        byteOrder.PutUint64(chanIDB[:], channelID)
×
2268

×
2269
                        // TODO(elle): potentially optimize this by using
×
2270
                        //  sqlc.slice() once that works for both SQLite and
×
2271
                        //  Postgres.
×
2272
                        _, err := db.GetChannelBySCID(
×
2273
                                ctx, sqlc.GetChannelBySCIDParams{
×
2274
                                        Version: int16(ProtocolV1),
×
2275
                                        Scid:    chanIDB[:],
×
2276
                                },
×
2277
                        )
×
2278
                        if err == nil {
×
2279
                                continue
×
2280
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
2281
                                return fmt.Errorf("unable to fetch channel: %w",
×
2282
                                        err)
×
2283
                        }
×
2284

2285
                        isZombie, err := db.IsZombieChannel(
×
2286
                                ctx, sqlc.IsZombieChannelParams{
×
2287
                                        Scid:    chanIDB[:],
×
2288
                                        Version: int16(ProtocolV1),
×
2289
                                },
×
2290
                        )
×
2291
                        if err != nil {
×
2292
                                return fmt.Errorf("unable to fetch zombie "+
×
2293
                                        "channel: %w", err)
×
2294
                        }
×
2295

2296
                        if isZombie {
×
2297
                                knownZombies = append(knownZombies, chanInfo)
×
2298

×
2299
                                continue
×
2300
                        }
2301

2302
                        newChanIDs = append(newChanIDs, channelID)
×
2303
                }
2304

2305
                return nil
×
2306
        }, func() {
×
2307
                newChanIDs = nil
×
2308
                knownZombies = nil
×
2309
        })
×
2310
        if err != nil {
×
2311
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2312
        }
×
2313

2314
        return newChanIDs, knownZombies, nil
×
2315
}
2316

2317
// PruneGraphNodes is a garbage collection method which attempts to prune out
2318
// any nodes from the channel graph that are currently unconnected. This ensure
2319
// that we only maintain a graph of reachable nodes. In the event that a pruned
2320
// node gains more channels, it will be re-added back to the graph.
2321
//
2322
// NOTE: this prunes nodes across protocol versions. It will never prune the
2323
// source nodes.
2324
//
2325
// NOTE: part of the V1Store interface.
2326
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2327
        var ctx = context.TODO()
×
2328

×
2329
        var prunedNodes []route.Vertex
×
2330
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2331
                var err error
×
2332
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2333

×
2334
                return err
×
2335
        }, func() {
×
2336
                prunedNodes = nil
×
2337
        })
×
2338
        if err != nil {
×
2339
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2340
        }
×
2341

2342
        return prunedNodes, nil
×
2343
}
2344

2345
// PruneGraph prunes newly closed channels from the channel graph in response
2346
// to a new block being solved on the network. Any transactions which spend the
2347
// funding output of any known channels within he graph will be deleted.
2348
// Additionally, the "prune tip", or the last block which has been used to
2349
// prune the graph is stored so callers can ensure the graph is fully in sync
2350
// with the current UTXO state. A slice of channels that have been closed by
2351
// the target block along with any pruned nodes are returned if the function
2352
// succeeds without error.
2353
//
2354
// NOTE: part of the V1Store interface.
2355
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2356
        blockHash *chainhash.Hash, blockHeight uint32) (
2357
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2358

×
2359
        ctx := context.TODO()
×
2360

×
2361
        s.cacheMu.Lock()
×
2362
        defer s.cacheMu.Unlock()
×
2363

×
2364
        var (
×
2365
                closedChans []*models.ChannelEdgeInfo
×
2366
                prunedNodes []route.Vertex
×
2367
        )
×
2368
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2369
                for _, outpoint := range spentOutputs {
×
2370
                        // TODO(elle): potentially optimize this by using
×
2371
                        //  sqlc.slice() once that works for both SQLite and
×
2372
                        //  Postgres.
×
2373
                        //
×
2374
                        // NOTE: this fetches channels for all protocol
×
2375
                        // versions.
×
2376
                        row, err := db.GetChannelByOutpoint(
×
2377
                                ctx, outpoint.String(),
×
2378
                        )
×
2379
                        if errors.Is(err, sql.ErrNoRows) {
×
2380
                                continue
×
2381
                        } else if err != nil {
×
2382
                                return fmt.Errorf("unable to fetch channel: %w",
×
2383
                                        err)
×
2384
                        }
×
2385

2386
                        node1, node2, err := buildNodeVertices(
×
2387
                                row.Node1Pubkey, row.Node2Pubkey,
×
2388
                        )
×
2389
                        if err != nil {
×
2390
                                return err
×
2391
                        }
×
2392

2393
                        info, err := getAndBuildEdgeInfo(
×
2394
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2395
                                row.Channel, node1, node2,
×
2396
                        )
×
2397
                        if err != nil {
×
2398
                                return err
×
2399
                        }
×
2400

2401
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2402
                        if err != nil {
×
2403
                                return fmt.Errorf("unable to delete "+
×
2404
                                        "channel: %w", err)
×
2405
                        }
×
2406

2407
                        closedChans = append(closedChans, info)
×
2408
                }
2409

2410
                err := db.UpsertPruneLogEntry(
×
2411
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2412
                                BlockHash:   blockHash[:],
×
2413
                                BlockHeight: int64(blockHeight),
×
2414
                        },
×
2415
                )
×
2416
                if err != nil {
×
2417
                        return fmt.Errorf("unable to insert prune log "+
×
2418
                                "entry: %w", err)
×
2419
                }
×
2420

2421
                // Now that we've pruned some channels, we'll also prune any
2422
                // nodes that no longer have any channels.
2423
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2424
                if err != nil {
×
2425
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2426
                                err)
×
2427
                }
×
2428

2429
                return nil
×
2430
        }, func() {
×
2431
                prunedNodes = nil
×
2432
                closedChans = nil
×
2433
        })
×
2434
        if err != nil {
×
2435
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2436
        }
×
2437

2438
        for _, channel := range closedChans {
×
2439
                s.rejectCache.remove(channel.ChannelID)
×
2440
                s.chanCache.remove(channel.ChannelID)
×
2441
        }
×
2442

2443
        return closedChans, prunedNodes, nil
×
2444
}
2445

2446
// ChannelView returns the verifiable edge information for each active channel
2447
// within the known channel graph. The set of UTXOs (along with their scripts)
2448
// returned are the ones that need to be watched on chain to detect channel
2449
// closes on the resident blockchain.
2450
//
2451
// NOTE: part of the V1Store interface.
2452
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2453
        var (
×
2454
                ctx        = context.TODO()
×
2455
                edgePoints []EdgePoint
×
2456
        )
×
2457

×
2458
        handleChannel := func(db SQLQueries,
×
2459
                channel sqlc.ListChannelsPaginatedRow) error {
×
2460

×
2461
                pkScript, err := genMultiSigP2WSH(
×
2462
                        channel.BitcoinKey1, channel.BitcoinKey2,
×
2463
                )
×
2464
                if err != nil {
×
2465
                        return err
×
2466
                }
×
2467

2468
                op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2469
                if err != nil {
×
2470
                        return err
×
2471
                }
×
2472

2473
                edgePoints = append(edgePoints, EdgePoint{
×
2474
                        FundingPkScript: pkScript,
×
2475
                        OutPoint:        *op,
×
2476
                })
×
2477

×
2478
                return nil
×
2479
        }
2480

2481
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2482
                lastID := int64(-1)
×
2483
                for {
×
2484
                        rows, err := db.ListChannelsPaginated(
×
2485
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2486
                                        Version: int16(ProtocolV1),
×
2487
                                        ID:      lastID,
×
2488
                                        Limit:   pageSize,
×
2489
                                },
×
2490
                        )
×
2491
                        if err != nil {
×
2492
                                return err
×
2493
                        }
×
2494

2495
                        if len(rows) == 0 {
×
2496
                                break
×
2497
                        }
2498

2499
                        for _, row := range rows {
×
2500
                                err := handleChannel(db, row)
×
2501
                                if err != nil {
×
2502
                                        return err
×
2503
                                }
×
2504

2505
                                lastID = row.ID
×
2506
                        }
2507
                }
2508

2509
                return nil
×
2510
        }, func() {
×
2511
                edgePoints = nil
×
2512
        })
×
2513
        if err != nil {
×
2514
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2515
        }
×
2516

2517
        return edgePoints, nil
×
2518
}
2519

2520
// PruneTip returns the block height and hash of the latest block that has been
2521
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2522
// to tell if the graph is currently in sync with the current best known UTXO
2523
// state.
2524
//
2525
// NOTE: part of the V1Store interface.
2526
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2527
        var (
×
2528
                ctx       = context.TODO()
×
2529
                tipHash   chainhash.Hash
×
2530
                tipHeight uint32
×
2531
        )
×
2532
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2533
                pruneTip, err := db.GetPruneTip(ctx)
×
2534
                if errors.Is(err, sql.ErrNoRows) {
×
2535
                        return ErrGraphNeverPruned
×
2536
                } else if err != nil {
×
2537
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2538
                }
×
2539

2540
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2541
                tipHeight = uint32(pruneTip.BlockHeight)
×
2542

×
2543
                return nil
×
2544
        }, sqldb.NoOpReset)
2545
        if err != nil {
×
2546
                return nil, 0, err
×
2547
        }
×
2548

2549
        return &tipHash, tipHeight, nil
×
2550
}
2551

2552
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2553
//
2554
// NOTE: this prunes nodes across protocol versions. It will never prune the
2555
// source nodes.
2556
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2557
        db SQLQueries) ([]route.Vertex, error) {
×
2558

×
2559
        // Fetch all un-connected nodes from the database.
×
2560
        // NOTE: this will not include any nodes that are listed in the
×
2561
        // source table.
×
2562
        nodes, err := db.GetUnconnectedNodes(ctx)
×
2563
        if err != nil {
×
2564
                return nil, fmt.Errorf("unable to fetch unconnected nodes: %w",
×
2565
                        err)
×
2566
        }
×
2567

2568
        prunedNodes := make([]route.Vertex, 0, len(nodes))
×
2569
        for _, node := range nodes {
×
2570
                // TODO(elle): update to use sqlc.slice() once that works.
×
2571
                if err = db.DeleteNode(ctx, node.ID); err != nil {
×
2572
                        return nil, fmt.Errorf("unable to delete "+
×
2573
                                "node(id=%d): %w", node.ID, err)
×
2574
                }
×
2575

2576
                pubKey, err := route.NewVertexFromBytes(node.PubKey)
×
2577
                if err != nil {
×
2578
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2579
                                "for node(id=%d): %w", node.ID, err)
×
2580
                }
×
2581

2582
                prunedNodes = append(prunedNodes, pubKey)
×
2583
        }
2584

2585
        return prunedNodes, nil
×
2586
}
2587

2588
// DisconnectBlockAtHeight is used to indicate that the block specified
2589
// by the passed height has been disconnected from the main chain. This
2590
// will "rewind" the graph back to the height below, deleting channels
2591
// that are no longer confirmed from the graph. The prune log will be
2592
// set to the last prune height valid for the remaining chain.
2593
// Channels that were removed from the graph resulting from the
2594
// disconnected block are returned.
2595
//
2596
// NOTE: part of the V1Store interface.
2597
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2598
        []*models.ChannelEdgeInfo, error) {
×
2599

×
2600
        ctx := context.TODO()
×
2601

×
2602
        var (
×
2603
                // Every channel having a ShortChannelID starting at 'height'
×
2604
                // will no longer be confirmed.
×
2605
                startShortChanID = lnwire.ShortChannelID{
×
2606
                        BlockHeight: height,
×
2607
                }
×
2608

×
2609
                // Delete everything after this height from the db up until the
×
2610
                // SCID alias range.
×
2611
                endShortChanID = aliasmgr.StartingAlias
×
2612

×
2613
                removedChans []*models.ChannelEdgeInfo
×
2614
        )
×
2615

×
2616
        var chanIDStart [8]byte
×
2617
        byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64())
×
2618
        var chanIDEnd [8]byte
×
2619
        byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64())
×
2620

×
2621
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2622
                rows, err := db.GetChannelsBySCIDRange(
×
2623
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2624
                                StartScid: chanIDStart[:],
×
2625
                                EndScid:   chanIDEnd[:],
×
2626
                        },
×
2627
                )
×
2628
                if err != nil {
×
2629
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2630
                }
×
2631

2632
                for _, row := range rows {
×
2633
                        node1, node2, err := buildNodeVertices(
×
2634
                                row.Node1PubKey, row.Node2PubKey,
×
2635
                        )
×
2636
                        if err != nil {
×
2637
                                return err
×
2638
                        }
×
2639

2640
                        channel, err := getAndBuildEdgeInfo(
×
2641
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2642
                                row.Channel, node1, node2,
×
2643
                        )
×
2644
                        if err != nil {
×
2645
                                return err
×
2646
                        }
×
2647

2648
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2649
                        if err != nil {
×
2650
                                return fmt.Errorf("unable to delete "+
×
2651
                                        "channel: %w", err)
×
2652
                        }
×
2653

2654
                        removedChans = append(removedChans, channel)
×
2655
                }
2656

2657
                return db.DeletePruneLogEntriesInRange(
×
2658
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2659
                                StartHeight: int64(height),
×
2660
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2661
                        },
×
2662
                )
×
2663
        }, func() {
×
2664
                removedChans = nil
×
2665
        })
×
2666
        if err != nil {
×
2667
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2668
                        "height: %w", err)
×
2669
        }
×
2670

2671
        for _, channel := range removedChans {
×
2672
                s.rejectCache.remove(channel.ChannelID)
×
2673
                s.chanCache.remove(channel.ChannelID)
×
2674
        }
×
2675

2676
        return removedChans, nil
×
2677
}
2678

2679
// AddEdgeProof sets the proof of an existing edge in the graph database.
2680
//
2681
// NOTE: part of the V1Store interface.
2682
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2683
        proof *models.ChannelAuthProof) error {
×
2684

×
2685
        var (
×
2686
                ctx       = context.TODO()
×
2687
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2688
        )
×
2689

×
2690
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2691
                res, err := db.AddV1ChannelProof(
×
2692
                        ctx, sqlc.AddV1ChannelProofParams{
×
2693
                                Scid:              scidBytes[:],
×
2694
                                Node1Signature:    proof.NodeSig1Bytes,
×
2695
                                Node2Signature:    proof.NodeSig2Bytes,
×
2696
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2697
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2698
                        },
×
2699
                )
×
2700
                if err != nil {
×
2701
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2702
                }
×
2703

2704
                n, err := res.RowsAffected()
×
2705
                if err != nil {
×
2706
                        return err
×
2707
                }
×
2708

2709
                if n == 0 {
×
2710
                        return fmt.Errorf("no rows affected when adding edge "+
×
2711
                                "proof for SCID %v", scid)
×
2712
                } else if n > 1 {
×
2713
                        return fmt.Errorf("multiple rows affected when adding "+
×
2714
                                "edge proof for SCID %v: %d rows affected",
×
2715
                                scid, n)
×
2716
                }
×
2717

2718
                return nil
×
2719
        }, sqldb.NoOpReset)
2720
        if err != nil {
×
2721
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2722
        }
×
2723

2724
        return nil
×
2725
}
2726

2727
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2728
// that we can ignore channel announcements that we know to be closed without
2729
// having to validate them and fetch a block.
2730
//
2731
// NOTE: part of the V1Store interface.
2732
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2733
        var (
×
2734
                ctx     = context.TODO()
×
2735
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2736
        )
×
2737

×
2738
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2739
                return db.InsertClosedChannel(ctx, chanIDB[:])
×
2740
        }, sqldb.NoOpReset)
×
2741
}
2742

2743
// IsClosedScid checks whether a channel identified by the passed in scid is
2744
// closed. This helps avoid having to perform expensive validation checks.
2745
//
2746
// NOTE: part of the V1Store interface.
2747
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2748
        var (
×
2749
                ctx      = context.TODO()
×
2750
                isClosed bool
×
2751
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2752
        )
×
2753
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2754
                var err error
×
2755
                isClosed, err = db.IsClosedChannel(ctx, chanIDB[:])
×
2756
                if err != nil {
×
2757
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2758
                                err)
×
2759
                }
×
2760

2761
                return nil
×
2762
        }, sqldb.NoOpReset)
2763
        if err != nil {
×
2764
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2765
                        err)
×
2766
        }
×
2767

2768
        return isClosed, nil
×
2769
}
2770

2771
// GraphSession will provide the call-back with access to a NodeTraverser
2772
// instance which can be used to perform queries against the channel graph.
2773
//
2774
// NOTE: part of the V1Store interface.
2775
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error) error {
×
2776
        var ctx = context.TODO()
×
2777

×
2778
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2779
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2780
        }, sqldb.NoOpReset)
×
2781
}
2782

2783
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2784
// read only transaction for a consistent view of the graph.
2785
type sqlNodeTraverser struct {
2786
        db    SQLQueries
2787
        chain chainhash.Hash
2788
}
2789

2790
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2791
// NodeTraverser interface.
2792
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2793

2794
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2795
func newSQLNodeTraverser(db SQLQueries,
2796
        chain chainhash.Hash) *sqlNodeTraverser {
×
2797

×
2798
        return &sqlNodeTraverser{
×
2799
                db:    db,
×
2800
                chain: chain,
×
2801
        }
×
2802
}
×
2803

2804
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2805
// node.
2806
//
2807
// NOTE: Part of the NodeTraverser interface.
2808
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2809
        cb func(channel *DirectedChannel) error) error {
×
2810

×
2811
        ctx := context.TODO()
×
2812

×
2813
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2814
}
×
2815

2816
// FetchNodeFeatures returns the features of the given node. If the node is
2817
// unknown, assume no additional features are supported.
2818
//
2819
// NOTE: Part of the NodeTraverser interface.
2820
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2821
        *lnwire.FeatureVector, error) {
×
2822

×
2823
        ctx := context.TODO()
×
2824

×
2825
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2826
}
×
2827

2828
// forEachNodeDirectedChannel iterates through all channels of a given
2829
// node, executing the passed callback on the directed edge representing the
2830
// channel and its incoming policy. If the node is not found, no error is
2831
// returned.
2832
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2833
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2834

×
2835
        toNodeCallback := func() route.Vertex {
×
2836
                return nodePub
×
2837
        }
×
2838

2839
        dbID, err := db.GetNodeIDByPubKey(
×
2840
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2841
                        Version: int16(ProtocolV1),
×
2842
                        PubKey:  nodePub[:],
×
2843
                },
×
2844
        )
×
2845
        if errors.Is(err, sql.ErrNoRows) {
×
2846
                return nil
×
2847
        } else if err != nil {
×
2848
                return fmt.Errorf("unable to fetch node: %w", err)
×
2849
        }
×
2850

2851
        rows, err := db.ListChannelsByNodeID(
×
2852
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2853
                        Version: int16(ProtocolV1),
×
2854
                        NodeID1: dbID,
×
2855
                },
×
2856
        )
×
2857
        if err != nil {
×
2858
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2859
        }
×
2860

2861
        // Exit early if there are no channels for this node so we don't
2862
        // do the unnecessary feature fetching.
2863
        if len(rows) == 0 {
×
2864
                return nil
×
2865
        }
×
2866

2867
        features, err := getNodeFeatures(ctx, db, dbID)
×
2868
        if err != nil {
×
2869
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2870
        }
×
2871

2872
        for _, row := range rows {
×
2873
                node1, node2, err := buildNodeVertices(
×
2874
                        row.Node1Pubkey, row.Node2Pubkey,
×
2875
                )
×
2876
                if err != nil {
×
2877
                        return fmt.Errorf("unable to build node vertices: %w",
×
2878
                                err)
×
2879
                }
×
2880

2881
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
2882

×
2883
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2884
                if err != nil {
×
2885
                        return err
×
2886
                }
×
2887

2888
                var p1, p2 *models.CachedEdgePolicy
×
2889
                if dbPol1 != nil {
×
2890
                        policy1, err := buildChanPolicy(
×
NEW
2891
                                *dbPol1, edge.ChannelID, nil, node2,
×
2892
                        )
×
2893
                        if err != nil {
×
2894
                                return err
×
2895
                        }
×
2896

2897
                        p1 = models.NewCachedPolicy(policy1)
×
2898
                }
2899
                if dbPol2 != nil {
×
2900
                        policy2, err := buildChanPolicy(
×
NEW
2901
                                *dbPol2, edge.ChannelID, nil, node1,
×
2902
                        )
×
2903
                        if err != nil {
×
2904
                                return err
×
2905
                        }
×
2906

2907
                        p2 = models.NewCachedPolicy(policy2)
×
2908
                }
2909

2910
                // Determine the outgoing and incoming policy for this
2911
                // channel and node combo.
2912
                outPolicy, inPolicy := p1, p2
×
2913
                if p1 != nil && node2 == nodePub {
×
2914
                        outPolicy, inPolicy = p2, p1
×
2915
                } else if p2 != nil && node1 != nodePub {
×
2916
                        outPolicy, inPolicy = p2, p1
×
2917
                }
×
2918

2919
                var cachedInPolicy *models.CachedEdgePolicy
×
2920
                if inPolicy != nil {
×
2921
                        cachedInPolicy = inPolicy
×
2922
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
2923
                        cachedInPolicy.ToNodeFeatures = features
×
2924
                }
×
2925

2926
                directedChannel := &DirectedChannel{
×
2927
                        ChannelID:    edge.ChannelID,
×
2928
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2929
                        OtherNode:    edge.NodeKey2Bytes,
×
2930
                        Capacity:     edge.Capacity,
×
2931
                        OutPolicySet: outPolicy != nil,
×
2932
                        InPolicy:     cachedInPolicy,
×
2933
                }
×
2934
                if outPolicy != nil {
×
2935
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2936
                                directedChannel.InboundFee = fee
×
2937
                        })
×
2938
                }
2939

2940
                if nodePub == edge.NodeKey2Bytes {
×
2941
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
2942
                }
×
2943

2944
                if err := cb(directedChannel); err != nil {
×
2945
                        return err
×
2946
                }
×
2947
        }
2948

2949
        return nil
×
2950
}
2951

2952
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
2953
// and executes the provided callback for each node.
2954
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
2955
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
2956

×
2957
        lastID := int64(-1)
×
2958

×
2959
        for {
×
2960
                nodes, err := db.ListNodeIDsAndPubKeys(
×
2961
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
2962
                                Version: int16(ProtocolV1),
×
2963
                                ID:      lastID,
×
2964
                                Limit:   pageSize,
×
2965
                        },
×
2966
                )
×
2967
                if err != nil {
×
2968
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
2969
                }
×
2970

2971
                if len(nodes) == 0 {
×
2972
                        break
×
2973
                }
2974

2975
                for _, node := range nodes {
×
2976
                        var pub route.Vertex
×
2977
                        copy(pub[:], node.PubKey)
×
2978

×
2979
                        if err := cb(node.ID, pub); err != nil {
×
2980
                                return fmt.Errorf("forEachNodeCacheable "+
×
2981
                                        "callback failed for node(id=%d): %w",
×
2982
                                        node.ID, err)
×
2983
                        }
×
2984

2985
                        lastID = node.ID
×
2986
                }
2987
        }
2988

2989
        return nil
×
2990
}
2991

2992
// forEachNodeChannel iterates through all channels of a node, executing
2993
// the passed callback on each. The call-back is provided with the channel's
2994
// edge information, the outgoing policy and the incoming policy for the
2995
// channel and node combo.
2996
func forEachNodeChannel(ctx context.Context, db SQLQueries,
2997
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
2998
                *models.ChannelEdgePolicy,
2999
                *models.ChannelEdgePolicy) error) error {
×
3000

×
3001
        // Get all the V1 channels for this node.Add commentMore actions
×
3002
        rows, err := db.ListChannelsByNodeID(
×
3003
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3004
                        Version: int16(ProtocolV1),
×
3005
                        NodeID1: id,
×
3006
                },
×
3007
        )
×
3008
        if err != nil {
×
3009
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3010
        }
×
3011

3012
        // Call the call-back for each channel and its known policies.
3013
        for _, row := range rows {
×
3014
                node1, node2, err := buildNodeVertices(
×
3015
                        row.Node1Pubkey, row.Node2Pubkey,
×
3016
                )
×
3017
                if err != nil {
×
3018
                        return fmt.Errorf("unable to build node vertices: %w",
×
3019
                                err)
×
3020
                }
×
3021

3022
                edge, err := getAndBuildEdgeInfo(
×
3023
                        ctx, db, chain, row.Channel.ID, row.Channel, node1,
×
3024
                        node2,
×
3025
                )
×
3026
                if err != nil {
×
3027
                        return fmt.Errorf("unable to build channel info: %w",
×
3028
                                err)
×
3029
                }
×
3030

3031
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3032
                if err != nil {
×
3033
                        return fmt.Errorf("unable to extract channel "+
×
3034
                                "policies: %w", err)
×
3035
                }
×
3036

3037
                p1, p2, err := getAndBuildChanPolicies(
×
3038
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3039
                )
×
3040
                if err != nil {
×
3041
                        return fmt.Errorf("unable to build channel "+
×
3042
                                "policies: %w", err)
×
3043
                }
×
3044

3045
                // Determine the outgoing and incoming policy for this
3046
                // channel and node combo.
3047
                p1ToNode := row.Channel.NodeID2
×
3048
                p2ToNode := row.Channel.NodeID1
×
3049
                outPolicy, inPolicy := p1, p2
×
3050
                if (p1 != nil && p1ToNode == id) ||
×
3051
                        (p2 != nil && p2ToNode != id) {
×
3052

×
3053
                        outPolicy, inPolicy = p2, p1
×
3054
                }
×
3055

3056
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3057
                        return err
×
3058
                }
×
3059
        }
3060

3061
        return nil
×
3062
}
3063

3064
// updateChanEdgePolicy upserts the channel policy info we have stored for
3065
// a channel we already know of.
3066
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3067
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3068
        error) {
×
3069

×
3070
        var (
×
3071
                node1Pub, node2Pub route.Vertex
×
3072
                isNode1            bool
×
3073
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3074
        )
×
3075

×
3076
        // Check that this edge policy refers to a channel that we already
×
3077
        // know of. We do this explicitly so that we can return the appropriate
×
3078
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3079
        // abort the transaction which would abort the entire batch.
×
3080
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3081
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3082
                        Scid:    chanIDB[:],
×
3083
                        Version: int16(ProtocolV1),
×
3084
                },
×
3085
        )
×
3086
        if errors.Is(err, sql.ErrNoRows) {
×
3087
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3088
        } else if err != nil {
×
3089
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3090
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3091
        }
×
3092

3093
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3094
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3095

×
3096
        // Figure out which node this edge is from.
×
3097
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3098
        nodeID := dbChan.NodeID1
×
3099
        if !isNode1 {
×
3100
                nodeID = dbChan.NodeID2
×
3101
        }
×
3102

3103
        var (
×
3104
                inboundBase sql.NullInt64
×
3105
                inboundRate sql.NullInt64
×
3106
        )
×
3107
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3108
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3109
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3110
        })
×
3111

3112
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3113
                Version:     int16(ProtocolV1),
×
3114
                ChannelID:   dbChan.ID,
×
3115
                NodeID:      nodeID,
×
3116
                Timelock:    int32(edge.TimeLockDelta),
×
3117
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3118
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3119
                MinHtlcMsat: int64(edge.MinHTLC),
×
3120
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3121
                Disabled: sql.NullBool{
×
3122
                        Valid: true,
×
3123
                        Bool:  edge.IsDisabled(),
×
3124
                },
×
3125
                MaxHtlcMsat: sql.NullInt64{
×
3126
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3127
                        Int64: int64(edge.MaxHTLC),
×
3128
                },
×
NEW
3129
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
NEW
3130
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3131
                InboundBaseFeeMsat:      inboundBase,
×
3132
                InboundFeeRateMilliMsat: inboundRate,
×
3133
                Signature:               edge.SigBytes,
×
3134
        })
×
3135
        if err != nil {
×
3136
                return node1Pub, node2Pub, isNode1,
×
3137
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3138
        }
×
3139

3140
        // Convert the flat extra opaque data into a map of TLV types to
3141
        // values.
3142
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3143
        if err != nil {
×
3144
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3145
                        "marshal extra opaque data: %w", err)
×
3146
        }
×
3147

3148
        // Update the channel policy's extra signed fields.
3149
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3150
        if err != nil {
×
3151
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3152
                        "policy extra TLVs: %w", err)
×
3153
        }
×
3154

3155
        return node1Pub, node2Pub, isNode1, nil
×
3156
}
3157

3158
// getNodeByPubKey attempts to look up a target node by its public key.
3159
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3160
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3161

×
3162
        dbNode, err := db.GetNodeByPubKey(
×
3163
                ctx, sqlc.GetNodeByPubKeyParams{
×
3164
                        Version: int16(ProtocolV1),
×
3165
                        PubKey:  pubKey[:],
×
3166
                },
×
3167
        )
×
3168
        if errors.Is(err, sql.ErrNoRows) {
×
3169
                return 0, nil, ErrGraphNodeNotFound
×
3170
        } else if err != nil {
×
3171
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3172
        }
×
3173

3174
        node, err := buildNode(ctx, db, &dbNode)
×
3175
        if err != nil {
×
3176
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3177
        }
×
3178

3179
        return dbNode.ID, node, nil
×
3180
}
3181

3182
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3183
// provided database channel row and the public keys of the two nodes
3184
// involved in the channel.
3185
func buildCacheableChannelInfo(dbChan sqlc.Channel, node1Pub,
3186
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3187

×
3188
        return &models.CachedEdgeInfo{
×
3189
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
3190
                NodeKey1Bytes: node1Pub,
×
3191
                NodeKey2Bytes: node2Pub,
×
3192
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
3193
        }
×
3194
}
×
3195

3196
// buildNode constructs a LightningNode instance from the given database node
3197
// record. The node's features, addresses and extra signed fields are also
3198
// fetched from the database and set on the node.
3199
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
3200
        *models.LightningNode, error) {
×
3201

×
3202
        if dbNode.Version != int16(ProtocolV1) {
×
3203
                return nil, fmt.Errorf("unsupported node version: %d",
×
3204
                        dbNode.Version)
×
3205
        }
×
3206

3207
        var pub [33]byte
×
3208
        copy(pub[:], dbNode.PubKey)
×
3209

×
3210
        node := &models.LightningNode{
×
3211
                PubKeyBytes: pub,
×
3212
                Features:    lnwire.EmptyFeatureVector(),
×
3213
                LastUpdate:  time.Unix(0, 0),
×
3214
        }
×
3215

×
3216
        if len(dbNode.Signature) == 0 {
×
3217
                return node, nil
×
3218
        }
×
3219

3220
        node.HaveNodeAnnouncement = true
×
3221
        node.AuthSigBytes = dbNode.Signature
×
3222
        node.Alias = dbNode.Alias.String
×
3223
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3224

×
3225
        var err error
×
3226
        if dbNode.Color.Valid {
×
3227
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3228
                if err != nil {
×
3229
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3230
                                err)
×
3231
                }
×
3232
        }
3233

3234
        // Fetch the node's features.
3235
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
3236
        if err != nil {
×
3237
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3238
                        "features: %w", dbNode.ID, err)
×
3239
        }
×
3240

3241
        // Fetch the node's addresses.
3242
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
3243
        if err != nil {
×
3244
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3245
                        "addresses: %w", dbNode.ID, err)
×
3246
        }
×
3247

3248
        // Fetch the node's extra signed fields.
3249
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
3250
        if err != nil {
×
3251
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3252
                        "extra signed fields: %w", dbNode.ID, err)
×
3253
        }
×
3254

3255
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
3256
        if err != nil {
×
3257
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3258
                        "fields: %w", err)
×
3259
        }
×
3260

3261
        if len(recs) != 0 {
×
3262
                node.ExtraOpaqueData = recs
×
3263
        }
×
3264

3265
        return node, nil
×
3266
}
3267

3268
// getNodeFeatures fetches the feature bits and constructs the feature vector
3269
// for a node with the given DB ID.
3270
func getNodeFeatures(ctx context.Context, db SQLQueries,
3271
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3272

×
3273
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3274
        if err != nil {
×
3275
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3276
                        nodeID, err)
×
3277
        }
×
3278

3279
        features := lnwire.EmptyFeatureVector()
×
3280
        for _, feature := range rows {
×
3281
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3282
        }
×
3283

3284
        return features, nil
×
3285
}
3286

3287
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
3288
// given DB ID.
3289
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3290
        nodeID int64) (map[uint64][]byte, error) {
×
3291

×
3292
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3293
        if err != nil {
×
3294
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
3295
                        "signed fields: %w", nodeID, err)
×
3296
        }
×
3297

3298
        extraFields := make(map[uint64][]byte)
×
3299
        for _, field := range fields {
×
3300
                extraFields[uint64(field.Type)] = field.Value
×
3301
        }
×
3302

3303
        return extraFields, nil
×
3304
}
3305

3306
// upsertNode upserts the node record into the database. If the node already
3307
// exists, then the node's information is updated. If the node doesn't exist,
3308
// then a new node is created. The node's features, addresses and extra TLV
3309
// types are also updated. The node's DB ID is returned.
3310
func upsertNode(ctx context.Context, db SQLQueries,
3311
        node *models.LightningNode) (int64, error) {
×
3312

×
3313
        params := sqlc.UpsertNodeParams{
×
3314
                Version: int16(ProtocolV1),
×
3315
                PubKey:  node.PubKeyBytes[:],
×
3316
        }
×
3317

×
3318
        if node.HaveNodeAnnouncement {
×
3319
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3320
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3321
                params.Alias = sqldb.SQLStr(node.Alias)
×
3322
                params.Signature = node.AuthSigBytes
×
3323
        }
×
3324

3325
        nodeID, err := db.UpsertNode(ctx, params)
×
3326
        if err != nil {
×
3327
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3328
                        err)
×
3329
        }
×
3330

3331
        // We can exit here if we don't have the announcement yet.
3332
        if !node.HaveNodeAnnouncement {
×
3333
                return nodeID, nil
×
3334
        }
×
3335

3336
        // Update the node's features.
3337
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3338
        if err != nil {
×
3339
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3340
        }
×
3341

3342
        // Update the node's addresses.
3343
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3344
        if err != nil {
×
3345
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3346
        }
×
3347

3348
        // Convert the flat extra opaque data into a map of TLV types to
3349
        // values.
3350
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3351
        if err != nil {
×
3352
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3353
                        err)
×
3354
        }
×
3355

3356
        // Update the node's extra signed fields.
3357
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3358
        if err != nil {
×
3359
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3360
        }
×
3361

3362
        return nodeID, nil
×
3363
}
3364

3365
// upsertNodeFeatures updates the node's features node_features table. This
3366
// includes deleting any feature bits no longer present and inserting any new
3367
// feature bits. If the feature bit does not yet exist in the features table,
3368
// then an entry is created in that table first.
3369
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3370
        features *lnwire.FeatureVector) error {
×
3371

×
3372
        // Get any existing features for the node.
×
3373
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3374
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3375
                return err
×
3376
        }
×
3377

3378
        // Copy the nodes latest set of feature bits.
3379
        newFeatures := make(map[int32]struct{})
×
3380
        if features != nil {
×
3381
                for feature := range features.Features() {
×
3382
                        newFeatures[int32(feature)] = struct{}{}
×
3383
                }
×
3384
        }
3385

3386
        // For any current feature that already exists in the DB, remove it from
3387
        // the in-memory map. For any existing feature that does not exist in
3388
        // the in-memory map, delete it from the database.
3389
        for _, feature := range existingFeatures {
×
3390
                // The feature is still present, so there are no updates to be
×
3391
                // made.
×
3392
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3393
                        delete(newFeatures, feature.FeatureBit)
×
3394
                        continue
×
3395
                }
3396

3397
                // The feature is no longer present, so we remove it from the
3398
                // database.
3399
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3400
                        NodeID:     nodeID,
×
3401
                        FeatureBit: feature.FeatureBit,
×
3402
                })
×
3403
                if err != nil {
×
3404
                        return fmt.Errorf("unable to delete node(%d) "+
×
3405
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3406
                                err)
×
3407
                }
×
3408
        }
3409

3410
        // Any remaining entries in newFeatures are new features that need to be
3411
        // added to the database for the first time.
3412
        for feature := range newFeatures {
×
3413
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3414
                        NodeID:     nodeID,
×
3415
                        FeatureBit: feature,
×
3416
                })
×
3417
                if err != nil {
×
3418
                        return fmt.Errorf("unable to insert node(%d) "+
×
3419
                                "feature(%v): %w", nodeID, feature, err)
×
3420
                }
×
3421
        }
3422

3423
        return nil
×
3424
}
3425

3426
// fetchNodeFeatures fetches the features for a node with the given public key.
3427
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3428
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3429

×
3430
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3431
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3432
                        PubKey:  nodePub[:],
×
3433
                        Version: int16(ProtocolV1),
×
3434
                },
×
3435
        )
×
3436
        if err != nil {
×
3437
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3438
                        nodePub, err)
×
3439
        }
×
3440

3441
        features := lnwire.EmptyFeatureVector()
×
3442
        for _, bit := range rows {
×
3443
                features.Set(lnwire.FeatureBit(bit))
×
3444
        }
×
3445

3446
        return features, nil
×
3447
}
3448

3449
// dbAddressType enumerates the kinds of address stored in the node_addresses
// table. The type recorded alongside an address determines how it is
// serialised when written and deserialised when read back.
type dbAddressType uint8

// Address type values persisted in the node_addresses table. They are written
// to the database, so existing values must not be renumbered.
const (
	// addressTypeIPv4 marks a TCP address whose host is IPv4.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 marks a TCP address whose host is IPv6.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 marks a Tor v2 onion service address.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 marks a Tor v3 onion service address.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeOpaque marks an lnwire.OpaqueAddrs payload.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3461

3462
// upsertNodeAddresses updates the node's addresses in the database. This
3463
// includes deleting any existing addresses and inserting the new set of
3464
// addresses. The deletion is necessary since the ordering of the addresses may
3465
// change, and we need to ensure that the database reflects the latest set of
3466
// addresses so that at the time of reconstructing the node announcement, the
3467
// order is preserved and the signature over the message remains valid.
3468
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3469
        addresses []net.Addr) error {
×
3470

×
3471
        // Delete any existing addresses for the node. This is required since
×
3472
        // even if the new set of addresses is the same, the ordering may have
×
3473
        // changed for a given address type.
×
3474
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3475
        if err != nil {
×
3476
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3477
                        nodeID, err)
×
3478
        }
×
3479

3480
        // Copy the nodes latest set of addresses.
3481
        newAddresses := map[dbAddressType][]string{
×
3482
                addressTypeIPv4:   {},
×
3483
                addressTypeIPv6:   {},
×
3484
                addressTypeTorV2:  {},
×
3485
                addressTypeTorV3:  {},
×
3486
                addressTypeOpaque: {},
×
3487
        }
×
3488
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3489
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3490
        }
×
3491

3492
        for _, address := range addresses {
×
3493
                switch addr := address.(type) {
×
3494
                case *net.TCPAddr:
×
3495
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3496
                                addAddr(addressTypeIPv4, addr)
×
3497
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3498
                                addAddr(addressTypeIPv6, addr)
×
3499
                        } else {
×
3500
                                return fmt.Errorf("unhandled IP address: %v",
×
3501
                                        addr)
×
3502
                        }
×
3503

3504
                case *tor.OnionAddr:
×
3505
                        switch len(addr.OnionService) {
×
3506
                        case tor.V2Len:
×
3507
                                addAddr(addressTypeTorV2, addr)
×
3508
                        case tor.V3Len:
×
3509
                                addAddr(addressTypeTorV3, addr)
×
3510
                        default:
×
3511
                                return fmt.Errorf("invalid length for a tor " +
×
3512
                                        "address")
×
3513
                        }
3514

3515
                case *lnwire.OpaqueAddrs:
×
3516
                        addAddr(addressTypeOpaque, addr)
×
3517

3518
                default:
×
3519
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3520
                }
3521
        }
3522

3523
        // Any remaining entries in newAddresses are new addresses that need to
3524
        // be added to the database for the first time.
3525
        for addrType, addrList := range newAddresses {
×
3526
                for position, addr := range addrList {
×
3527
                        err := db.InsertNodeAddress(
×
3528
                                ctx, sqlc.InsertNodeAddressParams{
×
3529
                                        NodeID:   nodeID,
×
3530
                                        Type:     int16(addrType),
×
3531
                                        Address:  addr,
×
3532
                                        Position: int32(position),
×
3533
                                },
×
3534
                        )
×
3535
                        if err != nil {
×
3536
                                return fmt.Errorf("unable to insert "+
×
3537
                                        "node(%d) address(%v): %w", nodeID,
×
3538
                                        addr, err)
×
3539
                        }
×
3540
                }
3541
        }
3542

3543
        return nil
×
3544
}
3545

3546
// getNodeAddresses fetches the addresses for a node with the given public key.
3547
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3548
        []net.Addr, error) {
×
3549

×
3550
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
3551
        // are returned in the same order as they were inserted.
×
3552
        rows, err := db.GetNodeAddressesByPubKey(
×
3553
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
3554
                        Version: int16(ProtocolV1),
×
3555
                        PubKey:  nodePub,
×
3556
                },
×
3557
        )
×
3558
        if err != nil {
×
3559
                return false, nil, err
×
3560
        }
×
3561

3562
        // GetNodeAddressesByPubKey uses a left join so there should always be
3563
        // at least one row returned if the node exists even if it has no
3564
        // addresses.
3565
        if len(rows) == 0 {
×
3566
                return false, nil, nil
×
3567
        }
×
3568

3569
        addresses := make([]net.Addr, 0, len(rows))
×
3570
        for _, addr := range rows {
×
3571
                if !(addr.Type.Valid && addr.Address.Valid) {
×
3572
                        continue
×
3573
                }
3574

3575
                address := addr.Address.String
×
3576

×
3577
                switch dbAddressType(addr.Type.Int16) {
×
3578
                case addressTypeIPv4:
×
3579
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
3580
                        if err != nil {
×
3581
                                return false, nil, nil
×
3582
                        }
×
3583
                        tcp.IP = tcp.IP.To4()
×
3584

×
3585
                        addresses = append(addresses, tcp)
×
3586

3587
                case addressTypeIPv6:
×
3588
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
3589
                        if err != nil {
×
3590
                                return false, nil, nil
×
3591
                        }
×
3592
                        addresses = append(addresses, tcp)
×
3593

3594
                case addressTypeTorV3, addressTypeTorV2:
×
3595
                        service, portStr, err := net.SplitHostPort(address)
×
3596
                        if err != nil {
×
3597
                                return false, nil, fmt.Errorf("unable to "+
×
3598
                                        "split tor v3 address: %v",
×
3599
                                        addr.Address)
×
3600
                        }
×
3601

3602
                        port, err := strconv.Atoi(portStr)
×
3603
                        if err != nil {
×
3604
                                return false, nil, err
×
3605
                        }
×
3606

3607
                        addresses = append(addresses, &tor.OnionAddr{
×
3608
                                OnionService: service,
×
3609
                                Port:         port,
×
3610
                        })
×
3611

3612
                case addressTypeOpaque:
×
3613
                        opaque, err := hex.DecodeString(address)
×
3614
                        if err != nil {
×
3615
                                return false, nil, fmt.Errorf("unable to "+
×
3616
                                        "decode opaque address: %v", addr)
×
3617
                        }
×
3618

3619
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
3620
                                Payload: opaque,
×
3621
                        })
×
3622

3623
                default:
×
3624
                        return false, nil, fmt.Errorf("unknown address "+
×
3625
                                "type: %v", addr.Type)
×
3626
                }
3627
        }
3628

3629
        return true, addresses, nil
×
3630
}
3631

3632
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3633
// database. This includes updating any existing types, inserting any new types,
3634
// and deleting any types that are no longer present.
3635
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3636
        nodeID int64, extraFields map[uint64][]byte) error {
×
3637

×
3638
        // Get any existing extra signed fields for the node.
×
3639
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3640
        if err != nil {
×
3641
                return err
×
3642
        }
×
3643

3644
        // Make a lookup map of the existing field types so that we can use it
3645
        // to keep track of any fields we should delete.
3646
        m := make(map[uint64]bool)
×
3647
        for _, field := range existingFields {
×
3648
                m[uint64(field.Type)] = true
×
3649
        }
×
3650

3651
        // For all the new fields, we'll upsert them and remove them from the
3652
        // map of existing fields.
3653
        for tlvType, value := range extraFields {
×
3654
                err = db.UpsertNodeExtraType(
×
3655
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3656
                                NodeID: nodeID,
×
3657
                                Type:   int64(tlvType),
×
3658
                                Value:  value,
×
3659
                        },
×
3660
                )
×
3661
                if err != nil {
×
3662
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3663
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3664
                }
×
3665

3666
                // Remove the field from the map of existing fields if it was
3667
                // present.
3668
                delete(m, tlvType)
×
3669
        }
3670

3671
        // For all the fields that are left in the map of existing fields, we'll
3672
        // delete them as they are no longer present in the new set of fields.
3673
        for tlvType := range m {
×
3674
                err = db.DeleteExtraNodeType(
×
3675
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3676
                                NodeID: nodeID,
×
3677
                                Type:   int64(tlvType),
×
3678
                        },
×
3679
                )
×
3680
                if err != nil {
×
3681
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3682
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3683
                }
×
3684
        }
3685

3686
        return nil
×
3687
}
3688

3689
// srcNodeInfo holds the information about the source node of the graph.
3690
type srcNodeInfo struct {
3691
        // id is the DB level ID of the source node entry in the "nodes" table.
3692
        id int64
3693

3694
        // pub is the public key of the source node.
3695
        pub route.Vertex
3696
}
3697

3698
// getSourceNode returns the DB node ID and pub key of the source node for the
3699
// specified protocol version.
3700
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3701
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3702

×
3703
        s.srcNodeMu.Lock()
×
3704
        defer s.srcNodeMu.Unlock()
×
3705

×
3706
        // If we already have the source node ID and pub key cached, then
×
3707
        // return them.
×
3708
        if info, ok := s.srcNodes[version]; ok {
×
3709
                return info.id, info.pub, nil
×
3710
        }
×
3711

3712
        var pubKey route.Vertex
×
3713

×
3714
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3715
        if err != nil {
×
3716
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3717
                        err)
×
3718
        }
×
3719

3720
        if len(nodes) == 0 {
×
3721
                return 0, pubKey, ErrSourceNodeNotSet
×
3722
        } else if len(nodes) > 1 {
×
3723
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3724
                        "protocol %s found", version)
×
3725
        }
×
3726

3727
        copy(pubKey[:], nodes[0].PubKey)
×
3728

×
3729
        s.srcNodes[version] = &srcNodeInfo{
×
3730
                id:  nodes[0].NodeID,
×
3731
                pub: pubKey,
×
3732
        }
×
3733

×
3734
        return nodes[0].NodeID, pubKey, nil
×
3735
}
3736

3737
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3738
// This then produces a map from TLV type to value. If the input is not a
3739
// valid TLV stream, then an error is returned.
3740
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3741
        r := bytes.NewReader(data)
×
3742

×
3743
        tlvStream, err := tlv.NewStream()
×
3744
        if err != nil {
×
3745
                return nil, err
×
3746
        }
×
3747

3748
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3749
        // pass it into the P2P decoding variant.
3750
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3751
        if err != nil {
×
3752
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3753
        }
×
3754
        if len(parsedTypes) == 0 {
×
3755
                return nil, nil
×
3756
        }
×
3757

3758
        records := make(map[uint64][]byte)
×
3759
        for k, v := range parsedTypes {
×
3760
                records[uint64(k)] = v
×
3761
        }
×
3762

3763
        return records, nil
×
3764
}
3765

3766
// insertChannel inserts a new channel record into the database.
3767
func insertChannel(ctx context.Context, db SQLQueries,
3768
        edge *models.ChannelEdgeInfo) error {
×
3769

×
3770
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3771

×
3772
        // Make sure that the channel doesn't already exist. We do this
×
3773
        // explicitly instead of relying on catching a unique constraint error
×
3774
        // because relying on SQL to throw that error would abort the entire
×
3775
        // batch of transactions.
×
3776
        _, err := db.GetChannelBySCID(
×
3777
                ctx, sqlc.GetChannelBySCIDParams{
×
3778
                        Scid:    chanIDB[:],
×
3779
                        Version: int16(ProtocolV1),
×
3780
                },
×
3781
        )
×
3782
        if err == nil {
×
3783
                return ErrEdgeAlreadyExist
×
3784
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3785
                return fmt.Errorf("unable to fetch channel: %w", err)
×
3786
        }
×
3787

3788
        // Make sure that at least a "shell" entry for each node is present in
3789
        // the nodes table.
3790
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3791
        if err != nil {
×
3792
                return fmt.Errorf("unable to create shell node: %w", err)
×
3793
        }
×
3794

3795
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3796
        if err != nil {
×
3797
                return fmt.Errorf("unable to create shell node: %w", err)
×
3798
        }
×
3799

3800
        var capacity sql.NullInt64
×
3801
        if edge.Capacity != 0 {
×
3802
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3803
        }
×
3804

3805
        createParams := sqlc.CreateChannelParams{
×
3806
                Version:     int16(ProtocolV1),
×
3807
                Scid:        chanIDB[:],
×
3808
                NodeID1:     node1DBID,
×
3809
                NodeID2:     node2DBID,
×
3810
                Outpoint:    edge.ChannelPoint.String(),
×
3811
                Capacity:    capacity,
×
3812
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3813
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3814
        }
×
3815

×
3816
        if edge.AuthProof != nil {
×
3817
                proof := edge.AuthProof
×
3818

×
3819
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3820
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3821
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3822
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3823
        }
×
3824

3825
        // Insert the new channel record.
3826
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3827
        if err != nil {
×
3828
                return err
×
3829
        }
×
3830

3831
        // Insert any channel features.
3832
        if len(edge.Features) != 0 {
×
3833
                chanFeatures := lnwire.NewRawFeatureVector()
×
3834
                err := chanFeatures.Decode(bytes.NewReader(edge.Features))
×
3835
                if err != nil {
×
3836
                        return err
×
3837
                }
×
3838

3839
                fv := lnwire.NewFeatureVector(chanFeatures, lnwire.Features)
×
3840
                for feature := range fv.Features() {
×
3841
                        err = db.InsertChannelFeature(
×
3842
                                ctx, sqlc.InsertChannelFeatureParams{
×
3843
                                        ChannelID:  dbChanID,
×
3844
                                        FeatureBit: int32(feature),
×
3845
                                },
×
3846
                        )
×
3847
                        if err != nil {
×
3848
                                return fmt.Errorf("unable to insert "+
×
3849
                                        "channel(%d) feature(%v): %w", dbChanID,
×
3850
                                        feature, err)
×
3851
                        }
×
3852
                }
3853
        }
3854

3855
        // Finally, insert any extra TLV fields in the channel announcement.
3856
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3857
        if err != nil {
×
3858
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
3859
                        err)
×
3860
        }
×
3861

3862
        for tlvType, value := range extra {
×
3863
                err := db.CreateChannelExtraType(
×
3864
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3865
                                ChannelID: dbChanID,
×
3866
                                Type:      int64(tlvType),
×
3867
                                Value:     value,
×
3868
                        },
×
3869
                )
×
3870
                if err != nil {
×
3871
                        return fmt.Errorf("unable to upsert channel(%d) extra "+
×
3872
                                "signed field(%v): %w", edge.ChannelID,
×
3873
                                tlvType, err)
×
3874
                }
×
3875
        }
3876

3877
        return nil
×
3878
}
3879

3880
// maybeCreateShellNode checks if a shell node entry exists for the
3881
// given public key. If it does not exist, then a new shell node entry is
3882
// created. The ID of the node is returned. A shell node only has a protocol
3883
// version and public key persisted.
3884
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3885
        pubKey route.Vertex) (int64, error) {
×
3886

×
3887
        dbNode, err := db.GetNodeByPubKey(
×
3888
                ctx, sqlc.GetNodeByPubKeyParams{
×
3889
                        PubKey:  pubKey[:],
×
3890
                        Version: int16(ProtocolV1),
×
3891
                },
×
3892
        )
×
3893
        // The node exists. Return the ID.
×
3894
        if err == nil {
×
3895
                return dbNode.ID, nil
×
3896
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3897
                return 0, err
×
3898
        }
×
3899

3900
        // Otherwise, the node does not exist, so we create a shell entry for
3901
        // it.
3902
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3903
                Version: int16(ProtocolV1),
×
3904
                PubKey:  pubKey[:],
×
3905
        })
×
3906
        if err != nil {
×
3907
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3908
        }
×
3909

3910
        return id, nil
×
3911
}
3912

3913
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3914
// the database. This includes deleting any existing types and then inserting
3915
// the new types.
3916
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3917
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3918

×
3919
        // Delete all existing extra signed fields for the channel policy.
×
3920
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
3921
        if err != nil {
×
3922
                return fmt.Errorf("unable to delete "+
×
3923
                        "existing policy extra signed fields for policy %d: %w",
×
3924
                        chanPolicyID, err)
×
3925
        }
×
3926

3927
        // Insert all new extra signed fields for the channel policy.
3928
        for tlvType, value := range extraFields {
×
3929
                err = db.InsertChanPolicyExtraType(
×
3930
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
3931
                                ChannelPolicyID: chanPolicyID,
×
3932
                                Type:            int64(tlvType),
×
3933
                                Value:           value,
×
3934
                        },
×
3935
                )
×
3936
                if err != nil {
×
3937
                        return fmt.Errorf("unable to insert "+
×
3938
                                "channel_policy(%d) extra signed field(%v): %w",
×
3939
                                chanPolicyID, tlvType, err)
×
3940
                }
×
3941
        }
3942

3943
        return nil
×
3944
}
3945

3946
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3947
// provided dbChanRow and also fetches any other required information
3948
// to construct the edge info.
3949
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
3950
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.Channel, node1,
3951
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
3952

×
3953
        if dbChan.Version != int16(ProtocolV1) {
×
3954
                return nil, fmt.Errorf("unsupported channel version: %d",
×
3955
                        dbChan.Version)
×
3956
        }
×
3957

3958
        fv, extras, err := getChanFeaturesAndExtras(
×
3959
                ctx, db, dbChanID,
×
3960
        )
×
3961
        if err != nil {
×
3962
                return nil, err
×
3963
        }
×
3964

3965
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
3966
        if err != nil {
×
3967
                return nil, err
×
3968
        }
×
3969

3970
        var featureBuf bytes.Buffer
×
3971
        if err := fv.Encode(&featureBuf); err != nil {
×
3972
                return nil, fmt.Errorf("unable to encode features: %w", err)
×
3973
        }
×
3974

3975
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
3976
        if err != nil {
×
3977
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3978
                        "fields: %w", err)
×
3979
        }
×
3980
        if recs == nil {
×
3981
                recs = make([]byte, 0)
×
3982
        }
×
3983

3984
        var btcKey1, btcKey2 route.Vertex
×
3985
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
3986
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
3987

×
3988
        channel := &models.ChannelEdgeInfo{
×
3989
                ChainHash:        chain,
×
3990
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
3991
                NodeKey1Bytes:    node1,
×
3992
                NodeKey2Bytes:    node2,
×
3993
                BitcoinKey1Bytes: btcKey1,
×
3994
                BitcoinKey2Bytes: btcKey2,
×
3995
                ChannelPoint:     *op,
×
3996
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
3997
                Features:         featureBuf.Bytes(),
×
3998
                ExtraOpaqueData:  recs,
×
3999
        }
×
4000

×
4001
        // We always set all the signatures at the same time, so we can
×
4002
        // safely check if one signature is present to determine if we have the
×
4003
        // rest of the signatures for the auth proof.
×
4004
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4005
                channel.AuthProof = &models.ChannelAuthProof{
×
4006
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4007
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4008
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4009
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4010
                }
×
4011
        }
×
4012

4013
        return channel, nil
×
4014
}
4015

4016
// buildNodeVertices is a helper that converts raw node public keys
4017
// into route.Vertex instances.
4018
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4019
        route.Vertex, error) {
×
4020

×
4021
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4022
        if err != nil {
×
4023
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4024
                        "create vertex from node1 pubkey: %w", err)
×
4025
        }
×
4026

4027
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4028
        if err != nil {
×
4029
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4030
                        "create vertex from node2 pubkey: %w", err)
×
4031
        }
×
4032

4033
        return node1Vertex, node2Vertex, nil
×
4034
}
4035

4036
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
4037
// for a channel with the given ID.
4038
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
4039
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
4040

×
4041
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
4042
        if err != nil {
×
4043
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
4044
                        "features and extras: %w", err)
×
4045
        }
×
4046

4047
        var (
×
4048
                fv     = lnwire.EmptyFeatureVector()
×
4049
                extras = make(map[uint64][]byte)
×
4050
        )
×
4051
        for _, row := range rows {
×
4052
                if row.IsFeature {
×
4053
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
4054

×
4055
                        continue
×
4056
                }
4057

4058
                tlvType, ok := row.ExtraKey.(int64)
×
4059
                if !ok {
×
4060
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4061
                                "TLV type: %T", row.ExtraKey)
×
4062
                }
×
4063

4064
                valueBytes, ok := row.Value.([]byte)
×
4065
                if !ok {
×
4066
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4067
                                "Value: %T", row.Value)
×
4068
                }
×
4069

4070
                extras[uint64(tlvType)] = valueBytes
×
4071
        }
4072

4073
        return fv, extras, nil
×
4074
}
4075

4076
// getAndBuildChanPolicies uses the given sqlc.ChannelPolicy and also retrieves
4077
// all the extra info required to build the complete models.ChannelEdgePolicy
4078
// types. It returns two policies, which may be nil if the provided
4079
// sqlc.ChannelPolicy records are nil.
4080
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4081
        dbPol1, dbPol2 *sqlc.ChannelPolicy, channelID uint64, node1,
4082
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4083
        *models.ChannelEdgePolicy, error) {
×
4084

×
4085
        if dbPol1 == nil && dbPol2 == nil {
×
4086
                return nil, nil, nil
×
4087
        }
×
4088

4089
        var (
×
4090
                policy1ID int64
×
4091
                policy2ID int64
×
4092
        )
×
4093
        if dbPol1 != nil {
×
4094
                policy1ID = dbPol1.ID
×
4095
        }
×
4096
        if dbPol2 != nil {
×
4097
                policy2ID = dbPol2.ID
×
4098
        }
×
4099
        rows, err := db.GetChannelPolicyExtraTypes(
×
4100
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
4101
                        ID:   policy1ID,
×
4102
                        ID_2: policy2ID,
×
4103
                },
×
4104
        )
×
4105
        if err != nil {
×
4106
                return nil, nil, err
×
4107
        }
×
4108

4109
        var (
×
4110
                dbPol1Extras = make(map[uint64][]byte)
×
4111
                dbPol2Extras = make(map[uint64][]byte)
×
4112
        )
×
4113
        for _, row := range rows {
×
4114
                switch row.PolicyID {
×
4115
                case policy1ID:
×
4116
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
4117
                case policy2ID:
×
4118
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
4119
                default:
×
4120
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
4121
                                "in row: %v", row.PolicyID, row)
×
4122
                }
4123
        }
4124

4125
        var pol1, pol2 *models.ChannelEdgePolicy
×
4126
        if dbPol1 != nil {
×
4127
                pol1, err = buildChanPolicy(
×
NEW
4128
                        *dbPol1, channelID, dbPol1Extras, node2,
×
4129
                )
×
4130
                if err != nil {
×
4131
                        return nil, nil, err
×
4132
                }
×
4133
        }
4134
        if dbPol2 != nil {
×
4135
                pol2, err = buildChanPolicy(
×
NEW
4136
                        *dbPol2, channelID, dbPol2Extras, node1,
×
4137
                )
×
4138
                if err != nil {
×
4139
                        return nil, nil, err
×
4140
                }
×
4141
        }
4142

4143
        return pol1, pol2, nil
×
4144
}
4145

4146
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4147
// provided sqlc.ChannelPolicy and other required information.
4148
func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
4149
        extras map[uint64][]byte,
NEW
4150
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4151

×
4152
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4153
        if err != nil {
×
4154
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4155
                        "fields: %w", err)
×
4156
        }
×
4157

4158
        var inboundFee fn.Option[lnwire.Fee]
×
4159
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4160
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4161

×
4162
                inboundFee = fn.Some(lnwire.Fee{
×
4163
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4164
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4165
                })
×
4166
        }
×
4167

4168
        return &models.ChannelEdgePolicy{
×
4169
                SigBytes:  dbPolicy.Signature,
×
4170
                ChannelID: channelID,
×
4171
                LastUpdate: time.Unix(
×
4172
                        dbPolicy.LastUpdate.Int64, 0,
×
4173
                ),
×
NEW
4174
                MessageFlags: lnwire.ChanUpdateMsgFlags(
×
NEW
4175
                        dbPolicy.MessageFlags.Int16,
×
NEW
4176
                ),
×
NEW
4177
                ChannelFlags: lnwire.ChanUpdateChanFlags(
×
NEW
4178
                        dbPolicy.ChannelFlags.Int16,
×
NEW
4179
                ),
×
4180
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4181
                MinHTLC: lnwire.MilliSatoshi(
×
4182
                        dbPolicy.MinHtlcMsat,
×
4183
                ),
×
4184
                MaxHTLC: lnwire.MilliSatoshi(
×
4185
                        dbPolicy.MaxHtlcMsat.Int64,
×
4186
                ),
×
4187
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4188
                        dbPolicy.BaseFeeMsat,
×
4189
                ),
×
4190
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4191
                ToNode:                    toNode,
×
4192
                InboundFee:                inboundFee,
×
4193
                ExtraOpaqueData:           recs,
×
4194
        }, nil
×
4195
}
4196

4197
// buildNodes builds the models.LightningNode instances for the
4198
// given row which is expected to be a sqlc type that contains node information.
4199
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4200
        dbNode2 sqlc.Node) (*models.LightningNode, *models.LightningNode,
4201
        error) {
×
4202

×
4203
        node1, err := buildNode(ctx, db, &dbNode1)
×
4204
        if err != nil {
×
4205
                return nil, nil, err
×
4206
        }
×
4207

4208
        node2, err := buildNode(ctx, db, &dbNode2)
×
4209
        if err != nil {
×
4210
                return nil, nil, err
×
4211
        }
×
4212

4213
        return node1, node2, nil
×
4214
}
4215

4216
// extractChannelPolicies extracts the sqlc.ChannelPolicy records from the give
4217
// row which is expected to be a sqlc type that contains channel policy
4218
// information. It returns two policies, which may be nil if the policy
4219
// information is not present in the row.
4220
//
4221
//nolint:ll,dupl,funlen
4222
func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
4223
        error) {
×
4224

×
4225
        var policy1, policy2 *sqlc.ChannelPolicy
×
4226
        switch r := row.(type) {
×
4227
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4228
                if r.Policy1ID.Valid {
×
4229
                        policy1 = &sqlc.ChannelPolicy{
×
4230
                                ID:                      r.Policy1ID.Int64,
×
4231
                                Version:                 r.Policy1Version.Int16,
×
4232
                                ChannelID:               r.Channel.ID,
×
4233
                                NodeID:                  r.Policy1NodeID.Int64,
×
4234
                                Timelock:                r.Policy1Timelock.Int32,
×
4235
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4236
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4237
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4238
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4239
                                LastUpdate:              r.Policy1LastUpdate,
×
4240
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4241
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4242
                                Disabled:                r.Policy1Disabled,
×
NEW
4243
                                MessageFlags:            r.Policy1MessageFlags,
×
NEW
4244
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4245
                                Signature:               r.Policy1Signature,
×
4246
                        }
×
4247
                }
×
4248
                if r.Policy2ID.Valid {
×
4249
                        policy2 = &sqlc.ChannelPolicy{
×
4250
                                ID:                      r.Policy2ID.Int64,
×
4251
                                Version:                 r.Policy2Version.Int16,
×
4252
                                ChannelID:               r.Channel.ID,
×
4253
                                NodeID:                  r.Policy2NodeID.Int64,
×
4254
                                Timelock:                r.Policy2Timelock.Int32,
×
4255
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4256
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4257
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4258
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4259
                                LastUpdate:              r.Policy2LastUpdate,
×
4260
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4261
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4262
                                Disabled:                r.Policy2Disabled,
×
NEW
4263
                                MessageFlags:            r.Policy2MessageFlags,
×
NEW
4264
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4265
                                Signature:               r.Policy2Signature,
×
4266
                        }
×
4267
                }
×
4268

4269
                return policy1, policy2, nil
×
4270

4271
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4272
                if r.Policy1ID.Valid {
×
4273
                        policy1 = &sqlc.ChannelPolicy{
×
4274
                                ID:                      r.Policy1ID.Int64,
×
4275
                                Version:                 r.Policy1Version.Int16,
×
4276
                                ChannelID:               r.Channel.ID,
×
4277
                                NodeID:                  r.Policy1NodeID.Int64,
×
4278
                                Timelock:                r.Policy1Timelock.Int32,
×
4279
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4280
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4281
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4282
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4283
                                LastUpdate:              r.Policy1LastUpdate,
×
4284
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4285
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4286
                                Disabled:                r.Policy1Disabled,
×
NEW
4287
                                MessageFlags:            r.Policy1MessageFlags,
×
NEW
4288
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4289
                                Signature:               r.Policy1Signature,
×
4290
                        }
×
4291
                }
×
4292
                if r.Policy2ID.Valid {
×
4293
                        policy2 = &sqlc.ChannelPolicy{
×
4294
                                ID:                      r.Policy2ID.Int64,
×
4295
                                Version:                 r.Policy2Version.Int16,
×
4296
                                ChannelID:               r.Channel.ID,
×
4297
                                NodeID:                  r.Policy2NodeID.Int64,
×
4298
                                Timelock:                r.Policy2Timelock.Int32,
×
4299
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4300
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4301
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4302
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4303
                                LastUpdate:              r.Policy2LastUpdate,
×
4304
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4305
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4306
                                Disabled:                r.Policy2Disabled,
×
NEW
4307
                                MessageFlags:            r.Policy2MessageFlags,
×
NEW
4308
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4309
                                Signature:               r.Policy2Signature,
×
4310
                        }
×
4311
                }
×
4312

4313
                return policy1, policy2, nil
×
4314

4315
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4316
                if r.Policy1ID.Valid {
×
4317
                        policy1 = &sqlc.ChannelPolicy{
×
4318
                                ID:                      r.Policy1ID.Int64,
×
4319
                                Version:                 r.Policy1Version.Int16,
×
4320
                                ChannelID:               r.Channel.ID,
×
4321
                                NodeID:                  r.Policy1NodeID.Int64,
×
4322
                                Timelock:                r.Policy1Timelock.Int32,
×
4323
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4324
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4325
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4326
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4327
                                LastUpdate:              r.Policy1LastUpdate,
×
4328
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4329
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4330
                                Disabled:                r.Policy1Disabled,
×
NEW
4331
                                MessageFlags:            r.Policy1MessageFlags,
×
NEW
4332
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4333
                                Signature:               r.Policy1Signature,
×
4334
                        }
×
4335
                }
×
4336
                if r.Policy2ID.Valid {
×
4337
                        policy2 = &sqlc.ChannelPolicy{
×
4338
                                ID:                      r.Policy2ID.Int64,
×
4339
                                Version:                 r.Policy2Version.Int16,
×
4340
                                ChannelID:               r.Channel.ID,
×
4341
                                NodeID:                  r.Policy2NodeID.Int64,
×
4342
                                Timelock:                r.Policy2Timelock.Int32,
×
4343
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4344
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4345
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4346
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4347
                                LastUpdate:              r.Policy2LastUpdate,
×
4348
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4349
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4350
                                Disabled:                r.Policy2Disabled,
×
NEW
4351
                                MessageFlags:            r.Policy2MessageFlags,
×
NEW
4352
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4353
                                Signature:               r.Policy2Signature,
×
4354
                        }
×
4355
                }
×
4356

4357
                return policy1, policy2, nil
×
4358

4359
        case sqlc.ListChannelsByNodeIDRow:
×
4360
                if r.Policy1ID.Valid {
×
4361
                        policy1 = &sqlc.ChannelPolicy{
×
4362
                                ID:                      r.Policy1ID.Int64,
×
4363
                                Version:                 r.Policy1Version.Int16,
×
4364
                                ChannelID:               r.Channel.ID,
×
4365
                                NodeID:                  r.Policy1NodeID.Int64,
×
4366
                                Timelock:                r.Policy1Timelock.Int32,
×
4367
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4368
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4369
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4370
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4371
                                LastUpdate:              r.Policy1LastUpdate,
×
4372
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4373
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4374
                                Disabled:                r.Policy1Disabled,
×
NEW
4375
                                MessageFlags:            r.Policy1MessageFlags,
×
NEW
4376
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4377
                                Signature:               r.Policy1Signature,
×
4378
                        }
×
4379
                }
×
4380
                if r.Policy2ID.Valid {
×
4381
                        policy2 = &sqlc.ChannelPolicy{
×
4382
                                ID:                      r.Policy2ID.Int64,
×
4383
                                Version:                 r.Policy2Version.Int16,
×
4384
                                ChannelID:               r.Channel.ID,
×
4385
                                NodeID:                  r.Policy2NodeID.Int64,
×
4386
                                Timelock:                r.Policy2Timelock.Int32,
×
4387
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4388
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4389
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4390
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4391
                                LastUpdate:              r.Policy2LastUpdate,
×
4392
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4393
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4394
                                Disabled:                r.Policy2Disabled,
×
NEW
4395
                                MessageFlags:            r.Policy2MessageFlags,
×
NEW
4396
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4397
                                Signature:               r.Policy2Signature,
×
4398
                        }
×
4399
                }
×
4400

4401
                return policy1, policy2, nil
×
4402

4403
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4404
                if r.Policy1ID.Valid {
×
4405
                        policy1 = &sqlc.ChannelPolicy{
×
4406
                                ID:                      r.Policy1ID.Int64,
×
4407
                                Version:                 r.Policy1Version.Int16,
×
4408
                                ChannelID:               r.Channel.ID,
×
4409
                                NodeID:                  r.Policy1NodeID.Int64,
×
4410
                                Timelock:                r.Policy1Timelock.Int32,
×
4411
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4412
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4413
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4414
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4415
                                LastUpdate:              r.Policy1LastUpdate,
×
4416
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4417
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4418
                                Disabled:                r.Policy1Disabled,
×
NEW
4419
                                MessageFlags:            r.Policy1MessageFlags,
×
NEW
4420
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4421
                                Signature:               r.Policy1Signature,
×
4422
                        }
×
4423
                }
×
4424
                if r.Policy2ID.Valid {
×
4425
                        policy2 = &sqlc.ChannelPolicy{
×
4426
                                ID:                      r.Policy2ID.Int64,
×
4427
                                Version:                 r.Policy2Version.Int16,
×
4428
                                ChannelID:               r.Channel.ID,
×
4429
                                NodeID:                  r.Policy2NodeID.Int64,
×
4430
                                Timelock:                r.Policy2Timelock.Int32,
×
4431
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4432
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4433
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4434
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4435
                                LastUpdate:              r.Policy2LastUpdate,
×
4436
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4437
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4438
                                Disabled:                r.Policy2Disabled,
×
NEW
4439
                                MessageFlags:            r.Policy2MessageFlags,
×
NEW
4440
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4441
                                Signature:               r.Policy2Signature,
×
4442
                        }
×
4443
                }
×
4444

4445
                return policy1, policy2, nil
×
4446
        default:
×
4447
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4448
                        "extractChannelPolicies: %T", r)
×
4449
        }
4450
}
4451

4452
// channelIDToBytes converts a channel ID (SCID) to a byte array
4453
// representation.
4454
func channelIDToBytes(channelID uint64) [8]byte {
×
4455
        var chanIDB [8]byte
×
4456
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4457

×
4458
        return chanIDB
×
4459
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc