• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15872517631

25 Jun 2025 09:22AM UTC coverage: 67.648% (-0.2%) from 67.8%
15872517631

Pull #9939

github

web-flow
Merge e875183c4 into 33e6f2854
Pull Request #9939: [15] graph/db: SQL prune log

0 of 386 new or added lines in 2 files covered. (0.0%)

83 existing lines in 18 files now uncovered.

134987 of 199542 relevant lines covered (67.65%)

21930.35 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// pageSize is the limit for the number of records that can be returned
// in a paginated query. This can be tuned after some benchmarks.
//
// NOTE: this is an untyped constant, so it converts implicitly to whatever
// integer type the paginated query parameter it is passed to expects.
const pageSize = 2000
37

38
// ProtocolVersion enumerates the gossip protocol version that a message
// belongs to. Its underlying type is uint8.
type ProtocolVersion uint8

// ProtocolV1 identifies the gossip protocol version specified in BOLT #7.
const ProtocolV1 ProtocolVersion = 1

// String renders the protocol version as the letter "V" followed by its
// decimal number, e.g. "V1".
func (v ProtocolVersion) String() string {
	// Format the raw integer value of the version.
	num := uint8(v)

	return fmt.Sprintf("V%d", num)
}
×
51

52
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
// execute queries against the SQL graph tables. Because it is a subset, any
// implementation of sqlc.Querier also satisfies it. Keeping the method set
// small here means the SQL store only depends on the queries it actually
// uses.
//
// Methods that take a `version int16` argument expect a ProtocolVersion
// converted to int16 (the callers in this file pass int16(ProtocolV1)).
//
//nolint:ll,interfacebloat
type SQLQueries interface {
	/*
		Node queries.
	*/
	UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
	GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
	GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
	GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
	ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error)
	ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
	IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
	GetUnconnectedNodes(ctx context.Context) ([]sqlc.GetUnconnectedNodesRow, error)
	DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
	DeleteNode(ctx context.Context, id int64) error

	// Node extra (unknown TLV) types.
	GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
	UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
	DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error

	// Node addresses.
	InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
	GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
	DeleteNodeAddresses(ctx context.Context, nodeID int64) error

	// Node features.
	InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
	GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
	GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
	DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error

	/*
		Source node queries.
	*/
	AddSourceNode(ctx context.Context, nodeID int64) error
	GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)

	/*
		Channel queries.
	*/
	CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
	GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
	GetChannelByOutpoint(ctx context.Context, outpoint string) (sqlc.GetChannelByOutpointRow, error)
	GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
	GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
	GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
	GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
	HighestSCID(ctx context.Context, version int16) ([]byte, error)
	ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
	ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
	ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
	GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
	GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
	GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
	GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
	DeleteChannel(ctx context.Context, id int64) error

	// Channel extra (unknown TLV) types and features.
	CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
	InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error

	/*
		Channel Policy table queries.
	*/
	UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
	GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.ChannelPolicy, error)
	GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)

	// Channel policy extra (unknown TLV) types.
	InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
	GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
	DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error

	/*
		Zombie index queries.
	*/
	UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
	GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.ZombieChannel, error)
	CountZombieChannels(ctx context.Context, version int16) (int64, error)
	DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
	IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)

	/*
		Prune log table queries.
	*/
	GetPruneTip(ctx context.Context) (sqlc.PruneLog, error)
	UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
	DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
}
140

141
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
// database operations. On top of the plain SQLQueries methods, it embeds
// sqldb.BatchedTx[SQLQueries], which supplies the transactional entry point
// (ExecTx) that SQLStore uses to run a closure over SQLQueries inside a single
// database transaction.
type BatchedSQLQueries interface {
	SQLQueries
	sqldb.BatchedTx[SQLQueries]
}
147

148
// SQLStore is an implementation of the V1Store interface that uses a SQL
// database as the backend.
//
// NOTE: currently, this temporarily embeds the KVStore struct so that we can
// implement the V1Store interface incrementally. For any method not
// implemented, things will fall back to the KVStore. This is ONLY the case
// for the time being while this struct is purely used in unit tests only.
type SQLStore struct {
	// cfg holds the configuration the store was created with, and db is
	// the SQL backend that all queries and transactions run against.
	cfg *SQLStoreConfig
	db  BatchedSQLQueries

	// cacheMu guards all caches (rejectCache and chanCache). If
	// this mutex will be acquired at the same time as the DB mutex then
	// the cacheMu MUST be acquired first to prevent deadlock.
	cacheMu     sync.RWMutex
	rejectCache *rejectCache
	chanCache   *channelCache

	// chanScheduler batches channel writes (AddChannelEdge,
	// UpdateEdgePolicy) and nodeScheduler batches node writes
	// (AddLightningNode).
	chanScheduler batch.Scheduler[SQLQueries]
	nodeScheduler batch.Scheduler[SQLQueries]

	// srcNodes holds source-node information keyed by protocol version.
	// NOTE(review): presumably a cache guarded by srcNodeMu — confirm
	// against getSourceNode.
	srcNodes  map[ProtocolVersion]*srcNodeInfo
	srcNodeMu sync.Mutex

	// Temporary fall-back to the KVStore so that we can implement the
	// interface incrementally.
	*KVStore
}
176

177
// A compile-time assertion to ensure that SQLStore implements the V1Store
// interface. The nil *SQLStore is never used at runtime; the assignment only
// makes the compiler check that *SQLStore's method set satisfies V1Store.
var _ V1Store = (*SQLStore)(nil)
180

181
// SQLStoreConfig holds the configuration for the SQLStore. It is passed by
// pointer to NewSQLStore, which keeps that pointer in the store.
type SQLStoreConfig struct {
	// ChainHash is the genesis hash for the chain that all the gossip
	// messages in this store are aimed at.
	ChainHash chainhash.Hash
}
187

188
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
189
// storage backend.
190
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries, kvStore *KVStore,
191
        options ...StoreOptionModifier) (*SQLStore, error) {
×
192

×
193
        opts := DefaultOptions()
×
194
        for _, o := range options {
×
195
                o(opts)
×
196
        }
×
197

198
        if opts.NoMigration {
×
199
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
200
                        "supported for SQL stores")
×
201
        }
×
202

203
        s := &SQLStore{
×
204
                cfg:         cfg,
×
205
                db:          db,
×
206
                KVStore:     kvStore,
×
207
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
208
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
209
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
210
        }
×
211

×
212
        s.chanScheduler = batch.NewTimeScheduler(
×
213
                db, &s.cacheMu, opts.BatchCommitInterval,
×
214
        )
×
215
        s.nodeScheduler = batch.NewTimeScheduler(
×
216
                db, nil, opts.BatchCommitInterval,
×
217
        )
×
218

×
219
        return s, nil
×
220
}
221

222
// AddLightningNode adds a vertex/node to the graph database. If the node is not
223
// in the database from before, this will add a new, unconnected one to the
224
// graph. If it is present from before, this will update that node's
225
// information.
226
//
227
// NOTE: part of the V1Store interface.
228
func (s *SQLStore) AddLightningNode(ctx context.Context,
229
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
230

×
231
        r := &batch.Request[SQLQueries]{
×
232
                Opts: batch.NewSchedulerOptions(opts...),
×
233
                Do: func(queries SQLQueries) error {
×
234
                        _, err := upsertNode(ctx, queries, node)
×
235
                        return err
×
236
                },
×
237
        }
238

239
        return s.nodeScheduler.Execute(ctx, r)
×
240
}
241

242
// FetchLightningNode attempts to look up a target node by its identity public
243
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
244
// returned.
245
//
246
// NOTE: part of the V1Store interface.
247
func (s *SQLStore) FetchLightningNode(ctx context.Context,
248
        pubKey route.Vertex) (*models.LightningNode, error) {
×
249

×
250
        var node *models.LightningNode
×
251
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
252
                var err error
×
253
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
254

×
255
                return err
×
256
        }, sqldb.NoOpReset)
×
257
        if err != nil {
×
258
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
259
        }
×
260

261
        return node, nil
×
262
}
263

264
// HasLightningNode determines if the graph has a vertex identified by the
265
// target node identity public key. If the node exists in the database, a
266
// timestamp of when the data for the node was lasted updated is returned along
267
// with a true boolean. Otherwise, an empty time.Time is returned with a false
268
// boolean.
269
//
270
// NOTE: part of the V1Store interface.
271
func (s *SQLStore) HasLightningNode(ctx context.Context,
272
        pubKey [33]byte) (time.Time, bool, error) {
×
273

×
274
        var (
×
275
                exists     bool
×
276
                lastUpdate time.Time
×
277
        )
×
278
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
279
                dbNode, err := db.GetNodeByPubKey(
×
280
                        ctx, sqlc.GetNodeByPubKeyParams{
×
281
                                Version: int16(ProtocolV1),
×
282
                                PubKey:  pubKey[:],
×
283
                        },
×
284
                )
×
285
                if errors.Is(err, sql.ErrNoRows) {
×
286
                        return nil
×
287
                } else if err != nil {
×
288
                        return fmt.Errorf("unable to fetch node: %w", err)
×
289
                }
×
290

291
                exists = true
×
292

×
293
                if dbNode.LastUpdate.Valid {
×
294
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
295
                }
×
296

297
                return nil
×
298
        }, sqldb.NoOpReset)
299
        if err != nil {
×
300
                return time.Time{}, false,
×
301
                        fmt.Errorf("unable to fetch node: %w", err)
×
302
        }
×
303

304
        return lastUpdate, exists, nil
×
305
}
306

307
// AddrsForNode returns all known addresses for the target node public key
308
// that the graph DB is aware of. The returned boolean indicates if the
309
// given node is unknown to the graph DB or not.
310
//
311
// NOTE: part of the V1Store interface.
312
func (s *SQLStore) AddrsForNode(ctx context.Context,
313
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
314

×
315
        var (
×
316
                addresses []net.Addr
×
317
                known     bool
×
318
        )
×
319
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
320
                var err error
×
321
                known, addresses, err = getNodeAddresses(
×
322
                        ctx, db, nodePub.SerializeCompressed(),
×
323
                )
×
324
                if err != nil {
×
325
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
326
                                err)
×
327
                }
×
328

329
                return nil
×
330
        }, sqldb.NoOpReset)
331
        if err != nil {
×
332
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
333
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
334
        }
×
335

336
        return known, addresses, nil
×
337
}
338

339
// DeleteLightningNode starts a new database transaction to remove a vertex/node
340
// from the database according to the node's public key.
341
//
342
// NOTE: part of the V1Store interface.
343
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
344
        pubKey route.Vertex) error {
×
345

×
346
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
347
                res, err := db.DeleteNodeByPubKey(
×
348
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
349
                                Version: int16(ProtocolV1),
×
350
                                PubKey:  pubKey[:],
×
351
                        },
×
352
                )
×
353
                if err != nil {
×
354
                        return err
×
355
                }
×
356

357
                rows, err := res.RowsAffected()
×
358
                if err != nil {
×
359
                        return err
×
360
                }
×
361

362
                if rows == 0 {
×
363
                        return ErrGraphNodeNotFound
×
364
                } else if rows > 1 {
×
365
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
366
                }
×
367

368
                return err
×
369
        }, sqldb.NoOpReset)
370
        if err != nil {
×
371
                return fmt.Errorf("unable to delete node: %w", err)
×
372
        }
×
373

374
        return nil
×
375
}
376

377
// FetchNodeFeatures returns the features of the given node. If no features are
378
// known for the node, an empty feature vector is returned.
379
//
380
// NOTE: this is part of the graphdb.NodeTraverser interface.
381
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
382
        *lnwire.FeatureVector, error) {
×
383

×
384
        ctx := context.TODO()
×
385

×
386
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
387
}
×
388

389
// DisabledChannelIDs returns the channel ids of disabled channels.
390
// A channel is disabled when two of the associated ChanelEdgePolicies
391
// have their disabled bit on.
392
//
393
// NOTE: part of the V1Store interface.
394
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
395
        var (
×
396
                ctx     = context.TODO()
×
397
                chanIDs []uint64
×
398
        )
×
399
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
400
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
401
                if err != nil {
×
402
                        return fmt.Errorf("unable to fetch disabled "+
×
403
                                "channels: %w", err)
×
404
                }
×
405

406
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
407

×
408
                return nil
×
409
        }, sqldb.NoOpReset)
410
        if err != nil {
×
411
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
412
                        err)
×
413
        }
×
414

415
        return chanIDs, nil
×
416
}
417

418
// LookupAlias attempts to return the alias as advertised by the target node.
419
//
420
// NOTE: part of the V1Store interface.
421
func (s *SQLStore) LookupAlias(ctx context.Context,
422
        pub *btcec.PublicKey) (string, error) {
×
423

×
424
        var alias string
×
425
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
426
                dbNode, err := db.GetNodeByPubKey(
×
427
                        ctx, sqlc.GetNodeByPubKeyParams{
×
428
                                Version: int16(ProtocolV1),
×
429
                                PubKey:  pub.SerializeCompressed(),
×
430
                        },
×
431
                )
×
432
                if errors.Is(err, sql.ErrNoRows) {
×
433
                        return ErrNodeAliasNotFound
×
434
                } else if err != nil {
×
435
                        return fmt.Errorf("unable to fetch node: %w", err)
×
436
                }
×
437

438
                if !dbNode.Alias.Valid {
×
439
                        return ErrNodeAliasNotFound
×
440
                }
×
441

442
                alias = dbNode.Alias.String
×
443

×
444
                return nil
×
445
        }, sqldb.NoOpReset)
446
        if err != nil {
×
447
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
448
        }
×
449

450
        return alias, nil
×
451
}
452

453
// SourceNode returns the source node of the graph. The source node is treated
454
// as the center node within a star-graph. This method may be used to kick off
455
// a path finding algorithm in order to explore the reachability of another
456
// node based off the source node.
457
//
458
// NOTE: part of the V1Store interface.
459
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
460
        error) {
×
461

×
462
        var node *models.LightningNode
×
463
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
464
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
465
                if err != nil {
×
466
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
467
                                err)
×
468
                }
×
469

470
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
471

×
472
                return err
×
473
        }, sqldb.NoOpReset)
474
        if err != nil {
×
475
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
476
        }
×
477

478
        return node, nil
×
479
}
480

481
// SetSourceNode sets the source node within the graph database. The source
482
// node is to be used as the center of a star-graph within path finding
483
// algorithms.
484
//
485
// NOTE: part of the V1Store interface.
486
func (s *SQLStore) SetSourceNode(ctx context.Context,
487
        node *models.LightningNode) error {
×
488

×
489
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
490
                id, err := upsertNode(ctx, db, node)
×
491
                if err != nil {
×
492
                        return fmt.Errorf("unable to upsert source node: %w",
×
493
                                err)
×
494
                }
×
495

496
                // Make sure that if a source node for this version is already
497
                // set, then the ID is the same as the one we are about to set.
498
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
499
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
500
                        return fmt.Errorf("unable to fetch source node: %w",
×
501
                                err)
×
502
                } else if err == nil {
×
503
                        if dbSourceNodeID != id {
×
504
                                return fmt.Errorf("v1 source node already "+
×
505
                                        "set to a different node: %d vs %d",
×
506
                                        dbSourceNodeID, id)
×
507
                        }
×
508

509
                        return nil
×
510
                }
511

512
                return db.AddSourceNode(ctx, id)
×
513
        }, sqldb.NoOpReset)
514
}
515

516
// NodeUpdatesInHorizon returns all the known lightning node which have an
517
// update timestamp within the passed range. This method can be used by two
518
// nodes to quickly determine if they have the same set of up to date node
519
// announcements.
520
//
521
// NOTE: This is part of the V1Store interface.
522
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
523
        endTime time.Time) ([]models.LightningNode, error) {
×
524

×
525
        ctx := context.TODO()
×
526

×
527
        var nodes []models.LightningNode
×
528
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
529
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
530
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
531
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
532
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
533
                        },
×
534
                )
×
535
                if err != nil {
×
536
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
537
                }
×
538

539
                for _, dbNode := range dbNodes {
×
540
                        node, err := buildNode(ctx, db, &dbNode)
×
541
                        if err != nil {
×
542
                                return fmt.Errorf("unable to build node: %w",
×
543
                                        err)
×
544
                        }
×
545

546
                        nodes = append(nodes, *node)
×
547
                }
548

549
                return nil
×
550
        }, sqldb.NoOpReset)
551
        if err != nil {
×
552
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
553
        }
×
554

555
        return nodes, nil
×
556
}
557

558
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
559
// undirected edge from the two target nodes are created. The information stored
560
// denotes the static attributes of the channel, such as the channelID, the keys
561
// involved in creation of the channel, and the set of features that the channel
562
// supports. The chanPoint and chanID are used to uniquely identify the edge
563
// globally within the database.
564
//
565
// NOTE: part of the V1Store interface.
566
func (s *SQLStore) AddChannelEdge(ctx context.Context,
567
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
568

×
569
        var alreadyExists bool
×
570
        r := &batch.Request[SQLQueries]{
×
571
                Opts: batch.NewSchedulerOptions(opts...),
×
572
                Reset: func() {
×
573
                        alreadyExists = false
×
574
                },
×
575
                Do: func(tx SQLQueries) error {
×
576
                        err := insertChannel(ctx, tx, edge)
×
577

×
578
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
579
                        // succeed, but propagate the error via local state.
×
580
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
581
                                alreadyExists = true
×
582
                                return nil
×
583
                        }
×
584

585
                        return err
×
586
                },
587
                OnCommit: func(err error) error {
×
588
                        switch {
×
589
                        case err != nil:
×
590
                                return err
×
591
                        case alreadyExists:
×
592
                                return ErrEdgeAlreadyExist
×
593
                        default:
×
594
                                s.rejectCache.remove(edge.ChannelID)
×
595
                                s.chanCache.remove(edge.ChannelID)
×
596
                                return nil
×
597
                        }
598
                },
599
        }
600

601
        return s.chanScheduler.Execute(ctx, r)
×
602
}
603

604
// HighestChanID returns the "highest" known channel ID in the channel graph.
605
// This represents the "newest" channel from the PoV of the chain. This method
606
// can be used by peers to quickly determine if their graphs are in sync.
607
//
608
// NOTE: This is part of the V1Store interface.
609
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
610
        var highestChanID uint64
×
611
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
612
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
613
                if errors.Is(err, sql.ErrNoRows) {
×
614
                        return nil
×
615
                } else if err != nil {
×
616
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
617
                                err)
×
618
                }
×
619

620
                highestChanID = byteOrder.Uint64(chanID)
×
621

×
622
                return nil
×
623
        }, sqldb.NoOpReset)
624
        if err != nil {
×
625
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
626
        }
×
627

628
        return highestChanID, nil
×
629
}
630

631
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
632
// within the database for the referenced channel. The `flags` attribute within
633
// the ChannelEdgePolicy determines which of the directed edges are being
634
// updated. If the flag is 1, then the first node's information is being
635
// updated, otherwise it's the second node's information. The node ordering is
636
// determined by the lexicographical ordering of the identity public keys of the
637
// nodes on either side of the channel.
638
//
639
// NOTE: part of the V1Store interface.
640
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
641
        edge *models.ChannelEdgePolicy,
642
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
643

×
644
        var (
×
645
                isUpdate1    bool
×
646
                edgeNotFound bool
×
647
                from, to     route.Vertex
×
648
        )
×
649

×
650
        r := &batch.Request[SQLQueries]{
×
651
                Opts: batch.NewSchedulerOptions(opts...),
×
652
                Reset: func() {
×
653
                        isUpdate1 = false
×
654
                        edgeNotFound = false
×
655
                },
×
656
                Do: func(tx SQLQueries) error {
×
657
                        var err error
×
658
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
659
                                ctx, tx, edge,
×
660
                        )
×
661
                        if err != nil {
×
662
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
663
                        }
×
664

665
                        // Silence ErrEdgeNotFound so that the batch can
666
                        // succeed, but propagate the error via local state.
667
                        if errors.Is(err, ErrEdgeNotFound) {
×
668
                                edgeNotFound = true
×
669
                                return nil
×
670
                        }
×
671

672
                        return err
×
673
                },
674
                OnCommit: func(err error) error {
×
675
                        switch {
×
676
                        case err != nil:
×
677
                                return err
×
678
                        case edgeNotFound:
×
679
                                return ErrEdgeNotFound
×
680
                        default:
×
681
                                s.updateEdgeCache(edge, isUpdate1)
×
682
                                return nil
×
683
                        }
684
                },
685
        }
686

687
        err := s.chanScheduler.Execute(ctx, r)
×
688

×
689
        return from, to, err
×
690
}
691

692
// updateEdgeCache updates our reject and channel caches with the new
693
// edge policy information.
694
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
695
        isUpdate1 bool) {
×
696

×
697
        // If an entry for this channel is found in reject cache, we'll modify
×
698
        // the entry with the updated timestamp for the direction that was just
×
699
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
700
        // during the next query for this edge.
×
701
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
702
                if isUpdate1 {
×
703
                        entry.upd1Time = e.LastUpdate.Unix()
×
704
                } else {
×
705
                        entry.upd2Time = e.LastUpdate.Unix()
×
706
                }
×
707
                s.rejectCache.insert(e.ChannelID, entry)
×
708
        }
709

710
        // If an entry for this channel is found in channel cache, we'll modify
711
        // the entry with the updated policy for the direction that was just
712
        // written. If the edge doesn't exist, we'll defer loading the info and
713
        // policies and lazily read from disk during the next query.
714
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
715
                if isUpdate1 {
×
716
                        channel.Policy1 = e
×
717
                } else {
×
718
                        channel.Policy2 = e
×
719
                }
×
720
                s.chanCache.insert(e.ChannelID, channel)
×
721
        }
722
}
723

724
// ForEachSourceNodeChannel iterates through all channels of the source node,
725
// executing the passed callback on each. The call-back is provided with the
726
// channel's outpoint, whether we have a policy for the channel and the channel
727
// peer's node information.
728
//
729
// NOTE: part of the V1Store interface.
730
func (s *SQLStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint,
731
        havePolicy bool, otherNode *models.LightningNode) error) error {
×
732

×
733
        var ctx = context.TODO()
×
734

×
735
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
736
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
737
                if err != nil {
×
738
                        return fmt.Errorf("unable to fetch source node: %w",
×
739
                                err)
×
740
                }
×
741

742
                return forEachNodeChannel(
×
743
                        ctx, db, s.cfg.ChainHash, nodeID,
×
744
                        func(info *models.ChannelEdgeInfo,
×
745
                                outPolicy *models.ChannelEdgePolicy,
×
746
                                _ *models.ChannelEdgePolicy) error {
×
747

×
748
                                // Fetch the other node.
×
749
                                var (
×
750
                                        otherNodePub [33]byte
×
751
                                        node1        = info.NodeKey1Bytes
×
752
                                        node2        = info.NodeKey2Bytes
×
753
                                )
×
754
                                switch {
×
755
                                case bytes.Equal(node1[:], nodePub[:]):
×
756
                                        otherNodePub = node2
×
757
                                case bytes.Equal(node2[:], nodePub[:]):
×
758
                                        otherNodePub = node1
×
759
                                default:
×
760
                                        return fmt.Errorf("node not " +
×
761
                                                "participating in this channel")
×
762
                                }
763

764
                                _, otherNode, err := getNodeByPubKey(
×
765
                                        ctx, db, otherNodePub,
×
766
                                )
×
767
                                if err != nil {
×
768
                                        return fmt.Errorf("unable to fetch "+
×
769
                                                "other node(%x): %w",
×
770
                                                otherNodePub, err)
×
771
                                }
×
772

773
                                return cb(
×
774
                                        info.ChannelPoint, outPolicy != nil,
×
775
                                        otherNode,
×
776
                                )
×
777
                        },
778
                )
779
        }, sqldb.NoOpReset)
780
}
781

782
// ForEachNode iterates through all the stored vertices/nodes in the graph,
783
// executing the passed callback with each node encountered. If the callback
784
// returns an error, then the transaction is aborted and the iteration stops
785
// early. Any operations performed on the NodeTx passed to the call-back are
786
// executed under the same read transaction and so, methods on the NodeTx object
787
// _MUST_ only be called from within the call-back.
788
//
789
// NOTE: part of the V1Store interface.
790
func (s *SQLStore) ForEachNode(cb func(tx NodeRTx) error) error {
×
791
        var (
×
792
                ctx          = context.TODO()
×
793
                lastID int64 = 0
×
794
        )
×
795

×
796
        handleNode := func(db SQLQueries, dbNode sqlc.Node) error {
×
797
                node, err := buildNode(ctx, db, &dbNode)
×
798
                if err != nil {
×
799
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
800
                                dbNode.ID, err)
×
801
                }
×
802

803
                err = cb(
×
804
                        newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
×
805
                )
×
806
                if err != nil {
×
807
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
808
                                dbNode.ID, err)
×
809
                }
×
810

811
                return nil
×
812
        }
813

814
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
815
                for {
×
816
                        nodes, err := db.ListNodesPaginated(
×
817
                                ctx, sqlc.ListNodesPaginatedParams{
×
818
                                        Version: int16(ProtocolV1),
×
819
                                        ID:      lastID,
×
820
                                        Limit:   pageSize,
×
821
                                },
×
822
                        )
×
823
                        if err != nil {
×
824
                                return fmt.Errorf("unable to fetch nodes: %w",
×
825
                                        err)
×
826
                        }
×
827

828
                        if len(nodes) == 0 {
×
829
                                break
×
830
                        }
831

832
                        for _, dbNode := range nodes {
×
833
                                err = handleNode(db, dbNode)
×
834
                                if err != nil {
×
835
                                        return err
×
836
                                }
×
837

838
                                lastID = dbNode.ID
×
839
                        }
840
                }
841

842
                return nil
×
843
        }, sqldb.NoOpReset)
844
}
845

846
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
847
// SQLStore and a SQL transaction.
848
type sqlGraphNodeTx struct {
849
        db    SQLQueries
850
        id    int64
851
        node  *models.LightningNode
852
        chain chainhash.Hash
853
}
854

855
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
856
// interface.
857
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
858

859
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
860
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
861

×
862
        return &sqlGraphNodeTx{
×
863
                db:    db,
×
864
                chain: chain,
×
865
                id:    id,
×
866
                node:  node,
×
867
        }
×
868
}
×
869

870
// Node returns the raw information of the node.
871
//
872
// NOTE: This is a part of the NodeRTx interface.
873
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
874
        return s.node
×
875
}
×
876

877
// ForEachChannel can be used to iterate over the node's channels under the same
878
// transaction used to fetch the node.
879
//
880
// NOTE: This is a part of the NodeRTx interface.
881
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
882
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
883

×
884
        ctx := context.TODO()
×
885

×
886
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
887
}
×
888

889
// FetchNode fetches the node with the given pub key under the same transaction
890
// used to fetch the current node. The returned node is also a NodeRTx and any
891
// operations on that NodeRTx will also be done under the same transaction.
892
//
893
// NOTE: This is a part of the NodeRTx interface.
894
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
895
        ctx := context.TODO()
×
896

×
897
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
898
        if err != nil {
×
899
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
900
                        nodePub, err)
×
901
        }
×
902

903
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
904
}
905

906
// ForEachNodeDirectedChannel iterates through all channels of a given node,
907
// executing the passed callback on the directed edge representing the channel
908
// and its incoming policy. If the callback returns an error, then the iteration
909
// is halted with the error propagated back up to the caller.
910
//
911
// Unknown policies are passed into the callback as nil values.
912
//
913
// NOTE: this is part of the graphdb.NodeTraverser interface.
914
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
915
        cb func(channel *DirectedChannel) error) error {
×
916

×
917
        var ctx = context.TODO()
×
918

×
919
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
920
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
921
        }, sqldb.NoOpReset)
×
922
}
923

924
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
925
// graph, executing the passed callback with each node encountered. If the
926
// callback returns an error, then the transaction is aborted and the iteration
927
// stops early.
928
//
929
// NOTE: This is a part of the V1Store interface.
930
func (s *SQLStore) ForEachNodeCacheable(cb func(route.Vertex,
931
        *lnwire.FeatureVector) error) error {
×
932

×
933
        ctx := context.TODO()
×
934

×
935
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
936
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
937
                        nodePub route.Vertex) error {
×
938

×
939
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
940
                        if err != nil {
×
941
                                return fmt.Errorf("unable to fetch node "+
×
942
                                        "features: %w", err)
×
943
                        }
×
944

945
                        return cb(nodePub, features)
×
946
                })
947
        }, sqldb.NoOpReset)
948
        if err != nil {
×
949
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
950
        }
×
951

952
        return nil
×
953
}
954

955
// ForEachNodeChannel iterates through all channels of the given node,
956
// executing the passed callback with an edge info structure and the policies
957
// of each end of the channel. The first edge policy is the outgoing edge *to*
958
// the connecting node, while the second is the incoming edge *from* the
959
// connecting node. If the callback returns an error, then the iteration is
960
// halted with the error propagated back up to the caller.
961
//
962
// Unknown policies are passed into the callback as nil values.
963
//
964
// NOTE: part of the V1Store interface.
965
func (s *SQLStore) ForEachNodeChannel(nodePub route.Vertex,
966
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
967
                *models.ChannelEdgePolicy) error) error {
×
968

×
969
        var ctx = context.TODO()
×
970

×
971
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
972
                dbNode, err := db.GetNodeByPubKey(
×
973
                        ctx, sqlc.GetNodeByPubKeyParams{
×
974
                                Version: int16(ProtocolV1),
×
975
                                PubKey:  nodePub[:],
×
976
                        },
×
977
                )
×
978
                if errors.Is(err, sql.ErrNoRows) {
×
979
                        return nil
×
980
                } else if err != nil {
×
981
                        return fmt.Errorf("unable to fetch node: %w", err)
×
982
                }
×
983

984
                return forEachNodeChannel(
×
985
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
986
                )
×
987
        }, sqldb.NoOpReset)
988
}
989

990
// ChanUpdatesInHorizon returns all the known channel edges which have at least
991
// one edge that has an update timestamp within the specified horizon.
992
//
993
// NOTE: This is part of the V1Store interface.
994
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
995
        endTime time.Time) ([]ChannelEdge, error) {
×
996

×
997
        s.cacheMu.Lock()
×
998
        defer s.cacheMu.Unlock()
×
999

×
1000
        var (
×
1001
                ctx = context.TODO()
×
1002
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
1003
                // an additional map to keep track of the edges already seen to
×
1004
                // prevent re-adding it.
×
1005
                edgesSeen    = make(map[uint64]struct{})
×
1006
                edgesToCache = make(map[uint64]ChannelEdge)
×
1007
                edges        []ChannelEdge
×
1008
                hits         int
×
1009
        )
×
1010
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1011
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1012
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1013
                                Version:   int16(ProtocolV1),
×
1014
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
1015
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
1016
                        },
×
1017
                )
×
1018
                if err != nil {
×
1019
                        return err
×
1020
                }
×
1021

1022
                for _, row := range rows {
×
1023
                        // If we've already retrieved the info and policies for
×
1024
                        // this edge, then we can skip it as we don't need to do
×
1025
                        // so again.
×
1026
                        chanIDInt := byteOrder.Uint64(row.Channel.Scid)
×
1027
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1028
                                continue
×
1029
                        }
1030

1031
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1032
                                hits++
×
1033
                                edgesSeen[chanIDInt] = struct{}{}
×
1034
                                edges = append(edges, channel)
×
1035

×
1036
                                continue
×
1037
                        }
1038

1039
                        node1, node2, err := buildNodes(
×
1040
                                ctx, db, row.Node, row.Node_2,
×
1041
                        )
×
1042
                        if err != nil {
×
1043
                                return err
×
1044
                        }
×
1045

1046
                        channel, err := getAndBuildEdgeInfo(
×
1047
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1048
                                row.Channel, node1.PubKeyBytes,
×
1049
                                node2.PubKeyBytes,
×
1050
                        )
×
1051
                        if err != nil {
×
1052
                                return fmt.Errorf("unable to build channel "+
×
1053
                                        "info: %w", err)
×
1054
                        }
×
1055

1056
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1057
                        if err != nil {
×
1058
                                return fmt.Errorf("unable to extract channel "+
×
1059
                                        "policies: %w", err)
×
1060
                        }
×
1061

1062
                        p1, p2, err := getAndBuildChanPolicies(
×
1063
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1064
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1065
                        )
×
1066
                        if err != nil {
×
1067
                                return fmt.Errorf("unable to build channel "+
×
1068
                                        "policies: %w", err)
×
1069
                        }
×
1070

1071
                        edgesSeen[chanIDInt] = struct{}{}
×
1072
                        chanEdge := ChannelEdge{
×
1073
                                Info:    channel,
×
1074
                                Policy1: p1,
×
1075
                                Policy2: p2,
×
1076
                                Node1:   node1,
×
1077
                                Node2:   node2,
×
1078
                        }
×
1079
                        edges = append(edges, chanEdge)
×
1080
                        edgesToCache[chanIDInt] = chanEdge
×
1081
                }
1082

1083
                return nil
×
1084
        }, func() {
×
1085
                edgesSeen = make(map[uint64]struct{})
×
1086
                edgesToCache = make(map[uint64]ChannelEdge)
×
1087
                edges = nil
×
1088
        })
×
1089
        if err != nil {
×
1090
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1091
        }
×
1092

1093
        // Insert any edges loaded from disk into the cache.
1094
        for chanid, channel := range edgesToCache {
×
1095
                s.chanCache.insert(chanid, channel)
×
1096
        }
×
1097

1098
        if len(edges) > 0 {
×
1099
                log.Debugf("ChanUpdatesInHorizon hit percentage: %f (%d/%d)",
×
1100
                        float64(hits)/float64(len(edges)), hits, len(edges))
×
1101
        } else {
×
1102
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1103
                        "horizon (%s, %s)", startTime, endTime)
×
1104
        }
×
1105

1106
        return edges, nil
×
1107
}
1108

1109
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1110
// data to the call-back.
1111
//
1112
// NOTE: The callback contents MUST not be modified.
1113
//
1114
// NOTE: part of the V1Store interface.
1115
func (s *SQLStore) ForEachNodeCached(cb func(node route.Vertex,
1116
        chans map[uint64]*DirectedChannel) error) error {
×
1117

×
1118
        var ctx = context.TODO()
×
1119

×
1120
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1121
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
1122
                        nodePub route.Vertex) error {
×
1123

×
1124
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
1125
                        if err != nil {
×
1126
                                return fmt.Errorf("unable to fetch "+
×
1127
                                        "node(id=%d) features: %w", nodeID, err)
×
1128
                        }
×
1129

1130
                        toNodeCallback := func() route.Vertex {
×
1131
                                return nodePub
×
1132
                        }
×
1133

1134
                        rows, err := db.ListChannelsByNodeID(
×
1135
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
1136
                                        Version: int16(ProtocolV1),
×
1137
                                        NodeID1: nodeID,
×
1138
                                },
×
1139
                        )
×
1140
                        if err != nil {
×
1141
                                return fmt.Errorf("unable to fetch channels "+
×
1142
                                        "of node(id=%d): %w", nodeID, err)
×
1143
                        }
×
1144

1145
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
1146
                        for _, row := range rows {
×
1147
                                node1, node2, err := buildNodeVertices(
×
1148
                                        row.Node1Pubkey, row.Node2Pubkey,
×
1149
                                )
×
1150
                                if err != nil {
×
1151
                                        return err
×
1152
                                }
×
1153

1154
                                e, err := getAndBuildEdgeInfo(
×
1155
                                        ctx, db, s.cfg.ChainHash,
×
1156
                                        row.Channel.ID, row.Channel, node1,
×
1157
                                        node2,
×
1158
                                )
×
1159
                                if err != nil {
×
1160
                                        return fmt.Errorf("unable to build "+
×
1161
                                                "channel info: %w", err)
×
1162
                                }
×
1163

1164
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1165
                                        row,
×
1166
                                )
×
1167
                                if err != nil {
×
1168
                                        return fmt.Errorf("unable to "+
×
1169
                                                "extract channel "+
×
1170
                                                "policies: %w", err)
×
1171
                                }
×
1172

1173
                                p1, p2, err := getAndBuildChanPolicies(
×
1174
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
1175
                                        node1, node2,
×
1176
                                )
×
1177
                                if err != nil {
×
1178
                                        return fmt.Errorf("unable to "+
×
1179
                                                "build channel policies: %w",
×
1180
                                                err)
×
1181
                                }
×
1182

1183
                                // Determine the outgoing and incoming policy
1184
                                // for this channel and node combo.
1185
                                outPolicy, inPolicy := p1, p2
×
1186
                                if p1 != nil && p1.ToNode == nodePub {
×
1187
                                        outPolicy, inPolicy = p2, p1
×
1188
                                } else if p2 != nil && p2.ToNode != nodePub {
×
1189
                                        outPolicy, inPolicy = p2, p1
×
1190
                                }
×
1191

1192
                                var cachedInPolicy *models.CachedEdgePolicy
×
1193
                                if inPolicy != nil {
×
1194
                                        cachedInPolicy = models.NewCachedPolicy(
×
1195
                                                p2,
×
1196
                                        )
×
1197
                                        cachedInPolicy.ToNodePubKey =
×
1198
                                                toNodeCallback
×
1199
                                        cachedInPolicy.ToNodeFeatures =
×
1200
                                                features
×
1201
                                }
×
1202

1203
                                var inboundFee lnwire.Fee
×
1204
                                outPolicy.InboundFee.WhenSome(
×
1205
                                        func(fee lnwire.Fee) {
×
1206
                                                inboundFee = fee
×
1207
                                        },
×
1208
                                )
1209

1210
                                directedChannel := &DirectedChannel{
×
1211
                                        ChannelID: e.ChannelID,
×
1212
                                        IsNode1: nodePub ==
×
1213
                                                e.NodeKey1Bytes,
×
1214
                                        OtherNode:    e.NodeKey2Bytes,
×
1215
                                        Capacity:     e.Capacity,
×
1216
                                        OutPolicySet: p1 != nil,
×
1217
                                        InPolicy:     cachedInPolicy,
×
1218
                                        InboundFee:   inboundFee,
×
1219
                                }
×
1220

×
1221
                                if nodePub == e.NodeKey2Bytes {
×
1222
                                        directedChannel.OtherNode =
×
1223
                                                e.NodeKey1Bytes
×
1224
                                }
×
1225

1226
                                channels[e.ChannelID] = directedChannel
×
1227
                        }
1228

1229
                        return cb(nodePub, channels)
×
1230
                })
1231
        }, sqldb.NoOpReset)
1232
}
1233

1234
// ForEachChannel iterates through all the channel edges stored within the
1235
// graph and invokes the passed callback for each edge. The callback takes two
1236
// edges as since this is a directed graph, both the in/out edges are visited.
1237
// If the callback returns an error, then the transaction is aborted and the
1238
// iteration stops early.
1239
//
1240
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1241
// for that particular channel edge routing policy will be passed into the
1242
// callback.
1243
//
1244
// NOTE: part of the V1Store interface.
1245
func (s *SQLStore) ForEachChannel(cb func(*models.ChannelEdgeInfo,
1246
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
1247

×
1248
        ctx := context.TODO()
×
1249

×
1250
        handleChannel := func(db SQLQueries,
×
1251
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1252

×
1253
                node1, node2, err := buildNodeVertices(
×
1254
                        row.Node1Pubkey, row.Node2Pubkey,
×
1255
                )
×
1256
                if err != nil {
×
1257
                        return fmt.Errorf("unable to build node vertices: %w",
×
1258
                                err)
×
1259
                }
×
1260

1261
                edge, err := getAndBuildEdgeInfo(
×
1262
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1263
                        node1, node2,
×
1264
                )
×
1265
                if err != nil {
×
1266
                        return fmt.Errorf("unable to build channel info: %w",
×
1267
                                err)
×
1268
                }
×
1269

1270
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1271
                if err != nil {
×
1272
                        return fmt.Errorf("unable to extract channel "+
×
1273
                                "policies: %w", err)
×
1274
                }
×
1275

1276
                p1, p2, err := getAndBuildChanPolicies(
×
1277
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1278
                )
×
1279
                if err != nil {
×
1280
                        return fmt.Errorf("unable to build channel "+
×
1281
                                "policies: %w", err)
×
1282
                }
×
1283

1284
                err = cb(edge, p1, p2)
×
1285
                if err != nil {
×
1286
                        return fmt.Errorf("callback failed for channel "+
×
1287
                                "id=%d: %w", edge.ChannelID, err)
×
1288
                }
×
1289

1290
                return nil
×
1291
        }
1292

1293
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1294
                var lastID int64
×
1295
                for {
×
1296
                        //nolint:ll
×
1297
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1298
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1299
                                        Version: int16(ProtocolV1),
×
1300
                                        ID:      lastID,
×
1301
                                        Limit:   pageSize,
×
1302
                                },
×
1303
                        )
×
1304
                        if err != nil {
×
1305
                                return err
×
1306
                        }
×
1307

1308
                        if len(rows) == 0 {
×
1309
                                break
×
1310
                        }
1311

1312
                        for _, row := range rows {
×
1313
                                err := handleChannel(db, row)
×
1314
                                if err != nil {
×
1315
                                        return err
×
1316
                                }
×
1317

1318
                                lastID = row.Channel.ID
×
1319
                        }
1320
                }
1321

1322
                return nil
×
1323
        }, sqldb.NoOpReset)
1324
}
1325

1326
// FilterChannelRange returns the channel ID's of all known channels which were
1327
// mined in a block height within the passed range. The channel IDs are grouped
1328
// by their common block height. This method can be used to quickly share with a
1329
// peer the set of channels we know of within a particular range to catch them
1330
// up after a period of time offline. If withTimestamps is true then the
1331
// timestamp info of the latest received channel update messages of the channel
1332
// will be included in the response.
1333
//
1334
// NOTE: This is part of the V1Store interface.
1335
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1336
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1337

×
1338
        var (
×
1339
                ctx       = context.TODO()
×
1340
                startSCID = &lnwire.ShortChannelID{
×
1341
                        BlockHeight: startHeight,
×
1342
                }
×
1343
                endSCID = lnwire.ShortChannelID{
×
1344
                        BlockHeight: endHeight,
×
1345
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1346
                        TxPosition:  math.MaxUint16,
×
1347
                }
×
1348
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1349
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1350
        )
×
1351

×
1352
        // 1) get all channels where channelID is between start and end chan ID.
×
1353
        // 2) skip if not public (ie, no channel_proof)
×
1354
        // 3) collect that channel.
×
1355
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1356
        //    and add those timestamps to the collected channel.
×
1357
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1358
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1359
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1360
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1361
                                StartScid: chanIDStart[:],
×
1362
                                EndScid:   chanIDEnd[:],
×
1363
                        },
×
1364
                )
×
1365
                if err != nil {
×
1366
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1367
                                err)
×
1368
                }
×
1369

1370
                for _, dbChan := range dbChans {
×
1371
                        cid := lnwire.NewShortChanIDFromInt(
×
1372
                                byteOrder.Uint64(dbChan.Scid),
×
1373
                        )
×
1374
                        chanInfo := NewChannelUpdateInfo(
×
1375
                                cid, time.Time{}, time.Time{},
×
1376
                        )
×
1377

×
1378
                        if !withTimestamps {
×
1379
                                channelsPerBlock[cid.BlockHeight] = append(
×
1380
                                        channelsPerBlock[cid.BlockHeight],
×
1381
                                        chanInfo,
×
1382
                                )
×
1383

×
1384
                                continue
×
1385
                        }
1386

1387
                        //nolint:ll
1388
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1389
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1390
                                        Version:   int16(ProtocolV1),
×
1391
                                        ChannelID: dbChan.ID,
×
1392
                                        NodeID:    dbChan.NodeID1,
×
1393
                                },
×
1394
                        )
×
1395
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1396
                                return fmt.Errorf("unable to fetch node1 "+
×
1397
                                        "policy: %w", err)
×
1398
                        } else if err == nil {
×
1399
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1400
                                        node1Policy.LastUpdate.Int64, 0,
×
1401
                                )
×
1402
                        }
×
1403

1404
                        //nolint:ll
1405
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1406
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1407
                                        Version:   int16(ProtocolV1),
×
1408
                                        ChannelID: dbChan.ID,
×
1409
                                        NodeID:    dbChan.NodeID2,
×
1410
                                },
×
1411
                        )
×
1412
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1413
                                return fmt.Errorf("unable to fetch node2 "+
×
1414
                                        "policy: %w", err)
×
1415
                        } else if err == nil {
×
1416
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1417
                                        node2Policy.LastUpdate.Int64, 0,
×
1418
                                )
×
1419
                        }
×
1420

1421
                        channelsPerBlock[cid.BlockHeight] = append(
×
1422
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1423
                        )
×
1424
                }
1425

1426
                return nil
×
1427
        }, func() {
×
1428
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1429
        })
×
1430
        if err != nil {
×
1431
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1432
        }
×
1433

1434
        if len(channelsPerBlock) == 0 {
×
1435
                return nil, nil
×
1436
        }
×
1437

1438
        // Return the channel ranges in ascending block height order.
1439
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1440
        slices.Sort(blocks)
×
1441

×
1442
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1443
                return BlockChannelRange{
×
1444
                        Height:   block,
×
1445
                        Channels: channelsPerBlock[block],
×
1446
                }
×
1447
        }), nil
×
1448
}
1449

1450
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1451
// zombie. This method is used on an ad-hoc basis, when channels need to be
1452
// marked as zombies outside the normal pruning cycle.
1453
//
1454
// NOTE: part of the V1Store interface.
1455
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1456
        pubKey1, pubKey2 [33]byte) error {
×
1457

×
1458
        ctx := context.TODO()
×
1459

×
1460
        s.cacheMu.Lock()
×
1461
        defer s.cacheMu.Unlock()
×
1462

×
1463
        chanIDB := channelIDToBytes(chanID)
×
1464

×
1465
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1466
                return db.UpsertZombieChannel(
×
1467
                        ctx, sqlc.UpsertZombieChannelParams{
×
1468
                                Version:  int16(ProtocolV1),
×
1469
                                Scid:     chanIDB[:],
×
1470
                                NodeKey1: pubKey1[:],
×
1471
                                NodeKey2: pubKey2[:],
×
1472
                        },
×
1473
                )
×
1474
        }, sqldb.NoOpReset)
×
1475
        if err != nil {
×
1476
                return fmt.Errorf("unable to upsert zombie channel "+
×
1477
                        "(channel_id=%d): %w", chanID, err)
×
1478
        }
×
1479

1480
        s.rejectCache.remove(chanID)
×
1481
        s.chanCache.remove(chanID)
×
1482

×
1483
        return nil
×
1484
}
1485

1486
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1487
//
1488
// NOTE: part of the V1Store interface.
1489
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1490
        s.cacheMu.Lock()
×
1491
        defer s.cacheMu.Unlock()
×
1492

×
1493
        var (
×
1494
                ctx     = context.TODO()
×
1495
                chanIDB = channelIDToBytes(chanID)
×
1496
        )
×
1497

×
1498
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1499
                res, err := db.DeleteZombieChannel(
×
1500
                        ctx, sqlc.DeleteZombieChannelParams{
×
1501
                                Scid:    chanIDB[:],
×
1502
                                Version: int16(ProtocolV1),
×
1503
                        },
×
1504
                )
×
1505
                if err != nil {
×
1506
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1507
                                err)
×
1508
                }
×
1509

1510
                rows, err := res.RowsAffected()
×
1511
                if err != nil {
×
1512
                        return err
×
1513
                }
×
1514

1515
                if rows == 0 {
×
1516
                        return ErrZombieEdgeNotFound
×
1517
                } else if rows > 1 {
×
1518
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1519
                                "expected 1", rows)
×
1520
                }
×
1521

1522
                return nil
×
1523
        }, sqldb.NoOpReset)
1524
        if err != nil {
×
1525
                return fmt.Errorf("unable to mark edge live "+
×
1526
                        "(channel_id=%d): %w", chanID, err)
×
1527
        }
×
1528

1529
        s.rejectCache.remove(chanID)
×
1530
        s.chanCache.remove(chanID)
×
1531

×
1532
        return err
×
1533
}
1534

1535
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1536
// zombie, then the two node public keys corresponding to this edge are also
1537
// returned.
1538
//
1539
// NOTE: part of the V1Store interface.
1540
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte) {
×
1541
        var (
×
1542
                ctx              = context.TODO()
×
1543
                isZombie         bool
×
1544
                pubKey1, pubKey2 route.Vertex
×
1545
                chanIDB          = channelIDToBytes(chanID)
×
1546
        )
×
1547

×
1548
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1549
                zombie, err := db.GetZombieChannel(
×
1550
                        ctx, sqlc.GetZombieChannelParams{
×
1551
                                Scid:    chanIDB[:],
×
1552
                                Version: int16(ProtocolV1),
×
1553
                        },
×
1554
                )
×
1555
                if errors.Is(err, sql.ErrNoRows) {
×
1556
                        return nil
×
1557
                }
×
1558
                if err != nil {
×
1559
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1560
                                err)
×
1561
                }
×
1562

1563
                copy(pubKey1[:], zombie.NodeKey1)
×
1564
                copy(pubKey2[:], zombie.NodeKey2)
×
1565
                isZombie = true
×
1566

×
1567
                return nil
×
1568
        }, sqldb.NoOpReset)
1569
        if err != nil {
×
1570
                // TODO(elle): update the IsZombieEdge method to return an
×
1571
                // error.
×
1572
                return false, route.Vertex{}, route.Vertex{}
×
1573
        }
×
1574

1575
        return isZombie, pubKey1, pubKey2
×
1576
}
1577

1578
// NumZombies returns the current number of zombie channels in the graph.
1579
//
1580
// NOTE: part of the V1Store interface.
1581
func (s *SQLStore) NumZombies() (uint64, error) {
×
1582
        var (
×
1583
                ctx        = context.TODO()
×
1584
                numZombies uint64
×
1585
        )
×
1586
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1587
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1588
                if err != nil {
×
1589
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1590
                                err)
×
1591
                }
×
1592

1593
                numZombies = uint64(count)
×
1594

×
1595
                return nil
×
1596
        }, sqldb.NoOpReset)
1597
        if err != nil {
×
1598
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1599
        }
×
1600

1601
        return numZombies, nil
×
1602
}
1603

1604
// DeleteChannelEdges removes edges with the given channel IDs from the
1605
// database and marks them as zombies. This ensures that we're unable to re-add
1606
// it to our database once again. If an edge does not exist within the
1607
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1608
// true, then when we mark these edges as zombies, we'll set up the keys such
1609
// that we require the node that failed to send the fresh update to be the one
1610
// that resurrects the channel from its zombie state. The markZombie bool
1611
// denotes whether to mark the channel as a zombie.
1612
//
1613
// NOTE: part of the V1Store interface.
1614
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1615
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1616

×
1617
        s.cacheMu.Lock()
×
1618
        defer s.cacheMu.Unlock()
×
1619

×
1620
        var (
×
1621
                ctx     = context.TODO()
×
1622
                deleted []*models.ChannelEdgeInfo
×
1623
        )
×
1624
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1625
                for _, chanID := range chanIDs {
×
1626
                        chanIDB := channelIDToBytes(chanID)
×
1627

×
1628
                        row, err := db.GetChannelBySCIDWithPolicies(
×
1629
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1630
                                        Scid:    chanIDB[:],
×
1631
                                        Version: int16(ProtocolV1),
×
1632
                                },
×
1633
                        )
×
1634
                        if errors.Is(err, sql.ErrNoRows) {
×
1635
                                return ErrEdgeNotFound
×
1636
                        } else if err != nil {
×
1637
                                return fmt.Errorf("unable to fetch channel: %w",
×
1638
                                        err)
×
1639
                        }
×
1640

1641
                        node1, node2, err := buildNodeVertices(
×
1642
                                row.Node.PubKey, row.Node_2.PubKey,
×
1643
                        )
×
1644
                        if err != nil {
×
1645
                                return err
×
1646
                        }
×
1647

1648
                        info, err := getAndBuildEdgeInfo(
×
1649
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1650
                                row.Channel, node1, node2,
×
1651
                        )
×
1652
                        if err != nil {
×
1653
                                return err
×
1654
                        }
×
1655

1656
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
1657
                        if err != nil {
×
1658
                                return fmt.Errorf("unable to delete "+
×
1659
                                        "channel: %w", err)
×
1660
                        }
×
1661

1662
                        deleted = append(deleted, info)
×
1663

×
1664
                        if !markZombie {
×
1665
                                continue
×
1666
                        }
1667

1668
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1669
                                info.NodeKey2Bytes
×
1670
                        if strictZombiePruning {
×
1671
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1672
                                if row.Policy1LastUpdate.Valid {
×
1673
                                        e1Time := time.Unix(
×
1674
                                                row.Policy1LastUpdate.Int64, 0,
×
1675
                                        )
×
1676
                                        e1UpdateTime = &e1Time
×
1677
                                }
×
1678
                                if row.Policy2LastUpdate.Valid {
×
1679
                                        e2Time := time.Unix(
×
1680
                                                row.Policy2LastUpdate.Int64, 0,
×
1681
                                        )
×
1682
                                        e2UpdateTime = &e2Time
×
1683
                                }
×
1684

1685
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1686
                                        info, e1UpdateTime, e2UpdateTime,
×
1687
                                )
×
1688
                        }
1689

1690
                        err = db.UpsertZombieChannel(
×
1691
                                ctx, sqlc.UpsertZombieChannelParams{
×
1692
                                        Version:  int16(ProtocolV1),
×
1693
                                        Scid:     chanIDB[:],
×
1694
                                        NodeKey1: nodeKey1[:],
×
1695
                                        NodeKey2: nodeKey2[:],
×
1696
                                },
×
1697
                        )
×
1698
                        if err != nil {
×
1699
                                return fmt.Errorf("unable to mark channel as "+
×
1700
                                        "zombie: %w", err)
×
1701
                        }
×
1702
                }
1703

1704
                return nil
×
1705
        }, func() {
×
1706
                deleted = nil
×
1707
        })
×
1708
        if err != nil {
×
1709
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1710
                        err)
×
1711
        }
×
1712

1713
        for _, chanID := range chanIDs {
×
1714
                s.rejectCache.remove(chanID)
×
1715
                s.chanCache.remove(chanID)
×
1716
        }
×
1717

1718
        return deleted, nil
×
1719
}
1720

1721
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1722
// channel identified by the channel ID. If the channel can't be found, then
1723
// ErrEdgeNotFound is returned. A struct which houses the general information
1724
// for the channel itself is returned as well as two structs that contain the
1725
// routing policies for the channel in either direction.
1726
//
1727
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1728
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1729
// the ChannelEdgeInfo will only include the public keys of each node.
1730
//
1731
// NOTE: part of the V1Store interface.
1732
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1733
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1734
        *models.ChannelEdgePolicy, error) {
×
1735

×
1736
        var (
×
1737
                ctx              = context.TODO()
×
1738
                edge             *models.ChannelEdgeInfo
×
1739
                policy1, policy2 *models.ChannelEdgePolicy
×
1740
        )
×
1741
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1742
                var chanIDB [8]byte
×
1743
                byteOrder.PutUint64(chanIDB[:], chanID)
×
1744

×
1745
                row, err := db.GetChannelBySCIDWithPolicies(
×
1746
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1747
                                Scid:    chanIDB[:],
×
1748
                                Version: int16(ProtocolV1),
×
1749
                        },
×
1750
                )
×
1751
                if errors.Is(err, sql.ErrNoRows) {
×
1752
                        // First check if this edge is perhaps in the zombie
×
1753
                        // index.
×
1754
                        isZombie, err := db.IsZombieChannel(
×
1755
                                ctx, sqlc.IsZombieChannelParams{
×
1756
                                        Scid:    chanIDB[:],
×
1757
                                        Version: int16(ProtocolV1),
×
1758
                                },
×
1759
                        )
×
1760
                        if err != nil {
×
1761
                                return fmt.Errorf("unable to check if "+
×
1762
                                        "channel is zombie: %w", err)
×
1763
                        } else if isZombie {
×
1764
                                return ErrZombieEdge
×
1765
                        }
×
1766

1767
                        return ErrEdgeNotFound
×
1768
                } else if err != nil {
×
1769
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1770
                }
×
1771

1772
                node1, node2, err := buildNodeVertices(
×
1773
                        row.Node.PubKey, row.Node_2.PubKey,
×
1774
                )
×
1775
                if err != nil {
×
1776
                        return err
×
1777
                }
×
1778

1779
                edge, err = getAndBuildEdgeInfo(
×
1780
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1781
                        node1, node2,
×
1782
                )
×
1783
                if err != nil {
×
1784
                        return fmt.Errorf("unable to build channel info: %w",
×
1785
                                err)
×
1786
                }
×
1787

1788
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1789
                if err != nil {
×
1790
                        return fmt.Errorf("unable to extract channel "+
×
1791
                                "policies: %w", err)
×
1792
                }
×
1793

1794
                policy1, policy2, err = getAndBuildChanPolicies(
×
1795
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1796
                )
×
1797
                if err != nil {
×
1798
                        return fmt.Errorf("unable to build channel "+
×
1799
                                "policies: %w", err)
×
1800
                }
×
1801

1802
                return nil
×
1803
        }, sqldb.NoOpReset)
1804
        if err != nil {
×
1805
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1806
                        err)
×
1807
        }
×
1808

1809
        return edge, policy1, policy2, nil
×
1810
}
1811

1812
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1813
// the channel identified by the funding outpoint. If the channel can't be
1814
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1815
// information for the channel itself is returned as well as two structs that
1816
// contain the routing policies for the channel in either direction.
1817
//
1818
// NOTE: part of the V1Store interface.
1819
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1820
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1821
        *models.ChannelEdgePolicy, error) {
×
1822

×
1823
        var (
×
1824
                ctx              = context.TODO()
×
1825
                edge             *models.ChannelEdgeInfo
×
1826
                policy1, policy2 *models.ChannelEdgePolicy
×
1827
        )
×
1828
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1829
                row, err := db.GetChannelByOutpointWithPolicies(
×
1830
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1831
                                Outpoint: op.String(),
×
1832
                                Version:  int16(ProtocolV1),
×
1833
                        },
×
1834
                )
×
1835
                if errors.Is(err, sql.ErrNoRows) {
×
1836
                        return ErrEdgeNotFound
×
1837
                } else if err != nil {
×
1838
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1839
                }
×
1840

1841
                node1, node2, err := buildNodeVertices(
×
1842
                        row.Node1Pubkey, row.Node2Pubkey,
×
1843
                )
×
1844
                if err != nil {
×
1845
                        return err
×
1846
                }
×
1847

1848
                edge, err = getAndBuildEdgeInfo(
×
1849
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1850
                        node1, node2,
×
1851
                )
×
1852
                if err != nil {
×
1853
                        return fmt.Errorf("unable to build channel info: %w",
×
1854
                                err)
×
1855
                }
×
1856

1857
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1858
                if err != nil {
×
1859
                        return fmt.Errorf("unable to extract channel "+
×
1860
                                "policies: %w", err)
×
1861
                }
×
1862

1863
                policy1, policy2, err = getAndBuildChanPolicies(
×
1864
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1865
                )
×
1866
                if err != nil {
×
1867
                        return fmt.Errorf("unable to build channel "+
×
1868
                                "policies: %w", err)
×
1869
                }
×
1870

1871
                return nil
×
1872
        }, sqldb.NoOpReset)
1873
        if err != nil {
×
1874
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1875
                        err)
×
1876
        }
×
1877

1878
        return edge, policy1, policy2, nil
×
1879
}
1880

1881
// HasChannelEdge returns true if the database knows of a channel edge with the
1882
// passed channel ID, and false otherwise. If an edge with that ID is found
1883
// within the graph, then two time stamps representing the last time the edge
1884
// was updated for both directed edges are returned along with the boolean. If
1885
// it is not found, then the zombie index is checked and its result is returned
1886
// as the second boolean.
1887
//
1888
// NOTE: part of the V1Store interface.
1889
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1890
        bool, error) {
×
1891

×
1892
        ctx := context.TODO()
×
1893

×
1894
        var (
×
1895
                exists          bool
×
1896
                isZombie        bool
×
1897
                node1LastUpdate time.Time
×
1898
                node2LastUpdate time.Time
×
1899
        )
×
1900

×
1901
        // We'll query the cache with the shared lock held to allow multiple
×
1902
        // readers to access values in the cache concurrently if they exist.
×
1903
        s.cacheMu.RLock()
×
1904
        if entry, ok := s.rejectCache.get(chanID); ok {
×
1905
                s.cacheMu.RUnlock()
×
1906
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
1907
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
1908
                exists, isZombie = entry.flags.unpack()
×
1909

×
1910
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
1911
        }
×
1912
        s.cacheMu.RUnlock()
×
1913

×
1914
        s.cacheMu.Lock()
×
1915
        defer s.cacheMu.Unlock()
×
1916

×
1917
        // The item was not found with the shared lock, so we'll acquire the
×
1918
        // exclusive lock and check the cache again in case another method added
×
1919
        // the entry to the cache while no lock was held.
×
1920
        if entry, ok := s.rejectCache.get(chanID); ok {
×
1921
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
1922
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
1923
                exists, isZombie = entry.flags.unpack()
×
1924

×
1925
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
1926
        }
×
1927

1928
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1929
                var chanIDB [8]byte
×
1930
                byteOrder.PutUint64(chanIDB[:], chanID)
×
1931

×
1932
                channel, err := db.GetChannelBySCID(
×
1933
                        ctx, sqlc.GetChannelBySCIDParams{
×
1934
                                Scid:    chanIDB[:],
×
1935
                                Version: int16(ProtocolV1),
×
1936
                        },
×
1937
                )
×
1938
                if errors.Is(err, sql.ErrNoRows) {
×
1939
                        // Check if it is a zombie channel.
×
1940
                        isZombie, err = db.IsZombieChannel(
×
1941
                                ctx, sqlc.IsZombieChannelParams{
×
1942
                                        Scid:    chanIDB[:],
×
1943
                                        Version: int16(ProtocolV1),
×
1944
                                },
×
1945
                        )
×
1946
                        if err != nil {
×
1947
                                return fmt.Errorf("could not check if channel "+
×
1948
                                        "is zombie: %w", err)
×
1949
                        }
×
1950

1951
                        return nil
×
1952
                } else if err != nil {
×
1953
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1954
                }
×
1955

1956
                exists = true
×
1957

×
1958
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
1959
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1960
                                Version:   int16(ProtocolV1),
×
1961
                                ChannelID: channel.ID,
×
1962
                                NodeID:    channel.NodeID1,
×
1963
                        },
×
1964
                )
×
1965
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1966
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
1967
                                err)
×
1968
                } else if err == nil {
×
1969
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
1970
                }
×
1971

1972
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
1973
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1974
                                Version:   int16(ProtocolV1),
×
1975
                                ChannelID: channel.ID,
×
1976
                                NodeID:    channel.NodeID2,
×
1977
                        },
×
1978
                )
×
1979
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1980
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
1981
                                err)
×
1982
                } else if err == nil {
×
1983
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
1984
                }
×
1985

1986
                return nil
×
1987
        }, sqldb.NoOpReset)
1988
        if err != nil {
×
1989
                return time.Time{}, time.Time{}, false, false,
×
1990
                        fmt.Errorf("unable to fetch channel: %w", err)
×
1991
        }
×
1992

1993
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
1994
                upd1Time: node1LastUpdate.Unix(),
×
1995
                upd2Time: node2LastUpdate.Unix(),
×
1996
                flags:    packRejectFlags(exists, isZombie),
×
1997
        })
×
1998

×
1999
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2000
}
2001

2002
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2003
// passed channel point (outpoint). If the passed channel doesn't exist within
2004
// the database, then ErrEdgeNotFound is returned.
2005
//
2006
// NOTE: part of the V1Store interface.
2007
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2008
        var (
×
2009
                ctx       = context.TODO()
×
2010
                channelID uint64
×
2011
        )
×
2012
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2013
                chanID, err := db.GetSCIDByOutpoint(
×
2014
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2015
                                Outpoint: chanPoint.String(),
×
2016
                                Version:  int16(ProtocolV1),
×
2017
                        },
×
2018
                )
×
2019
                if errors.Is(err, sql.ErrNoRows) {
×
2020
                        return ErrEdgeNotFound
×
2021
                } else if err != nil {
×
2022
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2023
                                err)
×
2024
                }
×
2025

2026
                channelID = byteOrder.Uint64(chanID)
×
2027

×
2028
                return nil
×
2029
        }, sqldb.NoOpReset)
2030
        if err != nil {
×
2031
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2032
        }
×
2033

2034
        return channelID, nil
×
2035
}
2036

2037
// IsPublicNode is a helper method that determines whether the node with the
2038
// given public key is seen as a public node in the graph from the graph's
2039
// source node's point of view.
2040
//
2041
// NOTE: part of the V1Store interface.
2042
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2043
        ctx := context.TODO()
×
2044

×
2045
        var isPublic bool
×
2046
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2047
                var err error
×
2048
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2049

×
2050
                return err
×
2051
        }, sqldb.NoOpReset)
×
2052
        if err != nil {
×
2053
                return false, fmt.Errorf("unable to check if node is "+
×
2054
                        "public: %w", err)
×
2055
        }
×
2056

2057
        return isPublic, nil
×
2058
}
2059

2060
// FetchChanInfos returns the set of channel edges that correspond to the passed
2061
// channel ID's. If an edge is the query is unknown to the database, it will
2062
// skipped and the result will contain only those edges that exist at the time
2063
// of the query. This can be used to respond to peer queries that are seeking to
2064
// fill in gaps in their view of the channel graph.
2065
//
2066
// NOTE: part of the V1Store interface.
2067
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2068
        var (
×
2069
                ctx   = context.TODO()
×
2070
                edges []ChannelEdge
×
2071
        )
×
2072
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2073
                for _, chanID := range chanIDs {
×
2074
                        var chanIDB [8]byte
×
2075
                        byteOrder.PutUint64(chanIDB[:], chanID)
×
2076

×
2077
                        // TODO(elle): potentially optimize this by using
×
2078
                        //  sqlc.slice() once that works for both SQLite and
×
2079
                        //  Postgres.
×
2080
                        row, err := db.GetChannelBySCIDWithPolicies(
×
2081
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
2082
                                        Scid:    chanIDB[:],
×
2083
                                        Version: int16(ProtocolV1),
×
2084
                                },
×
2085
                        )
×
2086
                        if errors.Is(err, sql.ErrNoRows) {
×
2087
                                continue
×
2088
                        } else if err != nil {
×
2089
                                return fmt.Errorf("unable to fetch channel: %w",
×
2090
                                        err)
×
2091
                        }
×
2092

2093
                        node1, node2, err := buildNodes(
×
2094
                                ctx, db, row.Node, row.Node_2,
×
2095
                        )
×
2096
                        if err != nil {
×
2097
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2098
                                        err)
×
2099
                        }
×
2100

2101
                        edge, err := getAndBuildEdgeInfo(
×
2102
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2103
                                row.Channel, node1.PubKeyBytes,
×
2104
                                node2.PubKeyBytes,
×
2105
                        )
×
2106
                        if err != nil {
×
2107
                                return fmt.Errorf("unable to build "+
×
2108
                                        "channel info: %w", err)
×
2109
                        }
×
2110

2111
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2112
                        if err != nil {
×
2113
                                return fmt.Errorf("unable to extract channel "+
×
2114
                                        "policies: %w", err)
×
2115
                        }
×
2116

2117
                        p1, p2, err := getAndBuildChanPolicies(
×
2118
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2119
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2120
                        )
×
2121
                        if err != nil {
×
2122
                                return fmt.Errorf("unable to build channel "+
×
2123
                                        "policies: %w", err)
×
2124
                        }
×
2125

2126
                        edges = append(edges, ChannelEdge{
×
2127
                                Info:    edge,
×
2128
                                Policy1: p1,
×
2129
                                Policy2: p2,
×
2130
                                Node1:   node1,
×
2131
                                Node2:   node2,
×
2132
                        })
×
2133
                }
2134

2135
                return nil
×
2136
        }, func() {
×
2137
                edges = nil
×
2138
        })
×
2139
        if err != nil {
×
2140
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2141
        }
×
2142

2143
        return edges, nil
×
2144
}
2145

2146
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2147
// ID's that we don't know and are not known zombies of the passed set. In other
2148
// words, we perform a set difference of our set of chan ID's and the ones
2149
// passed in. This method can be used by callers to determine the set of
2150
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2151
// known zombies is also returned.
2152
//
2153
// NOTE: part of the V1Store interface.
2154
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2155
        []ChannelUpdateInfo, error) {
×
2156

×
2157
        var (
×
2158
                ctx          = context.TODO()
×
2159
                newChanIDs   []uint64
×
2160
                knownZombies []ChannelUpdateInfo
×
2161
        )
×
2162
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2163
                for _, chanInfo := range chansInfo {
×
2164
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2165
                        var chanIDB [8]byte
×
2166
                        byteOrder.PutUint64(chanIDB[:], channelID)
×
2167

×
2168
                        // TODO(elle): potentially optimize this by using
×
2169
                        //  sqlc.slice() once that works for both SQLite and
×
2170
                        //  Postgres.
×
2171
                        _, err := db.GetChannelBySCID(
×
2172
                                ctx, sqlc.GetChannelBySCIDParams{
×
2173
                                        Version: int16(ProtocolV1),
×
2174
                                        Scid:    chanIDB[:],
×
2175
                                },
×
2176
                        )
×
2177
                        if err == nil {
×
2178
                                continue
×
2179
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
2180
                                return fmt.Errorf("unable to fetch channel: %w",
×
2181
                                        err)
×
2182
                        }
×
2183

2184
                        isZombie, err := db.IsZombieChannel(
×
2185
                                ctx, sqlc.IsZombieChannelParams{
×
2186
                                        Scid:    chanIDB[:],
×
2187
                                        Version: int16(ProtocolV1),
×
2188
                                },
×
2189
                        )
×
2190
                        if err != nil {
×
2191
                                return fmt.Errorf("unable to fetch zombie "+
×
2192
                                        "channel: %w", err)
×
2193
                        }
×
2194

2195
                        if isZombie {
×
2196
                                knownZombies = append(knownZombies, chanInfo)
×
2197

×
2198
                                continue
×
2199
                        }
2200

2201
                        newChanIDs = append(newChanIDs, channelID)
×
2202
                }
2203

2204
                return nil
×
2205
        }, func() {
×
2206
                newChanIDs = nil
×
2207
                knownZombies = nil
×
2208
        })
×
2209
        if err != nil {
×
2210
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2211
        }
×
2212

2213
        return newChanIDs, knownZombies, nil
×
2214
}
2215

2216
// PruneGraphNodes is a garbage collection method which attempts to prune out
2217
// any nodes from the channel graph that are currently unconnected. This ensure
2218
// that we only maintain a graph of reachable nodes. In the event that a pruned
2219
// node gains more channels, it will be re-added back to the graph.
2220
//
2221
// NOTE: this prunes nodes across protocol versions. It will never prune the
2222
// source nodes.
2223
//
2224
// NOTE: part of the V1Store interface.
NEW
2225
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
NEW
2226
        var ctx = context.TODO()
×
NEW
2227

×
NEW
2228
        var prunedNodes []route.Vertex
×
NEW
2229
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2230
                var err error
×
NEW
2231
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
NEW
2232

×
NEW
2233
                return err
×
NEW
2234
        }, func() {
×
NEW
2235
                prunedNodes = nil
×
NEW
2236
        })
×
NEW
2237
        if err != nil {
×
NEW
2238
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
NEW
2239
        }
×
2240

NEW
2241
        return prunedNodes, nil
×
2242
}
2243

2244
// PruneGraph prunes newly closed channels from the channel graph in response
2245
// to a new block being solved on the network. Any transactions which spend the
2246
// funding output of any known channels within he graph will be deleted.
2247
// Additionally, the "prune tip", or the last block which has been used to
2248
// prune the graph is stored so callers can ensure the graph is fully in sync
2249
// with the current UTXO state. A slice of channels that have been closed by
2250
// the target block along with any pruned nodes are returned if the function
2251
// succeeds without error.
2252
//
2253
// NOTE: part of the V1Store interface.
2254
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2255
        blockHash *chainhash.Hash, blockHeight uint32) (
NEW
2256
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
NEW
2257

×
NEW
2258
        ctx := context.TODO()
×
NEW
2259

×
NEW
2260
        s.cacheMu.Lock()
×
NEW
2261
        defer s.cacheMu.Unlock()
×
NEW
2262

×
NEW
2263
        var (
×
NEW
2264
                closedChans []*models.ChannelEdgeInfo
×
NEW
2265
                prunedNodes []route.Vertex
×
NEW
2266
        )
×
NEW
2267
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2268
                for _, outpoint := range spentOutputs {
×
NEW
2269
                        // TODO(elle): potentially optimize this by using
×
NEW
2270
                        //  sqlc.slice() once that works for both SQLite and
×
NEW
2271
                        //  Postgres.
×
NEW
2272
                        //
×
NEW
2273
                        // NOTE: this fetches channels for all protocol
×
NEW
2274
                        // versions.
×
NEW
2275
                        row, err := db.GetChannelByOutpoint(
×
NEW
2276
                                ctx, outpoint.String(),
×
NEW
2277
                        )
×
NEW
2278
                        if errors.Is(err, sql.ErrNoRows) {
×
NEW
2279
                                continue
×
NEW
2280
                        } else if err != nil {
×
NEW
2281
                                return fmt.Errorf("unable to fetch channel: %w",
×
NEW
2282
                                        err)
×
NEW
2283
                        }
×
2284

NEW
2285
                        node1, node2, err := buildNodeVertices(
×
NEW
2286
                                row.Node1Pubkey, row.Node2Pubkey,
×
NEW
2287
                        )
×
NEW
2288
                        if err != nil {
×
NEW
2289
                                return err
×
NEW
2290
                        }
×
2291

NEW
2292
                        info, err := getAndBuildEdgeInfo(
×
NEW
2293
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
NEW
2294
                                row.Channel, node1, node2,
×
NEW
2295
                        )
×
NEW
2296
                        if err != nil {
×
NEW
2297
                                return err
×
NEW
2298
                        }
×
2299

NEW
2300
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
NEW
2301
                        if err != nil {
×
NEW
2302
                                return fmt.Errorf("unable to delete "+
×
NEW
2303
                                        "channel: %w", err)
×
NEW
2304
                        }
×
2305

NEW
2306
                        closedChans = append(closedChans, info)
×
2307
                }
2308

NEW
2309
                err := db.UpsertPruneLogEntry(
×
NEW
2310
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
NEW
2311
                                BlockHash:   blockHash[:],
×
NEW
2312
                                BlockHeight: int64(blockHeight),
×
NEW
2313
                        },
×
NEW
2314
                )
×
NEW
2315
                if err != nil {
×
NEW
2316
                        return fmt.Errorf("unable to insert prune log "+
×
NEW
2317
                                "entry: %w", err)
×
NEW
2318
                }
×
2319

2320
                // Now that we've pruned some channels, we'll also prune any
2321
                // nodes that no longer have any channels.
NEW
2322
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
NEW
2323
                if err != nil {
×
NEW
2324
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
NEW
2325
                                err)
×
NEW
2326
                }
×
2327

NEW
2328
                return nil
×
NEW
2329
        }, func() {
×
NEW
2330
                prunedNodes = nil
×
NEW
2331
                closedChans = nil
×
NEW
2332
        })
×
NEW
2333
        if err != nil {
×
NEW
2334
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
NEW
2335
        }
×
2336

NEW
2337
        for _, channel := range closedChans {
×
NEW
2338
                s.rejectCache.remove(channel.ChannelID)
×
NEW
2339
                s.chanCache.remove(channel.ChannelID)
×
NEW
2340
        }
×
2341

NEW
2342
        return closedChans, prunedNodes, nil
×
2343
}
2344

2345
// ChannelView returns the verifiable edge information for each active channel
2346
// within the known channel graph. The set of UTXOs (along with their scripts)
2347
// returned are the ones that need to be watched on chain to detect channel
2348
// closes on the resident blockchain.
2349
//
2350
// NOTE: part of the V1Store interface.
NEW
2351
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
NEW
2352
        var (
×
NEW
2353
                ctx        = context.TODO()
×
NEW
2354
                edgePoints []EdgePoint
×
NEW
2355
        )
×
NEW
2356

×
NEW
2357
        handleChannel := func(db SQLQueries,
×
NEW
2358
                channel sqlc.ListChannelsPaginatedRow) error {
×
NEW
2359

×
NEW
2360
                pkScript, err := genMultiSigP2WSH(
×
NEW
2361
                        channel.BitcoinKey1, channel.BitcoinKey2,
×
NEW
2362
                )
×
NEW
2363
                if err != nil {
×
NEW
2364
                        return err
×
NEW
2365
                }
×
2366

NEW
2367
                op, err := wire.NewOutPointFromString(channel.Outpoint)
×
NEW
2368
                if err != nil {
×
NEW
2369
                        return err
×
NEW
2370
                }
×
2371

NEW
2372
                edgePoints = append(edgePoints, EdgePoint{
×
NEW
2373
                        FundingPkScript: pkScript,
×
NEW
2374
                        OutPoint:        *op,
×
NEW
2375
                })
×
NEW
2376

×
NEW
2377
                return nil
×
2378
        }
2379

NEW
2380
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2381
                lastID := int64(-1)
×
NEW
2382
                for {
×
NEW
2383
                        rows, err := db.ListChannelsPaginated(
×
NEW
2384
                                ctx, sqlc.ListChannelsPaginatedParams{
×
NEW
2385
                                        Version: int16(ProtocolV1),
×
NEW
2386
                                        ID:      lastID,
×
NEW
2387
                                        Limit:   pageSize,
×
NEW
2388
                                },
×
NEW
2389
                        )
×
NEW
2390
                        if err != nil {
×
NEW
2391
                                return err
×
NEW
2392
                        }
×
2393

NEW
2394
                        if len(rows) == 0 {
×
NEW
2395
                                break
×
2396
                        }
2397

NEW
2398
                        for _, row := range rows {
×
NEW
2399
                                err := handleChannel(db, row)
×
NEW
2400
                                if err != nil {
×
NEW
2401
                                        return err
×
NEW
2402
                                }
×
2403

NEW
2404
                                lastID = row.ID
×
2405
                        }
2406
                }
2407

NEW
2408
                return nil
×
NEW
2409
        }, func() {
×
NEW
2410
                edgePoints = nil
×
NEW
2411
        })
×
NEW
2412
        if err != nil {
×
NEW
2413
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
NEW
2414
        }
×
2415

NEW
2416
        return edgePoints, nil
×
2417
}
2418

2419
// PruneTip returns the block height and hash of the latest block that has been
2420
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2421
// to tell if the graph is currently in sync with the current best known UTXO
2422
// state.
2423
//
2424
// NOTE: part of the V1Store interface.
NEW
2425
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
NEW
2426
        var (
×
NEW
2427
                ctx       = context.TODO()
×
NEW
2428
                tipHash   chainhash.Hash
×
NEW
2429
                tipHeight uint32
×
NEW
2430
        )
×
NEW
2431
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2432
                pruneTip, err := db.GetPruneTip(ctx)
×
NEW
2433
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
2434
                        return ErrGraphNeverPruned
×
NEW
2435
                } else if err != nil {
×
NEW
2436
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
NEW
2437
                }
×
2438

NEW
2439
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
NEW
2440
                tipHeight = uint32(pruneTip.BlockHeight)
×
NEW
2441

×
NEW
2442
                return nil
×
2443
        }, sqldb.NoOpReset)
NEW
2444
        if err != nil {
×
NEW
2445
                return nil, 0, err
×
NEW
2446
        }
×
2447

NEW
2448
        return &tipHash, tipHeight, nil
×
2449
}
2450

2451
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2452
//
2453
// NOTE: this prunes nodes across protocol versions. It will never prune the
2454
// source nodes.
2455
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
NEW
2456
        db SQLQueries) ([]route.Vertex, error) {
×
NEW
2457

×
NEW
2458
        // Fetch all un-connected nodes from the database.
×
NEW
2459
        // NOTE: this will not include any nodes that are listed in the
×
NEW
2460
        // source table.
×
NEW
2461
        nodes, err := db.GetUnconnectedNodes(ctx)
×
NEW
2462
        if err != nil {
×
NEW
2463
                return nil, fmt.Errorf("unable to fetch unconnected nodes: %w",
×
NEW
2464
                        err)
×
NEW
2465
        }
×
2466

NEW
2467
        prunedNodes := make([]route.Vertex, 0, len(nodes))
×
NEW
2468
        for _, node := range nodes {
×
NEW
2469
                // TODO(elle): update to use sqlc.slice() once that works.
×
NEW
2470
                if err = db.DeleteNode(ctx, node.ID); err != nil {
×
NEW
2471
                        return nil, fmt.Errorf("unable to delete "+
×
NEW
2472
                                "node(id=%d): %w", node.ID, err)
×
NEW
2473
                }
×
2474

NEW
2475
                pubKey, err := route.NewVertexFromBytes(node.PubKey)
×
NEW
2476
                if err != nil {
×
NEW
2477
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
NEW
2478
                                "for node(id=%d): %w", node.ID, err)
×
NEW
2479
                }
×
2480

NEW
2481
                prunedNodes = append(prunedNodes, pubKey)
×
2482
        }
2483

NEW
2484
        return prunedNodes, nil
×
2485
}
2486

2487
// DisconnectBlockAtHeight is used to indicate that the block specified
2488
// by the passed height has been disconnected from the main chain. This
2489
// will "rewind" the graph back to the height below, deleting channels
2490
// that are no longer confirmed from the graph. The prune log will be
2491
// set to the last prune height valid for the remaining chain.
2492
// Channels that were removed from the graph resulting from the
2493
// disconnected block are returned.
2494
//
2495
// NOTE: part of the V1Store interface.
2496
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
NEW
2497
        []*models.ChannelEdgeInfo, error) {
×
NEW
2498

×
NEW
2499
        ctx := context.TODO()
×
NEW
2500

×
NEW
2501
        var (
×
NEW
2502
                // Every channel having a ShortChannelID starting at 'height'
×
NEW
2503
                // will no longer be confirmed.
×
NEW
2504
                startShortChanID = lnwire.ShortChannelID{
×
NEW
2505
                        BlockHeight: height,
×
NEW
2506
                }
×
NEW
2507

×
NEW
2508
                // Delete everything after this height from the db up until the
×
NEW
2509
                // SCID alias range.
×
NEW
2510
                endShortChanID = aliasmgr.StartingAlias
×
NEW
2511

×
NEW
2512
                removedChans []*models.ChannelEdgeInfo
×
NEW
2513
        )
×
NEW
2514

×
NEW
2515
        var chanIDStart [8]byte
×
NEW
2516
        byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64())
×
NEW
2517
        var chanIDEnd [8]byte
×
NEW
2518
        byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64())
×
NEW
2519

×
NEW
2520
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2521
                rows, err := db.GetChannelsBySCIDRange(
×
NEW
2522
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
NEW
2523
                                StartScid: chanIDStart[:],
×
NEW
2524
                                EndScid:   chanIDEnd[:],
×
NEW
2525
                        },
×
NEW
2526
                )
×
NEW
2527
                if err != nil {
×
NEW
2528
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
NEW
2529
                }
×
2530

NEW
2531
                for _, row := range rows {
×
NEW
2532
                        node1, node2, err := buildNodeVertices(
×
NEW
2533
                                row.Node1PubKey, row.Node2PubKey,
×
NEW
2534
                        )
×
NEW
2535
                        if err != nil {
×
NEW
2536
                                return err
×
NEW
2537
                        }
×
2538

NEW
2539
                        channel, err := getAndBuildEdgeInfo(
×
NEW
2540
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
NEW
2541
                                row.Channel, node1, node2,
×
NEW
2542
                        )
×
NEW
2543
                        if err != nil {
×
NEW
2544
                                return err
×
NEW
2545
                        }
×
2546

NEW
2547
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
NEW
2548
                        if err != nil {
×
NEW
2549
                                return fmt.Errorf("unable to delete "+
×
NEW
2550
                                        "channel: %w", err)
×
NEW
2551
                        }
×
2552

NEW
2553
                        removedChans = append(removedChans, channel)
×
2554
                }
2555

NEW
2556
                return db.DeletePruneLogEntriesInRange(
×
NEW
2557
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
NEW
2558
                                StartHeight: int64(height),
×
NEW
2559
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
NEW
2560
                        },
×
NEW
2561
                )
×
NEW
2562
        }, func() {
×
NEW
2563
                removedChans = nil
×
NEW
2564
        })
×
NEW
2565
        if err != nil {
×
NEW
2566
                return nil, fmt.Errorf("unable to disconnect block at "+
×
NEW
2567
                        "height: %w", err)
×
NEW
2568
        }
×
2569

NEW
2570
        for _, channel := range removedChans {
×
NEW
2571
                s.rejectCache.remove(channel.ChannelID)
×
NEW
2572
                s.chanCache.remove(channel.ChannelID)
×
NEW
2573
        }
×
2574

NEW
2575
        return removedChans, nil
×
2576
}
2577

2578
// forEachNodeDirectedChannel iterates through all channels of a given
2579
// node, executing the passed callback on the directed edge representing the
2580
// channel and its incoming policy. If the node is not found, no error is
2581
// returned.
2582
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2583
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2584

×
2585
        toNodeCallback := func() route.Vertex {
×
2586
                return nodePub
×
2587
        }
×
2588

2589
        dbID, err := db.GetNodeIDByPubKey(
×
2590
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2591
                        Version: int16(ProtocolV1),
×
2592
                        PubKey:  nodePub[:],
×
2593
                },
×
2594
        )
×
2595
        if errors.Is(err, sql.ErrNoRows) {
×
2596
                return nil
×
2597
        } else if err != nil {
×
2598
                return fmt.Errorf("unable to fetch node: %w", err)
×
2599
        }
×
2600

2601
        rows, err := db.ListChannelsByNodeID(
×
2602
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2603
                        Version: int16(ProtocolV1),
×
2604
                        NodeID1: dbID,
×
2605
                },
×
2606
        )
×
2607
        if err != nil {
×
2608
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2609
        }
×
2610

2611
        // Exit early if there are no channels for this node so we don't
2612
        // do the unnecessary feature fetching.
2613
        if len(rows) == 0 {
×
2614
                return nil
×
2615
        }
×
2616

2617
        features, err := getNodeFeatures(ctx, db, dbID)
×
2618
        if err != nil {
×
2619
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2620
        }
×
2621

2622
        for _, row := range rows {
×
2623
                node1, node2, err := buildNodeVertices(
×
2624
                        row.Node1Pubkey, row.Node2Pubkey,
×
2625
                )
×
2626
                if err != nil {
×
2627
                        return fmt.Errorf("unable to build node vertices: %w",
×
2628
                                err)
×
2629
                }
×
2630

2631
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
2632

×
2633
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2634
                if err != nil {
×
2635
                        return err
×
2636
                }
×
2637

2638
                var p1, p2 *models.CachedEdgePolicy
×
2639
                if dbPol1 != nil {
×
2640
                        policy1, err := buildChanPolicy(
×
2641
                                *dbPol1, edge.ChannelID, nil, node2, true,
×
2642
                        )
×
2643
                        if err != nil {
×
2644
                                return err
×
2645
                        }
×
2646

2647
                        p1 = models.NewCachedPolicy(policy1)
×
2648
                }
2649
                if dbPol2 != nil {
×
2650
                        policy2, err := buildChanPolicy(
×
2651
                                *dbPol2, edge.ChannelID, nil, node1, false,
×
2652
                        )
×
2653
                        if err != nil {
×
2654
                                return err
×
2655
                        }
×
2656

2657
                        p2 = models.NewCachedPolicy(policy2)
×
2658
                }
2659

2660
                // Determine the outgoing and incoming policy for this
2661
                // channel and node combo.
2662
                outPolicy, inPolicy := p1, p2
×
2663
                if p1 != nil && node2 == nodePub {
×
2664
                        outPolicy, inPolicy = p2, p1
×
2665
                } else if p2 != nil && node1 != nodePub {
×
2666
                        outPolicy, inPolicy = p2, p1
×
2667
                }
×
2668

2669
                var cachedInPolicy *models.CachedEdgePolicy
×
2670
                if inPolicy != nil {
×
2671
                        cachedInPolicy = inPolicy
×
2672
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
2673
                        cachedInPolicy.ToNodeFeatures = features
×
2674
                }
×
2675

2676
                directedChannel := &DirectedChannel{
×
2677
                        ChannelID:    edge.ChannelID,
×
2678
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2679
                        OtherNode:    edge.NodeKey2Bytes,
×
2680
                        Capacity:     edge.Capacity,
×
2681
                        OutPolicySet: outPolicy != nil,
×
2682
                        InPolicy:     cachedInPolicy,
×
2683
                }
×
2684
                if outPolicy != nil {
×
2685
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2686
                                directedChannel.InboundFee = fee
×
2687
                        })
×
2688
                }
2689

2690
                if nodePub == edge.NodeKey2Bytes {
×
2691
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
2692
                }
×
2693

2694
                if err := cb(directedChannel); err != nil {
×
2695
                        return err
×
2696
                }
×
2697
        }
2698

2699
        return nil
×
2700
}
2701

2702
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
2703
// and executes the provided callback for each node.
2704
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
2705
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
2706

×
2707
        var lastID int64
×
2708

×
2709
        for {
×
2710
                nodes, err := db.ListNodeIDsAndPubKeys(
×
2711
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
2712
                                Version: int16(ProtocolV1),
×
2713
                                ID:      lastID,
×
2714
                                Limit:   pageSize,
×
2715
                        },
×
2716
                )
×
2717
                if err != nil {
×
2718
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
2719
                }
×
2720

2721
                if len(nodes) == 0 {
×
2722
                        break
×
2723
                }
2724

2725
                for _, node := range nodes {
×
2726
                        var pub route.Vertex
×
2727
                        copy(pub[:], node.PubKey)
×
2728

×
2729
                        if err := cb(node.ID, pub); err != nil {
×
2730
                                return fmt.Errorf("forEachNodeCacheable "+
×
2731
                                        "callback failed for node(id=%d): %w",
×
2732
                                        node.ID, err)
×
2733
                        }
×
2734

2735
                        lastID = node.ID
×
2736
                }
2737
        }
2738

2739
        return nil
×
2740
}
2741

2742
// forEachNodeChannel iterates through all channels of a node, executing
2743
// the passed callback on each. The call-back is provided with the channel's
2744
// edge information, the outgoing policy and the incoming policy for the
2745
// channel and node combo.
2746
func forEachNodeChannel(ctx context.Context, db SQLQueries,
2747
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
2748
                *models.ChannelEdgePolicy,
2749
                *models.ChannelEdgePolicy) error) error {
×
2750

×
2751
        // Get all the V1 channels for this node.Add commentMore actions
×
2752
        rows, err := db.ListChannelsByNodeID(
×
2753
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2754
                        Version: int16(ProtocolV1),
×
2755
                        NodeID1: id,
×
2756
                },
×
2757
        )
×
2758
        if err != nil {
×
2759
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2760
        }
×
2761

2762
        // Call the call-back for each channel and its known policies.
2763
        for _, row := range rows {
×
2764
                node1, node2, err := buildNodeVertices(
×
2765
                        row.Node1Pubkey, row.Node2Pubkey,
×
2766
                )
×
2767
                if err != nil {
×
2768
                        return fmt.Errorf("unable to build node vertices: %w",
×
2769
                                err)
×
2770
                }
×
2771

2772
                edge, err := getAndBuildEdgeInfo(
×
2773
                        ctx, db, chain, row.Channel.ID, row.Channel, node1,
×
2774
                        node2,
×
2775
                )
×
2776
                if err != nil {
×
2777
                        return fmt.Errorf("unable to build channel info: %w",
×
2778
                                err)
×
2779
                }
×
2780

2781
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2782
                if err != nil {
×
2783
                        return fmt.Errorf("unable to extract channel "+
×
2784
                                "policies: %w", err)
×
2785
                }
×
2786

2787
                p1, p2, err := getAndBuildChanPolicies(
×
2788
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
2789
                )
×
2790
                if err != nil {
×
2791
                        return fmt.Errorf("unable to build channel "+
×
2792
                                "policies: %w", err)
×
2793
                }
×
2794

2795
                // Determine the outgoing and incoming policy for this
2796
                // channel and node combo.
2797
                p1ToNode := row.Channel.NodeID2
×
2798
                p2ToNode := row.Channel.NodeID1
×
2799
                outPolicy, inPolicy := p1, p2
×
2800
                if (p1 != nil && p1ToNode == id) ||
×
2801
                        (p2 != nil && p2ToNode != id) {
×
2802

×
2803
                        outPolicy, inPolicy = p2, p1
×
2804
                }
×
2805

2806
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
2807
                        return err
×
2808
                }
×
2809
        }
2810

2811
        return nil
×
2812
}
2813

2814
// updateChanEdgePolicy upserts the channel policy info we have stored for
2815
// a channel we already know of.
2816
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
2817
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
2818
        error) {
×
2819

×
2820
        var (
×
2821
                node1Pub, node2Pub route.Vertex
×
2822
                isNode1            bool
×
2823
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
2824
        )
×
2825

×
2826
        // Check that this edge policy refers to a channel that we already
×
2827
        // know of. We do this explicitly so that we can return the appropriate
×
2828
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
2829
        // abort the transaction which would abort the entire batch.
×
2830
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
2831
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
2832
                        Scid:    chanIDB[:],
×
2833
                        Version: int16(ProtocolV1),
×
2834
                },
×
2835
        )
×
2836
        if errors.Is(err, sql.ErrNoRows) {
×
2837
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
2838
        } else if err != nil {
×
2839
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
2840
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
2841
        }
×
2842

2843
        copy(node1Pub[:], dbChan.Node1PubKey)
×
2844
        copy(node2Pub[:], dbChan.Node2PubKey)
×
2845

×
2846
        // Figure out which node this edge is from.
×
2847
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
2848
        nodeID := dbChan.NodeID1
×
2849
        if !isNode1 {
×
2850
                nodeID = dbChan.NodeID2
×
2851
        }
×
2852

2853
        var (
×
2854
                inboundBase sql.NullInt64
×
2855
                inboundRate sql.NullInt64
×
2856
        )
×
2857
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2858
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
2859
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
2860
        })
×
2861

2862
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
2863
                Version:     int16(ProtocolV1),
×
2864
                ChannelID:   dbChan.ID,
×
2865
                NodeID:      nodeID,
×
2866
                Timelock:    int32(edge.TimeLockDelta),
×
2867
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
2868
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
2869
                MinHtlcMsat: int64(edge.MinHTLC),
×
2870
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
2871
                Disabled: sql.NullBool{
×
2872
                        Valid: true,
×
2873
                        Bool:  edge.IsDisabled(),
×
2874
                },
×
2875
                MaxHtlcMsat: sql.NullInt64{
×
2876
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
2877
                        Int64: int64(edge.MaxHTLC),
×
2878
                },
×
2879
                InboundBaseFeeMsat:      inboundBase,
×
2880
                InboundFeeRateMilliMsat: inboundRate,
×
2881
                Signature:               edge.SigBytes,
×
2882
        })
×
2883
        if err != nil {
×
2884
                return node1Pub, node2Pub, isNode1,
×
2885
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
2886
        }
×
2887

2888
        // Convert the flat extra opaque data into a map of TLV types to
2889
        // values.
2890
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
2891
        if err != nil {
×
2892
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
2893
                        "marshal extra opaque data: %w", err)
×
2894
        }
×
2895

2896
        // Update the channel policy's extra signed fields.
2897
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
2898
        if err != nil {
×
2899
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
2900
                        "policy extra TLVs: %w", err)
×
2901
        }
×
2902

2903
        return node1Pub, node2Pub, isNode1, nil
×
2904
}
2905

2906
// getNodeByPubKey attempts to look up a target node by its public key.
2907
func getNodeByPubKey(ctx context.Context, db SQLQueries,
2908
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
2909

×
2910
        dbNode, err := db.GetNodeByPubKey(
×
2911
                ctx, sqlc.GetNodeByPubKeyParams{
×
2912
                        Version: int16(ProtocolV1),
×
2913
                        PubKey:  pubKey[:],
×
2914
                },
×
2915
        )
×
2916
        if errors.Is(err, sql.ErrNoRows) {
×
2917
                return 0, nil, ErrGraphNodeNotFound
×
2918
        } else if err != nil {
×
2919
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
2920
        }
×
2921

2922
        node, err := buildNode(ctx, db, &dbNode)
×
2923
        if err != nil {
×
2924
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
2925
        }
×
2926

2927
        return dbNode.ID, node, nil
×
2928
}
2929

2930
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
2931
// provided database channel row and the public keys of the two nodes
2932
// involved in the channel.
2933
func buildCacheableChannelInfo(dbChan sqlc.Channel, node1Pub,
2934
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
2935

×
2936
        return &models.CachedEdgeInfo{
×
2937
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
2938
                NodeKey1Bytes: node1Pub,
×
2939
                NodeKey2Bytes: node2Pub,
×
2940
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
2941
        }
×
2942
}
×
2943

2944
// buildNode constructs a LightningNode instance from the given database node
2945
// record. The node's features, addresses and extra signed fields are also
2946
// fetched from the database and set on the node.
2947
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
2948
        *models.LightningNode, error) {
×
2949

×
2950
        if dbNode.Version != int16(ProtocolV1) {
×
2951
                return nil, fmt.Errorf("unsupported node version: %d",
×
2952
                        dbNode.Version)
×
2953
        }
×
2954

2955
        var pub [33]byte
×
2956
        copy(pub[:], dbNode.PubKey)
×
2957

×
2958
        node := &models.LightningNode{
×
2959
                PubKeyBytes: pub,
×
2960
                Features:    lnwire.EmptyFeatureVector(),
×
2961
                LastUpdate:  time.Unix(0, 0),
×
2962
        }
×
2963

×
2964
        if len(dbNode.Signature) == 0 {
×
2965
                return node, nil
×
2966
        }
×
2967

2968
        node.HaveNodeAnnouncement = true
×
2969
        node.AuthSigBytes = dbNode.Signature
×
2970
        node.Alias = dbNode.Alias.String
×
2971
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
2972

×
2973
        var err error
×
2974
        if dbNode.Color.Valid {
×
2975
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
2976
                if err != nil {
×
2977
                        return nil, fmt.Errorf("unable to decode color: %w",
×
2978
                                err)
×
2979
                }
×
2980
        }
2981

2982
        // Fetch the node's features.
2983
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
2984
        if err != nil {
×
2985
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
2986
                        "features: %w", dbNode.ID, err)
×
2987
        }
×
2988

2989
        // Fetch the node's addresses.
2990
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
2991
        if err != nil {
×
2992
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
2993
                        "addresses: %w", dbNode.ID, err)
×
2994
        }
×
2995

2996
        // Fetch the node's extra signed fields.
2997
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
2998
        if err != nil {
×
2999
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3000
                        "extra signed fields: %w", dbNode.ID, err)
×
3001
        }
×
3002

3003
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
3004
        if err != nil {
×
3005
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3006
                        "fields: %w", err)
×
3007
        }
×
3008

3009
        if len(recs) != 0 {
×
3010
                node.ExtraOpaqueData = recs
×
3011
        }
×
3012

3013
        return node, nil
×
3014
}
3015

3016
// getNodeFeatures fetches the feature bits and constructs the feature vector
3017
// for a node with the given DB ID.
3018
func getNodeFeatures(ctx context.Context, db SQLQueries,
3019
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3020

×
3021
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3022
        if err != nil {
×
3023
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3024
                        nodeID, err)
×
3025
        }
×
3026

3027
        features := lnwire.EmptyFeatureVector()
×
3028
        for _, feature := range rows {
×
3029
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3030
        }
×
3031

3032
        return features, nil
×
3033
}
3034

3035
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
3036
// given DB ID.
3037
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3038
        nodeID int64) (map[uint64][]byte, error) {
×
3039

×
3040
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3041
        if err != nil {
×
3042
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
3043
                        "signed fields: %w", nodeID, err)
×
3044
        }
×
3045

3046
        extraFields := make(map[uint64][]byte)
×
3047
        for _, field := range fields {
×
3048
                extraFields[uint64(field.Type)] = field.Value
×
3049
        }
×
3050

3051
        return extraFields, nil
×
3052
}
3053

3054
// upsertNode upserts the node record into the database. If the node already
3055
// exists, then the node's information is updated. If the node doesn't exist,
3056
// then a new node is created. The node's features, addresses and extra TLV
3057
// types are also updated. The node's DB ID is returned.
3058
func upsertNode(ctx context.Context, db SQLQueries,
3059
        node *models.LightningNode) (int64, error) {
×
3060

×
3061
        params := sqlc.UpsertNodeParams{
×
3062
                Version: int16(ProtocolV1),
×
3063
                PubKey:  node.PubKeyBytes[:],
×
3064
        }
×
3065

×
3066
        if node.HaveNodeAnnouncement {
×
3067
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3068
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3069
                params.Alias = sqldb.SQLStr(node.Alias)
×
3070
                params.Signature = node.AuthSigBytes
×
3071
        }
×
3072

3073
        nodeID, err := db.UpsertNode(ctx, params)
×
3074
        if err != nil {
×
3075
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3076
                        err)
×
3077
        }
×
3078

3079
        // We can exit here if we don't have the announcement yet.
3080
        if !node.HaveNodeAnnouncement {
×
3081
                return nodeID, nil
×
3082
        }
×
3083

3084
        // Update the node's features.
3085
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3086
        if err != nil {
×
3087
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3088
        }
×
3089

3090
        // Update the node's addresses.
3091
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3092
        if err != nil {
×
3093
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3094
        }
×
3095

3096
        // Convert the flat extra opaque data into a map of TLV types to
3097
        // values.
3098
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3099
        if err != nil {
×
3100
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3101
                        err)
×
3102
        }
×
3103

3104
        // Update the node's extra signed fields.
3105
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3106
        if err != nil {
×
3107
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3108
        }
×
3109

3110
        return nodeID, nil
×
3111
}
3112

3113
// upsertNodeFeatures updates the node's features node_features table. This
3114
// includes deleting any feature bits no longer present and inserting any new
3115
// feature bits. If the feature bit does not yet exist in the features table,
3116
// then an entry is created in that table first.
3117
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3118
        features *lnwire.FeatureVector) error {
×
3119

×
3120
        // Get any existing features for the node.
×
3121
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3122
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3123
                return err
×
3124
        }
×
3125

3126
        // Copy the nodes latest set of feature bits.
3127
        newFeatures := make(map[int32]struct{})
×
3128
        if features != nil {
×
3129
                for feature := range features.Features() {
×
3130
                        newFeatures[int32(feature)] = struct{}{}
×
3131
                }
×
3132
        }
3133

3134
        // For any current feature that already exists in the DB, remove it from
3135
        // the in-memory map. For any existing feature that does not exist in
3136
        // the in-memory map, delete it from the database.
3137
        for _, feature := range existingFeatures {
×
3138
                // The feature is still present, so there are no updates to be
×
3139
                // made.
×
3140
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3141
                        delete(newFeatures, feature.FeatureBit)
×
3142
                        continue
×
3143
                }
3144

3145
                // The feature is no longer present, so we remove it from the
3146
                // database.
3147
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3148
                        NodeID:     nodeID,
×
3149
                        FeatureBit: feature.FeatureBit,
×
3150
                })
×
3151
                if err != nil {
×
3152
                        return fmt.Errorf("unable to delete node(%d) "+
×
3153
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3154
                                err)
×
3155
                }
×
3156
        }
3157

3158
        // Any remaining entries in newFeatures are new features that need to be
3159
        // added to the database for the first time.
3160
        for feature := range newFeatures {
×
3161
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3162
                        NodeID:     nodeID,
×
3163
                        FeatureBit: feature,
×
3164
                })
×
3165
                if err != nil {
×
3166
                        return fmt.Errorf("unable to insert node(%d) "+
×
3167
                                "feature(%v): %w", nodeID, feature, err)
×
3168
                }
×
3169
        }
3170

3171
        return nil
×
3172
}
3173

3174
// fetchNodeFeatures fetches the features for a node with the given public key.
3175
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3176
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3177

×
3178
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3179
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3180
                        PubKey:  nodePub[:],
×
3181
                        Version: int16(ProtocolV1),
×
3182
                },
×
3183
        )
×
3184
        if err != nil {
×
3185
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3186
                        nodePub, err)
×
3187
        }
×
3188

3189
        features := lnwire.EmptyFeatureVector()
×
3190
        for _, bit := range rows {
×
3191
                features.Set(lnwire.FeatureBit(bit))
×
3192
        }
×
3193

3194
        return features, nil
×
3195
}
3196

3197
// dbAddressType is an enum type that represents the different address types
// that we store in the node_addresses table. The address type determines how
// the address is to be serialised/deserialised.
type dbAddressType uint8

const (
	// addressTypeIPv4 is a TCP address whose IP has an IPv4 form.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 is a TCP address whose IP is IPv6-only.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 is a Tor v2 onion service address.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 is a Tor v3 onion service address.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeOpaque is used for lnwire.OpaqueAddrs, i.e. address
	// payloads that are not interpreted. The stored string is decoded
	// back as hex when the address is read.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3209

3210
// upsertNodeAddresses updates the node's addresses in the database. This
3211
// includes deleting any existing addresses and inserting the new set of
3212
// addresses. The deletion is necessary since the ordering of the addresses may
3213
// change, and we need to ensure that the database reflects the latest set of
3214
// addresses so that at the time of reconstructing the node announcement, the
3215
// order is preserved and the signature over the message remains valid.
3216
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3217
        addresses []net.Addr) error {
×
3218

×
3219
        // Delete any existing addresses for the node. This is required since
×
3220
        // even if the new set of addresses is the same, the ordering may have
×
3221
        // changed for a given address type.
×
3222
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3223
        if err != nil {
×
3224
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3225
                        nodeID, err)
×
3226
        }
×
3227

3228
        // Copy the nodes latest set of addresses.
3229
        newAddresses := map[dbAddressType][]string{
×
3230
                addressTypeIPv4:   {},
×
3231
                addressTypeIPv6:   {},
×
3232
                addressTypeTorV2:  {},
×
3233
                addressTypeTorV3:  {},
×
3234
                addressTypeOpaque: {},
×
3235
        }
×
3236
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3237
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3238
        }
×
3239

3240
        for _, address := range addresses {
×
3241
                switch addr := address.(type) {
×
3242
                case *net.TCPAddr:
×
3243
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3244
                                addAddr(addressTypeIPv4, addr)
×
3245
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3246
                                addAddr(addressTypeIPv6, addr)
×
3247
                        } else {
×
3248
                                return fmt.Errorf("unhandled IP address: %v",
×
3249
                                        addr)
×
3250
                        }
×
3251

3252
                case *tor.OnionAddr:
×
3253
                        switch len(addr.OnionService) {
×
3254
                        case tor.V2Len:
×
3255
                                addAddr(addressTypeTorV2, addr)
×
3256
                        case tor.V3Len:
×
3257
                                addAddr(addressTypeTorV3, addr)
×
3258
                        default:
×
3259
                                return fmt.Errorf("invalid length for a tor " +
×
3260
                                        "address")
×
3261
                        }
3262

3263
                case *lnwire.OpaqueAddrs:
×
3264
                        addAddr(addressTypeOpaque, addr)
×
3265

3266
                default:
×
3267
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3268
                }
3269
        }
3270

3271
        // Any remaining entries in newAddresses are new addresses that need to
3272
        // be added to the database for the first time.
3273
        for addrType, addrList := range newAddresses {
×
3274
                for position, addr := range addrList {
×
3275
                        err := db.InsertNodeAddress(
×
3276
                                ctx, sqlc.InsertNodeAddressParams{
×
3277
                                        NodeID:   nodeID,
×
3278
                                        Type:     int16(addrType),
×
3279
                                        Address:  addr,
×
3280
                                        Position: int32(position),
×
3281
                                },
×
3282
                        )
×
3283
                        if err != nil {
×
3284
                                return fmt.Errorf("unable to insert "+
×
3285
                                        "node(%d) address(%v): %w", nodeID,
×
3286
                                        addr, err)
×
3287
                        }
×
3288
                }
3289
        }
3290

3291
        return nil
×
3292
}
3293

3294
// getNodeAddresses fetches the addresses for a node with the given public key.
3295
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3296
        []net.Addr, error) {
×
3297

×
3298
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
3299
        // are returned in the same order as they were inserted.
×
3300
        rows, err := db.GetNodeAddressesByPubKey(
×
3301
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
3302
                        Version: int16(ProtocolV1),
×
3303
                        PubKey:  nodePub,
×
3304
                },
×
3305
        )
×
3306
        if err != nil {
×
3307
                return false, nil, err
×
3308
        }
×
3309

3310
        // GetNodeAddressesByPubKey uses a left join so there should always be
3311
        // at least one row returned if the node exists even if it has no
3312
        // addresses.
3313
        if len(rows) == 0 {
×
3314
                return false, nil, nil
×
3315
        }
×
3316

3317
        addresses := make([]net.Addr, 0, len(rows))
×
3318
        for _, addr := range rows {
×
3319
                if !(addr.Type.Valid && addr.Address.Valid) {
×
3320
                        continue
×
3321
                }
3322

3323
                address := addr.Address.String
×
3324

×
3325
                switch dbAddressType(addr.Type.Int16) {
×
3326
                case addressTypeIPv4:
×
3327
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
3328
                        if err != nil {
×
3329
                                return false, nil, nil
×
3330
                        }
×
3331
                        tcp.IP = tcp.IP.To4()
×
3332

×
3333
                        addresses = append(addresses, tcp)
×
3334

3335
                case addressTypeIPv6:
×
3336
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
3337
                        if err != nil {
×
3338
                                return false, nil, nil
×
3339
                        }
×
3340
                        addresses = append(addresses, tcp)
×
3341

3342
                case addressTypeTorV3, addressTypeTorV2:
×
3343
                        service, portStr, err := net.SplitHostPort(address)
×
3344
                        if err != nil {
×
3345
                                return false, nil, fmt.Errorf("unable to "+
×
3346
                                        "split tor v3 address: %v",
×
3347
                                        addr.Address)
×
3348
                        }
×
3349

3350
                        port, err := strconv.Atoi(portStr)
×
3351
                        if err != nil {
×
3352
                                return false, nil, err
×
3353
                        }
×
3354

3355
                        addresses = append(addresses, &tor.OnionAddr{
×
3356
                                OnionService: service,
×
3357
                                Port:         port,
×
3358
                        })
×
3359

3360
                case addressTypeOpaque:
×
3361
                        opaque, err := hex.DecodeString(address)
×
3362
                        if err != nil {
×
3363
                                return false, nil, fmt.Errorf("unable to "+
×
3364
                                        "decode opaque address: %v", addr)
×
3365
                        }
×
3366

3367
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
3368
                                Payload: opaque,
×
3369
                        })
×
3370

3371
                default:
×
3372
                        return false, nil, fmt.Errorf("unknown address "+
×
3373
                                "type: %v", addr.Type)
×
3374
                }
3375
        }
3376

3377
        return true, addresses, nil
×
3378
}
3379

3380
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3381
// database. This includes updating any existing types, inserting any new types,
3382
// and deleting any types that are no longer present.
3383
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3384
        nodeID int64, extraFields map[uint64][]byte) error {
×
3385

×
3386
        // Get any existing extra signed fields for the node.
×
3387
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3388
        if err != nil {
×
3389
                return err
×
3390
        }
×
3391

3392
        // Make a lookup map of the existing field types so that we can use it
3393
        // to keep track of any fields we should delete.
3394
        m := make(map[uint64]bool)
×
3395
        for _, field := range existingFields {
×
3396
                m[uint64(field.Type)] = true
×
3397
        }
×
3398

3399
        // For all the new fields, we'll upsert them and remove them from the
3400
        // map of existing fields.
3401
        for tlvType, value := range extraFields {
×
3402
                err = db.UpsertNodeExtraType(
×
3403
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3404
                                NodeID: nodeID,
×
3405
                                Type:   int64(tlvType),
×
3406
                                Value:  value,
×
3407
                        },
×
3408
                )
×
3409
                if err != nil {
×
3410
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3411
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3412
                }
×
3413

3414
                // Remove the field from the map of existing fields if it was
3415
                // present.
3416
                delete(m, tlvType)
×
3417
        }
3418

3419
        // For all the fields that are left in the map of existing fields, we'll
3420
        // delete them as they are no longer present in the new set of fields.
3421
        for tlvType := range m {
×
3422
                err = db.DeleteExtraNodeType(
×
3423
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3424
                                NodeID: nodeID,
×
3425
                                Type:   int64(tlvType),
×
3426
                        },
×
3427
                )
×
3428
                if err != nil {
×
3429
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3430
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3431
                }
×
3432
        }
3433

3434
        return nil
×
3435
}
3436

3437
// srcNodeInfo holds the information about the source node of the graph.
3438
type srcNodeInfo struct {
3439
        // id is the DB level ID of the source node entry in the "nodes" table.
3440
        id int64
3441

3442
        // pub is the public key of the source node.
3443
        pub route.Vertex
3444
}
3445

3446
// getSourceNode returns the DB node ID and pub key of the source node for the
3447
// specified protocol version.
3448
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3449
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3450

×
3451
        s.srcNodeMu.Lock()
×
3452
        defer s.srcNodeMu.Unlock()
×
3453

×
3454
        // If we already have the source node ID and pub key cached, then
×
3455
        // return them.
×
3456
        if info, ok := s.srcNodes[version]; ok {
×
3457
                return info.id, info.pub, nil
×
3458
        }
×
3459

3460
        var pubKey route.Vertex
×
3461

×
3462
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3463
        if err != nil {
×
3464
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3465
                        err)
×
3466
        }
×
3467

3468
        if len(nodes) == 0 {
×
3469
                return 0, pubKey, ErrSourceNodeNotSet
×
3470
        } else if len(nodes) > 1 {
×
3471
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3472
                        "protocol %s found", version)
×
3473
        }
×
3474

3475
        copy(pubKey[:], nodes[0].PubKey)
×
3476

×
3477
        s.srcNodes[version] = &srcNodeInfo{
×
3478
                id:  nodes[0].NodeID,
×
3479
                pub: pubKey,
×
3480
        }
×
3481

×
3482
        return nodes[0].NodeID, pubKey, nil
×
3483
}
3484

3485
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3486
// This then produces a map from TLV type to value. If the input is not a
3487
// valid TLV stream, then an error is returned.
3488
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3489
        r := bytes.NewReader(data)
×
3490

×
3491
        tlvStream, err := tlv.NewStream()
×
3492
        if err != nil {
×
3493
                return nil, err
×
3494
        }
×
3495

3496
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3497
        // pass it into the P2P decoding variant.
3498
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3499
        if err != nil {
×
3500
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3501
        }
×
3502
        if len(parsedTypes) == 0 {
×
3503
                return nil, nil
×
3504
        }
×
3505

3506
        records := make(map[uint64][]byte)
×
3507
        for k, v := range parsedTypes {
×
3508
                records[uint64(k)] = v
×
3509
        }
×
3510

3511
        return records, nil
×
3512
}
3513

3514
// insertChannel inserts a new channel record into the database.
3515
func insertChannel(ctx context.Context, db SQLQueries,
3516
        edge *models.ChannelEdgeInfo) error {
×
3517

×
3518
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3519

×
3520
        // Make sure that the channel doesn't already exist. We do this
×
3521
        // explicitly instead of relying on catching a unique constraint error
×
3522
        // because relying on SQL to throw that error would abort the entire
×
3523
        // batch of transactions.
×
3524
        _, err := db.GetChannelBySCID(
×
3525
                ctx, sqlc.GetChannelBySCIDParams{
×
3526
                        Scid:    chanIDB[:],
×
3527
                        Version: int16(ProtocolV1),
×
3528
                },
×
3529
        )
×
3530
        if err == nil {
×
3531
                return ErrEdgeAlreadyExist
×
3532
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3533
                return fmt.Errorf("unable to fetch channel: %w", err)
×
3534
        }
×
3535

3536
        // Make sure that at least a "shell" entry for each node is present in
3537
        // the nodes table.
3538
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3539
        if err != nil {
×
3540
                return fmt.Errorf("unable to create shell node: %w", err)
×
3541
        }
×
3542

3543
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3544
        if err != nil {
×
3545
                return fmt.Errorf("unable to create shell node: %w", err)
×
3546
        }
×
3547

3548
        var capacity sql.NullInt64
×
3549
        if edge.Capacity != 0 {
×
3550
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3551
        }
×
3552

3553
        createParams := sqlc.CreateChannelParams{
×
3554
                Version:     int16(ProtocolV1),
×
3555
                Scid:        chanIDB[:],
×
3556
                NodeID1:     node1DBID,
×
3557
                NodeID2:     node2DBID,
×
3558
                Outpoint:    edge.ChannelPoint.String(),
×
3559
                Capacity:    capacity,
×
3560
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3561
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3562
        }
×
3563

×
3564
        if edge.AuthProof != nil {
×
3565
                proof := edge.AuthProof
×
3566

×
3567
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3568
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3569
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3570
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3571
        }
×
3572

3573
        // Insert the new channel record.
3574
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3575
        if err != nil {
×
3576
                return err
×
3577
        }
×
3578

3579
        // Insert any channel features.
3580
        if len(edge.Features) != 0 {
×
3581
                chanFeatures := lnwire.NewRawFeatureVector()
×
3582
                err := chanFeatures.Decode(bytes.NewReader(edge.Features))
×
3583
                if err != nil {
×
3584
                        return err
×
3585
                }
×
3586

3587
                fv := lnwire.NewFeatureVector(chanFeatures, lnwire.Features)
×
3588
                for feature := range fv.Features() {
×
3589
                        err = db.InsertChannelFeature(
×
3590
                                ctx, sqlc.InsertChannelFeatureParams{
×
3591
                                        ChannelID:  dbChanID,
×
3592
                                        FeatureBit: int32(feature),
×
3593
                                },
×
3594
                        )
×
3595
                        if err != nil {
×
3596
                                return fmt.Errorf("unable to insert "+
×
3597
                                        "channel(%d) feature(%v): %w", dbChanID,
×
3598
                                        feature, err)
×
3599
                        }
×
3600
                }
3601
        }
3602

3603
        // Finally, insert any extra TLV fields in the channel announcement.
3604
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3605
        if err != nil {
×
3606
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
3607
                        err)
×
3608
        }
×
3609

3610
        for tlvType, value := range extra {
×
3611
                err := db.CreateChannelExtraType(
×
3612
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3613
                                ChannelID: dbChanID,
×
3614
                                Type:      int64(tlvType),
×
3615
                                Value:     value,
×
3616
                        },
×
3617
                )
×
3618
                if err != nil {
×
3619
                        return fmt.Errorf("unable to upsert channel(%d) extra "+
×
3620
                                "signed field(%v): %w", edge.ChannelID,
×
3621
                                tlvType, err)
×
3622
                }
×
3623
        }
3624

3625
        return nil
×
3626
}
3627

3628
// maybeCreateShellNode checks if a shell node entry exists for the
3629
// given public key. If it does not exist, then a new shell node entry is
3630
// created. The ID of the node is returned. A shell node only has a protocol
3631
// version and public key persisted.
3632
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3633
        pubKey route.Vertex) (int64, error) {
×
3634

×
3635
        dbNode, err := db.GetNodeByPubKey(
×
3636
                ctx, sqlc.GetNodeByPubKeyParams{
×
3637
                        PubKey:  pubKey[:],
×
3638
                        Version: int16(ProtocolV1),
×
3639
                },
×
3640
        )
×
3641
        // The node exists. Return the ID.
×
3642
        if err == nil {
×
3643
                return dbNode.ID, nil
×
3644
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3645
                return 0, err
×
3646
        }
×
3647

3648
        // Otherwise, the node does not exist, so we create a shell entry for
3649
        // it.
3650
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3651
                Version: int16(ProtocolV1),
×
3652
                PubKey:  pubKey[:],
×
3653
        })
×
3654
        if err != nil {
×
3655
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3656
        }
×
3657

3658
        return id, nil
×
3659
}
3660

3661
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3662
// the database. This includes deleting any existing types and then inserting
3663
// the new types.
3664
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3665
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3666

×
3667
        // Delete all existing extra signed fields for the channel policy.
×
3668
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
3669
        if err != nil {
×
3670
                return fmt.Errorf("unable to delete "+
×
3671
                        "existing policy extra signed fields for policy %d: %w",
×
3672
                        chanPolicyID, err)
×
3673
        }
×
3674

3675
        // Insert all new extra signed fields for the channel policy.
3676
        for tlvType, value := range extraFields {
×
3677
                err = db.InsertChanPolicyExtraType(
×
3678
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
3679
                                ChannelPolicyID: chanPolicyID,
×
3680
                                Type:            int64(tlvType),
×
3681
                                Value:           value,
×
3682
                        },
×
3683
                )
×
3684
                if err != nil {
×
3685
                        return fmt.Errorf("unable to insert "+
×
3686
                                "channel_policy(%d) extra signed field(%v): %w",
×
3687
                                chanPolicyID, tlvType, err)
×
3688
                }
×
3689
        }
3690

3691
        return nil
×
3692
}
3693

3694
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3695
// provided dbChanRow and also fetches any other required information
3696
// to construct the edge info.
3697
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
3698
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.Channel, node1,
3699
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
3700

×
NEW
3701
        if dbChan.Version != int16(ProtocolV1) {
×
NEW
3702
                return nil, fmt.Errorf("unsupported channel version: %d",
×
NEW
3703
                        dbChan.Version)
×
NEW
3704
        }
×
3705

3706
        fv, extras, err := getChanFeaturesAndExtras(
×
3707
                ctx, db, dbChanID,
×
3708
        )
×
3709
        if err != nil {
×
3710
                return nil, err
×
3711
        }
×
3712

3713
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
3714
        if err != nil {
×
3715
                return nil, err
×
3716
        }
×
3717

3718
        var featureBuf bytes.Buffer
×
3719
        if err := fv.Encode(&featureBuf); err != nil {
×
3720
                return nil, fmt.Errorf("unable to encode features: %w", err)
×
3721
        }
×
3722

3723
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
3724
        if err != nil {
×
3725
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3726
                        "fields: %w", err)
×
3727
        }
×
3728
        if recs == nil {
×
3729
                recs = make([]byte, 0)
×
3730
        }
×
3731

3732
        var btcKey1, btcKey2 route.Vertex
×
3733
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
3734
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
3735

×
3736
        channel := &models.ChannelEdgeInfo{
×
3737
                ChainHash:        chain,
×
3738
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
3739
                NodeKey1Bytes:    node1,
×
3740
                NodeKey2Bytes:    node2,
×
3741
                BitcoinKey1Bytes: btcKey1,
×
3742
                BitcoinKey2Bytes: btcKey2,
×
3743
                ChannelPoint:     *op,
×
3744
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
3745
                Features:         featureBuf.Bytes(),
×
3746
                ExtraOpaqueData:  recs,
×
3747
        }
×
3748

×
3749
        // We always set all the signatures at the same time, so we can
×
3750
        // safely check if one signature is present to determine if we have the
×
3751
        // rest of the signatures for the auth proof.
×
3752
        if len(dbChan.Bitcoin1Signature) > 0 {
×
3753
                channel.AuthProof = &models.ChannelAuthProof{
×
3754
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
3755
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
3756
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
3757
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
3758
                }
×
3759
        }
×
3760

3761
        return channel, nil
×
3762
}
3763

3764
// buildNodeVertices is a helper that converts raw node public keys
3765
// into route.Vertex instances.
3766
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
3767
        route.Vertex, error) {
×
3768

×
3769
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
3770
        if err != nil {
×
3771
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
3772
                        "create vertex from node1 pubkey: %w", err)
×
3773
        }
×
3774

3775
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
3776
        if err != nil {
×
3777
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
3778
                        "create vertex from node2 pubkey: %w", err)
×
3779
        }
×
3780

3781
        return node1Vertex, node2Vertex, nil
×
3782
}
3783

3784
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
3785
// for a channel with the given ID.
3786
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
3787
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
3788

×
3789
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
3790
        if err != nil {
×
3791
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
3792
                        "features and extras: %w", err)
×
3793
        }
×
3794

3795
        var (
×
3796
                fv     = lnwire.EmptyFeatureVector()
×
3797
                extras = make(map[uint64][]byte)
×
3798
        )
×
3799
        for _, row := range rows {
×
3800
                if row.IsFeature {
×
3801
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
3802

×
3803
                        continue
×
3804
                }
3805

3806
                tlvType, ok := row.ExtraKey.(int64)
×
3807
                if !ok {
×
3808
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
3809
                                "TLV type: %T", row.ExtraKey)
×
3810
                }
×
3811

3812
                valueBytes, ok := row.Value.([]byte)
×
3813
                if !ok {
×
3814
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
3815
                                "Value: %T", row.Value)
×
3816
                }
×
3817

3818
                extras[uint64(tlvType)] = valueBytes
×
3819
        }
3820

3821
        return fv, extras, nil
×
3822
}
3823

3824
// getAndBuildChanPolicies uses the given sqlc.ChannelPolicy and also retrieves
3825
// all the extra info required to build the complete models.ChannelEdgePolicy
3826
// types. It returns two policies, which may be nil if the provided
3827
// sqlc.ChannelPolicy records are nil.
3828
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
3829
        dbPol1, dbPol2 *sqlc.ChannelPolicy, channelID uint64, node1,
3830
        node2 route.Vertex) (*models.ChannelEdgePolicy,
3831
        *models.ChannelEdgePolicy, error) {
×
3832

×
3833
        if dbPol1 == nil && dbPol2 == nil {
×
3834
                return nil, nil, nil
×
3835
        }
×
3836

3837
        var (
×
3838
                policy1ID int64
×
3839
                policy2ID int64
×
3840
        )
×
3841
        if dbPol1 != nil {
×
3842
                policy1ID = dbPol1.ID
×
3843
        }
×
3844
        if dbPol2 != nil {
×
3845
                policy2ID = dbPol2.ID
×
3846
        }
×
3847
        rows, err := db.GetChannelPolicyExtraTypes(
×
3848
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
3849
                        ID:   policy1ID,
×
3850
                        ID_2: policy2ID,
×
3851
                },
×
3852
        )
×
3853
        if err != nil {
×
3854
                return nil, nil, err
×
3855
        }
×
3856

3857
        var (
×
3858
                dbPol1Extras = make(map[uint64][]byte)
×
3859
                dbPol2Extras = make(map[uint64][]byte)
×
3860
        )
×
3861
        for _, row := range rows {
×
3862
                switch row.PolicyID {
×
3863
                case policy1ID:
×
3864
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
3865
                case policy2ID:
×
3866
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
3867
                default:
×
3868
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
3869
                                "in row: %v", row.PolicyID, row)
×
3870
                }
3871
        }
3872

3873
        var pol1, pol2 *models.ChannelEdgePolicy
×
3874
        if dbPol1 != nil {
×
3875
                pol1, err = buildChanPolicy(
×
3876
                        *dbPol1, channelID, dbPol1Extras, node2, true,
×
3877
                )
×
3878
                if err != nil {
×
3879
                        return nil, nil, err
×
3880
                }
×
3881
        }
3882
        if dbPol2 != nil {
×
3883
                pol2, err = buildChanPolicy(
×
3884
                        *dbPol2, channelID, dbPol2Extras, node1, false,
×
3885
                )
×
3886
                if err != nil {
×
3887
                        return nil, nil, err
×
3888
                }
×
3889
        }
3890

3891
        return pol1, pol2, nil
×
3892
}
3893

3894
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
3895
// provided sqlc.ChannelPolicy and other required information.
3896
func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
3897
        extras map[uint64][]byte, toNode route.Vertex,
3898
        isNode1 bool) (*models.ChannelEdgePolicy, error) {
×
3899

×
3900
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
3901
        if err != nil {
×
3902
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3903
                        "fields: %w", err)
×
3904
        }
×
3905

3906
        var msgFlags lnwire.ChanUpdateMsgFlags
×
3907
        if dbPolicy.MaxHtlcMsat.Valid {
×
3908
                msgFlags |= lnwire.ChanUpdateRequiredMaxHtlc
×
3909
        }
×
3910

3911
        var chanFlags lnwire.ChanUpdateChanFlags
×
3912
        if !isNode1 {
×
3913
                chanFlags |= lnwire.ChanUpdateDirection
×
3914
        }
×
3915
        if dbPolicy.Disabled.Bool {
×
3916
                chanFlags |= lnwire.ChanUpdateDisabled
×
3917
        }
×
3918

3919
        var inboundFee fn.Option[lnwire.Fee]
×
3920
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
3921
                dbPolicy.InboundBaseFeeMsat.Valid {
×
3922

×
3923
                inboundFee = fn.Some(lnwire.Fee{
×
3924
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
3925
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
3926
                })
×
3927
        }
×
3928

3929
        return &models.ChannelEdgePolicy{
×
3930
                SigBytes:  dbPolicy.Signature,
×
3931
                ChannelID: channelID,
×
3932
                LastUpdate: time.Unix(
×
3933
                        dbPolicy.LastUpdate.Int64, 0,
×
3934
                ),
×
3935
                MessageFlags:  msgFlags,
×
3936
                ChannelFlags:  chanFlags,
×
3937
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
3938
                MinHTLC: lnwire.MilliSatoshi(
×
3939
                        dbPolicy.MinHtlcMsat,
×
3940
                ),
×
3941
                MaxHTLC: lnwire.MilliSatoshi(
×
3942
                        dbPolicy.MaxHtlcMsat.Int64,
×
3943
                ),
×
3944
                FeeBaseMSat: lnwire.MilliSatoshi(
×
3945
                        dbPolicy.BaseFeeMsat,
×
3946
                ),
×
3947
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
3948
                ToNode:                    toNode,
×
3949
                InboundFee:                inboundFee,
×
3950
                ExtraOpaqueData:           recs,
×
3951
        }, nil
×
3952
}
3953

3954
// buildNodes builds the models.LightningNode instances for the
3955
// given row which is expected to be a sqlc type that contains node information.
3956
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
3957
        dbNode2 sqlc.Node) (*models.LightningNode, *models.LightningNode,
3958
        error) {
×
3959

×
3960
        node1, err := buildNode(ctx, db, &dbNode1)
×
3961
        if err != nil {
×
3962
                return nil, nil, err
×
3963
        }
×
3964

3965
        node2, err := buildNode(ctx, db, &dbNode2)
×
3966
        if err != nil {
×
3967
                return nil, nil, err
×
3968
        }
×
3969

3970
        return node1, node2, nil
×
3971
}
3972

3973
// extractChannelPolicies extracts the sqlc.ChannelPolicy records from the give
3974
// row which is expected to be a sqlc type that contains channel policy
3975
// information. It returns two policies, which may be nil if the policy
3976
// information is not present in the row.
3977
//
3978
//nolint:ll,dupl,funlen
3979
func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
3980
        error) {
×
3981

×
3982
        var policy1, policy2 *sqlc.ChannelPolicy
×
3983
        switch r := row.(type) {
×
3984
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
3985
                if r.Policy1ID.Valid {
×
3986
                        policy1 = &sqlc.ChannelPolicy{
×
3987
                                ID:                      r.Policy1ID.Int64,
×
3988
                                Version:                 r.Policy1Version.Int16,
×
3989
                                ChannelID:               r.Channel.ID,
×
3990
                                NodeID:                  r.Policy1NodeID.Int64,
×
3991
                                Timelock:                r.Policy1Timelock.Int32,
×
3992
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
3993
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
3994
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
3995
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
3996
                                LastUpdate:              r.Policy1LastUpdate,
×
3997
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
3998
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
3999
                                Disabled:                r.Policy1Disabled,
×
4000
                                Signature:               r.Policy1Signature,
×
4001
                        }
×
4002
                }
×
4003
                if r.Policy2ID.Valid {
×
4004
                        policy2 = &sqlc.ChannelPolicy{
×
4005
                                ID:                      r.Policy2ID.Int64,
×
4006
                                Version:                 r.Policy2Version.Int16,
×
4007
                                ChannelID:               r.Channel.ID,
×
4008
                                NodeID:                  r.Policy2NodeID.Int64,
×
4009
                                Timelock:                r.Policy2Timelock.Int32,
×
4010
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4011
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4012
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4013
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4014
                                LastUpdate:              r.Policy2LastUpdate,
×
4015
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4016
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4017
                                Disabled:                r.Policy2Disabled,
×
4018
                                Signature:               r.Policy2Signature,
×
4019
                        }
×
4020
                }
×
4021

4022
                return policy1, policy2, nil
×
4023

4024
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4025
                if r.Policy1ID.Valid {
×
4026
                        policy1 = &sqlc.ChannelPolicy{
×
4027
                                ID:                      r.Policy1ID.Int64,
×
4028
                                Version:                 r.Policy1Version.Int16,
×
4029
                                ChannelID:               r.Channel.ID,
×
4030
                                NodeID:                  r.Policy1NodeID.Int64,
×
4031
                                Timelock:                r.Policy1Timelock.Int32,
×
4032
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4033
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4034
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4035
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4036
                                LastUpdate:              r.Policy1LastUpdate,
×
4037
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4038
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4039
                                Disabled:                r.Policy1Disabled,
×
4040
                                Signature:               r.Policy1Signature,
×
4041
                        }
×
4042
                }
×
4043
                if r.Policy2ID.Valid {
×
4044
                        policy2 = &sqlc.ChannelPolicy{
×
4045
                                ID:                      r.Policy2ID.Int64,
×
4046
                                Version:                 r.Policy2Version.Int16,
×
4047
                                ChannelID:               r.Channel.ID,
×
4048
                                NodeID:                  r.Policy2NodeID.Int64,
×
4049
                                Timelock:                r.Policy2Timelock.Int32,
×
4050
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4051
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4052
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4053
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4054
                                LastUpdate:              r.Policy2LastUpdate,
×
4055
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4056
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4057
                                Disabled:                r.Policy2Disabled,
×
4058
                                Signature:               r.Policy2Signature,
×
4059
                        }
×
4060
                }
×
4061

4062
                return policy1, policy2, nil
×
4063

4064
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4065
                if r.Policy1ID.Valid {
×
4066
                        policy1 = &sqlc.ChannelPolicy{
×
4067
                                ID:                      r.Policy1ID.Int64,
×
4068
                                Version:                 r.Policy1Version.Int16,
×
4069
                                ChannelID:               r.Channel.ID,
×
4070
                                NodeID:                  r.Policy1NodeID.Int64,
×
4071
                                Timelock:                r.Policy1Timelock.Int32,
×
4072
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4073
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4074
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4075
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4076
                                LastUpdate:              r.Policy1LastUpdate,
×
4077
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4078
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4079
                                Disabled:                r.Policy1Disabled,
×
4080
                                Signature:               r.Policy1Signature,
×
4081
                        }
×
4082
                }
×
4083
                if r.Policy2ID.Valid {
×
4084
                        policy2 = &sqlc.ChannelPolicy{
×
4085
                                ID:                      r.Policy2ID.Int64,
×
4086
                                Version:                 r.Policy2Version.Int16,
×
4087
                                ChannelID:               r.Channel.ID,
×
4088
                                NodeID:                  r.Policy2NodeID.Int64,
×
4089
                                Timelock:                r.Policy2Timelock.Int32,
×
4090
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4091
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4092
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4093
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4094
                                LastUpdate:              r.Policy2LastUpdate,
×
4095
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4096
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4097
                                Disabled:                r.Policy2Disabled,
×
4098
                                Signature:               r.Policy2Signature,
×
4099
                        }
×
4100
                }
×
4101

4102
                return policy1, policy2, nil
×
4103

4104
        case sqlc.ListChannelsByNodeIDRow:
×
4105
                if r.Policy1ID.Valid {
×
4106
                        policy1 = &sqlc.ChannelPolicy{
×
4107
                                ID:                      r.Policy1ID.Int64,
×
4108
                                Version:                 r.Policy1Version.Int16,
×
4109
                                ChannelID:               r.Channel.ID,
×
4110
                                NodeID:                  r.Policy1NodeID.Int64,
×
4111
                                Timelock:                r.Policy1Timelock.Int32,
×
4112
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4113
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4114
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4115
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4116
                                LastUpdate:              r.Policy1LastUpdate,
×
4117
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4118
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4119
                                Disabled:                r.Policy1Disabled,
×
4120
                                Signature:               r.Policy1Signature,
×
4121
                        }
×
4122
                }
×
4123
                if r.Policy2ID.Valid {
×
4124
                        policy2 = &sqlc.ChannelPolicy{
×
4125
                                ID:                      r.Policy2ID.Int64,
×
4126
                                Version:                 r.Policy2Version.Int16,
×
4127
                                ChannelID:               r.Channel.ID,
×
4128
                                NodeID:                  r.Policy2NodeID.Int64,
×
4129
                                Timelock:                r.Policy2Timelock.Int32,
×
4130
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4131
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4132
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4133
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4134
                                LastUpdate:              r.Policy2LastUpdate,
×
4135
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4136
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4137
                                Disabled:                r.Policy2Disabled,
×
4138
                                Signature:               r.Policy2Signature,
×
4139
                        }
×
4140
                }
×
4141

4142
                return policy1, policy2, nil
×
4143

4144
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4145
                if r.Policy1ID.Valid {
×
4146
                        policy1 = &sqlc.ChannelPolicy{
×
4147
                                ID:                      r.Policy1ID.Int64,
×
4148
                                Version:                 r.Policy1Version.Int16,
×
4149
                                ChannelID:               r.Channel.ID,
×
4150
                                NodeID:                  r.Policy1NodeID.Int64,
×
4151
                                Timelock:                r.Policy1Timelock.Int32,
×
4152
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4153
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4154
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4155
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4156
                                LastUpdate:              r.Policy1LastUpdate,
×
4157
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4158
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4159
                                Disabled:                r.Policy1Disabled,
×
4160
                                Signature:               r.Policy1Signature,
×
4161
                        }
×
4162
                }
×
4163
                if r.Policy2ID.Valid {
×
4164
                        policy2 = &sqlc.ChannelPolicy{
×
4165
                                ID:                      r.Policy2ID.Int64,
×
4166
                                Version:                 r.Policy2Version.Int16,
×
4167
                                ChannelID:               r.Channel.ID,
×
4168
                                NodeID:                  r.Policy2NodeID.Int64,
×
4169
                                Timelock:                r.Policy2Timelock.Int32,
×
4170
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4171
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4172
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4173
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4174
                                LastUpdate:              r.Policy2LastUpdate,
×
4175
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4176
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4177
                                Disabled:                r.Policy2Disabled,
×
4178
                                Signature:               r.Policy2Signature,
×
4179
                        }
×
4180
                }
×
4181

4182
                return policy1, policy2, nil
×
4183
        default:
×
4184
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4185
                        "extractChannelPolicies: %T", r)
×
4186
        }
4187
}
4188

4189
// channelIDToBytes converts a channel ID (SCID) to a byte array
4190
// representation.
4191
func channelIDToBytes(channelID uint64) [8]byte {
×
4192
        var chanIDB [8]byte
×
4193
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4194

×
4195
        return chanIDB
×
4196
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc