• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 16298284123

15 Jul 2025 03:59PM UTC coverage: 67.225% (-0.1%) from 67.338%
16298284123

Pull #10081

github

web-flow
Merge a3f3f2e60 into df6c02e3a
Pull Request #10081: graph/db: use `/*SLICE:<field_name>*/` to optimise various graph queries

3 of 364 new or added lines in 5 files covered. (0.82%)

674 existing lines in 22 files now uncovered.

135394 of 201405 relevant lines covered (67.22%)

21718.17 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// pageSize caps the number of records a single page of a paginated query
// may return. The value is provisional and may be tuned after benchmarking.
const pageSize = 2000
37

38
// ProtocolVersion enumerates the gossip protocol versions a message can
// belong to.
type ProtocolVersion uint8

const (
	// ProtocolV1 is the gossip protocol version defined in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String renders the version as the letter "V" followed by its decimal
// value, for example "V1".
func (v ProtocolVersion) String() string {
	return "V" + strconv.FormatUint(uint64(v), 10)
}
×
51

52
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
53
// execute queries against the SQL graph tables.
54
//
55
//nolint:ll,interfacebloat
56
type SQLQueries interface {
57
        /*
58
                Node queries.
59
        */
60
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
61
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
62
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
63
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
64
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error)
65
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
66
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
67
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
68
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
69
        DeleteNode(ctx context.Context, id int64) error
70

71
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
72
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
73
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
74

75
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
76
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
77
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
78

79
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
80
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
81
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
82
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
83

84
        /*
85
                Source node queries.
86
        */
87
        AddSourceNode(ctx context.Context, nodeID int64) error
88
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
89

90
        /*
91
                Channel queries.
92
        */
93
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
94
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
95
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
96
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.Channel, error)
97
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
98
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
99
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
100
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
101
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
102
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
103
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
104
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
105
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
106
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
107
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
108
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
109
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
110
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
111
        DeleteChannel(ctx context.Context, id int64) error
112

113
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
114
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
115

116
        /*
117
                Channel Policy table queries.
118
        */
119
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
120
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.ChannelPolicy, error)
121
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
122

123
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
124
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
125
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
126

127
        /*
128
                Zombie index queries.
129
        */
130
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
131
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.ZombieChannel, error)
132
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
133
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
134
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
135

136
        /*
137
                Prune log table queries.
138
        */
139
        GetPruneTip(ctx context.Context) (sqlc.PruneLog, error)
140
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
141
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
142
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
143

144
        /*
145
                Closed SCID table queries.
146
        */
147
        InsertClosedChannel(ctx context.Context, scid []byte) error
148
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
149
}
150

151
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
152
// database operations.
153
type BatchedSQLQueries interface {
154
        SQLQueries
155
        sqldb.BatchedTx[SQLQueries]
156
}
157

158
// SQLStore is an implementation of the V1Store interface that uses a SQL
159
// database as the backend.
160
type SQLStore struct {
161
        cfg *SQLStoreConfig
162
        db  BatchedSQLQueries
163

164
        // cacheMu guards all caches (rejectCache and chanCache). If
165
        // this mutex will be acquired at the same time as the DB mutex then
166
        // the cacheMu MUST be acquired first to prevent deadlock.
167
        cacheMu     sync.RWMutex
168
        rejectCache *rejectCache
169
        chanCache   *channelCache
170

171
        chanScheduler batch.Scheduler[SQLQueries]
172
        nodeScheduler batch.Scheduler[SQLQueries]
173

174
        srcNodes  map[ProtocolVersion]*srcNodeInfo
175
        srcNodeMu sync.Mutex
176
}
177

178
// A compile-time assertion to ensure that SQLStore implements the V1Store
179
// interface.
180
var _ V1Store = (*SQLStore)(nil)
181

182
// SQLStoreConfig holds the configuration for the SQLStore.
183
type SQLStoreConfig struct {
184
        // ChainHash is the genesis hash for the chain that all the gossip
185
        // messages in this store are aimed at.
186
        ChainHash chainhash.Hash
187

188
        // PaginationCfg is the configuration for paginated queries.
189
        PaginationCfg *sqldb.PagedQueryConfig
190
}
191

192
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
193
// storage backend.
194
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
195
        options ...StoreOptionModifier) (*SQLStore, error) {
×
196

×
197
        opts := DefaultOptions()
×
198
        for _, o := range options {
×
199
                o(opts)
×
200
        }
×
201

202
        if opts.NoMigration {
×
203
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
204
                        "supported for SQL stores")
×
205
        }
×
206

207
        s := &SQLStore{
×
208
                cfg:         cfg,
×
209
                db:          db,
×
210
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
211
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
212
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
213
        }
×
214

×
215
        s.chanScheduler = batch.NewTimeScheduler(
×
216
                db, &s.cacheMu, opts.BatchCommitInterval,
×
217
        )
×
218
        s.nodeScheduler = batch.NewTimeScheduler(
×
219
                db, nil, opts.BatchCommitInterval,
×
220
        )
×
221

×
222
        return s, nil
×
223
}
224

225
// AddLightningNode adds a vertex/node to the graph database. If the node is not
226
// in the database from before, this will add a new, unconnected one to the
227
// graph. If it is present from before, this will update that node's
228
// information.
229
//
230
// NOTE: part of the V1Store interface.
231
func (s *SQLStore) AddLightningNode(ctx context.Context,
232
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
233

×
234
        r := &batch.Request[SQLQueries]{
×
235
                Opts: batch.NewSchedulerOptions(opts...),
×
236
                Do: func(queries SQLQueries) error {
×
237
                        _, err := upsertNode(ctx, queries, node)
×
238
                        return err
×
239
                },
×
240
        }
241

242
        return s.nodeScheduler.Execute(ctx, r)
×
243
}
244

245
// FetchLightningNode attempts to look up a target node by its identity public
246
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
247
// returned.
248
//
249
// NOTE: part of the V1Store interface.
250
func (s *SQLStore) FetchLightningNode(ctx context.Context,
251
        pubKey route.Vertex) (*models.LightningNode, error) {
×
252

×
253
        var node *models.LightningNode
×
254
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
255
                var err error
×
256
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
257

×
258
                return err
×
259
        }, sqldb.NoOpReset)
×
260
        if err != nil {
×
261
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
262
        }
×
263

264
        return node, nil
×
265
}
266

267
// HasLightningNode determines if the graph has a vertex identified by the
268
// target node identity public key. If the node exists in the database, a
269
// timestamp of when the data for the node was lasted updated is returned along
270
// with a true boolean. Otherwise, an empty time.Time is returned with a false
271
// boolean.
272
//
273
// NOTE: part of the V1Store interface.
274
func (s *SQLStore) HasLightningNode(ctx context.Context,
275
        pubKey [33]byte) (time.Time, bool, error) {
×
276

×
277
        var (
×
278
                exists     bool
×
279
                lastUpdate time.Time
×
280
        )
×
281
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
282
                dbNode, err := db.GetNodeByPubKey(
×
283
                        ctx, sqlc.GetNodeByPubKeyParams{
×
284
                                Version: int16(ProtocolV1),
×
285
                                PubKey:  pubKey[:],
×
286
                        },
×
287
                )
×
288
                if errors.Is(err, sql.ErrNoRows) {
×
289
                        return nil
×
290
                } else if err != nil {
×
291
                        return fmt.Errorf("unable to fetch node: %w", err)
×
292
                }
×
293

294
                exists = true
×
295

×
296
                if dbNode.LastUpdate.Valid {
×
297
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
298
                }
×
299

300
                return nil
×
301
        }, sqldb.NoOpReset)
302
        if err != nil {
×
303
                return time.Time{}, false,
×
304
                        fmt.Errorf("unable to fetch node: %w", err)
×
305
        }
×
306

307
        return lastUpdate, exists, nil
×
308
}
309

310
// AddrsForNode returns all known addresses for the target node public key
311
// that the graph DB is aware of. The returned boolean indicates if the
312
// given node is unknown to the graph DB or not.
313
//
314
// NOTE: part of the V1Store interface.
315
func (s *SQLStore) AddrsForNode(ctx context.Context,
316
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
317

×
318
        var (
×
319
                addresses []net.Addr
×
320
                known     bool
×
321
        )
×
322
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
323
                var err error
×
324
                known, addresses, err = getNodeAddresses(
×
325
                        ctx, db, nodePub.SerializeCompressed(),
×
326
                )
×
327
                if err != nil {
×
328
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
329
                                err)
×
330
                }
×
331

332
                return nil
×
333
        }, sqldb.NoOpReset)
334
        if err != nil {
×
335
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
336
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
337
        }
×
338

339
        return known, addresses, nil
×
340
}
341

342
// DeleteLightningNode starts a new database transaction to remove a vertex/node
343
// from the database according to the node's public key.
344
//
345
// NOTE: part of the V1Store interface.
346
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
347
        pubKey route.Vertex) error {
×
348

×
349
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
350
                res, err := db.DeleteNodeByPubKey(
×
351
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
352
                                Version: int16(ProtocolV1),
×
353
                                PubKey:  pubKey[:],
×
354
                        },
×
355
                )
×
356
                if err != nil {
×
357
                        return err
×
358
                }
×
359

360
                rows, err := res.RowsAffected()
×
361
                if err != nil {
×
362
                        return err
×
363
                }
×
364

365
                if rows == 0 {
×
366
                        return ErrGraphNodeNotFound
×
367
                } else if rows > 1 {
×
368
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
369
                }
×
370

371
                return err
×
372
        }, sqldb.NoOpReset)
373
        if err != nil {
×
374
                return fmt.Errorf("unable to delete node: %w", err)
×
375
        }
×
376

377
        return nil
×
378
}
379

380
// FetchNodeFeatures returns the features of the given node. If no features are
381
// known for the node, an empty feature vector is returned.
382
//
383
// NOTE: this is part of the graphdb.NodeTraverser interface.
384
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
385
        *lnwire.FeatureVector, error) {
×
386

×
387
        ctx := context.TODO()
×
388

×
389
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
390
}
×
391

392
// DisabledChannelIDs returns the channel ids of disabled channels.
393
// A channel is disabled when two of the associated ChanelEdgePolicies
394
// have their disabled bit on.
395
//
396
// NOTE: part of the V1Store interface.
397
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
398
        var (
×
399
                ctx     = context.TODO()
×
400
                chanIDs []uint64
×
401
        )
×
402
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
403
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
404
                if err != nil {
×
405
                        return fmt.Errorf("unable to fetch disabled "+
×
406
                                "channels: %w", err)
×
407
                }
×
408

409
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
410

×
411
                return nil
×
412
        }, sqldb.NoOpReset)
413
        if err != nil {
×
414
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
415
                        err)
×
416
        }
×
417

418
        return chanIDs, nil
×
419
}
420

421
// LookupAlias attempts to return the alias as advertised by the target node.
422
//
423
// NOTE: part of the V1Store interface.
424
func (s *SQLStore) LookupAlias(ctx context.Context,
425
        pub *btcec.PublicKey) (string, error) {
×
426

×
427
        var alias string
×
428
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
429
                dbNode, err := db.GetNodeByPubKey(
×
430
                        ctx, sqlc.GetNodeByPubKeyParams{
×
431
                                Version: int16(ProtocolV1),
×
432
                                PubKey:  pub.SerializeCompressed(),
×
433
                        },
×
434
                )
×
435
                if errors.Is(err, sql.ErrNoRows) {
×
436
                        return ErrNodeAliasNotFound
×
437
                } else if err != nil {
×
438
                        return fmt.Errorf("unable to fetch node: %w", err)
×
439
                }
×
440

441
                if !dbNode.Alias.Valid {
×
442
                        return ErrNodeAliasNotFound
×
443
                }
×
444

445
                alias = dbNode.Alias.String
×
446

×
447
                return nil
×
448
        }, sqldb.NoOpReset)
449
        if err != nil {
×
450
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
451
        }
×
452

453
        return alias, nil
×
454
}
455

456
// SourceNode returns the source node of the graph. The source node is treated
457
// as the center node within a star-graph. This method may be used to kick off
458
// a path finding algorithm in order to explore the reachability of another
459
// node based off the source node.
460
//
461
// NOTE: part of the V1Store interface.
462
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
463
        error) {
×
464

×
465
        var node *models.LightningNode
×
466
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
467
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
468
                if err != nil {
×
469
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
470
                                err)
×
471
                }
×
472

473
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
474

×
475
                return err
×
476
        }, sqldb.NoOpReset)
477
        if err != nil {
×
478
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
479
        }
×
480

481
        return node, nil
×
482
}
483

484
// SetSourceNode sets the source node within the graph database. The source
485
// node is to be used as the center of a star-graph within path finding
486
// algorithms.
487
//
488
// NOTE: part of the V1Store interface.
489
func (s *SQLStore) SetSourceNode(ctx context.Context,
490
        node *models.LightningNode) error {
×
491

×
492
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
493
                id, err := upsertNode(ctx, db, node)
×
494
                if err != nil {
×
495
                        return fmt.Errorf("unable to upsert source node: %w",
×
496
                                err)
×
497
                }
×
498

499
                // Make sure that if a source node for this version is already
500
                // set, then the ID is the same as the one we are about to set.
501
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
502
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
503
                        return fmt.Errorf("unable to fetch source node: %w",
×
504
                                err)
×
505
                } else if err == nil {
×
506
                        if dbSourceNodeID != id {
×
507
                                return fmt.Errorf("v1 source node already "+
×
508
                                        "set to a different node: %d vs %d",
×
509
                                        dbSourceNodeID, id)
×
510
                        }
×
511

512
                        return nil
×
513
                }
514

515
                return db.AddSourceNode(ctx, id)
×
516
        }, sqldb.NoOpReset)
517
}
518

519
// NodeUpdatesInHorizon returns all the known lightning node which have an
520
// update timestamp within the passed range. This method can be used by two
521
// nodes to quickly determine if they have the same set of up to date node
522
// announcements.
523
//
524
// NOTE: This is part of the V1Store interface.
525
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
526
        endTime time.Time) ([]models.LightningNode, error) {
×
527

×
528
        ctx := context.TODO()
×
529

×
530
        var nodes []models.LightningNode
×
531
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
532
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
533
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
534
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
535
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
536
                        },
×
537
                )
×
538
                if err != nil {
×
539
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
540
                }
×
541

542
                for _, dbNode := range dbNodes {
×
543
                        node, err := buildNode(ctx, db, &dbNode)
×
544
                        if err != nil {
×
545
                                return fmt.Errorf("unable to build node: %w",
×
546
                                        err)
×
547
                        }
×
548

549
                        nodes = append(nodes, *node)
×
550
                }
551

552
                return nil
×
553
        }, sqldb.NoOpReset)
554
        if err != nil {
×
555
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
556
        }
×
557

558
        return nodes, nil
×
559
}
560

561
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
562
// undirected edge from the two target nodes are created. The information stored
563
// denotes the static attributes of the channel, such as the channelID, the keys
564
// involved in creation of the channel, and the set of features that the channel
565
// supports. The chanPoint and chanID are used to uniquely identify the edge
566
// globally within the database.
567
//
568
// NOTE: part of the V1Store interface.
569
func (s *SQLStore) AddChannelEdge(ctx context.Context,
570
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
571

×
572
        var alreadyExists bool
×
573
        r := &batch.Request[SQLQueries]{
×
574
                Opts: batch.NewSchedulerOptions(opts...),
×
575
                Reset: func() {
×
576
                        alreadyExists = false
×
577
                },
×
578
                Do: func(tx SQLQueries) error {
×
579
                        _, err := insertChannel(ctx, tx, edge)
×
580

×
581
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
582
                        // succeed, but propagate the error via local state.
×
583
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
584
                                alreadyExists = true
×
585
                                return nil
×
586
                        }
×
587

588
                        return err
×
589
                },
590
                OnCommit: func(err error) error {
×
591
                        switch {
×
592
                        case err != nil:
×
593
                                return err
×
594
                        case alreadyExists:
×
595
                                return ErrEdgeAlreadyExist
×
596
                        default:
×
597
                                s.rejectCache.remove(edge.ChannelID)
×
598
                                s.chanCache.remove(edge.ChannelID)
×
599
                                return nil
×
600
                        }
601
                },
602
        }
603

604
        return s.chanScheduler.Execute(ctx, r)
×
605
}
606

607
// HighestChanID returns the "highest" known channel ID in the channel graph.
608
// This represents the "newest" channel from the PoV of the chain. This method
609
// can be used by peers to quickly determine if their graphs are in sync.
610
//
611
// NOTE: This is part of the V1Store interface.
612
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
613
        var highestChanID uint64
×
614
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
615
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
616
                if errors.Is(err, sql.ErrNoRows) {
×
617
                        return nil
×
618
                } else if err != nil {
×
619
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
620
                                err)
×
621
                }
×
622

623
                highestChanID = byteOrder.Uint64(chanID)
×
624

×
625
                return nil
×
626
        }, sqldb.NoOpReset)
627
        if err != nil {
×
628
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
629
        }
×
630

631
        return highestChanID, nil
×
632
}
633

634
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
635
// within the database for the referenced channel. The `flags` attribute within
636
// the ChannelEdgePolicy determines which of the directed edges are being
637
// updated. If the flag is 1, then the first node's information is being
638
// updated, otherwise it's the second node's information. The node ordering is
639
// determined by the lexicographical ordering of the identity public keys of the
640
// nodes on either side of the channel.
641
//
642
// NOTE: part of the V1Store interface.
643
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
644
        edge *models.ChannelEdgePolicy,
645
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
646

×
647
        var (
×
648
                isUpdate1    bool
×
649
                edgeNotFound bool
×
650
                from, to     route.Vertex
×
651
        )
×
652

×
653
        r := &batch.Request[SQLQueries]{
×
654
                Opts: batch.NewSchedulerOptions(opts...),
×
655
                Reset: func() {
×
656
                        isUpdate1 = false
×
657
                        edgeNotFound = false
×
658
                },
×
659
                Do: func(tx SQLQueries) error {
×
660
                        var err error
×
661
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
662
                                ctx, tx, edge,
×
663
                        )
×
664
                        if err != nil {
×
665
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
666
                        }
×
667

668
                        // Silence ErrEdgeNotFound so that the batch can
669
                        // succeed, but propagate the error via local state.
670
                        if errors.Is(err, ErrEdgeNotFound) {
×
671
                                edgeNotFound = true
×
672
                                return nil
×
673
                        }
×
674

675
                        return err
×
676
                },
677
                OnCommit: func(err error) error {
×
678
                        switch {
×
679
                        case err != nil:
×
680
                                return err
×
681
                        case edgeNotFound:
×
682
                                return ErrEdgeNotFound
×
683
                        default:
×
684
                                s.updateEdgeCache(edge, isUpdate1)
×
685
                                return nil
×
686
                        }
687
                },
688
        }
689

690
        err := s.chanScheduler.Execute(ctx, r)
×
691

×
692
        return from, to, err
×
693
}
694

695
// updateEdgeCache updates our reject and channel caches with the new
696
// edge policy information.
697
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
698
        isUpdate1 bool) {
×
699

×
700
        // If an entry for this channel is found in reject cache, we'll modify
×
701
        // the entry with the updated timestamp for the direction that was just
×
702
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
703
        // during the next query for this edge.
×
704
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
705
                if isUpdate1 {
×
706
                        entry.upd1Time = e.LastUpdate.Unix()
×
707
                } else {
×
708
                        entry.upd2Time = e.LastUpdate.Unix()
×
709
                }
×
710
                s.rejectCache.insert(e.ChannelID, entry)
×
711
        }
712

713
        // If an entry for this channel is found in channel cache, we'll modify
714
        // the entry with the updated policy for the direction that was just
715
        // written. If the edge doesn't exist, we'll defer loading the info and
716
        // policies and lazily read from disk during the next query.
717
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
718
                if isUpdate1 {
×
719
                        channel.Policy1 = e
×
720
                } else {
×
721
                        channel.Policy2 = e
×
722
                }
×
723
                s.chanCache.insert(e.ChannelID, channel)
×
724
        }
725
}
726

727
// ForEachSourceNodeChannel iterates through all channels of the source node,
728
// executing the passed callback on each. The call-back is provided with the
729
// channel's outpoint, whether we have a policy for the channel and the channel
730
// peer's node information.
731
//
732
// NOTE: part of the V1Store interface.
733
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
734
        cb func(chanPoint wire.OutPoint, havePolicy bool,
735
                otherNode *models.LightningNode) error, reset func()) error {
×
736

×
737
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
738
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
739
                if err != nil {
×
740
                        return fmt.Errorf("unable to fetch source node: %w",
×
741
                                err)
×
742
                }
×
743

744
                return forEachNodeChannel(
×
745
                        ctx, db, s.cfg.ChainHash, nodeID,
×
746
                        func(info *models.ChannelEdgeInfo,
×
747
                                outPolicy *models.ChannelEdgePolicy,
×
748
                                _ *models.ChannelEdgePolicy) error {
×
749

×
750
                                // Fetch the other node.
×
751
                                var (
×
752
                                        otherNodePub [33]byte
×
753
                                        node1        = info.NodeKey1Bytes
×
754
                                        node2        = info.NodeKey2Bytes
×
755
                                )
×
756
                                switch {
×
757
                                case bytes.Equal(node1[:], nodePub[:]):
×
758
                                        otherNodePub = node2
×
759
                                case bytes.Equal(node2[:], nodePub[:]):
×
760
                                        otherNodePub = node1
×
761
                                default:
×
762
                                        return fmt.Errorf("node not " +
×
763
                                                "participating in this channel")
×
764
                                }
765

766
                                _, otherNode, err := getNodeByPubKey(
×
767
                                        ctx, db, otherNodePub,
×
768
                                )
×
769
                                if err != nil {
×
770
                                        return fmt.Errorf("unable to fetch "+
×
771
                                                "other node(%x): %w",
×
772
                                                otherNodePub, err)
×
773
                                }
×
774

775
                                return cb(
×
776
                                        info.ChannelPoint, outPolicy != nil,
×
777
                                        otherNode,
×
778
                                )
×
779
                        },
780
                )
781
        }, reset)
782
}
783

784
// ForEachNode iterates through all the stored vertices/nodes in the graph,
785
// executing the passed callback with each node encountered. If the callback
786
// returns an error, then the transaction is aborted and the iteration stops
787
// early. Any operations performed on the NodeTx passed to the call-back are
788
// executed under the same read transaction and so, methods on the NodeTx object
789
// _MUST_ only be called from within the call-back.
790
//
791
// NOTE: part of the V1Store interface.
792
func (s *SQLStore) ForEachNode(ctx context.Context,
793
        cb func(tx NodeRTx) error, reset func()) error {
×
794

×
795
        var lastID int64 = 0
×
796
        handleNode := func(db SQLQueries, dbNode sqlc.Node) error {
×
797
                node, err := buildNode(ctx, db, &dbNode)
×
798
                if err != nil {
×
799
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
800
                                dbNode.ID, err)
×
801
                }
×
802

803
                err = cb(
×
804
                        newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
×
805
                )
×
806
                if err != nil {
×
807
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
808
                                dbNode.ID, err)
×
809
                }
×
810

811
                return nil
×
812
        }
813

814
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
815
                for {
×
816
                        nodes, err := db.ListNodesPaginated(
×
817
                                ctx, sqlc.ListNodesPaginatedParams{
×
818
                                        Version: int16(ProtocolV1),
×
819
                                        ID:      lastID,
×
820
                                        Limit:   pageSize,
×
821
                                },
×
822
                        )
×
823
                        if err != nil {
×
824
                                return fmt.Errorf("unable to fetch nodes: %w",
×
825
                                        err)
×
826
                        }
×
827

828
                        if len(nodes) == 0 {
×
829
                                break
×
830
                        }
831

832
                        for _, dbNode := range nodes {
×
833
                                err = handleNode(db, dbNode)
×
834
                                if err != nil {
×
835
                                        return err
×
836
                                }
×
837

838
                                lastID = dbNode.ID
×
839
                        }
840
                }
841

842
                return nil
×
843
        }, reset)
844
}
845

846
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
847
// SQLStore and a SQL transaction.
848
type sqlGraphNodeTx struct {
849
        db    SQLQueries
850
        id    int64
851
        node  *models.LightningNode
852
        chain chainhash.Hash
853
}
854

855
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
856
// interface.
857
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
858

859
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
860
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
861

×
862
        return &sqlGraphNodeTx{
×
863
                db:    db,
×
864
                chain: chain,
×
865
                id:    id,
×
866
                node:  node,
×
867
        }
×
868
}
×
869

870
// Node returns the raw information of the node.
871
//
872
// NOTE: This is a part of the NodeRTx interface.
873
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
874
        return s.node
×
875
}
×
876

877
// ForEachChannel can be used to iterate over the node's channels under the same
878
// transaction used to fetch the node.
879
//
880
// NOTE: This is a part of the NodeRTx interface.
881
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
882
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
883

×
884
        ctx := context.TODO()
×
885

×
886
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
887
}
×
888

889
// FetchNode fetches the node with the given pub key under the same transaction
890
// used to fetch the current node. The returned node is also a NodeRTx and any
891
// operations on that NodeRTx will also be done under the same transaction.
892
//
893
// NOTE: This is a part of the NodeRTx interface.
894
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
895
        ctx := context.TODO()
×
896

×
897
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
898
        if err != nil {
×
899
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
900
                        nodePub, err)
×
901
        }
×
902

903
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
904
}
905

906
// ForEachNodeDirectedChannel iterates through all channels of a given node,
907
// executing the passed callback on the directed edge representing the channel
908
// and its incoming policy. If the callback returns an error, then the iteration
909
// is halted with the error propagated back up to the caller.
910
//
911
// Unknown policies are passed into the callback as nil values.
912
//
913
// NOTE: this is part of the graphdb.NodeTraverser interface.
914
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
915
        cb func(channel *DirectedChannel) error, reset func()) error {
×
916

×
917
        var ctx = context.TODO()
×
918

×
919
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
920
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
921
        }, reset)
×
922
}
923

924
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
925
// graph, executing the passed callback with each node encountered. If the
926
// callback returns an error, then the transaction is aborted and the iteration
927
// stops early.
928
//
929
// NOTE: This is a part of the V1Store interface.
930
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
931
        cb func(route.Vertex, *lnwire.FeatureVector) error,
932
        reset func()) error {
×
933

×
934
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
935
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
936
                        nodePub route.Vertex) error {
×
937

×
938
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
939
                        if err != nil {
×
940
                                return fmt.Errorf("unable to fetch node "+
×
941
                                        "features: %w", err)
×
942
                        }
×
943

944
                        return cb(nodePub, features)
×
945
                })
946
        }, reset)
947
        if err != nil {
×
948
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
949
        }
×
950

951
        return nil
×
952
}
953

954
// ForEachNodeChannel iterates through all channels of the given node,
955
// executing the passed callback with an edge info structure and the policies
956
// of each end of the channel. The first edge policy is the outgoing edge *to*
957
// the connecting node, while the second is the incoming edge *from* the
958
// connecting node. If the callback returns an error, then the iteration is
959
// halted with the error propagated back up to the caller.
960
//
961
// Unknown policies are passed into the callback as nil values.
962
//
963
// NOTE: part of the V1Store interface.
964
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
965
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
966
                *models.ChannelEdgePolicy) error, reset func()) error {
×
967

×
968
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
969
                dbNode, err := db.GetNodeByPubKey(
×
970
                        ctx, sqlc.GetNodeByPubKeyParams{
×
971
                                Version: int16(ProtocolV1),
×
972
                                PubKey:  nodePub[:],
×
973
                        },
×
974
                )
×
975
                if errors.Is(err, sql.ErrNoRows) {
×
976
                        return nil
×
977
                } else if err != nil {
×
978
                        return fmt.Errorf("unable to fetch node: %w", err)
×
979
                }
×
980

981
                return forEachNodeChannel(
×
982
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
983
                )
×
984
        }, reset)
985
}
986

987
// ChanUpdatesInHorizon returns all the known channel edges which have at least
988
// one edge that has an update timestamp within the specified horizon.
989
//
990
// NOTE: This is part of the V1Store interface.
991
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
992
        endTime time.Time) ([]ChannelEdge, error) {
×
993

×
994
        s.cacheMu.Lock()
×
995
        defer s.cacheMu.Unlock()
×
996

×
997
        var (
×
998
                ctx = context.TODO()
×
999
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
1000
                // an additional map to keep track of the edges already seen to
×
1001
                // prevent re-adding it.
×
1002
                edgesSeen    = make(map[uint64]struct{})
×
1003
                edgesToCache = make(map[uint64]ChannelEdge)
×
1004
                edges        []ChannelEdge
×
1005
                hits         int
×
1006
        )
×
1007
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1008
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1009
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1010
                                Version:   int16(ProtocolV1),
×
1011
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
1012
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
1013
                        },
×
1014
                )
×
1015
                if err != nil {
×
1016
                        return err
×
1017
                }
×
1018

1019
                for _, row := range rows {
×
1020
                        // If we've already retrieved the info and policies for
×
1021
                        // this edge, then we can skip it as we don't need to do
×
1022
                        // so again.
×
1023
                        chanIDInt := byteOrder.Uint64(row.Channel.Scid)
×
1024
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1025
                                continue
×
1026
                        }
1027

1028
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1029
                                hits++
×
1030
                                edgesSeen[chanIDInt] = struct{}{}
×
1031
                                edges = append(edges, channel)
×
1032

×
1033
                                continue
×
1034
                        }
1035

1036
                        node1, node2, err := buildNodes(
×
1037
                                ctx, db, row.Node, row.Node_2,
×
1038
                        )
×
1039
                        if err != nil {
×
1040
                                return err
×
1041
                        }
×
1042

1043
                        channel, err := getAndBuildEdgeInfo(
×
1044
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1045
                                row.Channel, node1.PubKeyBytes,
×
1046
                                node2.PubKeyBytes,
×
1047
                        )
×
1048
                        if err != nil {
×
1049
                                return fmt.Errorf("unable to build channel "+
×
1050
                                        "info: %w", err)
×
1051
                        }
×
1052

1053
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1054
                        if err != nil {
×
1055
                                return fmt.Errorf("unable to extract channel "+
×
1056
                                        "policies: %w", err)
×
1057
                        }
×
1058

1059
                        p1, p2, err := getAndBuildChanPolicies(
×
1060
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1061
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1062
                        )
×
1063
                        if err != nil {
×
1064
                                return fmt.Errorf("unable to build channel "+
×
1065
                                        "policies: %w", err)
×
1066
                        }
×
1067

1068
                        edgesSeen[chanIDInt] = struct{}{}
×
1069
                        chanEdge := ChannelEdge{
×
1070
                                Info:    channel,
×
1071
                                Policy1: p1,
×
1072
                                Policy2: p2,
×
1073
                                Node1:   node1,
×
1074
                                Node2:   node2,
×
1075
                        }
×
1076
                        edges = append(edges, chanEdge)
×
1077
                        edgesToCache[chanIDInt] = chanEdge
×
1078
                }
1079

1080
                return nil
×
1081
        }, func() {
×
1082
                edgesSeen = make(map[uint64]struct{})
×
1083
                edgesToCache = make(map[uint64]ChannelEdge)
×
1084
                edges = nil
×
1085
        })
×
1086
        if err != nil {
×
1087
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1088
        }
×
1089

1090
        // Insert any edges loaded from disk into the cache.
1091
        for chanid, channel := range edgesToCache {
×
1092
                s.chanCache.insert(chanid, channel)
×
1093
        }
×
1094

1095
        if len(edges) > 0 {
×
1096
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
1097
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1098
        } else {
×
1099
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1100
                        "horizon (%s, %s)", startTime, endTime)
×
1101
        }
×
1102

1103
        return edges, nil
×
1104
}
1105

1106
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1107
// data to the call-back.
1108
//
1109
// NOTE: The callback contents MUST not be modified.
1110
//
1111
// NOTE: part of the V1Store interface.
1112
func (s *SQLStore) ForEachNodeCached(ctx context.Context,
1113
        cb func(node route.Vertex, chans map[uint64]*DirectedChannel) error,
1114
        reset func()) error {
×
1115

×
1116
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1117
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
1118
                        nodePub route.Vertex) error {
×
1119

×
1120
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
1121
                        if err != nil {
×
1122
                                return fmt.Errorf("unable to fetch "+
×
1123
                                        "node(id=%d) features: %w", nodeID, err)
×
1124
                        }
×
1125

1126
                        toNodeCallback := func() route.Vertex {
×
1127
                                return nodePub
×
1128
                        }
×
1129

1130
                        rows, err := db.ListChannelsByNodeID(
×
1131
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
1132
                                        Version: int16(ProtocolV1),
×
1133
                                        NodeID1: nodeID,
×
1134
                                },
×
1135
                        )
×
1136
                        if err != nil {
×
1137
                                return fmt.Errorf("unable to fetch channels "+
×
1138
                                        "of node(id=%d): %w", nodeID, err)
×
1139
                        }
×
1140

1141
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
1142
                        for _, row := range rows {
×
1143
                                node1, node2, err := buildNodeVertices(
×
1144
                                        row.Node1Pubkey, row.Node2Pubkey,
×
1145
                                )
×
1146
                                if err != nil {
×
1147
                                        return err
×
1148
                                }
×
1149

1150
                                e, err := getAndBuildEdgeInfo(
×
1151
                                        ctx, db, s.cfg.ChainHash,
×
1152
                                        row.Channel.ID, row.Channel, node1,
×
1153
                                        node2,
×
1154
                                )
×
1155
                                if err != nil {
×
1156
                                        return fmt.Errorf("unable to build "+
×
1157
                                                "channel info: %w", err)
×
1158
                                }
×
1159

1160
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1161
                                        row,
×
1162
                                )
×
1163
                                if err != nil {
×
1164
                                        return fmt.Errorf("unable to "+
×
1165
                                                "extract channel "+
×
1166
                                                "policies: %w", err)
×
1167
                                }
×
1168

1169
                                p1, p2, err := getAndBuildChanPolicies(
×
1170
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
1171
                                        node1, node2,
×
1172
                                )
×
1173
                                if err != nil {
×
1174
                                        return fmt.Errorf("unable to "+
×
1175
                                                "build channel policies: %w",
×
1176
                                                err)
×
1177
                                }
×
1178

1179
                                // Determine the outgoing and incoming policy
1180
                                // for this channel and node combo.
1181
                                outPolicy, inPolicy := p1, p2
×
1182
                                if p1 != nil && p1.ToNode == nodePub {
×
1183
                                        outPolicy, inPolicy = p2, p1
×
1184
                                } else if p2 != nil && p2.ToNode != nodePub {
×
1185
                                        outPolicy, inPolicy = p2, p1
×
1186
                                }
×
1187

1188
                                var cachedInPolicy *models.CachedEdgePolicy
×
1189
                                if inPolicy != nil {
×
1190
                                        cachedInPolicy = models.NewCachedPolicy(
×
1191
                                                p2,
×
1192
                                        )
×
1193
                                        cachedInPolicy.ToNodePubKey =
×
1194
                                                toNodeCallback
×
1195
                                        cachedInPolicy.ToNodeFeatures =
×
1196
                                                features
×
1197
                                }
×
1198

1199
                                var inboundFee lnwire.Fee
×
1200
                                outPolicy.InboundFee.WhenSome(
×
1201
                                        func(fee lnwire.Fee) {
×
1202
                                                inboundFee = fee
×
1203
                                        },
×
1204
                                )
1205

1206
                                directedChannel := &DirectedChannel{
×
1207
                                        ChannelID: e.ChannelID,
×
1208
                                        IsNode1: nodePub ==
×
1209
                                                e.NodeKey1Bytes,
×
1210
                                        OtherNode:    e.NodeKey2Bytes,
×
1211
                                        Capacity:     e.Capacity,
×
1212
                                        OutPolicySet: p1 != nil,
×
1213
                                        InPolicy:     cachedInPolicy,
×
1214
                                        InboundFee:   inboundFee,
×
1215
                                }
×
1216

×
1217
                                if nodePub == e.NodeKey2Bytes {
×
1218
                                        directedChannel.OtherNode =
×
1219
                                                e.NodeKey1Bytes
×
1220
                                }
×
1221

1222
                                channels[e.ChannelID] = directedChannel
×
1223
                        }
1224

1225
                        return cb(nodePub, channels)
×
1226
                })
1227
        }, reset)
1228
}
1229

1230
// ForEachChannelCacheable iterates through all the channel edges stored
1231
// within the graph and invokes the passed callback for each edge. The
1232
// callback takes two edges as since this is a directed graph, both the
1233
// in/out edges are visited. If the callback returns an error, then the
1234
// transaction is aborted and the iteration stops early.
1235
//
1236
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1237
// pointer for that particular channel edge routing policy will be
1238
// passed into the callback.
1239
//
1240
// NOTE: this method is like ForEachChannel but fetches only the data
1241
// required for the graph cache.
1242
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1243
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1244
        reset func()) error {
×
1245

×
1246
        ctx := context.TODO()
×
1247

×
1248
        handleChannel := func(db SQLQueries,
×
1249
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1250

×
1251
                node1, node2, err := buildNodeVertices(
×
1252
                        row.Node1Pubkey, row.Node2Pubkey,
×
1253
                )
×
1254
                if err != nil {
×
1255
                        return err
×
1256
                }
×
1257

1258
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
1259

×
1260
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1261
                if err != nil {
×
1262
                        return err
×
1263
                }
×
1264

1265
                var pol1, pol2 *models.CachedEdgePolicy
×
1266
                if dbPol1 != nil {
×
1267
                        policy1, err := buildChanPolicy(
×
1268
                                *dbPol1, edge.ChannelID, nil, node2,
×
1269
                        )
×
1270
                        if err != nil {
×
1271
                                return err
×
1272
                        }
×
1273

1274
                        pol1 = models.NewCachedPolicy(policy1)
×
1275
                }
1276
                if dbPol2 != nil {
×
1277
                        policy2, err := buildChanPolicy(
×
1278
                                *dbPol2, edge.ChannelID, nil, node1,
×
1279
                        )
×
1280
                        if err != nil {
×
1281
                                return err
×
1282
                        }
×
1283

1284
                        pol2 = models.NewCachedPolicy(policy2)
×
1285
                }
1286

1287
                if err := cb(edge, pol1, pol2); err != nil {
×
1288
                        return err
×
1289
                }
×
1290

1291
                return nil
×
1292
        }
1293

1294
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1295
                lastID := int64(-1)
×
1296
                for {
×
1297
                        //nolint:ll
×
1298
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1299
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1300
                                        Version: int16(ProtocolV1),
×
1301
                                        ID:      lastID,
×
1302
                                        Limit:   pageSize,
×
1303
                                },
×
1304
                        )
×
1305
                        if err != nil {
×
1306
                                return err
×
1307
                        }
×
1308

1309
                        if len(rows) == 0 {
×
1310
                                break
×
1311
                        }
1312

1313
                        for _, row := range rows {
×
1314
                                err := handleChannel(db, row)
×
1315
                                if err != nil {
×
1316
                                        return err
×
1317
                                }
×
1318

1319
                                lastID = row.Channel.ID
×
1320
                        }
1321
                }
1322

1323
                return nil
×
1324
        }, reset)
1325
}
1326

1327
// ForEachChannel iterates through all the channel edges stored within the
1328
// graph and invokes the passed callback for each edge. The callback takes two
1329
// edges as since this is a directed graph, both the in/out edges are visited.
1330
// If the callback returns an error, then the transaction is aborted and the
1331
// iteration stops early.
1332
//
1333
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1334
// for that particular channel edge routing policy will be passed into the
1335
// callback.
1336
//
1337
// NOTE: part of the V1Store interface.
1338
func (s *SQLStore) ForEachChannel(ctx context.Context,
1339
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1340
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1341

×
1342
        handleChannel := func(db SQLQueries,
×
1343
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1344

×
1345
                node1, node2, err := buildNodeVertices(
×
1346
                        row.Node1Pubkey, row.Node2Pubkey,
×
1347
                )
×
1348
                if err != nil {
×
1349
                        return fmt.Errorf("unable to build node vertices: %w",
×
1350
                                err)
×
1351
                }
×
1352

1353
                edge, err := getAndBuildEdgeInfo(
×
1354
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1355
                        node1, node2,
×
1356
                )
×
1357
                if err != nil {
×
1358
                        return fmt.Errorf("unable to build channel info: %w",
×
1359
                                err)
×
1360
                }
×
1361

1362
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1363
                if err != nil {
×
1364
                        return fmt.Errorf("unable to extract channel "+
×
1365
                                "policies: %w", err)
×
1366
                }
×
1367

1368
                p1, p2, err := getAndBuildChanPolicies(
×
1369
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1370
                )
×
1371
                if err != nil {
×
1372
                        return fmt.Errorf("unable to build channel "+
×
1373
                                "policies: %w", err)
×
1374
                }
×
1375

1376
                err = cb(edge, p1, p2)
×
1377
                if err != nil {
×
1378
                        return fmt.Errorf("callback failed for channel "+
×
1379
                                "id=%d: %w", edge.ChannelID, err)
×
1380
                }
×
1381

1382
                return nil
×
1383
        }
1384

1385
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1386
                lastID := int64(-1)
×
1387
                for {
×
1388
                        //nolint:ll
×
1389
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1390
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1391
                                        Version: int16(ProtocolV1),
×
1392
                                        ID:      lastID,
×
1393
                                        Limit:   pageSize,
×
1394
                                },
×
1395
                        )
×
1396
                        if err != nil {
×
1397
                                return err
×
1398
                        }
×
1399

1400
                        if len(rows) == 0 {
×
1401
                                break
×
1402
                        }
1403

1404
                        for _, row := range rows {
×
1405
                                err := handleChannel(db, row)
×
1406
                                if err != nil {
×
1407
                                        return err
×
1408
                                }
×
1409

1410
                                lastID = row.Channel.ID
×
1411
                        }
1412
                }
1413

1414
                return nil
×
1415
        }, reset)
1416
}
1417

1418
// FilterChannelRange returns the channel ID's of all known channels which were
1419
// mined in a block height within the passed range. The channel IDs are grouped
1420
// by their common block height. This method can be used to quickly share with a
1421
// peer the set of channels we know of within a particular range to catch them
1422
// up after a period of time offline. If withTimestamps is true then the
1423
// timestamp info of the latest received channel update messages of the channel
1424
// will be included in the response.
1425
//
1426
// NOTE: This is part of the V1Store interface.
1427
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1428
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1429

×
1430
        var (
×
1431
                ctx       = context.TODO()
×
1432
                startSCID = &lnwire.ShortChannelID{
×
1433
                        BlockHeight: startHeight,
×
1434
                }
×
1435
                endSCID = lnwire.ShortChannelID{
×
1436
                        BlockHeight: endHeight,
×
1437
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1438
                        TxPosition:  math.MaxUint16,
×
1439
                }
×
1440
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1441
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1442
        )
×
1443

×
1444
        // 1) get all channels where channelID is between start and end chan ID.
×
1445
        // 2) skip if not public (ie, no channel_proof)
×
1446
        // 3) collect that channel.
×
1447
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1448
        //    and add those timestamps to the collected channel.
×
1449
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1450
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1451
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1452
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1453
                                StartScid: chanIDStart,
×
1454
                                EndScid:   chanIDEnd,
×
1455
                        },
×
1456
                )
×
1457
                if err != nil {
×
1458
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1459
                                err)
×
1460
                }
×
1461

1462
                for _, dbChan := range dbChans {
×
1463
                        cid := lnwire.NewShortChanIDFromInt(
×
1464
                                byteOrder.Uint64(dbChan.Scid),
×
1465
                        )
×
1466
                        chanInfo := NewChannelUpdateInfo(
×
1467
                                cid, time.Time{}, time.Time{},
×
1468
                        )
×
1469

×
1470
                        if !withTimestamps {
×
1471
                                channelsPerBlock[cid.BlockHeight] = append(
×
1472
                                        channelsPerBlock[cid.BlockHeight],
×
1473
                                        chanInfo,
×
1474
                                )
×
1475

×
1476
                                continue
×
1477
                        }
1478

1479
                        //nolint:ll
1480
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1481
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1482
                                        Version:   int16(ProtocolV1),
×
1483
                                        ChannelID: dbChan.ID,
×
1484
                                        NodeID:    dbChan.NodeID1,
×
1485
                                },
×
1486
                        )
×
1487
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1488
                                return fmt.Errorf("unable to fetch node1 "+
×
1489
                                        "policy: %w", err)
×
1490
                        } else if err == nil {
×
1491
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1492
                                        node1Policy.LastUpdate.Int64, 0,
×
1493
                                )
×
1494
                        }
×
1495

1496
                        //nolint:ll
1497
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1498
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1499
                                        Version:   int16(ProtocolV1),
×
1500
                                        ChannelID: dbChan.ID,
×
1501
                                        NodeID:    dbChan.NodeID2,
×
1502
                                },
×
1503
                        )
×
1504
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1505
                                return fmt.Errorf("unable to fetch node2 "+
×
1506
                                        "policy: %w", err)
×
1507
                        } else if err == nil {
×
1508
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1509
                                        node2Policy.LastUpdate.Int64, 0,
×
1510
                                )
×
1511
                        }
×
1512

1513
                        channelsPerBlock[cid.BlockHeight] = append(
×
1514
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1515
                        )
×
1516
                }
1517

1518
                return nil
×
1519
        }, func() {
×
1520
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1521
        })
×
1522
        if err != nil {
×
1523
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1524
        }
×
1525

1526
        if len(channelsPerBlock) == 0 {
×
1527
                return nil, nil
×
1528
        }
×
1529

1530
        // Return the channel ranges in ascending block height order.
1531
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1532
        slices.Sort(blocks)
×
1533

×
1534
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1535
                return BlockChannelRange{
×
1536
                        Height:   block,
×
1537
                        Channels: channelsPerBlock[block],
×
1538
                }
×
1539
        }), nil
×
1540
}
1541

1542
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1543
// zombie. This method is used on an ad-hoc basis, when channels need to be
1544
// marked as zombies outside the normal pruning cycle.
1545
//
1546
// NOTE: part of the V1Store interface.
1547
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1548
        pubKey1, pubKey2 [33]byte) error {
×
1549

×
1550
        ctx := context.TODO()
×
1551

×
1552
        s.cacheMu.Lock()
×
1553
        defer s.cacheMu.Unlock()
×
1554

×
1555
        chanIDB := channelIDToBytes(chanID)
×
1556

×
1557
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1558
                return db.UpsertZombieChannel(
×
1559
                        ctx, sqlc.UpsertZombieChannelParams{
×
1560
                                Version:  int16(ProtocolV1),
×
1561
                                Scid:     chanIDB,
×
1562
                                NodeKey1: pubKey1[:],
×
1563
                                NodeKey2: pubKey2[:],
×
1564
                        },
×
1565
                )
×
1566
        }, sqldb.NoOpReset)
×
1567
        if err != nil {
×
1568
                return fmt.Errorf("unable to upsert zombie channel "+
×
1569
                        "(channel_id=%d): %w", chanID, err)
×
1570
        }
×
1571

1572
        s.rejectCache.remove(chanID)
×
1573
        s.chanCache.remove(chanID)
×
1574

×
1575
        return nil
×
1576
}
1577

1578
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1579
//
1580
// NOTE: part of the V1Store interface.
1581
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1582
        s.cacheMu.Lock()
×
1583
        defer s.cacheMu.Unlock()
×
1584

×
1585
        var (
×
1586
                ctx     = context.TODO()
×
1587
                chanIDB = channelIDToBytes(chanID)
×
1588
        )
×
1589

×
1590
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1591
                res, err := db.DeleteZombieChannel(
×
1592
                        ctx, sqlc.DeleteZombieChannelParams{
×
1593
                                Scid:    chanIDB,
×
1594
                                Version: int16(ProtocolV1),
×
1595
                        },
×
1596
                )
×
1597
                if err != nil {
×
1598
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1599
                                err)
×
1600
                }
×
1601

1602
                rows, err := res.RowsAffected()
×
1603
                if err != nil {
×
1604
                        return err
×
1605
                }
×
1606

1607
                if rows == 0 {
×
1608
                        return ErrZombieEdgeNotFound
×
1609
                } else if rows > 1 {
×
1610
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1611
                                "expected 1", rows)
×
1612
                }
×
1613

1614
                return nil
×
1615
        }, sqldb.NoOpReset)
1616
        if err != nil {
×
1617
                return fmt.Errorf("unable to mark edge live "+
×
1618
                        "(channel_id=%d): %w", chanID, err)
×
1619
        }
×
1620

1621
        s.rejectCache.remove(chanID)
×
1622
        s.chanCache.remove(chanID)
×
1623

×
1624
        return err
×
1625
}
1626

1627
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1628
// zombie, then the two node public keys corresponding to this edge are also
1629
// returned.
1630
//
1631
// NOTE: part of the V1Store interface.
1632
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1633
        error) {
×
1634

×
1635
        var (
×
1636
                ctx              = context.TODO()
×
1637
                isZombie         bool
×
1638
                pubKey1, pubKey2 route.Vertex
×
1639
                chanIDB          = channelIDToBytes(chanID)
×
1640
        )
×
1641

×
1642
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1643
                zombie, err := db.GetZombieChannel(
×
1644
                        ctx, sqlc.GetZombieChannelParams{
×
1645
                                Scid:    chanIDB,
×
1646
                                Version: int16(ProtocolV1),
×
1647
                        },
×
1648
                )
×
1649
                if errors.Is(err, sql.ErrNoRows) {
×
1650
                        return nil
×
1651
                }
×
1652
                if err != nil {
×
1653
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1654
                                err)
×
1655
                }
×
1656

1657
                copy(pubKey1[:], zombie.NodeKey1)
×
1658
                copy(pubKey2[:], zombie.NodeKey2)
×
1659
                isZombie = true
×
1660

×
1661
                return nil
×
1662
        }, sqldb.NoOpReset)
1663
        if err != nil {
×
1664
                return false, route.Vertex{}, route.Vertex{},
×
1665
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1666
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1667
        }
×
1668

1669
        return isZombie, pubKey1, pubKey2, nil
×
1670
}
1671

1672
// NumZombies returns the current number of zombie channels in the graph.
1673
//
1674
// NOTE: part of the V1Store interface.
1675
func (s *SQLStore) NumZombies() (uint64, error) {
×
1676
        var (
×
1677
                ctx        = context.TODO()
×
1678
                numZombies uint64
×
1679
        )
×
1680
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1681
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1682
                if err != nil {
×
1683
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1684
                                err)
×
1685
                }
×
1686

1687
                numZombies = uint64(count)
×
1688

×
1689
                return nil
×
1690
        }, sqldb.NoOpReset)
1691
        if err != nil {
×
1692
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1693
        }
×
1694

1695
        return numZombies, nil
×
1696
}
1697

1698
// DeleteChannelEdges removes edges with the given channel IDs from the
1699
// database and marks them as zombies. This ensures that we're unable to re-add
1700
// it to our database once again. If an edge does not exist within the
1701
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1702
// true, then when we mark these edges as zombies, we'll set up the keys such
1703
// that we require the node that failed to send the fresh update to be the one
1704
// that resurrects the channel from its zombie state. The markZombie bool
1705
// denotes whether to mark the channel as a zombie.
1706
//
1707
// NOTE: part of the V1Store interface.
1708
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1709
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1710

×
1711
        s.cacheMu.Lock()
×
1712
        defer s.cacheMu.Unlock()
×
1713

×
1714
        var (
×
1715
                ctx     = context.TODO()
×
1716
                deleted []*models.ChannelEdgeInfo
×
1717
        )
×
1718
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1719
                for _, chanID := range chanIDs {
×
1720
                        chanIDB := channelIDToBytes(chanID)
×
1721

×
1722
                        row, err := db.GetChannelBySCIDWithPolicies(
×
1723
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1724
                                        Scid:    chanIDB,
×
1725
                                        Version: int16(ProtocolV1),
×
1726
                                },
×
1727
                        )
×
1728
                        if errors.Is(err, sql.ErrNoRows) {
×
1729
                                return ErrEdgeNotFound
×
1730
                        } else if err != nil {
×
1731
                                return fmt.Errorf("unable to fetch channel: %w",
×
1732
                                        err)
×
1733
                        }
×
1734

1735
                        node1, node2, err := buildNodeVertices(
×
1736
                                row.Node.PubKey, row.Node_2.PubKey,
×
1737
                        )
×
1738
                        if err != nil {
×
1739
                                return err
×
1740
                        }
×
1741

1742
                        info, err := getAndBuildEdgeInfo(
×
1743
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1744
                                row.Channel, node1, node2,
×
1745
                        )
×
1746
                        if err != nil {
×
1747
                                return err
×
1748
                        }
×
1749

1750
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
1751
                        if err != nil {
×
1752
                                return fmt.Errorf("unable to delete "+
×
1753
                                        "channel: %w", err)
×
1754
                        }
×
1755

1756
                        deleted = append(deleted, info)
×
1757

×
1758
                        if !markZombie {
×
1759
                                continue
×
1760
                        }
1761

1762
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1763
                                info.NodeKey2Bytes
×
1764
                        if strictZombiePruning {
×
1765
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1766
                                if row.Policy1LastUpdate.Valid {
×
1767
                                        e1Time := time.Unix(
×
1768
                                                row.Policy1LastUpdate.Int64, 0,
×
1769
                                        )
×
1770
                                        e1UpdateTime = &e1Time
×
1771
                                }
×
1772
                                if row.Policy2LastUpdate.Valid {
×
1773
                                        e2Time := time.Unix(
×
1774
                                                row.Policy2LastUpdate.Int64, 0,
×
1775
                                        )
×
1776
                                        e2UpdateTime = &e2Time
×
1777
                                }
×
1778

1779
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1780
                                        info, e1UpdateTime, e2UpdateTime,
×
1781
                                )
×
1782
                        }
1783

1784
                        err = db.UpsertZombieChannel(
×
1785
                                ctx, sqlc.UpsertZombieChannelParams{
×
1786
                                        Version:  int16(ProtocolV1),
×
1787
                                        Scid:     chanIDB,
×
1788
                                        NodeKey1: nodeKey1[:],
×
1789
                                        NodeKey2: nodeKey2[:],
×
1790
                                },
×
1791
                        )
×
1792
                        if err != nil {
×
1793
                                return fmt.Errorf("unable to mark channel as "+
×
1794
                                        "zombie: %w", err)
×
1795
                        }
×
1796
                }
1797

1798
                return nil
×
1799
        }, func() {
×
1800
                deleted = nil
×
1801
        })
×
1802
        if err != nil {
×
1803
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1804
                        err)
×
1805
        }
×
1806

1807
        for _, chanID := range chanIDs {
×
1808
                s.rejectCache.remove(chanID)
×
1809
                s.chanCache.remove(chanID)
×
1810
        }
×
1811

1812
        return deleted, nil
×
1813
}
1814

1815
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1816
// channel identified by the channel ID. If the channel can't be found, then
1817
// ErrEdgeNotFound is returned. A struct which houses the general information
1818
// for the channel itself is returned as well as two structs that contain the
1819
// routing policies for the channel in either direction.
1820
//
1821
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1822
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1823
// the ChannelEdgeInfo will only include the public keys of each node.
1824
//
1825
// NOTE: part of the V1Store interface.
1826
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1827
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1828
        *models.ChannelEdgePolicy, error) {
×
1829

×
1830
        var (
×
1831
                ctx              = context.TODO()
×
1832
                edge             *models.ChannelEdgeInfo
×
1833
                policy1, policy2 *models.ChannelEdgePolicy
×
1834
                chanIDB          = channelIDToBytes(chanID)
×
1835
        )
×
1836
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1837
                row, err := db.GetChannelBySCIDWithPolicies(
×
1838
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1839
                                Scid:    chanIDB,
×
1840
                                Version: int16(ProtocolV1),
×
1841
                        },
×
1842
                )
×
1843
                if errors.Is(err, sql.ErrNoRows) {
×
1844
                        // First check if this edge is perhaps in the zombie
×
1845
                        // index.
×
1846
                        zombie, err := db.GetZombieChannel(
×
1847
                                ctx, sqlc.GetZombieChannelParams{
×
1848
                                        Scid:    chanIDB,
×
1849
                                        Version: int16(ProtocolV1),
×
1850
                                },
×
1851
                        )
×
1852
                        if errors.Is(err, sql.ErrNoRows) {
×
1853
                                return ErrEdgeNotFound
×
1854
                        } else if err != nil {
×
1855
                                return fmt.Errorf("unable to check if "+
×
1856
                                        "channel is zombie: %w", err)
×
1857
                        }
×
1858

1859
                        // At this point, we know the channel is a zombie, so
1860
                        // we'll return an error indicating this, and we will
1861
                        // populate the edge info with the public keys of each
1862
                        // party as this is the only information we have about
1863
                        // it.
1864
                        edge = &models.ChannelEdgeInfo{}
×
1865
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1866
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1867

×
1868
                        return ErrZombieEdge
×
1869
                } else if err != nil {
×
1870
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1871
                }
×
1872

1873
                node1, node2, err := buildNodeVertices(
×
1874
                        row.Node.PubKey, row.Node_2.PubKey,
×
1875
                )
×
1876
                if err != nil {
×
1877
                        return err
×
1878
                }
×
1879

1880
                edge, err = getAndBuildEdgeInfo(
×
1881
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1882
                        node1, node2,
×
1883
                )
×
1884
                if err != nil {
×
1885
                        return fmt.Errorf("unable to build channel info: %w",
×
1886
                                err)
×
1887
                }
×
1888

1889
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1890
                if err != nil {
×
1891
                        return fmt.Errorf("unable to extract channel "+
×
1892
                                "policies: %w", err)
×
1893
                }
×
1894

1895
                policy1, policy2, err = getAndBuildChanPolicies(
×
1896
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1897
                )
×
1898
                if err != nil {
×
1899
                        return fmt.Errorf("unable to build channel "+
×
1900
                                "policies: %w", err)
×
1901
                }
×
1902

1903
                return nil
×
1904
        }, sqldb.NoOpReset)
1905
        if err != nil {
×
1906
                // If we are returning the ErrZombieEdge, then we also need to
×
1907
                // return the edge info as the method comment indicates that
×
1908
                // this will be populated when the edge is a zombie.
×
1909
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1910
                        err)
×
1911
        }
×
1912

1913
        return edge, policy1, policy2, nil
×
1914
}
1915

1916
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1917
// the channel identified by the funding outpoint. If the channel can't be
1918
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1919
// information for the channel itself is returned as well as two structs that
1920
// contain the routing policies for the channel in either direction.
1921
//
1922
// NOTE: part of the V1Store interface.
1923
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1924
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1925
        *models.ChannelEdgePolicy, error) {
×
1926

×
1927
        var (
×
1928
                ctx              = context.TODO()
×
1929
                edge             *models.ChannelEdgeInfo
×
1930
                policy1, policy2 *models.ChannelEdgePolicy
×
1931
        )
×
1932
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1933
                row, err := db.GetChannelByOutpointWithPolicies(
×
1934
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1935
                                Outpoint: op.String(),
×
1936
                                Version:  int16(ProtocolV1),
×
1937
                        },
×
1938
                )
×
1939
                if errors.Is(err, sql.ErrNoRows) {
×
1940
                        return ErrEdgeNotFound
×
1941
                } else if err != nil {
×
1942
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1943
                }
×
1944

1945
                node1, node2, err := buildNodeVertices(
×
1946
                        row.Node1Pubkey, row.Node2Pubkey,
×
1947
                )
×
1948
                if err != nil {
×
1949
                        return err
×
1950
                }
×
1951

1952
                edge, err = getAndBuildEdgeInfo(
×
1953
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1954
                        node1, node2,
×
1955
                )
×
1956
                if err != nil {
×
1957
                        return fmt.Errorf("unable to build channel info: %w",
×
1958
                                err)
×
1959
                }
×
1960

1961
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1962
                if err != nil {
×
1963
                        return fmt.Errorf("unable to extract channel "+
×
1964
                                "policies: %w", err)
×
1965
                }
×
1966

1967
                policy1, policy2, err = getAndBuildChanPolicies(
×
1968
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1969
                )
×
1970
                if err != nil {
×
1971
                        return fmt.Errorf("unable to build channel "+
×
1972
                                "policies: %w", err)
×
1973
                }
×
1974

1975
                return nil
×
1976
        }, sqldb.NoOpReset)
1977
        if err != nil {
×
1978
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1979
                        err)
×
1980
        }
×
1981

1982
        return edge, policy1, policy2, nil
×
1983
}
1984

1985
// HasChannelEdge returns true if the database knows of a channel edge with the
1986
// passed channel ID, and false otherwise. If an edge with that ID is found
1987
// within the graph, then two time stamps representing the last time the edge
1988
// was updated for both directed edges are returned along with the boolean. If
1989
// it is not found, then the zombie index is checked and its result is returned
1990
// as the second boolean.
1991
//
1992
// NOTE: part of the V1Store interface.
1993
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1994
        bool, error) {
×
1995

×
1996
        ctx := context.TODO()
×
1997

×
1998
        var (
×
1999
                exists          bool
×
2000
                isZombie        bool
×
2001
                node1LastUpdate time.Time
×
2002
                node2LastUpdate time.Time
×
2003
        )
×
2004

×
2005
        // We'll query the cache with the shared lock held to allow multiple
×
2006
        // readers to access values in the cache concurrently if they exist.
×
2007
        s.cacheMu.RLock()
×
2008
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2009
                s.cacheMu.RUnlock()
×
2010
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2011
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2012
                exists, isZombie = entry.flags.unpack()
×
2013

×
2014
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2015
        }
×
2016
        s.cacheMu.RUnlock()
×
2017

×
2018
        s.cacheMu.Lock()
×
2019
        defer s.cacheMu.Unlock()
×
2020

×
2021
        // The item was not found with the shared lock, so we'll acquire the
×
2022
        // exclusive lock and check the cache again in case another method added
×
2023
        // the entry to the cache while no lock was held.
×
2024
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2025
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2026
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2027
                exists, isZombie = entry.flags.unpack()
×
2028

×
2029
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2030
        }
×
2031

2032
        chanIDB := channelIDToBytes(chanID)
×
2033
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2034
                channel, err := db.GetChannelBySCID(
×
2035
                        ctx, sqlc.GetChannelBySCIDParams{
×
2036
                                Scid:    chanIDB,
×
2037
                                Version: int16(ProtocolV1),
×
2038
                        },
×
2039
                )
×
2040
                if errors.Is(err, sql.ErrNoRows) {
×
2041
                        // Check if it is a zombie channel.
×
2042
                        isZombie, err = db.IsZombieChannel(
×
2043
                                ctx, sqlc.IsZombieChannelParams{
×
2044
                                        Scid:    chanIDB,
×
2045
                                        Version: int16(ProtocolV1),
×
2046
                                },
×
2047
                        )
×
2048
                        if err != nil {
×
2049
                                return fmt.Errorf("could not check if channel "+
×
2050
                                        "is zombie: %w", err)
×
2051
                        }
×
2052

2053
                        return nil
×
2054
                } else if err != nil {
×
2055
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2056
                }
×
2057

2058
                exists = true
×
2059

×
2060
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2061
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2062
                                Version:   int16(ProtocolV1),
×
2063
                                ChannelID: channel.ID,
×
2064
                                NodeID:    channel.NodeID1,
×
2065
                        },
×
2066
                )
×
2067
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2068
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2069
                                err)
×
2070
                } else if err == nil {
×
2071
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2072
                }
×
2073

2074
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2075
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2076
                                Version:   int16(ProtocolV1),
×
2077
                                ChannelID: channel.ID,
×
2078
                                NodeID:    channel.NodeID2,
×
2079
                        },
×
2080
                )
×
2081
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2082
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2083
                                err)
×
2084
                } else if err == nil {
×
2085
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2086
                }
×
2087

2088
                return nil
×
2089
        }, sqldb.NoOpReset)
2090
        if err != nil {
×
2091
                return time.Time{}, time.Time{}, false, false,
×
2092
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2093
        }
×
2094

2095
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2096
                upd1Time: node1LastUpdate.Unix(),
×
2097
                upd2Time: node2LastUpdate.Unix(),
×
2098
                flags:    packRejectFlags(exists, isZombie),
×
2099
        })
×
2100

×
2101
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2102
}
2103

2104
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2105
// passed channel point (outpoint). If the passed channel doesn't exist within
2106
// the database, then ErrEdgeNotFound is returned.
2107
//
2108
// NOTE: part of the V1Store interface.
2109
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2110
        var (
×
2111
                ctx       = context.TODO()
×
2112
                channelID uint64
×
2113
        )
×
2114
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2115
                chanID, err := db.GetSCIDByOutpoint(
×
2116
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2117
                                Outpoint: chanPoint.String(),
×
2118
                                Version:  int16(ProtocolV1),
×
2119
                        },
×
2120
                )
×
2121
                if errors.Is(err, sql.ErrNoRows) {
×
2122
                        return ErrEdgeNotFound
×
2123
                } else if err != nil {
×
2124
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2125
                                err)
×
2126
                }
×
2127

2128
                channelID = byteOrder.Uint64(chanID)
×
2129

×
2130
                return nil
×
2131
        }, sqldb.NoOpReset)
2132
        if err != nil {
×
2133
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2134
        }
×
2135

2136
        return channelID, nil
×
2137
}
2138

2139
// IsPublicNode is a helper method that determines whether the node with the
2140
// given public key is seen as a public node in the graph from the graph's
2141
// source node's point of view.
2142
//
2143
// NOTE: part of the V1Store interface.
2144
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2145
        ctx := context.TODO()
×
2146

×
2147
        var isPublic bool
×
2148
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2149
                var err error
×
2150
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2151

×
2152
                return err
×
2153
        }, sqldb.NoOpReset)
×
2154
        if err != nil {
×
2155
                return false, fmt.Errorf("unable to check if node is "+
×
2156
                        "public: %w", err)
×
2157
        }
×
2158

2159
        return isPublic, nil
×
2160
}
2161

2162
// FetchChanInfos returns the set of channel edges that correspond to the passed
2163
// channel ID's. If an edge is the query is unknown to the database, it will
2164
// skipped and the result will contain only those edges that exist at the time
2165
// of the query. This can be used to respond to peer queries that are seeking to
2166
// fill in gaps in their view of the channel graph.
2167
//
2168
// NOTE: part of the V1Store interface.
2169
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2170
        var (
×
2171
                ctx   = context.TODO()
×
NEW
2172
                edges = make(map[uint64]ChannelEdge)
×
2173
        )
×
2174
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2175
                chanCallBack := func(ctx context.Context,
×
NEW
2176
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2177

×
2178
                        node1, node2, err := buildNodes(
×
2179
                                ctx, db, row.Node, row.Node_2,
×
2180
                        )
×
2181
                        if err != nil {
×
2182
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2183
                                        err)
×
2184
                        }
×
2185

2186
                        edge, err := getAndBuildEdgeInfo(
×
2187
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2188
                                row.Channel, node1.PubKeyBytes,
×
2189
                                node2.PubKeyBytes,
×
2190
                        )
×
2191
                        if err != nil {
×
2192
                                return fmt.Errorf("unable to build "+
×
2193
                                        "channel info: %w", err)
×
2194
                        }
×
2195

2196
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2197
                        if err != nil {
×
2198
                                return fmt.Errorf("unable to extract channel "+
×
2199
                                        "policies: %w", err)
×
2200
                        }
×
2201

2202
                        p1, p2, err := getAndBuildChanPolicies(
×
2203
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2204
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2205
                        )
×
2206
                        if err != nil {
×
2207
                                return fmt.Errorf("unable to build channel "+
×
2208
                                        "policies: %w", err)
×
2209
                        }
×
2210

NEW
2211
                        edges[edge.ChannelID] = ChannelEdge{
×
2212
                                Info:    edge,
×
2213
                                Policy1: p1,
×
2214
                                Policy2: p2,
×
2215
                                Node1:   node1,
×
2216
                                Node2:   node2,
×
NEW
2217
                        }
×
NEW
2218

×
NEW
2219
                        return nil
×
2220
                }
2221

NEW
2222
                queryWrapper := func(ctx context.Context, scids [][]byte) (
×
NEW
2223
                        []sqlc.GetChannelsBySCIDWithPoliciesRow, error) {
×
NEW
2224

×
NEW
2225
                        return db.GetChannelsBySCIDWithPolicies(
×
NEW
2226
                                ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
NEW
2227
                                        Scids:   scids,
×
NEW
2228
                                        Version: int16(ProtocolV1),
×
NEW
2229
                                },
×
NEW
2230
                        )
×
NEW
2231
                }
×
2232

NEW
2233
                err := sqldb.ExecutePagedQuery(
×
NEW
2234
                        ctx, s.cfg.PaginationCfg, chanIDs, channelIDToBytes,
×
NEW
2235
                        queryWrapper, chanCallBack,
×
NEW
2236
                )
×
NEW
2237
                if err != nil {
×
NEW
2238
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
UNCOV
2239
                }
×
2240

2241
                return nil
×
2242
        }, func() {
×
NEW
2243
                clear(edges)
×
2244
        })
×
2245
        if err != nil {
×
2246
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2247
        }
×
2248

NEW
2249
        res := make([]ChannelEdge, 0, len(edges))
×
NEW
2250
        for _, chanID := range chanIDs {
×
NEW
2251
                edge, ok := edges[chanID]
×
NEW
2252
                if !ok {
×
NEW
2253
                        continue
×
2254
                }
2255

NEW
2256
                res = append(res, edge)
×
2257
        }
2258

NEW
2259
        return res, nil
×
2260
}
2261

2262
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2263
// ID's that we don't know and are not known zombies of the passed set. In other
2264
// words, we perform a set difference of our set of chan ID's and the ones
2265
// passed in. This method can be used by callers to determine the set of
2266
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2267
// known zombies is also returned.
2268
//
2269
// NOTE: part of the V1Store interface.
2270
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2271
        []ChannelUpdateInfo, error) {
×
2272

×
2273
        var (
×
2274
                ctx          = context.TODO()
×
2275
                newChanIDs   []uint64
×
2276
                knownZombies []ChannelUpdateInfo
×
NEW
2277
                infoLookup   = make(
×
NEW
2278
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
NEW
2279
                )
×
2280
        )
×
NEW
2281

×
NEW
2282
        for _, chanInfo := range chansInfo {
×
NEW
2283
                scid := chanInfo.ShortChannelID.ToUint64()
×
NEW
2284
                infoLookup[scid] = chanInfo
×
NEW
2285
        }
×
2286

2287
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2288
                queryWrapper := func(ctx context.Context,
×
NEW
2289
                        scids [][]byte) ([]sqlc.Channel, error) {
×
2290

×
NEW
2291
                        return db.GetChannelsBySCIDs(
×
NEW
2292
                                ctx, sqlc.GetChannelsBySCIDsParams{
×
2293
                                        Version: int16(ProtocolV1),
×
NEW
2294
                                        Scids:   scids,
×
2295
                                },
×
2296
                        )
×
NEW
2297
                }
×
2298

NEW
2299
                chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
NEW
2300
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
NEW
2301

×
NEW
2302
                        return channelIDToBytes(channelID)
×
NEW
2303
                }
×
2304

NEW
2305
                cb := func(ctx context.Context, channel sqlc.Channel) error {
×
NEW
2306
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
NEW
2307

×
NEW
2308
                        return nil
×
NEW
2309
                }
×
2310

NEW
2311
                err := sqldb.ExecutePagedQuery(
×
NEW
2312
                        ctx, s.cfg.PaginationCfg, chansInfo, chanIDConverter,
×
NEW
2313
                        queryWrapper, cb,
×
NEW
2314
                )
×
NEW
2315
                if err != nil {
×
NEW
2316
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
NEW
2317
                }
×
2318

NEW
2319
                for _, chanInfo := range chansInfo {
×
NEW
2320
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
NEW
2321
                        if _, ok := infoLookup[channelID]; !ok {
×
2322
                                continue
×
2323
                        }
NEW
2324
                        chanIDB := channelIDToBytes(channelID)
×
2325

×
2326
                        isZombie, err := db.IsZombieChannel(
×
2327
                                ctx, sqlc.IsZombieChannelParams{
×
2328
                                        Scid:    chanIDB,
×
2329
                                        Version: int16(ProtocolV1),
×
2330
                                },
×
2331
                        )
×
2332
                        if err != nil {
×
2333
                                return fmt.Errorf("unable to fetch zombie "+
×
2334
                                        "channel: %w", err)
×
2335
                        }
×
2336

2337
                        if isZombie {
×
2338
                                knownZombies = append(knownZombies, chanInfo)
×
2339

×
2340
                                continue
×
2341
                        }
2342

2343
                        newChanIDs = append(newChanIDs, channelID)
×
2344
                }
2345

2346
                return nil
×
2347
        }, func() {
×
2348
                newChanIDs = nil
×
2349
                knownZombies = nil
×
2350
        })
×
2351
        if err != nil {
×
2352
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2353
        }
×
2354

2355
        return newChanIDs, knownZombies, nil
×
2356
}
2357

2358
// PruneGraphNodes is a garbage collection method which attempts to prune out
2359
// any nodes from the channel graph that are currently unconnected. This ensure
2360
// that we only maintain a graph of reachable nodes. In the event that a pruned
2361
// node gains more channels, it will be re-added back to the graph.
2362
//
2363
// NOTE: this prunes nodes across protocol versions. It will never prune the
2364
// source nodes.
2365
//
2366
// NOTE: part of the V1Store interface.
2367
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2368
        var ctx = context.TODO()
×
2369

×
2370
        var prunedNodes []route.Vertex
×
2371
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2372
                var err error
×
2373
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2374

×
2375
                return err
×
2376
        }, func() {
×
2377
                prunedNodes = nil
×
2378
        })
×
2379
        if err != nil {
×
2380
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2381
        }
×
2382

2383
        return prunedNodes, nil
×
2384
}
2385

2386
// PruneGraph prunes newly closed channels from the channel graph in response
2387
// to a new block being solved on the network. Any transactions which spend the
2388
// funding output of any known channels within he graph will be deleted.
2389
// Additionally, the "prune tip", or the last block which has been used to
2390
// prune the graph is stored so callers can ensure the graph is fully in sync
2391
// with the current UTXO state. A slice of channels that have been closed by
2392
// the target block along with any pruned nodes are returned if the function
2393
// succeeds without error.
2394
//
2395
// NOTE: part of the V1Store interface.
2396
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2397
        blockHash *chainhash.Hash, blockHeight uint32) (
2398
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2399

×
2400
        ctx := context.TODO()
×
2401

×
2402
        s.cacheMu.Lock()
×
2403
        defer s.cacheMu.Unlock()
×
2404

×
2405
        var (
×
2406
                closedChans []*models.ChannelEdgeInfo
×
2407
                prunedNodes []route.Vertex
×
2408
        )
×
2409
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2410
                // Define the callback function for processing each channel
×
NEW
2411
                channelCallback := func(ctx context.Context,
×
NEW
2412
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2413

×
2414
                        node1, node2, err := buildNodeVertices(
×
2415
                                row.Node1Pubkey, row.Node2Pubkey,
×
2416
                        )
×
2417
                        if err != nil {
×
2418
                                return err
×
2419
                        }
×
2420

2421
                        info, err := getAndBuildEdgeInfo(
×
2422
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2423
                                row.Channel, node1, node2,
×
2424
                        )
×
2425
                        if err != nil {
×
2426
                                return err
×
2427
                        }
×
2428

2429
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2430
                        if err != nil {
×
2431
                                return fmt.Errorf("unable to delete "+
×
2432
                                        "channel: %w", err)
×
2433
                        }
×
2434

2435
                        closedChans = append(closedChans, info)
×
NEW
2436
                        return nil
×
2437
                }
2438

NEW
2439
                err := s.forEachChanInOutpoints(
×
NEW
2440
                        ctx, db, spentOutputs, channelCallback,
×
NEW
2441
                )
×
NEW
2442
                if err != nil {
×
NEW
2443
                        return fmt.Errorf("unable to fetch channels by "+
×
NEW
2444
                                "outpoints: %w", err)
×
NEW
2445
                }
×
2446

NEW
2447
                err = db.UpsertPruneLogEntry(
×
2448
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2449
                                BlockHash:   blockHash[:],
×
2450
                                BlockHeight: int64(blockHeight),
×
2451
                        },
×
2452
                )
×
2453
                if err != nil {
×
2454
                        return fmt.Errorf("unable to insert prune log "+
×
2455
                                "entry: %w", err)
×
2456
                }
×
2457

2458
                // Now that we've pruned some channels, we'll also prune any
2459
                // nodes that no longer have any channels.
2460
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2461
                if err != nil {
×
2462
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2463
                                err)
×
2464
                }
×
2465

2466
                return nil
×
2467
        }, func() {
×
2468
                prunedNodes = nil
×
2469
                closedChans = nil
×
2470
        })
×
2471
        if err != nil {
×
2472
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2473
        }
×
2474

2475
        for _, channel := range closedChans {
×
2476
                s.rejectCache.remove(channel.ChannelID)
×
2477
                s.chanCache.remove(channel.ChannelID)
×
2478
        }
×
2479

2480
        return closedChans, prunedNodes, nil
×
2481
}
2482

2483
// forEachChanInOutpoints is a helper function that executes a paginated
2484
// query to fetch channels by their outpoints and applies the given call-back
2485
// to each.
2486
//
2487
// NOTE: this fetches channels for all protocol versions.
2488
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2489
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
NEW
2490
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
NEW
2491

×
NEW
2492
        // Create a wrapper that uses the transaction's db instance to execute
×
NEW
2493
        // the query.
×
NEW
2494
        queryWrapper := func(ctx context.Context,
×
NEW
2495
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
NEW
2496
                error) {
×
NEW
2497

×
NEW
2498
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
NEW
2499
        }
×
2500

2501
        // Define the conversion function from Outpoint to string
NEW
2502
        outpointToString := func(outpoint *wire.OutPoint) string {
×
NEW
2503
                return outpoint.String()
×
NEW
2504
        }
×
2505

NEW
2506
        return sqldb.ExecutePagedQuery(
×
NEW
2507
                ctx, s.cfg.PaginationCfg, outpoints, outpointToString,
×
NEW
2508
                queryWrapper, cb,
×
NEW
2509
        )
×
2510
}
2511

2512
// ChannelView returns the verifiable edge information for each active channel
2513
// within the known channel graph. The set of UTXOs (along with their scripts)
2514
// returned are the ones that need to be watched on chain to detect channel
2515
// closes on the resident blockchain.
2516
//
2517
// NOTE: part of the V1Store interface.
2518
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2519
        var (
×
2520
                ctx        = context.TODO()
×
2521
                edgePoints []EdgePoint
×
2522
        )
×
2523

×
2524
        handleChannel := func(db SQLQueries,
×
2525
                channel sqlc.ListChannelsPaginatedRow) error {
×
2526

×
2527
                pkScript, err := genMultiSigP2WSH(
×
2528
                        channel.BitcoinKey1, channel.BitcoinKey2,
×
2529
                )
×
2530
                if err != nil {
×
2531
                        return err
×
2532
                }
×
2533

2534
                op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2535
                if err != nil {
×
2536
                        return err
×
2537
                }
×
2538

2539
                edgePoints = append(edgePoints, EdgePoint{
×
2540
                        FundingPkScript: pkScript,
×
2541
                        OutPoint:        *op,
×
2542
                })
×
2543

×
2544
                return nil
×
2545
        }
2546

2547
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2548
                lastID := int64(-1)
×
2549
                for {
×
2550
                        rows, err := db.ListChannelsPaginated(
×
2551
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2552
                                        Version: int16(ProtocolV1),
×
2553
                                        ID:      lastID,
×
2554
                                        Limit:   pageSize,
×
2555
                                },
×
2556
                        )
×
2557
                        if err != nil {
×
2558
                                return err
×
2559
                        }
×
2560

2561
                        if len(rows) == 0 {
×
2562
                                break
×
2563
                        }
2564

2565
                        for _, row := range rows {
×
2566
                                err := handleChannel(db, row)
×
2567
                                if err != nil {
×
2568
                                        return err
×
2569
                                }
×
2570

2571
                                lastID = row.ID
×
2572
                        }
2573
                }
2574

2575
                return nil
×
2576
        }, func() {
×
2577
                edgePoints = nil
×
2578
        })
×
2579
        if err != nil {
×
2580
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2581
        }
×
2582

2583
        return edgePoints, nil
×
2584
}
2585

2586
// PruneTip returns the block height and hash of the latest block that has been
2587
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2588
// to tell if the graph is currently in sync with the current best known UTXO
2589
// state.
2590
//
2591
// NOTE: part of the V1Store interface.
2592
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2593
        var (
×
2594
                ctx       = context.TODO()
×
2595
                tipHash   chainhash.Hash
×
2596
                tipHeight uint32
×
2597
        )
×
2598
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2599
                pruneTip, err := db.GetPruneTip(ctx)
×
2600
                if errors.Is(err, sql.ErrNoRows) {
×
2601
                        return ErrGraphNeverPruned
×
2602
                } else if err != nil {
×
2603
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2604
                }
×
2605

2606
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2607
                tipHeight = uint32(pruneTip.BlockHeight)
×
2608

×
2609
                return nil
×
2610
        }, sqldb.NoOpReset)
2611
        if err != nil {
×
2612
                return nil, 0, err
×
2613
        }
×
2614

2615
        return &tipHash, tipHeight, nil
×
2616
}
2617

2618
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2619
//
2620
// NOTE: this prunes nodes across protocol versions. It will never prune the
2621
// source nodes.
2622
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2623
        db SQLQueries) ([]route.Vertex, error) {
×
2624

×
2625
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2626
        if err != nil {
×
2627
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2628
                        "nodes: %w", err)
×
2629
        }
×
2630

2631
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2632
        for i, nodeKey := range nodeKeys {
×
2633
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2634
                if err != nil {
×
2635
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2636
                                "from bytes: %w", err)
×
2637
                }
×
2638

2639
                prunedNodes[i] = pub
×
2640
        }
2641

2642
        return prunedNodes, nil
×
2643
}
2644

2645
// DisconnectBlockAtHeight is used to indicate that the block specified
2646
// by the passed height has been disconnected from the main chain. This
2647
// will "rewind" the graph back to the height below, deleting channels
2648
// that are no longer confirmed from the graph. The prune log will be
2649
// set to the last prune height valid for the remaining chain.
2650
// Channels that were removed from the graph resulting from the
2651
// disconnected block are returned.
2652
//
2653
// NOTE: part of the V1Store interface.
2654
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2655
        []*models.ChannelEdgeInfo, error) {
×
2656

×
2657
        ctx := context.TODO()
×
2658

×
2659
        var (
×
2660
                // Every channel having a ShortChannelID starting at 'height'
×
2661
                // will no longer be confirmed.
×
2662
                startShortChanID = lnwire.ShortChannelID{
×
2663
                        BlockHeight: height,
×
2664
                }
×
2665

×
2666
                // Delete everything after this height from the db up until the
×
2667
                // SCID alias range.
×
2668
                endShortChanID = aliasmgr.StartingAlias
×
2669

×
2670
                removedChans []*models.ChannelEdgeInfo
×
2671

×
2672
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2673
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2674
        )
×
2675

×
2676
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2677
                rows, err := db.GetChannelsBySCIDRange(
×
2678
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2679
                                StartScid: chanIDStart,
×
2680
                                EndScid:   chanIDEnd,
×
2681
                        },
×
2682
                )
×
2683
                if err != nil {
×
2684
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2685
                }
×
2686

2687
                for _, row := range rows {
×
2688
                        node1, node2, err := buildNodeVertices(
×
2689
                                row.Node1PubKey, row.Node2PubKey,
×
2690
                        )
×
2691
                        if err != nil {
×
2692
                                return err
×
2693
                        }
×
2694

2695
                        channel, err := getAndBuildEdgeInfo(
×
2696
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2697
                                row.Channel, node1, node2,
×
2698
                        )
×
2699
                        if err != nil {
×
2700
                                return err
×
2701
                        }
×
2702

2703
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2704
                        if err != nil {
×
2705
                                return fmt.Errorf("unable to delete "+
×
2706
                                        "channel: %w", err)
×
2707
                        }
×
2708

2709
                        removedChans = append(removedChans, channel)
×
2710
                }
2711

2712
                return db.DeletePruneLogEntriesInRange(
×
2713
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2714
                                StartHeight: int64(height),
×
2715
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2716
                        },
×
2717
                )
×
2718
        }, func() {
×
2719
                removedChans = nil
×
2720
        })
×
2721
        if err != nil {
×
2722
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2723
                        "height: %w", err)
×
2724
        }
×
2725

2726
        for _, channel := range removedChans {
×
2727
                s.rejectCache.remove(channel.ChannelID)
×
2728
                s.chanCache.remove(channel.ChannelID)
×
2729
        }
×
2730

2731
        return removedChans, nil
×
2732
}
2733

2734
// AddEdgeProof sets the proof of an existing edge in the graph database.
2735
//
2736
// NOTE: part of the V1Store interface.
2737
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2738
        proof *models.ChannelAuthProof) error {
×
2739

×
2740
        var (
×
2741
                ctx       = context.TODO()
×
2742
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2743
        )
×
2744

×
2745
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2746
                res, err := db.AddV1ChannelProof(
×
2747
                        ctx, sqlc.AddV1ChannelProofParams{
×
2748
                                Scid:              scidBytes,
×
2749
                                Node1Signature:    proof.NodeSig1Bytes,
×
2750
                                Node2Signature:    proof.NodeSig2Bytes,
×
2751
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2752
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2753
                        },
×
2754
                )
×
2755
                if err != nil {
×
2756
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2757
                }
×
2758

2759
                n, err := res.RowsAffected()
×
2760
                if err != nil {
×
2761
                        return err
×
2762
                }
×
2763

2764
                if n == 0 {
×
2765
                        return fmt.Errorf("no rows affected when adding edge "+
×
2766
                                "proof for SCID %v", scid)
×
2767
                } else if n > 1 {
×
2768
                        return fmt.Errorf("multiple rows affected when adding "+
×
2769
                                "edge proof for SCID %v: %d rows affected",
×
2770
                                scid, n)
×
2771
                }
×
2772

2773
                return nil
×
2774
        }, sqldb.NoOpReset)
2775
        if err != nil {
×
2776
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2777
        }
×
2778

2779
        return nil
×
2780
}
2781

2782
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2783
// that we can ignore channel announcements that we know to be closed without
2784
// having to validate them and fetch a block.
2785
//
2786
// NOTE: part of the V1Store interface.
2787
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2788
        var (
×
2789
                ctx     = context.TODO()
×
2790
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2791
        )
×
2792

×
2793
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2794
                return db.InsertClosedChannel(ctx, chanIDB)
×
2795
        }, sqldb.NoOpReset)
×
2796
}
2797

2798
// IsClosedScid checks whether a channel identified by the passed in scid is
2799
// closed. This helps avoid having to perform expensive validation checks.
2800
//
2801
// NOTE: part of the V1Store interface.
2802
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2803
        var (
×
2804
                ctx      = context.TODO()
×
2805
                isClosed bool
×
2806
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2807
        )
×
2808
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2809
                var err error
×
2810
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2811
                if err != nil {
×
2812
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2813
                                err)
×
2814
                }
×
2815

2816
                return nil
×
2817
        }, sqldb.NoOpReset)
2818
        if err != nil {
×
2819
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2820
                        err)
×
2821
        }
×
2822

2823
        return isClosed, nil
×
2824
}
2825

2826
// GraphSession will provide the call-back with access to a NodeTraverser
2827
// instance which can be used to perform queries against the channel graph.
2828
//
2829
// NOTE: part of the V1Store interface.
2830
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
2831
        reset func()) error {
×
2832

×
2833
        var ctx = context.TODO()
×
2834

×
2835
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2836
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2837
        }, reset)
×
2838
}
2839

2840
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
// read only transaction for a consistent view of the graph.
type sqlNodeTraverser struct {
	// db is the query handle all traverser queries are issued through.
	// GraphSession passes the handle of its enclosing read transaction.
	db    SQLQueries
	// chain is the chain hash supplied at construction. GraphSession
	// passes s.cfg.ChainHash. It is not read by ForEachNodeDirectedChannel
	// or FetchNodeFeatures.
	chain chainhash.Hash
}

// A compile-time assertion to ensure that sqlNodeTraverser implements the
// NodeTraverser interface.
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2850

2851
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2852
func newSQLNodeTraverser(db SQLQueries,
2853
        chain chainhash.Hash) *sqlNodeTraverser {
×
2854

×
2855
        return &sqlNodeTraverser{
×
2856
                db:    db,
×
2857
                chain: chain,
×
2858
        }
×
2859
}
×
2860

2861
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2862
// node.
2863
//
2864
// NOTE: Part of the NodeTraverser interface.
2865
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2866
        cb func(channel *DirectedChannel) error, _ func()) error {
×
2867

×
2868
        ctx := context.TODO()
×
2869

×
2870
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2871
}
×
2872

2873
// FetchNodeFeatures returns the features of the given node. If the node is
2874
// unknown, assume no additional features are supported.
2875
//
2876
// NOTE: Part of the NodeTraverser interface.
2877
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2878
        *lnwire.FeatureVector, error) {
×
2879

×
2880
        ctx := context.TODO()
×
2881

×
2882
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2883
}
×
2884

2885
// forEachNodeDirectedChannel iterates through all channels of a given
2886
// node, executing the passed callback on the directed edge representing the
2887
// channel and its incoming policy. If the node is not found, no error is
2888
// returned.
2889
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2890
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2891

×
2892
        toNodeCallback := func() route.Vertex {
×
2893
                return nodePub
×
2894
        }
×
2895

2896
        dbID, err := db.GetNodeIDByPubKey(
×
2897
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2898
                        Version: int16(ProtocolV1),
×
2899
                        PubKey:  nodePub[:],
×
2900
                },
×
2901
        )
×
2902
        if errors.Is(err, sql.ErrNoRows) {
×
2903
                return nil
×
2904
        } else if err != nil {
×
2905
                return fmt.Errorf("unable to fetch node: %w", err)
×
2906
        }
×
2907

2908
        rows, err := db.ListChannelsByNodeID(
×
2909
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2910
                        Version: int16(ProtocolV1),
×
2911
                        NodeID1: dbID,
×
2912
                },
×
2913
        )
×
2914
        if err != nil {
×
2915
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2916
        }
×
2917

2918
        // Exit early if there are no channels for this node so we don't
2919
        // do the unnecessary feature fetching.
2920
        if len(rows) == 0 {
×
2921
                return nil
×
2922
        }
×
2923

2924
        features, err := getNodeFeatures(ctx, db, dbID)
×
2925
        if err != nil {
×
2926
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2927
        }
×
2928

2929
        for _, row := range rows {
×
2930
                node1, node2, err := buildNodeVertices(
×
2931
                        row.Node1Pubkey, row.Node2Pubkey,
×
2932
                )
×
2933
                if err != nil {
×
2934
                        return fmt.Errorf("unable to build node vertices: %w",
×
2935
                                err)
×
2936
                }
×
2937

2938
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
2939

×
2940
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2941
                if err != nil {
×
2942
                        return err
×
2943
                }
×
2944

2945
                var p1, p2 *models.CachedEdgePolicy
×
2946
                if dbPol1 != nil {
×
2947
                        policy1, err := buildChanPolicy(
×
2948
                                *dbPol1, edge.ChannelID, nil, node2,
×
2949
                        )
×
2950
                        if err != nil {
×
2951
                                return err
×
2952
                        }
×
2953

2954
                        p1 = models.NewCachedPolicy(policy1)
×
2955
                }
2956
                if dbPol2 != nil {
×
2957
                        policy2, err := buildChanPolicy(
×
2958
                                *dbPol2, edge.ChannelID, nil, node1,
×
2959
                        )
×
2960
                        if err != nil {
×
2961
                                return err
×
2962
                        }
×
2963

2964
                        p2 = models.NewCachedPolicy(policy2)
×
2965
                }
2966

2967
                // Determine the outgoing and incoming policy for this
2968
                // channel and node combo.
2969
                outPolicy, inPolicy := p1, p2
×
2970
                if p1 != nil && node2 == nodePub {
×
2971
                        outPolicy, inPolicy = p2, p1
×
2972
                } else if p2 != nil && node1 != nodePub {
×
2973
                        outPolicy, inPolicy = p2, p1
×
2974
                }
×
2975

2976
                var cachedInPolicy *models.CachedEdgePolicy
×
2977
                if inPolicy != nil {
×
2978
                        cachedInPolicy = inPolicy
×
2979
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
2980
                        cachedInPolicy.ToNodeFeatures = features
×
2981
                }
×
2982

2983
                directedChannel := &DirectedChannel{
×
2984
                        ChannelID:    edge.ChannelID,
×
2985
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2986
                        OtherNode:    edge.NodeKey2Bytes,
×
2987
                        Capacity:     edge.Capacity,
×
2988
                        OutPolicySet: outPolicy != nil,
×
2989
                        InPolicy:     cachedInPolicy,
×
2990
                }
×
2991
                if outPolicy != nil {
×
2992
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2993
                                directedChannel.InboundFee = fee
×
2994
                        })
×
2995
                }
2996

2997
                if nodePub == edge.NodeKey2Bytes {
×
2998
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
2999
                }
×
3000

3001
                if err := cb(directedChannel); err != nil {
×
3002
                        return err
×
3003
                }
×
3004
        }
3005

3006
        return nil
×
3007
}
3008

3009
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
3010
// and executes the provided callback for each node.
3011
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
3012
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
3013

×
3014
        lastID := int64(-1)
×
3015

×
3016
        for {
×
3017
                nodes, err := db.ListNodeIDsAndPubKeys(
×
3018
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
3019
                                Version: int16(ProtocolV1),
×
3020
                                ID:      lastID,
×
3021
                                Limit:   pageSize,
×
3022
                        },
×
3023
                )
×
3024
                if err != nil {
×
3025
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
3026
                }
×
3027

3028
                if len(nodes) == 0 {
×
3029
                        break
×
3030
                }
3031

3032
                for _, node := range nodes {
×
3033
                        var pub route.Vertex
×
3034
                        copy(pub[:], node.PubKey)
×
3035

×
3036
                        if err := cb(node.ID, pub); err != nil {
×
3037
                                return fmt.Errorf("forEachNodeCacheable "+
×
3038
                                        "callback failed for node(id=%d): %w",
×
3039
                                        node.ID, err)
×
3040
                        }
×
3041

3042
                        lastID = node.ID
×
3043
                }
3044
        }
3045

3046
        return nil
×
3047
}
3048

3049
// forEachNodeChannel iterates through all channels of a node, executing
3050
// the passed callback on each. The call-back is provided with the channel's
3051
// edge information, the outgoing policy and the incoming policy for the
3052
// channel and node combo.
3053
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3054
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
3055
                *models.ChannelEdgePolicy,
3056
                *models.ChannelEdgePolicy) error) error {
×
3057

×
3058
        // Get all the V1 channels for this node.Add commentMore actions
×
3059
        rows, err := db.ListChannelsByNodeID(
×
3060
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3061
                        Version: int16(ProtocolV1),
×
3062
                        NodeID1: id,
×
3063
                },
×
3064
        )
×
3065
        if err != nil {
×
3066
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3067
        }
×
3068

3069
        // Call the call-back for each channel and its known policies.
3070
        for _, row := range rows {
×
3071
                node1, node2, err := buildNodeVertices(
×
3072
                        row.Node1Pubkey, row.Node2Pubkey,
×
3073
                )
×
3074
                if err != nil {
×
3075
                        return fmt.Errorf("unable to build node vertices: %w",
×
3076
                                err)
×
3077
                }
×
3078

3079
                edge, err := getAndBuildEdgeInfo(
×
3080
                        ctx, db, chain, row.Channel.ID, row.Channel, node1,
×
3081
                        node2,
×
3082
                )
×
3083
                if err != nil {
×
3084
                        return fmt.Errorf("unable to build channel info: %w",
×
3085
                                err)
×
3086
                }
×
3087

3088
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3089
                if err != nil {
×
3090
                        return fmt.Errorf("unable to extract channel "+
×
3091
                                "policies: %w", err)
×
3092
                }
×
3093

3094
                p1, p2, err := getAndBuildChanPolicies(
×
3095
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3096
                )
×
3097
                if err != nil {
×
3098
                        return fmt.Errorf("unable to build channel "+
×
3099
                                "policies: %w", err)
×
3100
                }
×
3101

3102
                // Determine the outgoing and incoming policy for this
3103
                // channel and node combo.
3104
                p1ToNode := row.Channel.NodeID2
×
3105
                p2ToNode := row.Channel.NodeID1
×
3106
                outPolicy, inPolicy := p1, p2
×
3107
                if (p1 != nil && p1ToNode == id) ||
×
3108
                        (p2 != nil && p2ToNode != id) {
×
3109

×
3110
                        outPolicy, inPolicy = p2, p1
×
3111
                }
×
3112

3113
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3114
                        return err
×
3115
                }
×
3116
        }
3117

3118
        return nil
×
3119
}
3120

3121
// updateChanEdgePolicy upserts the channel policy info we have stored for
3122
// a channel we already know of.
3123
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3124
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3125
        error) {
×
3126

×
3127
        var (
×
3128
                node1Pub, node2Pub route.Vertex
×
3129
                isNode1            bool
×
3130
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3131
        )
×
3132

×
3133
        // Check that this edge policy refers to a channel that we already
×
3134
        // know of. We do this explicitly so that we can return the appropriate
×
3135
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3136
        // abort the transaction which would abort the entire batch.
×
3137
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3138
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3139
                        Scid:    chanIDB,
×
3140
                        Version: int16(ProtocolV1),
×
3141
                },
×
3142
        )
×
3143
        if errors.Is(err, sql.ErrNoRows) {
×
3144
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3145
        } else if err != nil {
×
3146
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3147
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3148
        }
×
3149

3150
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3151
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3152

×
3153
        // Figure out which node this edge is from.
×
3154
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3155
        nodeID := dbChan.NodeID1
×
3156
        if !isNode1 {
×
3157
                nodeID = dbChan.NodeID2
×
3158
        }
×
3159

3160
        var (
×
3161
                inboundBase sql.NullInt64
×
3162
                inboundRate sql.NullInt64
×
3163
        )
×
3164
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3165
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3166
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3167
        })
×
3168

3169
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3170
                Version:     int16(ProtocolV1),
×
3171
                ChannelID:   dbChan.ID,
×
3172
                NodeID:      nodeID,
×
3173
                Timelock:    int32(edge.TimeLockDelta),
×
3174
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3175
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3176
                MinHtlcMsat: int64(edge.MinHTLC),
×
3177
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3178
                Disabled: sql.NullBool{
×
3179
                        Valid: true,
×
3180
                        Bool:  edge.IsDisabled(),
×
3181
                },
×
3182
                MaxHtlcMsat: sql.NullInt64{
×
3183
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3184
                        Int64: int64(edge.MaxHTLC),
×
3185
                },
×
3186
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3187
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3188
                InboundBaseFeeMsat:      inboundBase,
×
3189
                InboundFeeRateMilliMsat: inboundRate,
×
3190
                Signature:               edge.SigBytes,
×
3191
        })
×
3192
        if err != nil {
×
3193
                return node1Pub, node2Pub, isNode1,
×
3194
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3195
        }
×
3196

3197
        // Convert the flat extra opaque data into a map of TLV types to
3198
        // values.
3199
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3200
        if err != nil {
×
3201
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3202
                        "marshal extra opaque data: %w", err)
×
3203
        }
×
3204

3205
        // Update the channel policy's extra signed fields.
3206
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3207
        if err != nil {
×
3208
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3209
                        "policy extra TLVs: %w", err)
×
3210
        }
×
3211

3212
        return node1Pub, node2Pub, isNode1, nil
×
3213
}
3214

3215
// getNodeByPubKey attempts to look up a target node by its public key.
3216
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3217
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3218

×
3219
        dbNode, err := db.GetNodeByPubKey(
×
3220
                ctx, sqlc.GetNodeByPubKeyParams{
×
3221
                        Version: int16(ProtocolV1),
×
3222
                        PubKey:  pubKey[:],
×
3223
                },
×
3224
        )
×
3225
        if errors.Is(err, sql.ErrNoRows) {
×
3226
                return 0, nil, ErrGraphNodeNotFound
×
3227
        } else if err != nil {
×
3228
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3229
        }
×
3230

3231
        node, err := buildNode(ctx, db, &dbNode)
×
3232
        if err != nil {
×
3233
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3234
        }
×
3235

3236
        return dbNode.ID, node, nil
×
3237
}
3238

3239
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3240
// provided database channel row and the public keys of the two nodes
3241
// involved in the channel.
3242
func buildCacheableChannelInfo(dbChan sqlc.Channel, node1Pub,
3243
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3244

×
3245
        return &models.CachedEdgeInfo{
×
3246
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
3247
                NodeKey1Bytes: node1Pub,
×
3248
                NodeKey2Bytes: node2Pub,
×
3249
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
3250
        }
×
3251
}
×
3252

3253
// buildNode constructs a LightningNode instance from the given database node
3254
// record. The node's features, addresses and extra signed fields are also
3255
// fetched from the database and set on the node.
3256
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
3257
        *models.LightningNode, error) {
×
3258

×
3259
        if dbNode.Version != int16(ProtocolV1) {
×
3260
                return nil, fmt.Errorf("unsupported node version: %d",
×
3261
                        dbNode.Version)
×
3262
        }
×
3263

3264
        var pub [33]byte
×
3265
        copy(pub[:], dbNode.PubKey)
×
3266

×
3267
        node := &models.LightningNode{
×
3268
                PubKeyBytes: pub,
×
3269
                Features:    lnwire.EmptyFeatureVector(),
×
3270
                LastUpdate:  time.Unix(0, 0),
×
3271
        }
×
3272

×
3273
        if len(dbNode.Signature) == 0 {
×
3274
                return node, nil
×
3275
        }
×
3276

3277
        node.HaveNodeAnnouncement = true
×
3278
        node.AuthSigBytes = dbNode.Signature
×
3279
        node.Alias = dbNode.Alias.String
×
3280
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3281

×
3282
        var err error
×
3283
        if dbNode.Color.Valid {
×
3284
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3285
                if err != nil {
×
3286
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3287
                                err)
×
3288
                }
×
3289
        }
3290

3291
        // Fetch the node's features.
3292
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
3293
        if err != nil {
×
3294
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3295
                        "features: %w", dbNode.ID, err)
×
3296
        }
×
3297

3298
        // Fetch the node's addresses.
3299
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
3300
        if err != nil {
×
3301
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3302
                        "addresses: %w", dbNode.ID, err)
×
3303
        }
×
3304

3305
        // Fetch the node's extra signed fields.
3306
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
3307
        if err != nil {
×
3308
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3309
                        "extra signed fields: %w", dbNode.ID, err)
×
3310
        }
×
3311

3312
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
3313
        if err != nil {
×
3314
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3315
                        "fields: %w", err)
×
3316
        }
×
3317

3318
        if len(recs) != 0 {
×
3319
                node.ExtraOpaqueData = recs
×
3320
        }
×
3321

3322
        return node, nil
×
3323
}
3324

3325
// getNodeFeatures fetches the feature bits and constructs the feature vector
3326
// for a node with the given DB ID.
3327
func getNodeFeatures(ctx context.Context, db SQLQueries,
3328
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3329

×
3330
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3331
        if err != nil {
×
3332
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3333
                        nodeID, err)
×
3334
        }
×
3335

3336
        features := lnwire.EmptyFeatureVector()
×
3337
        for _, feature := range rows {
×
3338
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3339
        }
×
3340

3341
        return features, nil
×
3342
}
3343

3344
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
3345
// given DB ID.
3346
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3347
        nodeID int64) (map[uint64][]byte, error) {
×
3348

×
3349
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3350
        if err != nil {
×
3351
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
3352
                        "signed fields: %w", nodeID, err)
×
3353
        }
×
3354

3355
        extraFields := make(map[uint64][]byte)
×
3356
        for _, field := range fields {
×
3357
                extraFields[uint64(field.Type)] = field.Value
×
3358
        }
×
3359

3360
        return extraFields, nil
×
3361
}
3362

3363
// upsertNode upserts the node record into the database. If the node already
3364
// exists, then the node's information is updated. If the node doesn't exist,
3365
// then a new node is created. The node's features, addresses and extra TLV
3366
// types are also updated. The node's DB ID is returned.
3367
func upsertNode(ctx context.Context, db SQLQueries,
3368
        node *models.LightningNode) (int64, error) {
×
3369

×
3370
        params := sqlc.UpsertNodeParams{
×
3371
                Version: int16(ProtocolV1),
×
3372
                PubKey:  node.PubKeyBytes[:],
×
3373
        }
×
3374

×
3375
        if node.HaveNodeAnnouncement {
×
3376
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3377
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3378
                params.Alias = sqldb.SQLStr(node.Alias)
×
3379
                params.Signature = node.AuthSigBytes
×
3380
        }
×
3381

3382
        nodeID, err := db.UpsertNode(ctx, params)
×
3383
        if err != nil {
×
3384
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3385
                        err)
×
3386
        }
×
3387

3388
        // We can exit here if we don't have the announcement yet.
3389
        if !node.HaveNodeAnnouncement {
×
3390
                return nodeID, nil
×
3391
        }
×
3392

3393
        // Update the node's features.
3394
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3395
        if err != nil {
×
3396
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3397
        }
×
3398

3399
        // Update the node's addresses.
3400
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3401
        if err != nil {
×
3402
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3403
        }
×
3404

3405
        // Convert the flat extra opaque data into a map of TLV types to
3406
        // values.
3407
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3408
        if err != nil {
×
3409
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3410
                        err)
×
3411
        }
×
3412

3413
        // Update the node's extra signed fields.
3414
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3415
        if err != nil {
×
3416
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3417
        }
×
3418

3419
        return nodeID, nil
×
3420
}
3421

3422
// upsertNodeFeatures updates the node's features node_features table. This
3423
// includes deleting any feature bits no longer present and inserting any new
3424
// feature bits. If the feature bit does not yet exist in the features table,
3425
// then an entry is created in that table first.
3426
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3427
        features *lnwire.FeatureVector) error {
×
3428

×
3429
        // Get any existing features for the node.
×
3430
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3431
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3432
                return err
×
3433
        }
×
3434

3435
        // Copy the nodes latest set of feature bits.
3436
        newFeatures := make(map[int32]struct{})
×
3437
        if features != nil {
×
3438
                for feature := range features.Features() {
×
3439
                        newFeatures[int32(feature)] = struct{}{}
×
3440
                }
×
3441
        }
3442

3443
        // For any current feature that already exists in the DB, remove it from
3444
        // the in-memory map. For any existing feature that does not exist in
3445
        // the in-memory map, delete it from the database.
3446
        for _, feature := range existingFeatures {
×
3447
                // The feature is still present, so there are no updates to be
×
3448
                // made.
×
3449
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3450
                        delete(newFeatures, feature.FeatureBit)
×
3451
                        continue
×
3452
                }
3453

3454
                // The feature is no longer present, so we remove it from the
3455
                // database.
3456
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3457
                        NodeID:     nodeID,
×
3458
                        FeatureBit: feature.FeatureBit,
×
3459
                })
×
3460
                if err != nil {
×
3461
                        return fmt.Errorf("unable to delete node(%d) "+
×
3462
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3463
                                err)
×
3464
                }
×
3465
        }
3466

3467
        // Any remaining entries in newFeatures are new features that need to be
3468
        // added to the database for the first time.
3469
        for feature := range newFeatures {
×
3470
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3471
                        NodeID:     nodeID,
×
3472
                        FeatureBit: feature,
×
3473
                })
×
3474
                if err != nil {
×
3475
                        return fmt.Errorf("unable to insert node(%d) "+
×
3476
                                "feature(%v): %w", nodeID, feature, err)
×
3477
                }
×
3478
        }
3479

3480
        return nil
×
3481
}
3482

3483
// fetchNodeFeatures fetches the features for a node with the given public key.
3484
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3485
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3486

×
3487
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3488
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3489
                        PubKey:  nodePub[:],
×
3490
                        Version: int16(ProtocolV1),
×
3491
                },
×
3492
        )
×
3493
        if err != nil {
×
3494
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3495
                        nodePub, err)
×
3496
        }
×
3497

3498
        features := lnwire.EmptyFeatureVector()
×
3499
        for _, bit := range rows {
×
3500
                features.Set(lnwire.FeatureBit(bit))
×
3501
        }
×
3502

3503
        return features, nil
×
3504
}
3505

3506
// dbAddressType identifies the kind of address held in a row of the
// node_addresses table. It tells the SQL store how the row's string form was
// serialised and how it must be deserialised again.
//
// These values are written to the database, so an existing value must never
// be changed or reused for a different kind of address.
type dbAddressType uint8

// Address types understood by the SQL store. Each one is pinned to an explicit
// value (rather than derived from iota) because it is persisted.
const (
	// addressTypeIPv4 is a TCP address whose IP has a 4-byte form.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 is a TCP address whose IP only has a 16-byte form.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 is a Tor onion service address of v2 length.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 is a Tor onion service address of v3 length.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeOpaque is an opaque (unparsed) address payload, read back
	// from the database via hex decoding.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3518

3519
// upsertNodeAddresses updates the node's addresses in the database. This
3520
// includes deleting any existing addresses and inserting the new set of
3521
// addresses. The deletion is necessary since the ordering of the addresses may
3522
// change, and we need to ensure that the database reflects the latest set of
3523
// addresses so that at the time of reconstructing the node announcement, the
3524
// order is preserved and the signature over the message remains valid.
3525
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3526
        addresses []net.Addr) error {
×
3527

×
3528
        // Delete any existing addresses for the node. This is required since
×
3529
        // even if the new set of addresses is the same, the ordering may have
×
3530
        // changed for a given address type.
×
3531
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3532
        if err != nil {
×
3533
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3534
                        nodeID, err)
×
3535
        }
×
3536

3537
        // Copy the nodes latest set of addresses.
3538
        newAddresses := map[dbAddressType][]string{
×
3539
                addressTypeIPv4:   {},
×
3540
                addressTypeIPv6:   {},
×
3541
                addressTypeTorV2:  {},
×
3542
                addressTypeTorV3:  {},
×
3543
                addressTypeOpaque: {},
×
3544
        }
×
3545
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3546
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3547
        }
×
3548

3549
        for _, address := range addresses {
×
3550
                switch addr := address.(type) {
×
3551
                case *net.TCPAddr:
×
3552
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3553
                                addAddr(addressTypeIPv4, addr)
×
3554
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3555
                                addAddr(addressTypeIPv6, addr)
×
3556
                        } else {
×
3557
                                return fmt.Errorf("unhandled IP address: %v",
×
3558
                                        addr)
×
3559
                        }
×
3560

3561
                case *tor.OnionAddr:
×
3562
                        switch len(addr.OnionService) {
×
3563
                        case tor.V2Len:
×
3564
                                addAddr(addressTypeTorV2, addr)
×
3565
                        case tor.V3Len:
×
3566
                                addAddr(addressTypeTorV3, addr)
×
3567
                        default:
×
3568
                                return fmt.Errorf("invalid length for a tor " +
×
3569
                                        "address")
×
3570
                        }
3571

3572
                case *lnwire.OpaqueAddrs:
×
3573
                        addAddr(addressTypeOpaque, addr)
×
3574

3575
                default:
×
3576
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3577
                }
3578
        }
3579

3580
        // Any remaining entries in newAddresses are new addresses that need to
3581
        // be added to the database for the first time.
3582
        for addrType, addrList := range newAddresses {
×
3583
                for position, addr := range addrList {
×
3584
                        err := db.InsertNodeAddress(
×
3585
                                ctx, sqlc.InsertNodeAddressParams{
×
3586
                                        NodeID:   nodeID,
×
3587
                                        Type:     int16(addrType),
×
3588
                                        Address:  addr,
×
3589
                                        Position: int32(position),
×
3590
                                },
×
3591
                        )
×
3592
                        if err != nil {
×
3593
                                return fmt.Errorf("unable to insert "+
×
3594
                                        "node(%d) address(%v): %w", nodeID,
×
3595
                                        addr, err)
×
3596
                        }
×
3597
                }
3598
        }
3599

3600
        return nil
×
3601
}
3602

3603
// getNodeAddresses fetches the addresses for a node with the given public key.
3604
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3605
        []net.Addr, error) {
×
3606

×
3607
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
3608
        // are returned in the same order as they were inserted.
×
3609
        rows, err := db.GetNodeAddressesByPubKey(
×
3610
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
3611
                        Version: int16(ProtocolV1),
×
3612
                        PubKey:  nodePub,
×
3613
                },
×
3614
        )
×
3615
        if err != nil {
×
3616
                return false, nil, err
×
3617
        }
×
3618

3619
        // GetNodeAddressesByPubKey uses a left join so there should always be
3620
        // at least one row returned if the node exists even if it has no
3621
        // addresses.
3622
        if len(rows) == 0 {
×
3623
                return false, nil, nil
×
3624
        }
×
3625

3626
        addresses := make([]net.Addr, 0, len(rows))
×
3627
        for _, addr := range rows {
×
3628
                if !(addr.Type.Valid && addr.Address.Valid) {
×
3629
                        continue
×
3630
                }
3631

3632
                address := addr.Address.String
×
3633

×
3634
                switch dbAddressType(addr.Type.Int16) {
×
3635
                case addressTypeIPv4:
×
3636
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
3637
                        if err != nil {
×
3638
                                return false, nil, nil
×
3639
                        }
×
3640
                        tcp.IP = tcp.IP.To4()
×
3641

×
3642
                        addresses = append(addresses, tcp)
×
3643

3644
                case addressTypeIPv6:
×
3645
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
3646
                        if err != nil {
×
3647
                                return false, nil, nil
×
3648
                        }
×
3649
                        addresses = append(addresses, tcp)
×
3650

3651
                case addressTypeTorV3, addressTypeTorV2:
×
3652
                        service, portStr, err := net.SplitHostPort(address)
×
3653
                        if err != nil {
×
3654
                                return false, nil, fmt.Errorf("unable to "+
×
3655
                                        "split tor v3 address: %v",
×
3656
                                        addr.Address)
×
3657
                        }
×
3658

3659
                        port, err := strconv.Atoi(portStr)
×
3660
                        if err != nil {
×
3661
                                return false, nil, err
×
3662
                        }
×
3663

3664
                        addresses = append(addresses, &tor.OnionAddr{
×
3665
                                OnionService: service,
×
3666
                                Port:         port,
×
3667
                        })
×
3668

3669
                case addressTypeOpaque:
×
3670
                        opaque, err := hex.DecodeString(address)
×
3671
                        if err != nil {
×
3672
                                return false, nil, fmt.Errorf("unable to "+
×
3673
                                        "decode opaque address: %v", addr)
×
3674
                        }
×
3675

3676
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
3677
                                Payload: opaque,
×
3678
                        })
×
3679

3680
                default:
×
3681
                        return false, nil, fmt.Errorf("unknown address "+
×
3682
                                "type: %v", addr.Type)
×
3683
                }
3684
        }
3685

3686
        // If we have no addresses, then we'll return nil instead of an
3687
        // empty slice.
3688
        if len(addresses) == 0 {
×
3689
                addresses = nil
×
3690
        }
×
3691

3692
        return true, addresses, nil
×
3693
}
3694

3695
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3696
// database. This includes updating any existing types, inserting any new types,
3697
// and deleting any types that are no longer present.
3698
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3699
        nodeID int64, extraFields map[uint64][]byte) error {
×
3700

×
3701
        // Get any existing extra signed fields for the node.
×
3702
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3703
        if err != nil {
×
3704
                return err
×
3705
        }
×
3706

3707
        // Make a lookup map of the existing field types so that we can use it
3708
        // to keep track of any fields we should delete.
3709
        m := make(map[uint64]bool)
×
3710
        for _, field := range existingFields {
×
3711
                m[uint64(field.Type)] = true
×
3712
        }
×
3713

3714
        // For all the new fields, we'll upsert them and remove them from the
3715
        // map of existing fields.
3716
        for tlvType, value := range extraFields {
×
3717
                err = db.UpsertNodeExtraType(
×
3718
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3719
                                NodeID: nodeID,
×
3720
                                Type:   int64(tlvType),
×
3721
                                Value:  value,
×
3722
                        },
×
3723
                )
×
3724
                if err != nil {
×
3725
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3726
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3727
                }
×
3728

3729
                // Remove the field from the map of existing fields if it was
3730
                // present.
3731
                delete(m, tlvType)
×
3732
        }
3733

3734
        // For all the fields that are left in the map of existing fields, we'll
3735
        // delete them as they are no longer present in the new set of fields.
3736
        for tlvType := range m {
×
3737
                err = db.DeleteExtraNodeType(
×
3738
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3739
                                NodeID: nodeID,
×
3740
                                Type:   int64(tlvType),
×
3741
                        },
×
3742
                )
×
3743
                if err != nil {
×
3744
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3745
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3746
                }
×
3747
        }
3748

3749
        return nil
×
3750
}
3751

3752
// srcNodeInfo holds the identity of the graph's source node for a single
// gossip protocol version. SQLStore caches one of these per version once the
// source node has been looked up.
type srcNodeInfo struct {
	// id is the DB level ID of the source node entry in the "nodes" table.
	id int64

	// pub is the public key of the source node.
	pub route.Vertex
}
3760

3761
// sourceNode returns the DB node ID and pub key of the source node for the
3762
// specified protocol version.
3763
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3764
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3765

×
3766
        s.srcNodeMu.Lock()
×
3767
        defer s.srcNodeMu.Unlock()
×
3768

×
3769
        // If we already have the source node ID and pub key cached, then
×
3770
        // return them.
×
3771
        if info, ok := s.srcNodes[version]; ok {
×
3772
                return info.id, info.pub, nil
×
3773
        }
×
3774

3775
        var pubKey route.Vertex
×
3776

×
3777
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3778
        if err != nil {
×
3779
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3780
                        err)
×
3781
        }
×
3782

3783
        if len(nodes) == 0 {
×
3784
                return 0, pubKey, ErrSourceNodeNotSet
×
3785
        } else if len(nodes) > 1 {
×
3786
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3787
                        "protocol %s found", version)
×
3788
        }
×
3789

3790
        copy(pubKey[:], nodes[0].PubKey)
×
3791

×
3792
        s.srcNodes[version] = &srcNodeInfo{
×
3793
                id:  nodes[0].NodeID,
×
3794
                pub: pubKey,
×
3795
        }
×
3796

×
3797
        return nodes[0].NodeID, pubKey, nil
×
3798
}
3799

3800
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3801
// This then produces a map from TLV type to value. If the input is not a
3802
// valid TLV stream, then an error is returned.
3803
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3804
        r := bytes.NewReader(data)
×
3805

×
3806
        tlvStream, err := tlv.NewStream()
×
3807
        if err != nil {
×
3808
                return nil, err
×
3809
        }
×
3810

3811
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3812
        // pass it into the P2P decoding variant.
3813
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3814
        if err != nil {
×
3815
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3816
        }
×
3817
        if len(parsedTypes) == 0 {
×
3818
                return nil, nil
×
3819
        }
×
3820

3821
        records := make(map[uint64][]byte)
×
3822
        for k, v := range parsedTypes {
×
3823
                records[uint64(k)] = v
×
3824
        }
×
3825

3826
        return records, nil
×
3827
}
3828

3829
// dbChanInfo carries the database-level IDs produced when a channel is
// inserted: the channel row's own ID and the IDs of its two endpoint node
// rows.
type dbChanInfo struct {
	// channelID is the DB ID of the channel row.
	channelID int64

	// node1ID and node2ID are the DB IDs of the channel's first and
	// second node respectively.
	node1ID, node2ID int64
}
3836

3837
// insertChannel inserts a new channel record into the database.
3838
func insertChannel(ctx context.Context, db SQLQueries,
3839
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
3840

×
3841
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3842

×
3843
        // Make sure that the channel doesn't already exist. We do this
×
3844
        // explicitly instead of relying on catching a unique constraint error
×
3845
        // because relying on SQL to throw that error would abort the entire
×
3846
        // batch of transactions.
×
3847
        _, err := db.GetChannelBySCID(
×
3848
                ctx, sqlc.GetChannelBySCIDParams{
×
3849
                        Scid:    chanIDB,
×
3850
                        Version: int16(ProtocolV1),
×
3851
                },
×
3852
        )
×
3853
        if err == nil {
×
3854
                return nil, ErrEdgeAlreadyExist
×
3855
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3856
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
×
3857
        }
×
3858

3859
        // Make sure that at least a "shell" entry for each node is present in
3860
        // the nodes table.
3861
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3862
        if err != nil {
×
3863
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3864
        }
×
3865

3866
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3867
        if err != nil {
×
3868
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3869
        }
×
3870

3871
        var capacity sql.NullInt64
×
3872
        if edge.Capacity != 0 {
×
3873
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3874
        }
×
3875

3876
        createParams := sqlc.CreateChannelParams{
×
3877
                Version:     int16(ProtocolV1),
×
3878
                Scid:        chanIDB,
×
3879
                NodeID1:     node1DBID,
×
3880
                NodeID2:     node2DBID,
×
3881
                Outpoint:    edge.ChannelPoint.String(),
×
3882
                Capacity:    capacity,
×
3883
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3884
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3885
        }
×
3886

×
3887
        if edge.AuthProof != nil {
×
3888
                proof := edge.AuthProof
×
3889

×
3890
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3891
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3892
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3893
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3894
        }
×
3895

3896
        // Insert the new channel record.
3897
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3898
        if err != nil {
×
3899
                return nil, err
×
3900
        }
×
3901

3902
        // Insert any channel features.
3903
        for feature := range edge.Features.Features() {
×
3904
                err = db.InsertChannelFeature(
×
3905
                        ctx, sqlc.InsertChannelFeatureParams{
×
3906
                                ChannelID:  dbChanID,
×
3907
                                FeatureBit: int32(feature),
×
3908
                        },
×
3909
                )
×
3910
                if err != nil {
×
3911
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
×
3912
                                "feature(%v): %w", dbChanID, feature, err)
×
3913
                }
×
3914
        }
3915

3916
        // Finally, insert any extra TLV fields in the channel announcement.
3917
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3918
        if err != nil {
×
3919
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
3920
                        "data: %w", err)
×
3921
        }
×
3922

3923
        for tlvType, value := range extra {
×
3924
                err := db.CreateChannelExtraType(
×
3925
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3926
                                ChannelID: dbChanID,
×
3927
                                Type:      int64(tlvType),
×
3928
                                Value:     value,
×
3929
                        },
×
3930
                )
×
3931
                if err != nil {
×
3932
                        return nil, fmt.Errorf("unable to upsert "+
×
3933
                                "channel(%d) extra signed field(%v): %w",
×
3934
                                edge.ChannelID, tlvType, err)
×
3935
                }
×
3936
        }
3937

3938
        return &dbChanInfo{
×
3939
                channelID: dbChanID,
×
3940
                node1ID:   node1DBID,
×
3941
                node2ID:   node2DBID,
×
3942
        }, nil
×
3943
}
3944

3945
// maybeCreateShellNode checks if a shell node entry exists for the
3946
// given public key. If it does not exist, then a new shell node entry is
3947
// created. The ID of the node is returned. A shell node only has a protocol
3948
// version and public key persisted.
3949
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3950
        pubKey route.Vertex) (int64, error) {
×
3951

×
3952
        dbNode, err := db.GetNodeByPubKey(
×
3953
                ctx, sqlc.GetNodeByPubKeyParams{
×
3954
                        PubKey:  pubKey[:],
×
3955
                        Version: int16(ProtocolV1),
×
3956
                },
×
3957
        )
×
3958
        // The node exists. Return the ID.
×
3959
        if err == nil {
×
3960
                return dbNode.ID, nil
×
3961
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3962
                return 0, err
×
3963
        }
×
3964

3965
        // Otherwise, the node does not exist, so we create a shell entry for
3966
        // it.
3967
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3968
                Version: int16(ProtocolV1),
×
3969
                PubKey:  pubKey[:],
×
3970
        })
×
3971
        if err != nil {
×
3972
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3973
        }
×
3974

3975
        return id, nil
×
3976
}
3977

3978
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3979
// the database. This includes deleting any existing types and then inserting
3980
// the new types.
3981
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3982
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3983

×
3984
        // Delete all existing extra signed fields for the channel policy.
×
3985
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
3986
        if err != nil {
×
3987
                return fmt.Errorf("unable to delete "+
×
3988
                        "existing policy extra signed fields for policy %d: %w",
×
3989
                        chanPolicyID, err)
×
3990
        }
×
3991

3992
        // Insert all new extra signed fields for the channel policy.
3993
        for tlvType, value := range extraFields {
×
3994
                err = db.InsertChanPolicyExtraType(
×
3995
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
3996
                                ChannelPolicyID: chanPolicyID,
×
3997
                                Type:            int64(tlvType),
×
3998
                                Value:           value,
×
3999
                        },
×
4000
                )
×
4001
                if err != nil {
×
4002
                        return fmt.Errorf("unable to insert "+
×
4003
                                "channel_policy(%d) extra signed field(%v): %w",
×
4004
                                chanPolicyID, tlvType, err)
×
4005
                }
×
4006
        }
4007

4008
        return nil
×
4009
}
4010

4011
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
4012
// provided dbChanRow and also fetches any other required information
4013
// to construct the edge info.
4014
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
4015
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.Channel, node1,
4016
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
4017

×
4018
        if dbChan.Version != int16(ProtocolV1) {
×
4019
                return nil, fmt.Errorf("unsupported channel version: %d",
×
4020
                        dbChan.Version)
×
4021
        }
×
4022

4023
        fv, extras, err := getChanFeaturesAndExtras(
×
4024
                ctx, db, dbChanID,
×
4025
        )
×
4026
        if err != nil {
×
4027
                return nil, err
×
4028
        }
×
4029

4030
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4031
        if err != nil {
×
4032
                return nil, err
×
4033
        }
×
4034

4035
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4036
        if err != nil {
×
4037
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4038
                        "fields: %w", err)
×
4039
        }
×
4040
        if recs == nil {
×
4041
                recs = make([]byte, 0)
×
4042
        }
×
4043

4044
        var btcKey1, btcKey2 route.Vertex
×
4045
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4046
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4047

×
4048
        channel := &models.ChannelEdgeInfo{
×
4049
                ChainHash:        chain,
×
4050
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4051
                NodeKey1Bytes:    node1,
×
4052
                NodeKey2Bytes:    node2,
×
4053
                BitcoinKey1Bytes: btcKey1,
×
4054
                BitcoinKey2Bytes: btcKey2,
×
4055
                ChannelPoint:     *op,
×
4056
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4057
                Features:         fv,
×
4058
                ExtraOpaqueData:  recs,
×
4059
        }
×
4060

×
4061
        // We always set all the signatures at the same time, so we can
×
4062
        // safely check if one signature is present to determine if we have the
×
4063
        // rest of the signatures for the auth proof.
×
4064
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4065
                channel.AuthProof = &models.ChannelAuthProof{
×
4066
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4067
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4068
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4069
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4070
                }
×
4071
        }
×
4072

4073
        return channel, nil
×
4074
}
4075

4076
// buildNodeVertices is a helper that converts raw node public keys
4077
// into route.Vertex instances.
4078
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4079
        route.Vertex, error) {
×
4080

×
4081
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4082
        if err != nil {
×
4083
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4084
                        "create vertex from node1 pubkey: %w", err)
×
4085
        }
×
4086

4087
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4088
        if err != nil {
×
4089
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4090
                        "create vertex from node2 pubkey: %w", err)
×
4091
        }
×
4092

4093
        return node1Vertex, node2Vertex, nil
×
4094
}
4095

4096
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
4097
// for a channel with the given ID.
4098
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
4099
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
4100

×
4101
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
4102
        if err != nil {
×
4103
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
4104
                        "features and extras: %w", err)
×
4105
        }
×
4106

4107
        var (
×
4108
                fv     = lnwire.EmptyFeatureVector()
×
4109
                extras = make(map[uint64][]byte)
×
4110
        )
×
4111
        for _, row := range rows {
×
4112
                if row.IsFeature {
×
4113
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
4114

×
4115
                        continue
×
4116
                }
4117

4118
                tlvType, ok := row.ExtraKey.(int64)
×
4119
                if !ok {
×
4120
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4121
                                "TLV type: %T", row.ExtraKey)
×
4122
                }
×
4123

4124
                valueBytes, ok := row.Value.([]byte)
×
4125
                if !ok {
×
4126
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4127
                                "Value: %T", row.Value)
×
4128
                }
×
4129

4130
                extras[uint64(tlvType)] = valueBytes
×
4131
        }
4132

4133
        return fv, extras, nil
×
4134
}
4135

4136
// getAndBuildChanPolicies uses the given sqlc.ChannelPolicy and also retrieves
4137
// all the extra info required to build the complete models.ChannelEdgePolicy
4138
// types. It returns two policies, which may be nil if the provided
4139
// sqlc.ChannelPolicy records are nil.
4140
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4141
        dbPol1, dbPol2 *sqlc.ChannelPolicy, channelID uint64, node1,
4142
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4143
        *models.ChannelEdgePolicy, error) {
×
4144

×
4145
        if dbPol1 == nil && dbPol2 == nil {
×
4146
                return nil, nil, nil
×
4147
        }
×
4148

4149
        var (
×
4150
                policy1ID int64
×
4151
                policy2ID int64
×
4152
        )
×
4153
        if dbPol1 != nil {
×
4154
                policy1ID = dbPol1.ID
×
4155
        }
×
4156
        if dbPol2 != nil {
×
4157
                policy2ID = dbPol2.ID
×
4158
        }
×
4159
        rows, err := db.GetChannelPolicyExtraTypes(
×
4160
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
4161
                        ID:   policy1ID,
×
4162
                        ID_2: policy2ID,
×
4163
                },
×
4164
        )
×
4165
        if err != nil {
×
4166
                return nil, nil, err
×
4167
        }
×
4168

4169
        var (
×
4170
                dbPol1Extras = make(map[uint64][]byte)
×
4171
                dbPol2Extras = make(map[uint64][]byte)
×
4172
        )
×
4173
        for _, row := range rows {
×
4174
                switch row.PolicyID {
×
4175
                case policy1ID:
×
4176
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
4177
                case policy2ID:
×
4178
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
4179
                default:
×
4180
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
4181
                                "in row: %v", row.PolicyID, row)
×
4182
                }
4183
        }
4184

4185
        var pol1, pol2 *models.ChannelEdgePolicy
×
4186
        if dbPol1 != nil {
×
4187
                pol1, err = buildChanPolicy(
×
4188
                        *dbPol1, channelID, dbPol1Extras, node2,
×
4189
                )
×
4190
                if err != nil {
×
4191
                        return nil, nil, err
×
4192
                }
×
4193
        }
4194
        if dbPol2 != nil {
×
4195
                pol2, err = buildChanPolicy(
×
4196
                        *dbPol2, channelID, dbPol2Extras, node1,
×
4197
                )
×
4198
                if err != nil {
×
4199
                        return nil, nil, err
×
4200
                }
×
4201
        }
4202

4203
        return pol1, pol2, nil
×
4204
}
4205

4206
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4207
// provided sqlc.ChannelPolicy and other required information.
4208
func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
4209
        extras map[uint64][]byte,
4210
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4211

×
4212
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4213
        if err != nil {
×
4214
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4215
                        "fields: %w", err)
×
4216
        }
×
4217

4218
        var inboundFee fn.Option[lnwire.Fee]
×
4219
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4220
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4221

×
4222
                inboundFee = fn.Some(lnwire.Fee{
×
4223
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4224
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4225
                })
×
4226
        }
×
4227

4228
        return &models.ChannelEdgePolicy{
×
4229
                SigBytes:  dbPolicy.Signature,
×
4230
                ChannelID: channelID,
×
4231
                LastUpdate: time.Unix(
×
4232
                        dbPolicy.LastUpdate.Int64, 0,
×
4233
                ),
×
4234
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4235
                        dbPolicy.MessageFlags,
×
4236
                ),
×
4237
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4238
                        dbPolicy.ChannelFlags,
×
4239
                ),
×
4240
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4241
                MinHTLC: lnwire.MilliSatoshi(
×
4242
                        dbPolicy.MinHtlcMsat,
×
4243
                ),
×
4244
                MaxHTLC: lnwire.MilliSatoshi(
×
4245
                        dbPolicy.MaxHtlcMsat.Int64,
×
4246
                ),
×
4247
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4248
                        dbPolicy.BaseFeeMsat,
×
4249
                ),
×
4250
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4251
                ToNode:                    toNode,
×
4252
                InboundFee:                inboundFee,
×
4253
                ExtraOpaqueData:           recs,
×
4254
        }, nil
×
4255
}
4256

4257
// buildNodes builds the models.LightningNode instances for the
4258
// given row which is expected to be a sqlc type that contains node information.
4259
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4260
        dbNode2 sqlc.Node) (*models.LightningNode, *models.LightningNode,
4261
        error) {
×
4262

×
4263
        node1, err := buildNode(ctx, db, &dbNode1)
×
4264
        if err != nil {
×
4265
                return nil, nil, err
×
4266
        }
×
4267

4268
        node2, err := buildNode(ctx, db, &dbNode2)
×
4269
        if err != nil {
×
4270
                return nil, nil, err
×
4271
        }
×
4272

4273
        return node1, node2, nil
×
4274
}
4275

4276
// extractChannelPolicies extracts the sqlc.ChannelPolicy records from the give
4277
// row which is expected to be a sqlc type that contains channel policy
4278
// information. It returns two policies, which may be nil if the policy
4279
// information is not present in the row.
4280
//
4281
//nolint:ll,dupl,funlen
4282
func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
4283
        error) {
×
4284

×
4285
        var policy1, policy2 *sqlc.ChannelPolicy
×
4286
        switch r := row.(type) {
×
NEW
4287
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
NEW
4288
                if r.Policy1ID.Valid {
×
NEW
4289
                        policy1 = &sqlc.ChannelPolicy{
×
NEW
4290
                                ID:                      r.Policy1ID.Int64,
×
NEW
4291
                                Version:                 r.Policy1Version.Int16,
×
NEW
4292
                                ChannelID:               r.Channel.ID,
×
NEW
4293
                                NodeID:                  r.Policy1NodeID.Int64,
×
NEW
4294
                                Timelock:                r.Policy1Timelock.Int32,
×
NEW
4295
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
NEW
4296
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
NEW
4297
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
NEW
4298
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
NEW
4299
                                LastUpdate:              r.Policy1LastUpdate,
×
NEW
4300
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
NEW
4301
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
NEW
4302
                                Disabled:                r.Policy1Disabled,
×
NEW
4303
                                MessageFlags:            r.Policy1MessageFlags,
×
NEW
4304
                                ChannelFlags:            r.Policy1ChannelFlags,
×
NEW
4305
                                Signature:               r.Policy1Signature,
×
NEW
4306
                        }
×
NEW
4307
                }
×
NEW
4308
                if r.Policy2ID.Valid {
×
NEW
4309
                        policy2 = &sqlc.ChannelPolicy{
×
NEW
4310
                                ID:                      r.Policy2ID.Int64,
×
NEW
4311
                                Version:                 r.Policy2Version.Int16,
×
NEW
4312
                                ChannelID:               r.Channel.ID,
×
NEW
4313
                                NodeID:                  r.Policy2NodeID.Int64,
×
NEW
4314
                                Timelock:                r.Policy2Timelock.Int32,
×
NEW
4315
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
NEW
4316
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
NEW
4317
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
NEW
4318
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
NEW
4319
                                LastUpdate:              r.Policy2LastUpdate,
×
NEW
4320
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
NEW
4321
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
NEW
4322
                                Disabled:                r.Policy2Disabled,
×
NEW
4323
                                MessageFlags:            r.Policy2MessageFlags,
×
NEW
4324
                                ChannelFlags:            r.Policy2ChannelFlags,
×
NEW
4325
                                Signature:               r.Policy2Signature,
×
NEW
4326
                        }
×
NEW
4327
                }
×
4328

NEW
4329
                return policy1, policy2, nil
×
4330

4331
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4332
                if r.Policy1ID.Valid {
×
4333
                        policy1 = &sqlc.ChannelPolicy{
×
4334
                                ID:                      r.Policy1ID.Int64,
×
4335
                                Version:                 r.Policy1Version.Int16,
×
4336
                                ChannelID:               r.Channel.ID,
×
4337
                                NodeID:                  r.Policy1NodeID.Int64,
×
4338
                                Timelock:                r.Policy1Timelock.Int32,
×
4339
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4340
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4341
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4342
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4343
                                LastUpdate:              r.Policy1LastUpdate,
×
4344
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4345
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4346
                                Disabled:                r.Policy1Disabled,
×
4347
                                MessageFlags:            r.Policy1MessageFlags,
×
4348
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4349
                                Signature:               r.Policy1Signature,
×
4350
                        }
×
4351
                }
×
4352
                if r.Policy2ID.Valid {
×
4353
                        policy2 = &sqlc.ChannelPolicy{
×
4354
                                ID:                      r.Policy2ID.Int64,
×
4355
                                Version:                 r.Policy2Version.Int16,
×
4356
                                ChannelID:               r.Channel.ID,
×
4357
                                NodeID:                  r.Policy2NodeID.Int64,
×
4358
                                Timelock:                r.Policy2Timelock.Int32,
×
4359
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4360
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4361
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4362
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4363
                                LastUpdate:              r.Policy2LastUpdate,
×
4364
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4365
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4366
                                Disabled:                r.Policy2Disabled,
×
4367
                                MessageFlags:            r.Policy2MessageFlags,
×
4368
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4369
                                Signature:               r.Policy2Signature,
×
4370
                        }
×
4371
                }
×
4372

4373
                return policy1, policy2, nil
×
4374

4375
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4376
                if r.Policy1ID.Valid {
×
4377
                        policy1 = &sqlc.ChannelPolicy{
×
4378
                                ID:                      r.Policy1ID.Int64,
×
4379
                                Version:                 r.Policy1Version.Int16,
×
4380
                                ChannelID:               r.Channel.ID,
×
4381
                                NodeID:                  r.Policy1NodeID.Int64,
×
4382
                                Timelock:                r.Policy1Timelock.Int32,
×
4383
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4384
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4385
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4386
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4387
                                LastUpdate:              r.Policy1LastUpdate,
×
4388
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4389
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4390
                                Disabled:                r.Policy1Disabled,
×
4391
                                MessageFlags:            r.Policy1MessageFlags,
×
4392
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4393
                                Signature:               r.Policy1Signature,
×
4394
                        }
×
4395
                }
×
4396
                if r.Policy2ID.Valid {
×
4397
                        policy2 = &sqlc.ChannelPolicy{
×
4398
                                ID:                      r.Policy2ID.Int64,
×
4399
                                Version:                 r.Policy2Version.Int16,
×
4400
                                ChannelID:               r.Channel.ID,
×
4401
                                NodeID:                  r.Policy2NodeID.Int64,
×
4402
                                Timelock:                r.Policy2Timelock.Int32,
×
4403
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4404
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4405
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4406
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4407
                                LastUpdate:              r.Policy2LastUpdate,
×
4408
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4409
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4410
                                Disabled:                r.Policy2Disabled,
×
4411
                                MessageFlags:            r.Policy2MessageFlags,
×
4412
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4413
                                Signature:               r.Policy2Signature,
×
4414
                        }
×
4415
                }
×
4416

4417
                return policy1, policy2, nil
×
4418

4419
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4420
                if r.Policy1ID.Valid {
×
4421
                        policy1 = &sqlc.ChannelPolicy{
×
4422
                                ID:                      r.Policy1ID.Int64,
×
4423
                                Version:                 r.Policy1Version.Int16,
×
4424
                                ChannelID:               r.Channel.ID,
×
4425
                                NodeID:                  r.Policy1NodeID.Int64,
×
4426
                                Timelock:                r.Policy1Timelock.Int32,
×
4427
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4428
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4429
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4430
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4431
                                LastUpdate:              r.Policy1LastUpdate,
×
4432
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4433
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4434
                                Disabled:                r.Policy1Disabled,
×
4435
                                MessageFlags:            r.Policy1MessageFlags,
×
4436
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4437
                                Signature:               r.Policy1Signature,
×
4438
                        }
×
4439
                }
×
4440
                if r.Policy2ID.Valid {
×
4441
                        policy2 = &sqlc.ChannelPolicy{
×
4442
                                ID:                      r.Policy2ID.Int64,
×
4443
                                Version:                 r.Policy2Version.Int16,
×
4444
                                ChannelID:               r.Channel.ID,
×
4445
                                NodeID:                  r.Policy2NodeID.Int64,
×
4446
                                Timelock:                r.Policy2Timelock.Int32,
×
4447
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4448
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4449
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4450
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4451
                                LastUpdate:              r.Policy2LastUpdate,
×
4452
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4453
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4454
                                Disabled:                r.Policy2Disabled,
×
4455
                                MessageFlags:            r.Policy2MessageFlags,
×
4456
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4457
                                Signature:               r.Policy2Signature,
×
4458
                        }
×
4459
                }
×
4460

4461
                return policy1, policy2, nil
×
4462

4463
        case sqlc.ListChannelsByNodeIDRow:
×
4464
                if r.Policy1ID.Valid {
×
4465
                        policy1 = &sqlc.ChannelPolicy{
×
4466
                                ID:                      r.Policy1ID.Int64,
×
4467
                                Version:                 r.Policy1Version.Int16,
×
4468
                                ChannelID:               r.Channel.ID,
×
4469
                                NodeID:                  r.Policy1NodeID.Int64,
×
4470
                                Timelock:                r.Policy1Timelock.Int32,
×
4471
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4472
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4473
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4474
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4475
                                LastUpdate:              r.Policy1LastUpdate,
×
4476
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4477
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4478
                                Disabled:                r.Policy1Disabled,
×
4479
                                MessageFlags:            r.Policy1MessageFlags,
×
4480
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4481
                                Signature:               r.Policy1Signature,
×
4482
                        }
×
4483
                }
×
4484
                if r.Policy2ID.Valid {
×
4485
                        policy2 = &sqlc.ChannelPolicy{
×
4486
                                ID:                      r.Policy2ID.Int64,
×
4487
                                Version:                 r.Policy2Version.Int16,
×
4488
                                ChannelID:               r.Channel.ID,
×
4489
                                NodeID:                  r.Policy2NodeID.Int64,
×
4490
                                Timelock:                r.Policy2Timelock.Int32,
×
4491
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4492
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4493
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4494
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4495
                                LastUpdate:              r.Policy2LastUpdate,
×
4496
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4497
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4498
                                Disabled:                r.Policy2Disabled,
×
4499
                                MessageFlags:            r.Policy2MessageFlags,
×
4500
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4501
                                Signature:               r.Policy2Signature,
×
4502
                        }
×
4503
                }
×
4504

4505
                return policy1, policy2, nil
×
4506

4507
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4508
                if r.Policy1ID.Valid {
×
4509
                        policy1 = &sqlc.ChannelPolicy{
×
4510
                                ID:                      r.Policy1ID.Int64,
×
4511
                                Version:                 r.Policy1Version.Int16,
×
4512
                                ChannelID:               r.Channel.ID,
×
4513
                                NodeID:                  r.Policy1NodeID.Int64,
×
4514
                                Timelock:                r.Policy1Timelock.Int32,
×
4515
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4516
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4517
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4518
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4519
                                LastUpdate:              r.Policy1LastUpdate,
×
4520
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4521
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4522
                                Disabled:                r.Policy1Disabled,
×
4523
                                MessageFlags:            r.Policy1MessageFlags,
×
4524
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4525
                                Signature:               r.Policy1Signature,
×
4526
                        }
×
4527
                }
×
4528
                if r.Policy2ID.Valid {
×
4529
                        policy2 = &sqlc.ChannelPolicy{
×
4530
                                ID:                      r.Policy2ID.Int64,
×
4531
                                Version:                 r.Policy2Version.Int16,
×
4532
                                ChannelID:               r.Channel.ID,
×
4533
                                NodeID:                  r.Policy2NodeID.Int64,
×
4534
                                Timelock:                r.Policy2Timelock.Int32,
×
4535
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4536
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4537
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4538
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4539
                                LastUpdate:              r.Policy2LastUpdate,
×
4540
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4541
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4542
                                Disabled:                r.Policy2Disabled,
×
4543
                                MessageFlags:            r.Policy2MessageFlags,
×
4544
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4545
                                Signature:               r.Policy2Signature,
×
4546
                        }
×
4547
                }
×
4548

4549
                return policy1, policy2, nil
×
4550
        default:
×
4551
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4552
                        "extractChannelPolicies: %T", r)
×
4553
        }
4554
}
4555

4556
// channelIDToBytes converts a channel ID (SCID) to a byte array
4557
// representation.
4558
func channelIDToBytes(channelID uint64) []byte {
×
4559
        var chanIDB [8]byte
×
4560
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4561

×
4562
        return chanIDB[:]
×
4563
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc