• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 16313565012

16 Jul 2025 07:46AM UTC coverage: 67.205% (-0.1%) from 67.321%
16313565012

Pull #10081

github

web-flow
Merge bf32adb8a into 9059a4e7b
Pull Request #10081: graph/db: use `/*SLICE:<field_name>*/` to optimise various graph queries

0 of 379 new or added lines in 4 files covered. (0.0%)

99 existing lines in 24 files now uncovered.

135374 of 201433 relevant lines covered (67.21%)

21718.6 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// pageSize caps how many records a single paginated query may return. The
// value is a starting point and may be tuned once benchmarks are available.
const pageSize = 2000

// ProtocolVersion enumerates the gossip protocol version that a message
// belongs to.
type ProtocolVersion uint8

const (
	// ProtocolV1 is the gossip protocol version specified in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String renders the version as the letter "V" followed by its numeric
// value, e.g. "V1".
func (v ProtocolVersion) String() string {
	return "V" + strconv.FormatUint(uint64(v), 10)
}
×
51

52
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
53
// execute queries against the SQL graph tables.
54
//
55
//nolint:ll,interfacebloat
56
type SQLQueries interface {
57
        /*
58
                Node queries.
59
        */
60
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
61
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
62
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
63
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
64
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
65
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
66
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
67
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
68
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
69
        DeleteNode(ctx context.Context, id int64) error
70

71
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
72
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
73
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
74

75
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
76
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
77
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
78

79
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
80
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
81
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
82
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
83

84
        /*
85
                Source node queries.
86
        */
87
        AddSourceNode(ctx context.Context, nodeID int64) error
88
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
89

90
        /*
91
                Channel queries.
92
        */
93
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
94
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
95
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
96
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
97
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
98
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
99
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
100
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
101
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
102
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
103
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
104
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
105
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
106
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
107
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
108
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
109
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
110
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
111
        DeleteChannel(ctx context.Context, id int64) error
112

113
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
114
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
115

116
        /*
117
                Channel Policy table queries.
118
        */
119
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
120
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
121
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
122

123
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
124
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
125
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
126

127
        /*
128
                Zombie index queries.
129
        */
130
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
131
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
132
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
133
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
134
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
135

136
        /*
137
                Prune log table queries.
138
        */
139
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
140
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
141
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
142
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
143

144
        /*
145
                Closed SCID table queries.
146
        */
147
        InsertClosedChannel(ctx context.Context, scid []byte) error
148
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
149
}
150

151
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
152
// database operations.
153
type BatchedSQLQueries interface {
154
        SQLQueries
155
        sqldb.BatchedTx[SQLQueries]
156
}
157

158
// SQLStore is an implementation of the V1Store interface that uses a SQL
159
// database as the backend.
160
type SQLStore struct {
161
        cfg *SQLStoreConfig
162
        db  BatchedSQLQueries
163

164
        // cacheMu guards all caches (rejectCache and chanCache). If
165
        // this mutex will be acquired at the same time as the DB mutex then
166
        // the cacheMu MUST be acquired first to prevent deadlock.
167
        cacheMu     sync.RWMutex
168
        rejectCache *rejectCache
169
        chanCache   *channelCache
170

171
        chanScheduler batch.Scheduler[SQLQueries]
172
        nodeScheduler batch.Scheduler[SQLQueries]
173

174
        srcNodes  map[ProtocolVersion]*srcNodeInfo
175
        srcNodeMu sync.Mutex
176
}
177

178
// A compile-time assertion to ensure that SQLStore implements the V1Store
179
// interface.
180
var _ V1Store = (*SQLStore)(nil)
181

182
// SQLStoreConfig holds the configuration for the SQLStore.
183
type SQLStoreConfig struct {
184
        // ChainHash is the genesis hash for the chain that all the gossip
185
        // messages in this store are aimed at.
186
        ChainHash chainhash.Hash
187

188
        // PaginationCfg is the configuration for paginated queries.
189
        PaginationCfg *sqldb.PagedQueryConfig
190
}
191

192
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
193
// storage backend.
194
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
195
        options ...StoreOptionModifier) (*SQLStore, error) {
×
196

×
197
        opts := DefaultOptions()
×
198
        for _, o := range options {
×
199
                o(opts)
×
200
        }
×
201

202
        if opts.NoMigration {
×
203
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
204
                        "supported for SQL stores")
×
205
        }
×
206

207
        s := &SQLStore{
×
208
                cfg:         cfg,
×
209
                db:          db,
×
210
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
211
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
212
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
213
        }
×
214

×
215
        s.chanScheduler = batch.NewTimeScheduler(
×
216
                db, &s.cacheMu, opts.BatchCommitInterval,
×
217
        )
×
218
        s.nodeScheduler = batch.NewTimeScheduler(
×
219
                db, nil, opts.BatchCommitInterval,
×
220
        )
×
221

×
222
        return s, nil
×
223
}
224

225
// AddLightningNode adds a vertex/node to the graph database. If the node is not
226
// in the database from before, this will add a new, unconnected one to the
227
// graph. If it is present from before, this will update that node's
228
// information.
229
//
230
// NOTE: part of the V1Store interface.
231
func (s *SQLStore) AddLightningNode(ctx context.Context,
232
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
233

×
234
        r := &batch.Request[SQLQueries]{
×
235
                Opts: batch.NewSchedulerOptions(opts...),
×
236
                Do: func(queries SQLQueries) error {
×
237
                        _, err := upsertNode(ctx, queries, node)
×
238
                        return err
×
239
                },
×
240
        }
241

242
        return s.nodeScheduler.Execute(ctx, r)
×
243
}
244

245
// FetchLightningNode attempts to look up a target node by its identity public
246
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
247
// returned.
248
//
249
// NOTE: part of the V1Store interface.
250
func (s *SQLStore) FetchLightningNode(ctx context.Context,
251
        pubKey route.Vertex) (*models.LightningNode, error) {
×
252

×
253
        var node *models.LightningNode
×
254
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
255
                var err error
×
256
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
257

×
258
                return err
×
259
        }, sqldb.NoOpReset)
×
260
        if err != nil {
×
261
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
262
        }
×
263

264
        return node, nil
×
265
}
266

267
// HasLightningNode determines if the graph has a vertex identified by the
268
// target node identity public key. If the node exists in the database, a
269
// timestamp of when the data for the node was lasted updated is returned along
270
// with a true boolean. Otherwise, an empty time.Time is returned with a false
271
// boolean.
272
//
273
// NOTE: part of the V1Store interface.
274
func (s *SQLStore) HasLightningNode(ctx context.Context,
275
        pubKey [33]byte) (time.Time, bool, error) {
×
276

×
277
        var (
×
278
                exists     bool
×
279
                lastUpdate time.Time
×
280
        )
×
281
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
282
                dbNode, err := db.GetNodeByPubKey(
×
283
                        ctx, sqlc.GetNodeByPubKeyParams{
×
284
                                Version: int16(ProtocolV1),
×
285
                                PubKey:  pubKey[:],
×
286
                        },
×
287
                )
×
288
                if errors.Is(err, sql.ErrNoRows) {
×
289
                        return nil
×
290
                } else if err != nil {
×
291
                        return fmt.Errorf("unable to fetch node: %w", err)
×
292
                }
×
293

294
                exists = true
×
295

×
296
                if dbNode.LastUpdate.Valid {
×
297
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
298
                }
×
299

300
                return nil
×
301
        }, sqldb.NoOpReset)
302
        if err != nil {
×
303
                return time.Time{}, false,
×
304
                        fmt.Errorf("unable to fetch node: %w", err)
×
305
        }
×
306

307
        return lastUpdate, exists, nil
×
308
}
309

310
// AddrsForNode returns all known addresses for the target node public key
311
// that the graph DB is aware of. The returned boolean indicates if the
312
// given node is unknown to the graph DB or not.
313
//
314
// NOTE: part of the V1Store interface.
315
func (s *SQLStore) AddrsForNode(ctx context.Context,
316
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
317

×
318
        var (
×
319
                addresses []net.Addr
×
320
                known     bool
×
321
        )
×
322
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
323
                var err error
×
324
                known, addresses, err = getNodeAddresses(
×
325
                        ctx, db, nodePub.SerializeCompressed(),
×
326
                )
×
327
                if err != nil {
×
328
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
329
                                err)
×
330
                }
×
331

332
                return nil
×
333
        }, sqldb.NoOpReset)
334
        if err != nil {
×
335
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
336
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
337
        }
×
338

339
        return known, addresses, nil
×
340
}
341

342
// DeleteLightningNode starts a new database transaction to remove a vertex/node
343
// from the database according to the node's public key.
344
//
345
// NOTE: part of the V1Store interface.
346
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
347
        pubKey route.Vertex) error {
×
348

×
349
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
350
                res, err := db.DeleteNodeByPubKey(
×
351
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
352
                                Version: int16(ProtocolV1),
×
353
                                PubKey:  pubKey[:],
×
354
                        },
×
355
                )
×
356
                if err != nil {
×
357
                        return err
×
358
                }
×
359

360
                rows, err := res.RowsAffected()
×
361
                if err != nil {
×
362
                        return err
×
363
                }
×
364

365
                if rows == 0 {
×
366
                        return ErrGraphNodeNotFound
×
367
                } else if rows > 1 {
×
368
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
369
                }
×
370

371
                return err
×
372
        }, sqldb.NoOpReset)
373
        if err != nil {
×
374
                return fmt.Errorf("unable to delete node: %w", err)
×
375
        }
×
376

377
        return nil
×
378
}
379

380
// FetchNodeFeatures returns the features of the given node. If no features are
381
// known for the node, an empty feature vector is returned.
382
//
383
// NOTE: this is part of the graphdb.NodeTraverser interface.
384
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
385
        *lnwire.FeatureVector, error) {
×
386

×
387
        ctx := context.TODO()
×
388

×
389
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
390
}
×
391

392
// DisabledChannelIDs returns the channel ids of disabled channels.
393
// A channel is disabled when two of the associated ChanelEdgePolicies
394
// have their disabled bit on.
395
//
396
// NOTE: part of the V1Store interface.
397
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
398
        var (
×
399
                ctx     = context.TODO()
×
400
                chanIDs []uint64
×
401
        )
×
402
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
403
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
404
                if err != nil {
×
405
                        return fmt.Errorf("unable to fetch disabled "+
×
406
                                "channels: %w", err)
×
407
                }
×
408

409
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
410

×
411
                return nil
×
412
        }, sqldb.NoOpReset)
413
        if err != nil {
×
414
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
415
                        err)
×
416
        }
×
417

418
        return chanIDs, nil
×
419
}
420

421
// LookupAlias attempts to return the alias as advertised by the target node.
422
//
423
// NOTE: part of the V1Store interface.
424
func (s *SQLStore) LookupAlias(ctx context.Context,
425
        pub *btcec.PublicKey) (string, error) {
×
426

×
427
        var alias string
×
428
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
429
                dbNode, err := db.GetNodeByPubKey(
×
430
                        ctx, sqlc.GetNodeByPubKeyParams{
×
431
                                Version: int16(ProtocolV1),
×
432
                                PubKey:  pub.SerializeCompressed(),
×
433
                        },
×
434
                )
×
435
                if errors.Is(err, sql.ErrNoRows) {
×
436
                        return ErrNodeAliasNotFound
×
437
                } else if err != nil {
×
438
                        return fmt.Errorf("unable to fetch node: %w", err)
×
439
                }
×
440

441
                if !dbNode.Alias.Valid {
×
442
                        return ErrNodeAliasNotFound
×
443
                }
×
444

445
                alias = dbNode.Alias.String
×
446

×
447
                return nil
×
448
        }, sqldb.NoOpReset)
449
        if err != nil {
×
450
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
451
        }
×
452

453
        return alias, nil
×
454
}
455

456
// SourceNode returns the source node of the graph. The source node is treated
457
// as the center node within a star-graph. This method may be used to kick off
458
// a path finding algorithm in order to explore the reachability of another
459
// node based off the source node.
460
//
461
// NOTE: part of the V1Store interface.
462
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
463
        error) {
×
464

×
465
        var node *models.LightningNode
×
466
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
467
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
468
                if err != nil {
×
469
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
470
                                err)
×
471
                }
×
472

473
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
474

×
475
                return err
×
476
        }, sqldb.NoOpReset)
477
        if err != nil {
×
478
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
479
        }
×
480

481
        return node, nil
×
482
}
483

484
// SetSourceNode sets the source node within the graph database. The source
485
// node is to be used as the center of a star-graph within path finding
486
// algorithms.
487
//
488
// NOTE: part of the V1Store interface.
489
func (s *SQLStore) SetSourceNode(ctx context.Context,
490
        node *models.LightningNode) error {
×
491

×
492
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
493
                id, err := upsertNode(ctx, db, node)
×
494
                if err != nil {
×
495
                        return fmt.Errorf("unable to upsert source node: %w",
×
496
                                err)
×
497
                }
×
498

499
                // Make sure that if a source node for this version is already
500
                // set, then the ID is the same as the one we are about to set.
501
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
502
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
503
                        return fmt.Errorf("unable to fetch source node: %w",
×
504
                                err)
×
505
                } else if err == nil {
×
506
                        if dbSourceNodeID != id {
×
507
                                return fmt.Errorf("v1 source node already "+
×
508
                                        "set to a different node: %d vs %d",
×
509
                                        dbSourceNodeID, id)
×
510
                        }
×
511

512
                        return nil
×
513
                }
514

515
                return db.AddSourceNode(ctx, id)
×
516
        }, sqldb.NoOpReset)
517
}
518

519
// NodeUpdatesInHorizon returns all the known lightning node which have an
520
// update timestamp within the passed range. This method can be used by two
521
// nodes to quickly determine if they have the same set of up to date node
522
// announcements.
523
//
524
// NOTE: This is part of the V1Store interface.
525
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
526
        endTime time.Time) ([]models.LightningNode, error) {
×
527

×
528
        ctx := context.TODO()
×
529

×
530
        var nodes []models.LightningNode
×
531
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
532
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
533
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
534
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
535
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
536
                        },
×
537
                )
×
538
                if err != nil {
×
539
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
540
                }
×
541

542
                for _, dbNode := range dbNodes {
×
543
                        node, err := buildNode(ctx, db, &dbNode)
×
544
                        if err != nil {
×
545
                                return fmt.Errorf("unable to build node: %w",
×
546
                                        err)
×
547
                        }
×
548

549
                        nodes = append(nodes, *node)
×
550
                }
551

552
                return nil
×
553
        }, sqldb.NoOpReset)
554
        if err != nil {
×
555
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
556
        }
×
557

558
        return nodes, nil
×
559
}
560

561
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
562
// undirected edge from the two target nodes are created. The information stored
563
// denotes the static attributes of the channel, such as the channelID, the keys
564
// involved in creation of the channel, and the set of features that the channel
565
// supports. The chanPoint and chanID are used to uniquely identify the edge
566
// globally within the database.
567
//
568
// NOTE: part of the V1Store interface.
569
func (s *SQLStore) AddChannelEdge(ctx context.Context,
570
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
571

×
572
        var alreadyExists bool
×
573
        r := &batch.Request[SQLQueries]{
×
574
                Opts: batch.NewSchedulerOptions(opts...),
×
575
                Reset: func() {
×
576
                        alreadyExists = false
×
577
                },
×
578
                Do: func(tx SQLQueries) error {
×
579
                        _, err := insertChannel(ctx, tx, edge)
×
580

×
581
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
582
                        // succeed, but propagate the error via local state.
×
583
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
584
                                alreadyExists = true
×
585
                                return nil
×
586
                        }
×
587

588
                        return err
×
589
                },
590
                OnCommit: func(err error) error {
×
591
                        switch {
×
592
                        case err != nil:
×
593
                                return err
×
594
                        case alreadyExists:
×
595
                                return ErrEdgeAlreadyExist
×
596
                        default:
×
597
                                s.rejectCache.remove(edge.ChannelID)
×
598
                                s.chanCache.remove(edge.ChannelID)
×
599
                                return nil
×
600
                        }
601
                },
602
        }
603

604
        return s.chanScheduler.Execute(ctx, r)
×
605
}
606

607
// HighestChanID returns the "highest" known channel ID in the channel graph.
608
// This represents the "newest" channel from the PoV of the chain. This method
609
// can be used by peers to quickly determine if their graphs are in sync.
610
//
611
// NOTE: This is part of the V1Store interface.
612
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
613
        var highestChanID uint64
×
614
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
615
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
616
                if errors.Is(err, sql.ErrNoRows) {
×
617
                        return nil
×
618
                } else if err != nil {
×
619
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
620
                                err)
×
621
                }
×
622

623
                highestChanID = byteOrder.Uint64(chanID)
×
624

×
625
                return nil
×
626
        }, sqldb.NoOpReset)
627
        if err != nil {
×
628
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
629
        }
×
630

631
        return highestChanID, nil
×
632
}
633

634
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
635
// within the database for the referenced channel. The `flags` attribute within
636
// the ChannelEdgePolicy determines which of the directed edges are being
637
// updated. If the flag is 1, then the first node's information is being
638
// updated, otherwise it's the second node's information. The node ordering is
639
// determined by the lexicographical ordering of the identity public keys of the
640
// nodes on either side of the channel.
641
//
642
// NOTE: part of the V1Store interface.
643
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
644
        edge *models.ChannelEdgePolicy,
645
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
646

×
647
        var (
×
648
                isUpdate1    bool
×
649
                edgeNotFound bool
×
650
                from, to     route.Vertex
×
651
        )
×
652

×
653
        r := &batch.Request[SQLQueries]{
×
654
                Opts: batch.NewSchedulerOptions(opts...),
×
655
                Reset: func() {
×
656
                        isUpdate1 = false
×
657
                        edgeNotFound = false
×
658
                },
×
659
                Do: func(tx SQLQueries) error {
×
660
                        var err error
×
661
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
662
                                ctx, tx, edge,
×
663
                        )
×
664
                        if err != nil {
×
665
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
666
                        }
×
667

668
                        // Silence ErrEdgeNotFound so that the batch can
669
                        // succeed, but propagate the error via local state.
670
                        if errors.Is(err, ErrEdgeNotFound) {
×
671
                                edgeNotFound = true
×
672
                                return nil
×
673
                        }
×
674

675
                        return err
×
676
                },
677
                OnCommit: func(err error) error {
×
678
                        switch {
×
679
                        case err != nil:
×
680
                                return err
×
681
                        case edgeNotFound:
×
682
                                return ErrEdgeNotFound
×
683
                        default:
×
684
                                s.updateEdgeCache(edge, isUpdate1)
×
685
                                return nil
×
686
                        }
687
                },
688
        }
689

690
        err := s.chanScheduler.Execute(ctx, r)
×
691

×
692
        return from, to, err
×
693
}
694

695
// updateEdgeCache updates our reject and channel caches with the new
696
// edge policy information.
697
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
698
        isUpdate1 bool) {
×
699

×
700
        // If an entry for this channel is found in reject cache, we'll modify
×
701
        // the entry with the updated timestamp for the direction that was just
×
702
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
703
        // during the next query for this edge.
×
704
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
705
                if isUpdate1 {
×
706
                        entry.upd1Time = e.LastUpdate.Unix()
×
707
                } else {
×
708
                        entry.upd2Time = e.LastUpdate.Unix()
×
709
                }
×
710
                s.rejectCache.insert(e.ChannelID, entry)
×
711
        }
712

713
        // If an entry for this channel is found in channel cache, we'll modify
714
        // the entry with the updated policy for the direction that was just
715
        // written. If the edge doesn't exist, we'll defer loading the info and
716
        // policies and lazily read from disk during the next query.
717
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
718
                if isUpdate1 {
×
719
                        channel.Policy1 = e
×
720
                } else {
×
721
                        channel.Policy2 = e
×
722
                }
×
723
                s.chanCache.insert(e.ChannelID, channel)
×
724
        }
725
}
726

727
// ForEachSourceNodeChannel iterates through all channels of the source node,
728
// executing the passed callback on each. The call-back is provided with the
729
// channel's outpoint, whether we have a policy for the channel and the channel
730
// peer's node information.
731
//
732
// NOTE: part of the V1Store interface.
733
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
734
        cb func(chanPoint wire.OutPoint, havePolicy bool,
735
                otherNode *models.LightningNode) error, reset func()) error {
×
736

×
737
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
738
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
739
                if err != nil {
×
740
                        return fmt.Errorf("unable to fetch source node: %w",
×
741
                                err)
×
742
                }
×
743

744
                return forEachNodeChannel(
×
745
                        ctx, db, s.cfg.ChainHash, nodeID,
×
746
                        func(info *models.ChannelEdgeInfo,
×
747
                                outPolicy *models.ChannelEdgePolicy,
×
748
                                _ *models.ChannelEdgePolicy) error {
×
749

×
750
                                // Fetch the other node.
×
751
                                var (
×
752
                                        otherNodePub [33]byte
×
753
                                        node1        = info.NodeKey1Bytes
×
754
                                        node2        = info.NodeKey2Bytes
×
755
                                )
×
756
                                switch {
×
757
                                case bytes.Equal(node1[:], nodePub[:]):
×
758
                                        otherNodePub = node2
×
759
                                case bytes.Equal(node2[:], nodePub[:]):
×
760
                                        otherNodePub = node1
×
761
                                default:
×
762
                                        return fmt.Errorf("node not " +
×
763
                                                "participating in this channel")
×
764
                                }
765

766
                                _, otherNode, err := getNodeByPubKey(
×
767
                                        ctx, db, otherNodePub,
×
768
                                )
×
769
                                if err != nil {
×
770
                                        return fmt.Errorf("unable to fetch "+
×
771
                                                "other node(%x): %w",
×
772
                                                otherNodePub, err)
×
773
                                }
×
774

775
                                return cb(
×
776
                                        info.ChannelPoint, outPolicy != nil,
×
777
                                        otherNode,
×
778
                                )
×
779
                        },
780
                )
781
        }, reset)
782
}
783

784
// ForEachNode iterates through all the stored vertices/nodes in the graph,
// executing the passed callback with each node encountered. If the callback
// returns an error, then the transaction is aborted and the iteration stops
// early. Any operations performed on the NodeTx passed to the call-back are
// executed under the same read transaction and so, methods on the NodeTx object
// _MUST_ only be called from within the call-back.
//
// Nodes are read in pages of at most pageSize rows, using the DB ID of the
// last node handled as the pagination cursor. reset is forwarded to ExecTx,
// which is expected to call it before re-running the transaction body.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNode(ctx context.Context,
	cb func(tx NodeRTx) error, reset func()) error {

	handleNode := func(db SQLQueries, dbNode sqlc.GraphNode) error {
		node, err := buildNode(ctx, db, &dbNode)
		if err != nil {
			return fmt.Errorf("unable to build node(id=%d): %w",
				dbNode.ID, err)
		}

		err = cb(
			newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
		)
		if err != nil {
			return fmt.Errorf("callback failed for node(id=%d): %w",
				dbNode.ID, err)
		}

		return nil
	}

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		// The pagination cursor is scoped to the transaction body. If
		// ExecTx re-runs this closure, iteration then restarts from the
		// first node. A cursor declared outside the closure would keep
		// the position reached by an aborted attempt and silently skip
		// every node before it.
		var lastID int64
		for {
			nodes, err := db.ListNodesPaginated(
				ctx, sqlc.ListNodesPaginatedParams{
					Version: int16(ProtocolV1),
					ID:      lastID,
					Limit:   pageSize,
				},
			)
			if err != nil {
				return fmt.Errorf("unable to fetch nodes: %w",
					err)
			}

			if len(nodes) == 0 {
				break
			}

			for _, dbNode := range nodes {
				err = handleNode(db, dbNode)
				if err != nil {
					return err
				}

				lastID = dbNode.ID
			}
		}

		return nil
	}, reset)
}
845

846
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
847
// SQLStore and a SQL transaction.
848
type sqlGraphNodeTx struct {
849
        db    SQLQueries
850
        id    int64
851
        node  *models.LightningNode
852
        chain chainhash.Hash
853
}
854

855
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
856
// interface.
857
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
858

859
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
860
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
861

×
862
        return &sqlGraphNodeTx{
×
863
                db:    db,
×
864
                chain: chain,
×
865
                id:    id,
×
866
                node:  node,
×
867
        }
×
868
}
×
869

870
// Node returns the raw information of the node.
871
//
872
// NOTE: This is a part of the NodeRTx interface.
873
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
874
        return s.node
×
875
}
×
876

877
// ForEachChannel can be used to iterate over the node's channels under the same
878
// transaction used to fetch the node.
879
//
880
// NOTE: This is a part of the NodeRTx interface.
881
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
882
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
883

×
884
        ctx := context.TODO()
×
885

×
886
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
887
}
×
888

889
// FetchNode fetches the node with the given pub key under the same transaction
890
// used to fetch the current node. The returned node is also a NodeRTx and any
891
// operations on that NodeRTx will also be done under the same transaction.
892
//
893
// NOTE: This is a part of the NodeRTx interface.
894
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
895
        ctx := context.TODO()
×
896

×
897
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
898
        if err != nil {
×
899
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
900
                        nodePub, err)
×
901
        }
×
902

903
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
904
}
905

906
// ForEachNodeDirectedChannel iterates through all channels of a given node,
907
// executing the passed callback on the directed edge representing the channel
908
// and its incoming policy. If the callback returns an error, then the iteration
909
// is halted with the error propagated back up to the caller.
910
//
911
// Unknown policies are passed into the callback as nil values.
912
//
913
// NOTE: this is part of the graphdb.NodeTraverser interface.
914
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
915
        cb func(channel *DirectedChannel) error, reset func()) error {
×
916

×
917
        var ctx = context.TODO()
×
918

×
919
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
920
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
921
        }, reset)
×
922
}
923

924
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
925
// graph, executing the passed callback with each node encountered. If the
926
// callback returns an error, then the transaction is aborted and the iteration
927
// stops early.
928
//
929
// NOTE: This is a part of the V1Store interface.
930
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
931
        cb func(route.Vertex, *lnwire.FeatureVector) error,
932
        reset func()) error {
×
933

×
934
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
935
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
936
                        nodePub route.Vertex) error {
×
937

×
938
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
939
                        if err != nil {
×
940
                                return fmt.Errorf("unable to fetch node "+
×
941
                                        "features: %w", err)
×
942
                        }
×
943

944
                        return cb(nodePub, features)
×
945
                })
946
        }, reset)
947
        if err != nil {
×
948
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
949
        }
×
950

951
        return nil
×
952
}
953

954
// ForEachNodeChannel iterates through all channels of the given node,
955
// executing the passed callback with an edge info structure and the policies
956
// of each end of the channel. The first edge policy is the outgoing edge *to*
957
// the connecting node, while the second is the incoming edge *from* the
958
// connecting node. If the callback returns an error, then the iteration is
959
// halted with the error propagated back up to the caller.
960
//
961
// Unknown policies are passed into the callback as nil values.
962
//
963
// NOTE: part of the V1Store interface.
964
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
965
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
966
                *models.ChannelEdgePolicy) error, reset func()) error {
×
967

×
968
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
969
                dbNode, err := db.GetNodeByPubKey(
×
970
                        ctx, sqlc.GetNodeByPubKeyParams{
×
971
                                Version: int16(ProtocolV1),
×
972
                                PubKey:  nodePub[:],
×
973
                        },
×
974
                )
×
975
                if errors.Is(err, sql.ErrNoRows) {
×
976
                        return nil
×
977
                } else if err != nil {
×
978
                        return fmt.Errorf("unable to fetch node: %w", err)
×
979
                }
×
980

981
                return forEachNodeChannel(
×
982
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
983
                )
×
984
        }, reset)
985
}
986

987
// ChanUpdatesInHorizon returns all the known channel edges which have at least
// one edge that has an update timestamp within the specified horizon.
//
// Edges already present in the channel cache are served from it. Edges that
// have to be built from the DB are inserted into the cache only after the read
// transaction has succeeded. cacheMu is held for the entire call.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
	endTime time.Time) ([]ChannelEdge, error) {

	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	var (
		ctx = context.TODO()
		// To ensure we don't return duplicate ChannelEdges, we'll use
		// an additional map to keep track of the edges already seen to
		// prevent re-adding it.
		edgesSeen    = make(map[uint64]struct{})
		edgesToCache = make(map[uint64]ChannelEdge)
		edges        []ChannelEdge
		hits         int
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		rows, err := db.GetChannelsByPolicyLastUpdateRange(
			ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
				Version:   int16(ProtocolV1),
				StartTime: sqldb.SQLInt64(startTime.Unix()),
				EndTime:   sqldb.SQLInt64(endTime.Unix()),
			},
		)
		if err != nil {
			return err
		}

		for _, row := range rows {
			// If we've already retrieved the info and policies for
			// this edge, then we can skip it as we don't need to do
			// so again.
			chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid)
			if _, ok := edgesSeen[chanIDInt]; ok {
				continue
			}

			if channel, ok := s.chanCache.get(chanIDInt); ok {
				hits++
				edgesSeen[chanIDInt] = struct{}{}
				edges = append(edges, channel)

				continue
			}

			node1, node2, err := buildNodes(
				ctx, db, row.GraphNode, row.GraphNode_2,
			)
			if err != nil {
				return err
			}

			channel, err := getAndBuildEdgeInfo(
				ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
				row.GraphChannel, node1.PubKeyBytes,
				node2.PubKeyBytes,
			)
			if err != nil {
				return fmt.Errorf("unable to build channel "+
					"info: %w", err)
			}

			dbPol1, dbPol2, err := extractChannelPolicies(row)
			if err != nil {
				return fmt.Errorf("unable to extract channel "+
					"policies: %w", err)
			}

			p1, p2, err := getAndBuildChanPolicies(
				ctx, db, dbPol1, dbPol2, channel.ChannelID,
				node1.PubKeyBytes, node2.PubKeyBytes,
			)
			if err != nil {
				return fmt.Errorf("unable to build channel "+
					"policies: %w", err)
			}

			edgesSeen[chanIDInt] = struct{}{}
			chanEdge := ChannelEdge{
				Info:    channel,
				Policy1: p1,
				Policy2: p2,
				Node1:   node1,
				Node2:   node2,
			}
			edges = append(edges, chanEdge)
			edgesToCache[chanIDInt] = chanEdge
		}

		return nil
	}, func() {
		// Discard everything accumulated by an earlier attempt of the
		// transaction body. This includes the cache-hit counter, so
		// that the logged hit ratio below only reflects the attempt
		// that succeeded and can never exceed 100%.
		edgesSeen = make(map[uint64]struct{})
		edgesToCache = make(map[uint64]ChannelEdge)
		edges = nil
		hits = 0
	})
	if err != nil {
		return nil, fmt.Errorf("unable to fetch channels: %w", err)
	}

	// Insert any edges loaded from disk into the cache.
	for chanid, channel := range edgesToCache {
		s.chanCache.insert(chanid, channel)
	}

	if len(edges) > 0 {
		log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
			float64(hits)*100/float64(len(edges)), hits, len(edges))
	} else {
		log.Debugf("ChanUpdatesInHorizon returned no edges in "+
			"horizon (%s, %s)", startTime, endTime)
	}

	return edges, nil
}
1105

1106
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
// data to the call-back.
//
// For every V1 node, cb receives the node's public key and a map, keyed by
// channel ID, that describes each of the node's channels from that node's
// side:
//   - OutPolicySet and InboundFee are taken from the node's outgoing policy.
//   - InPolicy is a cached copy of the incoming policy, or nil if it is
//     unknown.
//
// NOTE: The callback contents MUST not be modified.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNodeCached(ctx context.Context,
	cb func(node route.Vertex, chans map[uint64]*DirectedChannel) error,
	reset func()) error {

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		return forEachNodeCacheable(ctx, db, func(nodeID int64,
			nodePub route.Vertex) error {

			features, err := getNodeFeatures(ctx, db, nodeID)
			if err != nil {
				return fmt.Errorf("unable to fetch "+
					"node(id=%d) features: %w", nodeID, err)
			}

			toNodeCallback := func() route.Vertex {
				return nodePub
			}

			rows, err := db.ListChannelsByNodeID(
				ctx, sqlc.ListChannelsByNodeIDParams{
					Version: int16(ProtocolV1),
					NodeID1: nodeID,
				},
			)
			if err != nil {
				return fmt.Errorf("unable to fetch channels "+
					"of node(id=%d): %w", nodeID, err)
			}

			channels := make(map[uint64]*DirectedChannel, len(rows))
			for _, row := range rows {
				node1, node2, err := buildNodeVertices(
					row.Node1Pubkey, row.Node2Pubkey,
				)
				if err != nil {
					return err
				}

				e, err := getAndBuildEdgeInfo(
					ctx, db, s.cfg.ChainHash,
					row.GraphChannel.ID, row.GraphChannel,
					node1, node2,
				)
				if err != nil {
					return fmt.Errorf("unable to build "+
						"channel info: %w", err)
				}

				dbPol1, dbPol2, err := extractChannelPolicies(
					row,
				)
				if err != nil {
					return fmt.Errorf("unable to "+
						"extract channel "+
						"policies: %w", err)
				}

				p1, p2, err := getAndBuildChanPolicies(
					ctx, db, dbPol1, dbPol2, e.ChannelID,
					node1, node2,
				)
				if err != nil {
					return fmt.Errorf("unable to "+
						"build channel policies: %w",
						err)
				}

				// Determine the outgoing and incoming policy
				// for this channel and node combo. A policy
				// whose ToNode is nodePub is treated as the
				// incoming one.
				outPolicy, inPolicy := p1, p2
				if p1 != nil && p1.ToNode == nodePub {
					outPolicy, inPolicy = p2, p1
				} else if p2 != nil && p2.ToNode != nodePub {
					outPolicy, inPolicy = p2, p1
				}

				// Cache the policy selected as incoming above.
				// Using p2 unconditionally would expose the
				// wrong direction whenever the selection
				// swapped p1 and p2.
				var cachedInPolicy *models.CachedEdgePolicy
				if inPolicy != nil {
					cachedInPolicy = models.NewCachedPolicy(
						inPolicy,
					)
					cachedInPolicy.ToNodePubKey =
						toNodeCallback
					cachedInPolicy.ToNodeFeatures =
						features
				}

				// The outgoing policy may be unknown (nil).
				// Dereferencing it unconditionally would panic,
				// so the inbound fee is only read when it is
				// present.
				var inboundFee lnwire.Fee
				if outPolicy != nil {
					outPolicy.InboundFee.WhenSome(
						func(fee lnwire.Fee) {
							inboundFee = fee
						},
					)
				}

				// OutPolicySet must reflect the selected
				// outgoing policy, not p1, for the same reason
				// as above.
				directedChannel := &DirectedChannel{
					ChannelID: e.ChannelID,
					IsNode1: nodePub ==
						e.NodeKey1Bytes,
					OtherNode:    e.NodeKey2Bytes,
					Capacity:     e.Capacity,
					OutPolicySet: outPolicy != nil,
					InPolicy:     cachedInPolicy,
					InboundFee:   inboundFee,
				}

				if nodePub == e.NodeKey2Bytes {
					directedChannel.OtherNode =
						e.NodeKey1Bytes
				}

				channels[e.ChannelID] = directedChannel
			}

			return cb(nodePub, channels)
		})
	}, reset)
}
1229

1230
// ForEachChannelCacheable iterates through all the channel edges stored
// within the graph and invokes the passed callback for each edge. The
// callback takes two edges as since this is a directed graph, both the
// in/out edges are visited. If the callback returns an error, then the
// transaction is aborted and the iteration stops early.
//
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
// pointer for that particular channel edge routing policy will be
// passed into the callback.
//
// NOTE: this method is like ForEachChannel but fetches only the data
// required for the graph cache.
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
	*models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
	reset func()) error {

	ctx := context.TODO()

	// visit converts a single paginated row into its cacheable form and
	// hands the result to cb.
	visit := func(db SQLQueries,
		row sqlc.ListChannelsWithPoliciesPaginatedRow) error {

		node1, node2, err := buildNodeVertices(
			row.Node1Pubkey, row.Node2Pubkey,
		)
		if err != nil {
			return err
		}

		info := buildCacheableChannelInfo(row.GraphChannel, node1, node2)

		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return err
		}

		// Policy 1 is built against node2 and policy 2 against node1.
		// An absent policy stays nil.
		var cached1, cached2 *models.CachedEdgePolicy
		if dbPol1 != nil {
			pol, err := buildChanPolicy(
				*dbPol1, info.ChannelID, nil, node2,
			)
			if err != nil {
				return err
			}
			cached1 = models.NewCachedPolicy(pol)
		}
		if dbPol2 != nil {
			pol, err := buildChanPolicy(
				*dbPol2, info.ChannelID, nil, node1,
			)
			if err != nil {
				return err
			}
			cached2 = models.NewCachedPolicy(pol)
		}

		return cb(info, cached1, cached2)
	}

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		// cursor holds the ID of the last channel handled. It starts at
		// -1 (presumably the query filters on id > cursor; confirm
		// against the SQL). It is declared per attempt, so a re-run of
		// this body starts over from the beginning.
		for cursor := int64(-1); ; {
			//nolint:ll
			rows, err := db.ListChannelsWithPoliciesPaginated(
				ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
					Version: int16(ProtocolV1),
					ID:      cursor,
					Limit:   pageSize,
				},
			)
			if err != nil {
				return err
			}

			// An empty page means every channel has been visited.
			if len(rows) == 0 {
				return nil
			}

			for _, row := range rows {
				if err := visit(db, row); err != nil {
					return err
				}
				cursor = row.GraphChannel.ID
			}
		}
	}, reset)
}
1328

1329
// ForEachChannel iterates through all the channel edges stored within the
1330
// graph and invokes the passed callback for each edge. The callback takes two
1331
// edges as since this is a directed graph, both the in/out edges are visited.
1332
// If the callback returns an error, then the transaction is aborted and the
1333
// iteration stops early.
1334
//
1335
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1336
// for that particular channel edge routing policy will be passed into the
1337
// callback.
1338
//
1339
// NOTE: part of the V1Store interface.
1340
func (s *SQLStore) ForEachChannel(ctx context.Context,
1341
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1342
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1343

×
1344
        handleChannel := func(db SQLQueries,
×
1345
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1346

×
1347
                node1, node2, err := buildNodeVertices(
×
1348
                        row.Node1Pubkey, row.Node2Pubkey,
×
1349
                )
×
1350
                if err != nil {
×
1351
                        return fmt.Errorf("unable to build node vertices: %w",
×
1352
                                err)
×
1353
                }
×
1354

1355
                edge, err := getAndBuildEdgeInfo(
×
1356
                        ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1357
                        row.GraphChannel, node1, node2,
×
1358
                )
×
1359
                if err != nil {
×
1360
                        return fmt.Errorf("unable to build channel info: %w",
×
1361
                                err)
×
1362
                }
×
1363

1364
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1365
                if err != nil {
×
1366
                        return fmt.Errorf("unable to extract channel "+
×
1367
                                "policies: %w", err)
×
1368
                }
×
1369

1370
                p1, p2, err := getAndBuildChanPolicies(
×
1371
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1372
                )
×
1373
                if err != nil {
×
1374
                        return fmt.Errorf("unable to build channel "+
×
1375
                                "policies: %w", err)
×
1376
                }
×
1377

1378
                err = cb(edge, p1, p2)
×
1379
                if err != nil {
×
1380
                        return fmt.Errorf("callback failed for channel "+
×
1381
                                "id=%d: %w", edge.ChannelID, err)
×
1382
                }
×
1383

1384
                return nil
×
1385
        }
1386

1387
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1388
                lastID := int64(-1)
×
1389
                for {
×
1390
                        //nolint:ll
×
1391
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1392
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1393
                                        Version: int16(ProtocolV1),
×
1394
                                        ID:      lastID,
×
1395
                                        Limit:   pageSize,
×
1396
                                },
×
1397
                        )
×
1398
                        if err != nil {
×
1399
                                return err
×
1400
                        }
×
1401

1402
                        if len(rows) == 0 {
×
1403
                                break
×
1404
                        }
1405

1406
                        for _, row := range rows {
×
1407
                                err := handleChannel(db, row)
×
1408
                                if err != nil {
×
1409
                                        return err
×
1410
                                }
×
1411

1412
                                lastID = row.GraphChannel.ID
×
1413
                        }
1414
                }
1415

1416
                return nil
×
1417
        }, reset)
1418
}
1419

1420
// FilterChannelRange returns the channel ID's of all known channels which were
1421
// mined in a block height within the passed range. The channel IDs are grouped
1422
// by their common block height. This method can be used to quickly share with a
1423
// peer the set of channels we know of within a particular range to catch them
1424
// up after a period of time offline. If withTimestamps is true then the
1425
// timestamp info of the latest received channel update messages of the channel
1426
// will be included in the response.
1427
//
1428
// NOTE: This is part of the V1Store interface.
1429
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1430
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1431

×
1432
        var (
×
1433
                ctx       = context.TODO()
×
1434
                startSCID = &lnwire.ShortChannelID{
×
1435
                        BlockHeight: startHeight,
×
1436
                }
×
1437
                endSCID = lnwire.ShortChannelID{
×
1438
                        BlockHeight: endHeight,
×
1439
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1440
                        TxPosition:  math.MaxUint16,
×
1441
                }
×
1442
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1443
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1444
        )
×
1445

×
1446
        // 1) get all channels where channelID is between start and end chan ID.
×
1447
        // 2) skip if not public (ie, no channel_proof)
×
1448
        // 3) collect that channel.
×
1449
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1450
        //    and add those timestamps to the collected channel.
×
1451
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1452
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1453
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1454
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1455
                                StartScid: chanIDStart,
×
1456
                                EndScid:   chanIDEnd,
×
1457
                        },
×
1458
                )
×
1459
                if err != nil {
×
1460
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1461
                                err)
×
1462
                }
×
1463

1464
                for _, dbChan := range dbChans {
×
1465
                        cid := lnwire.NewShortChanIDFromInt(
×
1466
                                byteOrder.Uint64(dbChan.Scid),
×
1467
                        )
×
1468
                        chanInfo := NewChannelUpdateInfo(
×
1469
                                cid, time.Time{}, time.Time{},
×
1470
                        )
×
1471

×
1472
                        if !withTimestamps {
×
1473
                                channelsPerBlock[cid.BlockHeight] = append(
×
1474
                                        channelsPerBlock[cid.BlockHeight],
×
1475
                                        chanInfo,
×
1476
                                )
×
1477

×
1478
                                continue
×
1479
                        }
1480

1481
                        //nolint:ll
1482
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1483
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1484
                                        Version:   int16(ProtocolV1),
×
1485
                                        ChannelID: dbChan.ID,
×
1486
                                        NodeID:    dbChan.NodeID1,
×
1487
                                },
×
1488
                        )
×
1489
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1490
                                return fmt.Errorf("unable to fetch node1 "+
×
1491
                                        "policy: %w", err)
×
1492
                        } else if err == nil {
×
1493
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1494
                                        node1Policy.LastUpdate.Int64, 0,
×
1495
                                )
×
1496
                        }
×
1497

1498
                        //nolint:ll
1499
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1500
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1501
                                        Version:   int16(ProtocolV1),
×
1502
                                        ChannelID: dbChan.ID,
×
1503
                                        NodeID:    dbChan.NodeID2,
×
1504
                                },
×
1505
                        )
×
1506
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1507
                                return fmt.Errorf("unable to fetch node2 "+
×
1508
                                        "policy: %w", err)
×
1509
                        } else if err == nil {
×
1510
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1511
                                        node2Policy.LastUpdate.Int64, 0,
×
1512
                                )
×
1513
                        }
×
1514

1515
                        channelsPerBlock[cid.BlockHeight] = append(
×
1516
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1517
                        )
×
1518
                }
1519

1520
                return nil
×
1521
        }, func() {
×
1522
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1523
        })
×
1524
        if err != nil {
×
1525
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1526
        }
×
1527

1528
        if len(channelsPerBlock) == 0 {
×
1529
                return nil, nil
×
1530
        }
×
1531

1532
        // Return the channel ranges in ascending block height order.
1533
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1534
        slices.Sort(blocks)
×
1535

×
1536
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1537
                return BlockChannelRange{
×
1538
                        Height:   block,
×
1539
                        Channels: channelsPerBlock[block],
×
1540
                }
×
1541
        }), nil
×
1542
}
1543

1544
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1545
// zombie. This method is used on an ad-hoc basis, when channels need to be
1546
// marked as zombies outside the normal pruning cycle.
1547
//
1548
// NOTE: part of the V1Store interface.
1549
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1550
        pubKey1, pubKey2 [33]byte) error {
×
1551

×
1552
        ctx := context.TODO()
×
1553

×
1554
        s.cacheMu.Lock()
×
1555
        defer s.cacheMu.Unlock()
×
1556

×
1557
        chanIDB := channelIDToBytes(chanID)
×
1558

×
1559
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1560
                return db.UpsertZombieChannel(
×
1561
                        ctx, sqlc.UpsertZombieChannelParams{
×
1562
                                Version:  int16(ProtocolV1),
×
1563
                                Scid:     chanIDB,
×
1564
                                NodeKey1: pubKey1[:],
×
1565
                                NodeKey2: pubKey2[:],
×
1566
                        },
×
1567
                )
×
1568
        }, sqldb.NoOpReset)
×
1569
        if err != nil {
×
1570
                return fmt.Errorf("unable to upsert zombie channel "+
×
1571
                        "(channel_id=%d): %w", chanID, err)
×
1572
        }
×
1573

1574
        s.rejectCache.remove(chanID)
×
1575
        s.chanCache.remove(chanID)
×
1576

×
1577
        return nil
×
1578
}
1579

1580
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1581
//
1582
// NOTE: part of the V1Store interface.
1583
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1584
        s.cacheMu.Lock()
×
1585
        defer s.cacheMu.Unlock()
×
1586

×
1587
        var (
×
1588
                ctx     = context.TODO()
×
1589
                chanIDB = channelIDToBytes(chanID)
×
1590
        )
×
1591

×
1592
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1593
                res, err := db.DeleteZombieChannel(
×
1594
                        ctx, sqlc.DeleteZombieChannelParams{
×
1595
                                Scid:    chanIDB,
×
1596
                                Version: int16(ProtocolV1),
×
1597
                        },
×
1598
                )
×
1599
                if err != nil {
×
1600
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1601
                                err)
×
1602
                }
×
1603

1604
                rows, err := res.RowsAffected()
×
1605
                if err != nil {
×
1606
                        return err
×
1607
                }
×
1608

1609
                if rows == 0 {
×
1610
                        return ErrZombieEdgeNotFound
×
1611
                } else if rows > 1 {
×
1612
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1613
                                "expected 1", rows)
×
1614
                }
×
1615

1616
                return nil
×
1617
        }, sqldb.NoOpReset)
1618
        if err != nil {
×
1619
                return fmt.Errorf("unable to mark edge live "+
×
1620
                        "(channel_id=%d): %w", chanID, err)
×
1621
        }
×
1622

1623
        s.rejectCache.remove(chanID)
×
1624
        s.chanCache.remove(chanID)
×
1625

×
1626
        return err
×
1627
}
1628

1629
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1630
// zombie, then the two node public keys corresponding to this edge are also
1631
// returned.
1632
//
1633
// NOTE: part of the V1Store interface.
1634
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1635
        error) {
×
1636

×
1637
        var (
×
1638
                ctx              = context.TODO()
×
1639
                isZombie         bool
×
1640
                pubKey1, pubKey2 route.Vertex
×
1641
                chanIDB          = channelIDToBytes(chanID)
×
1642
        )
×
1643

×
1644
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1645
                zombie, err := db.GetZombieChannel(
×
1646
                        ctx, sqlc.GetZombieChannelParams{
×
1647
                                Scid:    chanIDB,
×
1648
                                Version: int16(ProtocolV1),
×
1649
                        },
×
1650
                )
×
1651
                if errors.Is(err, sql.ErrNoRows) {
×
1652
                        return nil
×
1653
                }
×
1654
                if err != nil {
×
1655
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1656
                                err)
×
1657
                }
×
1658

1659
                copy(pubKey1[:], zombie.NodeKey1)
×
1660
                copy(pubKey2[:], zombie.NodeKey2)
×
1661
                isZombie = true
×
1662

×
1663
                return nil
×
1664
        }, sqldb.NoOpReset)
1665
        if err != nil {
×
1666
                return false, route.Vertex{}, route.Vertex{},
×
1667
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1668
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1669
        }
×
1670

1671
        return isZombie, pubKey1, pubKey2, nil
×
1672
}
1673

1674
// NumZombies returns the current number of zombie channels in the graph.
1675
//
1676
// NOTE: part of the V1Store interface.
1677
func (s *SQLStore) NumZombies() (uint64, error) {
×
1678
        var (
×
1679
                ctx        = context.TODO()
×
1680
                numZombies uint64
×
1681
        )
×
1682
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1683
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1684
                if err != nil {
×
1685
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1686
                                err)
×
1687
                }
×
1688

1689
                numZombies = uint64(count)
×
1690

×
1691
                return nil
×
1692
        }, sqldb.NoOpReset)
1693
        if err != nil {
×
1694
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1695
        }
×
1696

1697
        return numZombies, nil
×
1698
}
1699

1700
// DeleteChannelEdges removes edges with the given channel IDs from the
1701
// database and marks them as zombies. This ensures that we're unable to re-add
1702
// it to our database once again. If an edge does not exist within the
1703
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1704
// true, then when we mark these edges as zombies, we'll set up the keys such
1705
// that we require the node that failed to send the fresh update to be the one
1706
// that resurrects the channel from its zombie state. The markZombie bool
1707
// denotes whether to mark the channel as a zombie.
1708
//
1709
// NOTE: part of the V1Store interface.
1710
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1711
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1712

×
1713
        s.cacheMu.Lock()
×
1714
        defer s.cacheMu.Unlock()
×
1715

×
1716
        var (
×
1717
                ctx     = context.TODO()
×
1718
                deleted []*models.ChannelEdgeInfo
×
1719
        )
×
1720
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1721
                for _, chanID := range chanIDs {
×
1722
                        chanIDB := channelIDToBytes(chanID)
×
1723

×
1724
                        row, err := db.GetChannelBySCIDWithPolicies(
×
1725
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1726
                                        Scid:    chanIDB,
×
1727
                                        Version: int16(ProtocolV1),
×
1728
                                },
×
1729
                        )
×
1730
                        if errors.Is(err, sql.ErrNoRows) {
×
1731
                                return ErrEdgeNotFound
×
1732
                        } else if err != nil {
×
1733
                                return fmt.Errorf("unable to fetch channel: %w",
×
1734
                                        err)
×
1735
                        }
×
1736

1737
                        node1, node2, err := buildNodeVertices(
×
1738
                                row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1739
                        )
×
1740
                        if err != nil {
×
1741
                                return err
×
1742
                        }
×
1743

1744
                        info, err := getAndBuildEdgeInfo(
×
1745
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1746
                                row.GraphChannel, node1, node2,
×
1747
                        )
×
1748
                        if err != nil {
×
1749
                                return err
×
1750
                        }
×
1751

1752
                        err = db.DeleteChannel(ctx, row.GraphChannel.ID)
×
1753
                        if err != nil {
×
1754
                                return fmt.Errorf("unable to delete "+
×
1755
                                        "channel: %w", err)
×
1756
                        }
×
1757

1758
                        deleted = append(deleted, info)
×
1759

×
1760
                        if !markZombie {
×
1761
                                continue
×
1762
                        }
1763

1764
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1765
                                info.NodeKey2Bytes
×
1766
                        if strictZombiePruning {
×
1767
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1768
                                if row.Policy1LastUpdate.Valid {
×
1769
                                        e1Time := time.Unix(
×
1770
                                                row.Policy1LastUpdate.Int64, 0,
×
1771
                                        )
×
1772
                                        e1UpdateTime = &e1Time
×
1773
                                }
×
1774
                                if row.Policy2LastUpdate.Valid {
×
1775
                                        e2Time := time.Unix(
×
1776
                                                row.Policy2LastUpdate.Int64, 0,
×
1777
                                        )
×
1778
                                        e2UpdateTime = &e2Time
×
1779
                                }
×
1780

1781
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1782
                                        info, e1UpdateTime, e2UpdateTime,
×
1783
                                )
×
1784
                        }
1785

1786
                        err = db.UpsertZombieChannel(
×
1787
                                ctx, sqlc.UpsertZombieChannelParams{
×
1788
                                        Version:  int16(ProtocolV1),
×
1789
                                        Scid:     chanIDB,
×
1790
                                        NodeKey1: nodeKey1[:],
×
1791
                                        NodeKey2: nodeKey2[:],
×
1792
                                },
×
1793
                        )
×
1794
                        if err != nil {
×
1795
                                return fmt.Errorf("unable to mark channel as "+
×
1796
                                        "zombie: %w", err)
×
1797
                        }
×
1798
                }
1799

1800
                return nil
×
1801
        }, func() {
×
1802
                deleted = nil
×
1803
        })
×
1804
        if err != nil {
×
1805
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1806
                        err)
×
1807
        }
×
1808

1809
        for _, chanID := range chanIDs {
×
1810
                s.rejectCache.remove(chanID)
×
1811
                s.chanCache.remove(chanID)
×
1812
        }
×
1813

1814
        return deleted, nil
×
1815
}
1816

1817
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1818
// channel identified by the channel ID. If the channel can't be found, then
1819
// ErrEdgeNotFound is returned. A struct which houses the general information
1820
// for the channel itself is returned as well as two structs that contain the
1821
// routing policies for the channel in either direction.
1822
//
1823
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1824
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1825
// the ChannelEdgeInfo will only include the public keys of each node.
1826
//
1827
// NOTE: part of the V1Store interface.
1828
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1829
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1830
        *models.ChannelEdgePolicy, error) {
×
1831

×
1832
        var (
×
1833
                ctx              = context.TODO()
×
1834
                edge             *models.ChannelEdgeInfo
×
1835
                policy1, policy2 *models.ChannelEdgePolicy
×
1836
                chanIDB          = channelIDToBytes(chanID)
×
1837
        )
×
1838
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1839
                row, err := db.GetChannelBySCIDWithPolicies(
×
1840
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1841
                                Scid:    chanIDB,
×
1842
                                Version: int16(ProtocolV1),
×
1843
                        },
×
1844
                )
×
1845
                if errors.Is(err, sql.ErrNoRows) {
×
1846
                        // First check if this edge is perhaps in the zombie
×
1847
                        // index.
×
1848
                        zombie, err := db.GetZombieChannel(
×
1849
                                ctx, sqlc.GetZombieChannelParams{
×
1850
                                        Scid:    chanIDB,
×
1851
                                        Version: int16(ProtocolV1),
×
1852
                                },
×
1853
                        )
×
1854
                        if errors.Is(err, sql.ErrNoRows) {
×
1855
                                return ErrEdgeNotFound
×
1856
                        } else if err != nil {
×
1857
                                return fmt.Errorf("unable to check if "+
×
1858
                                        "channel is zombie: %w", err)
×
1859
                        }
×
1860

1861
                        // At this point, we know the channel is a zombie, so
1862
                        // we'll return an error indicating this, and we will
1863
                        // populate the edge info with the public keys of each
1864
                        // party as this is the only information we have about
1865
                        // it.
1866
                        edge = &models.ChannelEdgeInfo{}
×
1867
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1868
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1869

×
1870
                        return ErrZombieEdge
×
1871
                } else if err != nil {
×
1872
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1873
                }
×
1874

1875
                node1, node2, err := buildNodeVertices(
×
1876
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1877
                )
×
1878
                if err != nil {
×
1879
                        return err
×
1880
                }
×
1881

1882
                edge, err = getAndBuildEdgeInfo(
×
1883
                        ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1884
                        row.GraphChannel, node1, node2,
×
1885
                )
×
1886
                if err != nil {
×
1887
                        return fmt.Errorf("unable to build channel info: %w",
×
1888
                                err)
×
1889
                }
×
1890

1891
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1892
                if err != nil {
×
1893
                        return fmt.Errorf("unable to extract channel "+
×
1894
                                "policies: %w", err)
×
1895
                }
×
1896

1897
                policy1, policy2, err = getAndBuildChanPolicies(
×
1898
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1899
                )
×
1900
                if err != nil {
×
1901
                        return fmt.Errorf("unable to build channel "+
×
1902
                                "policies: %w", err)
×
1903
                }
×
1904

1905
                return nil
×
1906
        }, sqldb.NoOpReset)
1907
        if err != nil {
×
1908
                // If we are returning the ErrZombieEdge, then we also need to
×
1909
                // return the edge info as the method comment indicates that
×
1910
                // this will be populated when the edge is a zombie.
×
1911
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1912
                        err)
×
1913
        }
×
1914

1915
        return edge, policy1, policy2, nil
×
1916
}
1917

1918
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1919
// the channel identified by the funding outpoint. If the channel can't be
1920
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1921
// information for the channel itself is returned as well as two structs that
1922
// contain the routing policies for the channel in either direction.
1923
//
1924
// NOTE: part of the V1Store interface.
1925
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1926
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1927
        *models.ChannelEdgePolicy, error) {
×
1928

×
1929
        var (
×
1930
                ctx              = context.TODO()
×
1931
                edge             *models.ChannelEdgeInfo
×
1932
                policy1, policy2 *models.ChannelEdgePolicy
×
1933
        )
×
1934
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1935
                row, err := db.GetChannelByOutpointWithPolicies(
×
1936
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1937
                                Outpoint: op.String(),
×
1938
                                Version:  int16(ProtocolV1),
×
1939
                        },
×
1940
                )
×
1941
                if errors.Is(err, sql.ErrNoRows) {
×
1942
                        return ErrEdgeNotFound
×
1943
                } else if err != nil {
×
1944
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1945
                }
×
1946

1947
                node1, node2, err := buildNodeVertices(
×
1948
                        row.Node1Pubkey, row.Node2Pubkey,
×
1949
                )
×
1950
                if err != nil {
×
1951
                        return err
×
1952
                }
×
1953

1954
                edge, err = getAndBuildEdgeInfo(
×
1955
                        ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1956
                        row.GraphChannel, node1, node2,
×
1957
                )
×
1958
                if err != nil {
×
1959
                        return fmt.Errorf("unable to build channel info: %w",
×
1960
                                err)
×
1961
                }
×
1962

1963
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1964
                if err != nil {
×
1965
                        return fmt.Errorf("unable to extract channel "+
×
1966
                                "policies: %w", err)
×
1967
                }
×
1968

1969
                policy1, policy2, err = getAndBuildChanPolicies(
×
1970
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1971
                )
×
1972
                if err != nil {
×
1973
                        return fmt.Errorf("unable to build channel "+
×
1974
                                "policies: %w", err)
×
1975
                }
×
1976

1977
                return nil
×
1978
        }, sqldb.NoOpReset)
1979
        if err != nil {
×
1980
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1981
                        err)
×
1982
        }
×
1983

1984
        return edge, policy1, policy2, nil
×
1985
}
1986

1987
// HasChannelEdge returns true if the database knows of a channel edge with the
1988
// passed channel ID, and false otherwise. If an edge with that ID is found
1989
// within the graph, then two time stamps representing the last time the edge
1990
// was updated for both directed edges are returned along with the boolean. If
1991
// it is not found, then the zombie index is checked and its result is returned
1992
// as the second boolean.
1993
//
1994
// NOTE: part of the V1Store interface.
1995
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1996
        bool, error) {
×
1997

×
1998
        ctx := context.TODO()
×
1999

×
2000
        var (
×
2001
                exists          bool
×
2002
                isZombie        bool
×
2003
                node1LastUpdate time.Time
×
2004
                node2LastUpdate time.Time
×
2005
        )
×
2006

×
2007
        // We'll query the cache with the shared lock held to allow multiple
×
2008
        // readers to access values in the cache concurrently if they exist.
×
2009
        s.cacheMu.RLock()
×
2010
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2011
                s.cacheMu.RUnlock()
×
2012
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2013
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2014
                exists, isZombie = entry.flags.unpack()
×
2015

×
2016
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2017
        }
×
2018
        s.cacheMu.RUnlock()
×
2019

×
2020
        s.cacheMu.Lock()
×
2021
        defer s.cacheMu.Unlock()
×
2022

×
2023
        // The item was not found with the shared lock, so we'll acquire the
×
2024
        // exclusive lock and check the cache again in case another method added
×
2025
        // the entry to the cache while no lock was held.
×
2026
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2027
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2028
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2029
                exists, isZombie = entry.flags.unpack()
×
2030

×
2031
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2032
        }
×
2033

2034
        chanIDB := channelIDToBytes(chanID)
×
2035
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2036
                channel, err := db.GetChannelBySCID(
×
2037
                        ctx, sqlc.GetChannelBySCIDParams{
×
2038
                                Scid:    chanIDB,
×
2039
                                Version: int16(ProtocolV1),
×
2040
                        },
×
2041
                )
×
2042
                if errors.Is(err, sql.ErrNoRows) {
×
2043
                        // Check if it is a zombie channel.
×
2044
                        isZombie, err = db.IsZombieChannel(
×
2045
                                ctx, sqlc.IsZombieChannelParams{
×
2046
                                        Scid:    chanIDB,
×
2047
                                        Version: int16(ProtocolV1),
×
2048
                                },
×
2049
                        )
×
2050
                        if err != nil {
×
2051
                                return fmt.Errorf("could not check if channel "+
×
2052
                                        "is zombie: %w", err)
×
2053
                        }
×
2054

2055
                        return nil
×
2056
                } else if err != nil {
×
2057
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2058
                }
×
2059

2060
                exists = true
×
2061

×
2062
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2063
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2064
                                Version:   int16(ProtocolV1),
×
2065
                                ChannelID: channel.ID,
×
2066
                                NodeID:    channel.NodeID1,
×
2067
                        },
×
2068
                )
×
2069
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2070
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2071
                                err)
×
2072
                } else if err == nil {
×
2073
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2074
                }
×
2075

2076
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2077
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2078
                                Version:   int16(ProtocolV1),
×
2079
                                ChannelID: channel.ID,
×
2080
                                NodeID:    channel.NodeID2,
×
2081
                        },
×
2082
                )
×
2083
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2084
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2085
                                err)
×
2086
                } else if err == nil {
×
2087
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2088
                }
×
2089

2090
                return nil
×
2091
        }, sqldb.NoOpReset)
2092
        if err != nil {
×
2093
                return time.Time{}, time.Time{}, false, false,
×
2094
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2095
        }
×
2096

2097
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2098
                upd1Time: node1LastUpdate.Unix(),
×
2099
                upd2Time: node2LastUpdate.Unix(),
×
2100
                flags:    packRejectFlags(exists, isZombie),
×
2101
        })
×
2102

×
2103
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2104
}
2105

2106
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2107
// passed channel point (outpoint). If the passed channel doesn't exist within
2108
// the database, then ErrEdgeNotFound is returned.
2109
//
2110
// NOTE: part of the V1Store interface.
2111
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2112
        var (
×
2113
                ctx       = context.TODO()
×
2114
                channelID uint64
×
2115
        )
×
2116
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2117
                chanID, err := db.GetSCIDByOutpoint(
×
2118
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2119
                                Outpoint: chanPoint.String(),
×
2120
                                Version:  int16(ProtocolV1),
×
2121
                        },
×
2122
                )
×
2123
                if errors.Is(err, sql.ErrNoRows) {
×
2124
                        return ErrEdgeNotFound
×
2125
                } else if err != nil {
×
2126
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2127
                                err)
×
2128
                }
×
2129

2130
                channelID = byteOrder.Uint64(chanID)
×
2131

×
2132
                return nil
×
2133
        }, sqldb.NoOpReset)
2134
        if err != nil {
×
2135
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2136
        }
×
2137

2138
        return channelID, nil
×
2139
}
2140

2141
// IsPublicNode is a helper method that determines whether the node with the
2142
// given public key is seen as a public node in the graph from the graph's
2143
// source node's point of view.
2144
//
2145
// NOTE: part of the V1Store interface.
2146
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2147
        ctx := context.TODO()
×
2148

×
2149
        var isPublic bool
×
2150
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2151
                var err error
×
2152
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2153

×
2154
                return err
×
2155
        }, sqldb.NoOpReset)
×
2156
        if err != nil {
×
2157
                return false, fmt.Errorf("unable to check if node is "+
×
2158
                        "public: %w", err)
×
2159
        }
×
2160

2161
        return isPublic, nil
×
2162
}
2163

2164
// FetchChanInfos returns the set of channel edges that correspond to the passed
2165
// channel ID's. If an edge is the query is unknown to the database, it will
2166
// skipped and the result will contain only those edges that exist at the time
2167
// of the query. This can be used to respond to peer queries that are seeking to
2168
// fill in gaps in their view of the channel graph.
2169
//
2170
// NOTE: part of the V1Store interface.
2171
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2172
        var (
×
2173
                ctx   = context.TODO()
×
NEW
2174
                edges = make(map[uint64]ChannelEdge)
×
2175
        )
×
2176
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2177
                chanCallBack := func(ctx context.Context,
×
NEW
2178
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2179

×
2180
                        node1, node2, err := buildNodes(
×
2181
                                ctx, db, row.GraphNode, row.GraphNode_2,
×
2182
                        )
×
2183
                        if err != nil {
×
2184
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2185
                                        err)
×
2186
                        }
×
2187

2188
                        edge, err := getAndBuildEdgeInfo(
×
2189
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
2190
                                row.GraphChannel, node1.PubKeyBytes,
×
2191
                                node2.PubKeyBytes,
×
2192
                        )
×
2193
                        if err != nil {
×
2194
                                return fmt.Errorf("unable to build "+
×
2195
                                        "channel info: %w", err)
×
2196
                        }
×
2197

2198
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2199
                        if err != nil {
×
2200
                                return fmt.Errorf("unable to extract channel "+
×
2201
                                        "policies: %w", err)
×
2202
                        }
×
2203

2204
                        p1, p2, err := getAndBuildChanPolicies(
×
2205
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2206
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2207
                        )
×
2208
                        if err != nil {
×
2209
                                return fmt.Errorf("unable to build channel "+
×
2210
                                        "policies: %w", err)
×
2211
                        }
×
2212

NEW
2213
                        edges[edge.ChannelID] = ChannelEdge{
×
2214
                                Info:    edge,
×
2215
                                Policy1: p1,
×
2216
                                Policy2: p2,
×
2217
                                Node1:   node1,
×
2218
                                Node2:   node2,
×
NEW
2219
                        }
×
NEW
2220

×
NEW
2221
                        return nil
×
2222
                }
2223

NEW
2224
                return s.forEachChanWithPoliciesInSCIDList(
×
NEW
2225
                        ctx, db, chanCallBack, chanIDs,
×
NEW
2226
                )
×
2227
        }, func() {
×
NEW
2228
                clear(edges)
×
2229
        })
×
2230
        if err != nil {
×
2231
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2232
        }
×
2233

NEW
2234
        res := make([]ChannelEdge, 0, len(edges))
×
NEW
2235
        for _, chanID := range chanIDs {
×
NEW
2236
                edge, ok := edges[chanID]
×
NEW
2237
                if !ok {
×
NEW
2238
                        continue
×
2239
                }
2240

NEW
2241
                res = append(res, edge)
×
2242
        }
2243

NEW
2244
        return res, nil
×
2245
}
2246

2247
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2248
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2249
// channels in a paginated manner.
2250
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2251
        db SQLQueries, cb func(ctx context.Context,
2252
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
NEW
2253
        chanIDs []uint64) error {
×
NEW
2254

×
NEW
2255
        queryWrapper := func(ctx context.Context,
×
NEW
2256
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
NEW
2257
                error) {
×
NEW
2258

×
NEW
2259
                return db.GetChannelsBySCIDWithPolicies(
×
NEW
2260
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
NEW
2261
                                Version: int16(ProtocolV1),
×
NEW
2262
                                Scids:   scids,
×
NEW
2263
                        },
×
NEW
2264
                )
×
NEW
2265
        }
×
2266

NEW
2267
        return sqldb.ExecutePagedQuery(
×
NEW
2268
                ctx, s.cfg.PaginationCfg, chanIDs, channelIDToBytes,
×
NEW
2269
                queryWrapper, cb,
×
NEW
2270
        )
×
2271
}
2272

2273
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2274
// ID's that we don't know and are not known zombies of the passed set. In other
2275
// words, we perform a set difference of our set of chan ID's and the ones
2276
// passed in. This method can be used by callers to determine the set of
2277
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2278
// known zombies is also returned.
2279
//
2280
// NOTE: part of the V1Store interface.
2281
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2282
        []ChannelUpdateInfo, error) {
×
2283

×
2284
        var (
×
2285
                ctx          = context.TODO()
×
2286
                newChanIDs   []uint64
×
2287
                knownZombies []ChannelUpdateInfo
×
NEW
2288
                infoLookup   = make(
×
NEW
2289
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
NEW
2290
                )
×
2291
        )
×
NEW
2292

×
NEW
2293
        // We first build a lookup map of the channel ID's to the
×
NEW
2294
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
NEW
2295
        // already know about.
×
NEW
2296
        for _, chanInfo := range chansInfo {
×
NEW
2297
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
NEW
2298
        }
×
2299

2300
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2301
                // The call-back function deletes known channels from
×
NEW
2302
                // infoLookup, so that we can later check which channels are
×
NEW
2303
                // zombies by only looking at the remaining channels in the set.
×
NEW
2304
                cb := func(ctx context.Context,
×
NEW
2305
                        channel sqlc.GraphChannel) error {
×
NEW
2306

×
NEW
2307
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
NEW
2308

×
NEW
2309
                        return nil
×
NEW
2310
                }
×
2311

NEW
2312
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
NEW
2313
                if err != nil {
×
NEW
2314
                        return fmt.Errorf("unable to iterate through "+
×
NEW
2315
                                "channels: %w", err)
×
NEW
2316
                }
×
2317

2318
                // We want to ensure that we deal with the channels in the
2319
                // same order that they were passed in, so we iterate over the
2320
                // original chansInfo slice and then check if that channel is
2321
                // still in the infoLookup map.
2322
                for _, chanInfo := range chansInfo {
×
2323
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
NEW
2324
                        if _, ok := infoLookup[channelID]; !ok {
×
2325
                                continue
×
2326
                        }
2327

2328
                        isZombie, err := db.IsZombieChannel(
×
2329
                                ctx, sqlc.IsZombieChannelParams{
×
NEW
2330
                                        Scid:    channelIDToBytes(channelID),
×
2331
                                        Version: int16(ProtocolV1),
×
2332
                                },
×
2333
                        )
×
2334
                        if err != nil {
×
2335
                                return fmt.Errorf("unable to fetch zombie "+
×
2336
                                        "channel: %w", err)
×
2337
                        }
×
2338

2339
                        if isZombie {
×
2340
                                knownZombies = append(knownZombies, chanInfo)
×
2341

×
2342
                                continue
×
2343
                        }
2344

2345
                        newChanIDs = append(newChanIDs, channelID)
×
2346
                }
2347

2348
                return nil
×
2349
        }, func() {
×
2350
                newChanIDs = nil
×
2351
                knownZombies = nil
×
2352
        })
×
2353
        if err != nil {
×
2354
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2355
        }
×
2356

2357
        return newChanIDs, knownZombies, nil
×
2358
}
2359

2360
// forEachChanInSCIDList is a helper method that executes a paged query
2361
// against the database to fetch all channels that match the passed
2362
// ChannelUpdateInfo slice. The callback function is called for each channel
2363
// that is found.
2364
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2365
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
NEW
2366
        chansInfo []ChannelUpdateInfo) error {
×
NEW
2367

×
NEW
2368
        queryWrapper := func(ctx context.Context,
×
NEW
2369
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
NEW
2370

×
NEW
2371
                return db.GetChannelsBySCIDs(
×
NEW
2372
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
NEW
2373
                                Version: int16(ProtocolV1),
×
NEW
2374
                                Scids:   scids,
×
NEW
2375
                        },
×
NEW
2376
                )
×
NEW
2377
        }
×
2378

NEW
2379
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
NEW
2380
                channelID := chanInfo.ShortChannelID.ToUint64()
×
NEW
2381

×
NEW
2382
                return channelIDToBytes(channelID)
×
NEW
2383
        }
×
2384

NEW
2385
        return sqldb.ExecutePagedQuery(
×
NEW
2386
                ctx, s.cfg.PaginationCfg, chansInfo, chanIDConverter,
×
NEW
2387
                queryWrapper, cb,
×
NEW
2388
        )
×
2389
}
2390

2391
// PruneGraphNodes is a garbage collection method which attempts to prune out
2392
// any nodes from the channel graph that are currently unconnected. This ensure
2393
// that we only maintain a graph of reachable nodes. In the event that a pruned
2394
// node gains more channels, it will be re-added back to the graph.
2395
//
2396
// NOTE: this prunes nodes across protocol versions. It will never prune the
2397
// source nodes.
2398
//
2399
// NOTE: part of the V1Store interface.
2400
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2401
        var ctx = context.TODO()
×
2402

×
2403
        var prunedNodes []route.Vertex
×
2404
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2405
                var err error
×
2406
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2407

×
2408
                return err
×
2409
        }, func() {
×
2410
                prunedNodes = nil
×
2411
        })
×
2412
        if err != nil {
×
2413
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2414
        }
×
2415

2416
        return prunedNodes, nil
×
2417
}
2418

2419
// PruneGraph prunes newly closed channels from the channel graph in response
2420
// to a new block being solved on the network. Any transactions which spend the
2421
// funding output of any known channels within he graph will be deleted.
2422
// Additionally, the "prune tip", or the last block which has been used to
2423
// prune the graph is stored so callers can ensure the graph is fully in sync
2424
// with the current UTXO state. A slice of channels that have been closed by
2425
// the target block along with any pruned nodes are returned if the function
2426
// succeeds without error.
2427
//
2428
// NOTE: part of the V1Store interface.
2429
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2430
        blockHash *chainhash.Hash, blockHeight uint32) (
2431
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2432

×
2433
        ctx := context.TODO()
×
2434

×
2435
        s.cacheMu.Lock()
×
2436
        defer s.cacheMu.Unlock()
×
2437

×
2438
        var (
×
2439
                closedChans []*models.ChannelEdgeInfo
×
2440
                prunedNodes []route.Vertex
×
2441
        )
×
2442
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2443
                // Define the callback function for processing each channel
×
NEW
2444
                channelCallback := func(ctx context.Context,
×
NEW
2445
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2446

×
2447
                        node1, node2, err := buildNodeVertices(
×
2448
                                row.Node1Pubkey, row.Node2Pubkey,
×
2449
                        )
×
2450
                        if err != nil {
×
2451
                                return err
×
2452
                        }
×
2453

2454
                        info, err := getAndBuildEdgeInfo(
×
2455
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
2456
                                row.GraphChannel, node1, node2,
×
2457
                        )
×
2458
                        if err != nil {
×
2459
                                return err
×
2460
                        }
×
2461

2462
                        err = db.DeleteChannel(ctx, row.GraphChannel.ID)
×
2463
                        if err != nil {
×
2464
                                return fmt.Errorf("unable to delete "+
×
2465
                                        "channel: %w", err)
×
2466
                        }
×
2467

2468
                        closedChans = append(closedChans, info)
×
NEW
2469

×
NEW
2470
                        return nil
×
2471
                }
2472

NEW
2473
                err := s.forEachChanInOutpoints(
×
NEW
2474
                        ctx, db, spentOutputs, channelCallback,
×
NEW
2475
                )
×
NEW
2476
                if err != nil {
×
NEW
2477
                        return fmt.Errorf("unable to fetch channels by "+
×
NEW
2478
                                "outpoints: %w", err)
×
NEW
2479
                }
×
2480

NEW
2481
                err = db.UpsertPruneLogEntry(
×
2482
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2483
                                BlockHash:   blockHash[:],
×
2484
                                BlockHeight: int64(blockHeight),
×
2485
                        },
×
2486
                )
×
2487
                if err != nil {
×
2488
                        return fmt.Errorf("unable to insert prune log "+
×
2489
                                "entry: %w", err)
×
2490
                }
×
2491

2492
                // Now that we've pruned some channels, we'll also prune any
2493
                // nodes that no longer have any channels.
2494
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2495
                if err != nil {
×
2496
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2497
                                err)
×
2498
                }
×
2499

2500
                return nil
×
2501
        }, func() {
×
2502
                prunedNodes = nil
×
2503
                closedChans = nil
×
2504
        })
×
2505
        if err != nil {
×
2506
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2507
        }
×
2508

2509
        for _, channel := range closedChans {
×
2510
                s.rejectCache.remove(channel.ChannelID)
×
2511
                s.chanCache.remove(channel.ChannelID)
×
2512
        }
×
2513

2514
        return closedChans, prunedNodes, nil
×
2515
}
2516

2517
// forEachChanInOutpoints is a helper function that executes a paginated
2518
// query to fetch channels by their outpoints and applies the given call-back
2519
// to each.
2520
//
2521
// NOTE: this fetches channels for all protocol versions.
2522
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2523
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
NEW
2524
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
NEW
2525

×
NEW
2526
        // Create a wrapper that uses the transaction's db instance to execute
×
NEW
2527
        // the query.
×
NEW
2528
        queryWrapper := func(ctx context.Context,
×
NEW
2529
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
NEW
2530
                error) {
×
NEW
2531

×
NEW
2532
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
NEW
2533
        }
×
2534

2535
        // Define the conversion function from Outpoint to string
NEW
2536
        outpointToString := func(outpoint *wire.OutPoint) string {
×
NEW
2537
                return outpoint.String()
×
NEW
2538
        }
×
2539

NEW
2540
        return sqldb.ExecutePagedQuery(
×
NEW
2541
                ctx, s.cfg.PaginationCfg, outpoints, outpointToString,
×
NEW
2542
                queryWrapper, cb,
×
NEW
2543
        )
×
2544
}
2545

2546
// ChannelView returns the verifiable edge information for each active channel
2547
// within the known channel graph. The set of UTXOs (along with their scripts)
2548
// returned are the ones that need to be watched on chain to detect channel
2549
// closes on the resident blockchain.
2550
//
2551
// NOTE: part of the V1Store interface.
2552
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2553
        var (
×
2554
                ctx        = context.TODO()
×
2555
                edgePoints []EdgePoint
×
2556
        )
×
2557

×
2558
        handleChannel := func(db SQLQueries,
×
2559
                channel sqlc.ListChannelsPaginatedRow) error {
×
2560

×
2561
                pkScript, err := genMultiSigP2WSH(
×
2562
                        channel.BitcoinKey1, channel.BitcoinKey2,
×
2563
                )
×
2564
                if err != nil {
×
2565
                        return err
×
2566
                }
×
2567

2568
                op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2569
                if err != nil {
×
2570
                        return err
×
2571
                }
×
2572

2573
                edgePoints = append(edgePoints, EdgePoint{
×
2574
                        FundingPkScript: pkScript,
×
2575
                        OutPoint:        *op,
×
2576
                })
×
2577

×
2578
                return nil
×
2579
        }
2580

2581
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2582
                lastID := int64(-1)
×
2583
                for {
×
2584
                        rows, err := db.ListChannelsPaginated(
×
2585
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2586
                                        Version: int16(ProtocolV1),
×
2587
                                        ID:      lastID,
×
2588
                                        Limit:   pageSize,
×
2589
                                },
×
2590
                        )
×
2591
                        if err != nil {
×
2592
                                return err
×
2593
                        }
×
2594

2595
                        if len(rows) == 0 {
×
2596
                                break
×
2597
                        }
2598

2599
                        for _, row := range rows {
×
2600
                                err := handleChannel(db, row)
×
2601
                                if err != nil {
×
2602
                                        return err
×
2603
                                }
×
2604

2605
                                lastID = row.ID
×
2606
                        }
2607
                }
2608

2609
                return nil
×
2610
        }, func() {
×
2611
                edgePoints = nil
×
2612
        })
×
2613
        if err != nil {
×
2614
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2615
        }
×
2616

2617
        return edgePoints, nil
×
2618
}
2619

2620
// PruneTip returns the block height and hash of the latest block that has been
2621
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2622
// to tell if the graph is currently in sync with the current best known UTXO
2623
// state.
2624
//
2625
// NOTE: part of the V1Store interface.
2626
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2627
        var (
×
2628
                ctx       = context.TODO()
×
2629
                tipHash   chainhash.Hash
×
2630
                tipHeight uint32
×
2631
        )
×
2632
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2633
                pruneTip, err := db.GetPruneTip(ctx)
×
2634
                if errors.Is(err, sql.ErrNoRows) {
×
2635
                        return ErrGraphNeverPruned
×
2636
                } else if err != nil {
×
2637
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2638
                }
×
2639

2640
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2641
                tipHeight = uint32(pruneTip.BlockHeight)
×
2642

×
2643
                return nil
×
2644
        }, sqldb.NoOpReset)
2645
        if err != nil {
×
2646
                return nil, 0, err
×
2647
        }
×
2648

2649
        return &tipHash, tipHeight, nil
×
2650
}
2651

2652
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2653
//
2654
// NOTE: this prunes nodes across protocol versions. It will never prune the
2655
// source nodes.
2656
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2657
        db SQLQueries) ([]route.Vertex, error) {
×
2658

×
2659
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2660
        if err != nil {
×
2661
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2662
                        "nodes: %w", err)
×
2663
        }
×
2664

2665
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2666
        for i, nodeKey := range nodeKeys {
×
2667
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2668
                if err != nil {
×
2669
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2670
                                "from bytes: %w", err)
×
2671
                }
×
2672

2673
                prunedNodes[i] = pub
×
2674
        }
2675

2676
        return prunedNodes, nil
×
2677
}
2678

2679
// DisconnectBlockAtHeight is used to indicate that the block specified
2680
// by the passed height has been disconnected from the main chain. This
2681
// will "rewind" the graph back to the height below, deleting channels
2682
// that are no longer confirmed from the graph. The prune log will be
2683
// set to the last prune height valid for the remaining chain.
2684
// Channels that were removed from the graph resulting from the
2685
// disconnected block are returned.
2686
//
2687
// NOTE: part of the V1Store interface.
2688
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2689
        []*models.ChannelEdgeInfo, error) {
×
2690

×
2691
        ctx := context.TODO()
×
2692

×
2693
        var (
×
2694
                // Every channel having a ShortChannelID starting at 'height'
×
2695
                // will no longer be confirmed.
×
2696
                startShortChanID = lnwire.ShortChannelID{
×
2697
                        BlockHeight: height,
×
2698
                }
×
2699

×
2700
                // Delete everything after this height from the db up until the
×
2701
                // SCID alias range.
×
2702
                endShortChanID = aliasmgr.StartingAlias
×
2703

×
2704
                removedChans []*models.ChannelEdgeInfo
×
2705

×
2706
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2707
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2708
        )
×
2709

×
2710
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2711
                rows, err := db.GetChannelsBySCIDRange(
×
2712
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2713
                                StartScid: chanIDStart,
×
2714
                                EndScid:   chanIDEnd,
×
2715
                        },
×
2716
                )
×
2717
                if err != nil {
×
2718
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2719
                }
×
2720

2721
                for _, row := range rows {
×
2722
                        node1, node2, err := buildNodeVertices(
×
2723
                                row.Node1PubKey, row.Node2PubKey,
×
2724
                        )
×
2725
                        if err != nil {
×
2726
                                return err
×
2727
                        }
×
2728

2729
                        channel, err := getAndBuildEdgeInfo(
×
2730
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
2731
                                row.GraphChannel, node1, node2,
×
2732
                        )
×
2733
                        if err != nil {
×
2734
                                return err
×
2735
                        }
×
2736

2737
                        err = db.DeleteChannel(ctx, row.GraphChannel.ID)
×
2738
                        if err != nil {
×
2739
                                return fmt.Errorf("unable to delete "+
×
2740
                                        "channel: %w", err)
×
2741
                        }
×
2742

2743
                        removedChans = append(removedChans, channel)
×
2744
                }
2745

2746
                return db.DeletePruneLogEntriesInRange(
×
2747
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2748
                                StartHeight: int64(height),
×
2749
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2750
                        },
×
2751
                )
×
2752
        }, func() {
×
2753
                removedChans = nil
×
2754
        })
×
2755
        if err != nil {
×
2756
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2757
                        "height: %w", err)
×
2758
        }
×
2759

2760
        for _, channel := range removedChans {
×
2761
                s.rejectCache.remove(channel.ChannelID)
×
2762
                s.chanCache.remove(channel.ChannelID)
×
2763
        }
×
2764

2765
        return removedChans, nil
×
2766
}
2767

2768
// AddEdgeProof sets the proof of an existing edge in the graph database.
2769
//
2770
// NOTE: part of the V1Store interface.
2771
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2772
        proof *models.ChannelAuthProof) error {
×
2773

×
2774
        var (
×
2775
                ctx       = context.TODO()
×
2776
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2777
        )
×
2778

×
2779
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2780
                res, err := db.AddV1ChannelProof(
×
2781
                        ctx, sqlc.AddV1ChannelProofParams{
×
2782
                                Scid:              scidBytes,
×
2783
                                Node1Signature:    proof.NodeSig1Bytes,
×
2784
                                Node2Signature:    proof.NodeSig2Bytes,
×
2785
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2786
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2787
                        },
×
2788
                )
×
2789
                if err != nil {
×
2790
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2791
                }
×
2792

2793
                n, err := res.RowsAffected()
×
2794
                if err != nil {
×
2795
                        return err
×
2796
                }
×
2797

2798
                if n == 0 {
×
2799
                        return fmt.Errorf("no rows affected when adding edge "+
×
2800
                                "proof for SCID %v", scid)
×
2801
                } else if n > 1 {
×
2802
                        return fmt.Errorf("multiple rows affected when adding "+
×
2803
                                "edge proof for SCID %v: %d rows affected",
×
2804
                                scid, n)
×
2805
                }
×
2806

2807
                return nil
×
2808
        }, sqldb.NoOpReset)
2809
        if err != nil {
×
2810
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2811
        }
×
2812

2813
        return nil
×
2814
}
2815

2816
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2817
// that we can ignore channel announcements that we know to be closed without
2818
// having to validate them and fetch a block.
2819
//
2820
// NOTE: part of the V1Store interface.
2821
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2822
        var (
×
2823
                ctx     = context.TODO()
×
2824
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2825
        )
×
2826

×
2827
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2828
                return db.InsertClosedChannel(ctx, chanIDB)
×
2829
        }, sqldb.NoOpReset)
×
2830
}
2831

2832
// IsClosedScid checks whether a channel identified by the passed in scid is
2833
// closed. This helps avoid having to perform expensive validation checks.
2834
//
2835
// NOTE: part of the V1Store interface.
2836
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2837
        var (
×
2838
                ctx      = context.TODO()
×
2839
                isClosed bool
×
2840
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2841
        )
×
2842
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2843
                var err error
×
2844
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2845
                if err != nil {
×
2846
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2847
                                err)
×
2848
                }
×
2849

2850
                return nil
×
2851
        }, sqldb.NoOpReset)
2852
        if err != nil {
×
2853
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2854
                        err)
×
2855
        }
×
2856

2857
        return isClosed, nil
×
2858
}
2859

2860
// GraphSession will provide the call-back with access to a NodeTraverser
2861
// instance which can be used to perform queries against the channel graph.
2862
//
2863
// NOTE: part of the V1Store interface.
2864
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
2865
        reset func()) error {
×
2866

×
2867
        var ctx = context.TODO()
×
2868

×
2869
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2870
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2871
        }, reset)
×
2872
}
2873

2874
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2875
// read only transaction for a consistent view of the graph.
2876
type sqlNodeTraverser struct {
2877
        db    SQLQueries
2878
        chain chainhash.Hash
2879
}
2880

2881
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2882
// NodeTraverser interface.
2883
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2884

2885
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2886
func newSQLNodeTraverser(db SQLQueries,
2887
        chain chainhash.Hash) *sqlNodeTraverser {
×
2888

×
2889
        return &sqlNodeTraverser{
×
2890
                db:    db,
×
2891
                chain: chain,
×
2892
        }
×
2893
}
×
2894

2895
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2896
// node.
2897
//
2898
// NOTE: Part of the NodeTraverser interface.
2899
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2900
        cb func(channel *DirectedChannel) error, _ func()) error {
×
2901

×
2902
        ctx := context.TODO()
×
2903

×
2904
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2905
}
×
2906

2907
// FetchNodeFeatures returns the features of the given node. If the node is
2908
// unknown, assume no additional features are supported.
2909
//
2910
// NOTE: Part of the NodeTraverser interface.
2911
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2912
        *lnwire.FeatureVector, error) {
×
2913

×
2914
        ctx := context.TODO()
×
2915

×
2916
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2917
}
×
2918

2919
// forEachNodeDirectedChannel iterates through all channels of a given
2920
// node, executing the passed callback on the directed edge representing the
2921
// channel and its incoming policy. If the node is not found, no error is
2922
// returned.
2923
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2924
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2925

×
2926
        toNodeCallback := func() route.Vertex {
×
2927
                return nodePub
×
2928
        }
×
2929

2930
        dbID, err := db.GetNodeIDByPubKey(
×
2931
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2932
                        Version: int16(ProtocolV1),
×
2933
                        PubKey:  nodePub[:],
×
2934
                },
×
2935
        )
×
2936
        if errors.Is(err, sql.ErrNoRows) {
×
2937
                return nil
×
2938
        } else if err != nil {
×
2939
                return fmt.Errorf("unable to fetch node: %w", err)
×
2940
        }
×
2941

2942
        rows, err := db.ListChannelsByNodeID(
×
2943
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2944
                        Version: int16(ProtocolV1),
×
2945
                        NodeID1: dbID,
×
2946
                },
×
2947
        )
×
2948
        if err != nil {
×
2949
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2950
        }
×
2951

2952
        // Exit early if there are no channels for this node so we don't
2953
        // do the unnecessary feature fetching.
2954
        if len(rows) == 0 {
×
2955
                return nil
×
2956
        }
×
2957

2958
        features, err := getNodeFeatures(ctx, db, dbID)
×
2959
        if err != nil {
×
2960
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2961
        }
×
2962

2963
        for _, row := range rows {
×
2964
                node1, node2, err := buildNodeVertices(
×
2965
                        row.Node1Pubkey, row.Node2Pubkey,
×
2966
                )
×
2967
                if err != nil {
×
2968
                        return fmt.Errorf("unable to build node vertices: %w",
×
2969
                                err)
×
2970
                }
×
2971

2972
                edge := buildCacheableChannelInfo(
×
2973
                        row.GraphChannel, node1, node2,
×
2974
                )
×
2975

×
2976
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2977
                if err != nil {
×
2978
                        return err
×
2979
                }
×
2980

2981
                var p1, p2 *models.CachedEdgePolicy
×
2982
                if dbPol1 != nil {
×
2983
                        policy1, err := buildChanPolicy(
×
2984
                                *dbPol1, edge.ChannelID, nil, node2,
×
2985
                        )
×
2986
                        if err != nil {
×
2987
                                return err
×
2988
                        }
×
2989

2990
                        p1 = models.NewCachedPolicy(policy1)
×
2991
                }
2992
                if dbPol2 != nil {
×
2993
                        policy2, err := buildChanPolicy(
×
2994
                                *dbPol2, edge.ChannelID, nil, node1,
×
2995
                        )
×
2996
                        if err != nil {
×
2997
                                return err
×
2998
                        }
×
2999

3000
                        p2 = models.NewCachedPolicy(policy2)
×
3001
                }
3002

3003
                // Determine the outgoing and incoming policy for this
3004
                // channel and node combo.
3005
                outPolicy, inPolicy := p1, p2
×
3006
                if p1 != nil && node2 == nodePub {
×
3007
                        outPolicy, inPolicy = p2, p1
×
3008
                } else if p2 != nil && node1 != nodePub {
×
3009
                        outPolicy, inPolicy = p2, p1
×
3010
                }
×
3011

3012
                var cachedInPolicy *models.CachedEdgePolicy
×
3013
                if inPolicy != nil {
×
3014
                        cachedInPolicy = inPolicy
×
3015
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
3016
                        cachedInPolicy.ToNodeFeatures = features
×
3017
                }
×
3018

3019
                directedChannel := &DirectedChannel{
×
3020
                        ChannelID:    edge.ChannelID,
×
3021
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
3022
                        OtherNode:    edge.NodeKey2Bytes,
×
3023
                        Capacity:     edge.Capacity,
×
3024
                        OutPolicySet: outPolicy != nil,
×
3025
                        InPolicy:     cachedInPolicy,
×
3026
                }
×
3027
                if outPolicy != nil {
×
3028
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3029
                                directedChannel.InboundFee = fee
×
3030
                        })
×
3031
                }
3032

3033
                if nodePub == edge.NodeKey2Bytes {
×
3034
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
3035
                }
×
3036

3037
                if err := cb(directedChannel); err != nil {
×
3038
                        return err
×
3039
                }
×
3040
        }
3041

3042
        return nil
×
3043
}
3044

3045
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
3046
// and executes the provided callback for each node.
3047
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
3048
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
3049

×
3050
        lastID := int64(-1)
×
3051

×
3052
        for {
×
3053
                nodes, err := db.ListNodeIDsAndPubKeys(
×
3054
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
3055
                                Version: int16(ProtocolV1),
×
3056
                                ID:      lastID,
×
3057
                                Limit:   pageSize,
×
3058
                        },
×
3059
                )
×
3060
                if err != nil {
×
3061
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
3062
                }
×
3063

3064
                if len(nodes) == 0 {
×
3065
                        break
×
3066
                }
3067

3068
                for _, node := range nodes {
×
3069
                        var pub route.Vertex
×
3070
                        copy(pub[:], node.PubKey)
×
3071

×
3072
                        if err := cb(node.ID, pub); err != nil {
×
3073
                                return fmt.Errorf("forEachNodeCacheable "+
×
3074
                                        "callback failed for node(id=%d): %w",
×
3075
                                        node.ID, err)
×
3076
                        }
×
3077

3078
                        lastID = node.ID
×
3079
                }
3080
        }
3081

3082
        return nil
×
3083
}
3084

3085
// forEachNodeChannel iterates through all channels of a node, executing
3086
// the passed callback on each. The call-back is provided with the channel's
3087
// edge information, the outgoing policy and the incoming policy for the
3088
// channel and node combo.
3089
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3090
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
3091
                *models.ChannelEdgePolicy,
3092
                *models.ChannelEdgePolicy) error) error {
×
3093

×
3094
        // Get all the V1 channels for this node.Add commentMore actions
×
3095
        rows, err := db.ListChannelsByNodeID(
×
3096
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3097
                        Version: int16(ProtocolV1),
×
3098
                        NodeID1: id,
×
3099
                },
×
3100
        )
×
3101
        if err != nil {
×
3102
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3103
        }
×
3104

3105
        // Call the call-back for each channel and its known policies.
3106
        for _, row := range rows {
×
3107
                node1, node2, err := buildNodeVertices(
×
3108
                        row.Node1Pubkey, row.Node2Pubkey,
×
3109
                )
×
3110
                if err != nil {
×
3111
                        return fmt.Errorf("unable to build node vertices: %w",
×
3112
                                err)
×
3113
                }
×
3114

3115
                edge, err := getAndBuildEdgeInfo(
×
3116
                        ctx, db, chain, row.GraphChannel.ID, row.GraphChannel,
×
3117
                        node1, node2,
×
3118
                )
×
3119
                if err != nil {
×
3120
                        return fmt.Errorf("unable to build channel info: %w",
×
3121
                                err)
×
3122
                }
×
3123

3124
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3125
                if err != nil {
×
3126
                        return fmt.Errorf("unable to extract channel "+
×
3127
                                "policies: %w", err)
×
3128
                }
×
3129

3130
                p1, p2, err := getAndBuildChanPolicies(
×
3131
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3132
                )
×
3133
                if err != nil {
×
3134
                        return fmt.Errorf("unable to build channel "+
×
3135
                                "policies: %w", err)
×
3136
                }
×
3137

3138
                // Determine the outgoing and incoming policy for this
3139
                // channel and node combo.
3140
                p1ToNode := row.GraphChannel.NodeID2
×
3141
                p2ToNode := row.GraphChannel.NodeID1
×
3142
                outPolicy, inPolicy := p1, p2
×
3143
                if (p1 != nil && p1ToNode == id) ||
×
3144
                        (p2 != nil && p2ToNode != id) {
×
3145

×
3146
                        outPolicy, inPolicy = p2, p1
×
3147
                }
×
3148

3149
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3150
                        return err
×
3151
                }
×
3152
        }
3153

3154
        return nil
×
3155
}
3156

3157
// updateChanEdgePolicy upserts the channel policy info we have stored for
3158
// a channel we already know of.
3159
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3160
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3161
        error) {
×
3162

×
3163
        var (
×
3164
                node1Pub, node2Pub route.Vertex
×
3165
                isNode1            bool
×
3166
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3167
        )
×
3168

×
3169
        // Check that this edge policy refers to a channel that we already
×
3170
        // know of. We do this explicitly so that we can return the appropriate
×
3171
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3172
        // abort the transaction which would abort the entire batch.
×
3173
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3174
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3175
                        Scid:    chanIDB,
×
3176
                        Version: int16(ProtocolV1),
×
3177
                },
×
3178
        )
×
3179
        if errors.Is(err, sql.ErrNoRows) {
×
3180
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3181
        } else if err != nil {
×
3182
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3183
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3184
        }
×
3185

3186
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3187
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3188

×
3189
        // Figure out which node this edge is from.
×
3190
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3191
        nodeID := dbChan.NodeID1
×
3192
        if !isNode1 {
×
3193
                nodeID = dbChan.NodeID2
×
3194
        }
×
3195

3196
        var (
×
3197
                inboundBase sql.NullInt64
×
3198
                inboundRate sql.NullInt64
×
3199
        )
×
3200
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3201
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3202
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3203
        })
×
3204

3205
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3206
                Version:     int16(ProtocolV1),
×
3207
                ChannelID:   dbChan.ID,
×
3208
                NodeID:      nodeID,
×
3209
                Timelock:    int32(edge.TimeLockDelta),
×
3210
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3211
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3212
                MinHtlcMsat: int64(edge.MinHTLC),
×
3213
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3214
                Disabled: sql.NullBool{
×
3215
                        Valid: true,
×
3216
                        Bool:  edge.IsDisabled(),
×
3217
                },
×
3218
                MaxHtlcMsat: sql.NullInt64{
×
3219
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3220
                        Int64: int64(edge.MaxHTLC),
×
3221
                },
×
3222
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3223
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3224
                InboundBaseFeeMsat:      inboundBase,
×
3225
                InboundFeeRateMilliMsat: inboundRate,
×
3226
                Signature:               edge.SigBytes,
×
3227
        })
×
3228
        if err != nil {
×
3229
                return node1Pub, node2Pub, isNode1,
×
3230
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3231
        }
×
3232

3233
        // Convert the flat extra opaque data into a map of TLV types to
3234
        // values.
3235
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3236
        if err != nil {
×
3237
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3238
                        "marshal extra opaque data: %w", err)
×
3239
        }
×
3240

3241
        // Update the channel policy's extra signed fields.
3242
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3243
        if err != nil {
×
3244
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3245
                        "policy extra TLVs: %w", err)
×
3246
        }
×
3247

3248
        return node1Pub, node2Pub, isNode1, nil
×
3249
}
3250

3251
// getNodeByPubKey attempts to look up a target node by its public key.
3252
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3253
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3254

×
3255
        dbNode, err := db.GetNodeByPubKey(
×
3256
                ctx, sqlc.GetNodeByPubKeyParams{
×
3257
                        Version: int16(ProtocolV1),
×
3258
                        PubKey:  pubKey[:],
×
3259
                },
×
3260
        )
×
3261
        if errors.Is(err, sql.ErrNoRows) {
×
3262
                return 0, nil, ErrGraphNodeNotFound
×
3263
        } else if err != nil {
×
3264
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3265
        }
×
3266

3267
        node, err := buildNode(ctx, db, &dbNode)
×
3268
        if err != nil {
×
3269
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3270
        }
×
3271

3272
        return dbNode.ID, node, nil
×
3273
}
3274

3275
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3276
// provided database channel row and the public keys of the two nodes
3277
// involved in the channel.
3278
func buildCacheableChannelInfo(dbChan sqlc.GraphChannel, node1Pub,
3279
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3280

×
3281
        return &models.CachedEdgeInfo{
×
3282
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
3283
                NodeKey1Bytes: node1Pub,
×
3284
                NodeKey2Bytes: node2Pub,
×
3285
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
3286
        }
×
3287
}
×
3288

3289
// buildNode constructs a LightningNode instance from the given database node
3290
// record. The node's features, addresses and extra signed fields are also
3291
// fetched from the database and set on the node.
3292
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.GraphNode) (
3293
        *models.LightningNode, error) {
×
3294

×
3295
        if dbNode.Version != int16(ProtocolV1) {
×
3296
                return nil, fmt.Errorf("unsupported node version: %d",
×
3297
                        dbNode.Version)
×
3298
        }
×
3299

3300
        var pub [33]byte
×
3301
        copy(pub[:], dbNode.PubKey)
×
3302

×
3303
        node := &models.LightningNode{
×
3304
                PubKeyBytes: pub,
×
3305
                Features:    lnwire.EmptyFeatureVector(),
×
3306
                LastUpdate:  time.Unix(0, 0),
×
3307
        }
×
3308

×
3309
        if len(dbNode.Signature) == 0 {
×
3310
                return node, nil
×
3311
        }
×
3312

3313
        node.HaveNodeAnnouncement = true
×
3314
        node.AuthSigBytes = dbNode.Signature
×
3315
        node.Alias = dbNode.Alias.String
×
3316
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3317

×
3318
        var err error
×
3319
        if dbNode.Color.Valid {
×
3320
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3321
                if err != nil {
×
3322
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3323
                                err)
×
3324
                }
×
3325
        }
3326

3327
        // Fetch the node's features.
3328
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
3329
        if err != nil {
×
3330
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3331
                        "features: %w", dbNode.ID, err)
×
3332
        }
×
3333

3334
        // Fetch the node's addresses.
3335
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
3336
        if err != nil {
×
3337
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3338
                        "addresses: %w", dbNode.ID, err)
×
3339
        }
×
3340

3341
        // Fetch the node's extra signed fields.
3342
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
3343
        if err != nil {
×
3344
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3345
                        "extra signed fields: %w", dbNode.ID, err)
×
3346
        }
×
3347

3348
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
3349
        if err != nil {
×
3350
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3351
                        "fields: %w", err)
×
3352
        }
×
3353

3354
        if len(recs) != 0 {
×
3355
                node.ExtraOpaqueData = recs
×
3356
        }
×
3357

3358
        return node, nil
×
3359
}
3360

3361
// getNodeFeatures fetches the feature bits and constructs the feature vector
3362
// for a node with the given DB ID.
3363
func getNodeFeatures(ctx context.Context, db SQLQueries,
3364
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3365

×
3366
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3367
        if err != nil {
×
3368
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3369
                        nodeID, err)
×
3370
        }
×
3371

3372
        features := lnwire.EmptyFeatureVector()
×
3373
        for _, feature := range rows {
×
3374
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3375
        }
×
3376

3377
        return features, nil
×
3378
}
3379

3380
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
3381
// given DB ID.
3382
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3383
        nodeID int64) (map[uint64][]byte, error) {
×
3384

×
3385
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3386
        if err != nil {
×
3387
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
3388
                        "signed fields: %w", nodeID, err)
×
3389
        }
×
3390

3391
        extraFields := make(map[uint64][]byte)
×
3392
        for _, field := range fields {
×
3393
                extraFields[uint64(field.Type)] = field.Value
×
3394
        }
×
3395

3396
        return extraFields, nil
×
3397
}
3398

3399
// upsertNode upserts the node record into the database. If the node already
3400
// exists, then the node's information is updated. If the node doesn't exist,
3401
// then a new node is created. The node's features, addresses and extra TLV
3402
// types are also updated. The node's DB ID is returned.
3403
func upsertNode(ctx context.Context, db SQLQueries,
3404
        node *models.LightningNode) (int64, error) {
×
3405

×
3406
        params := sqlc.UpsertNodeParams{
×
3407
                Version: int16(ProtocolV1),
×
3408
                PubKey:  node.PubKeyBytes[:],
×
3409
        }
×
3410

×
3411
        if node.HaveNodeAnnouncement {
×
3412
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3413
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3414
                params.Alias = sqldb.SQLStr(node.Alias)
×
3415
                params.Signature = node.AuthSigBytes
×
3416
        }
×
3417

3418
        nodeID, err := db.UpsertNode(ctx, params)
×
3419
        if err != nil {
×
3420
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3421
                        err)
×
3422
        }
×
3423

3424
        // We can exit here if we don't have the announcement yet.
3425
        if !node.HaveNodeAnnouncement {
×
3426
                return nodeID, nil
×
3427
        }
×
3428

3429
        // Update the node's features.
3430
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3431
        if err != nil {
×
3432
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3433
        }
×
3434

3435
        // Update the node's addresses.
3436
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3437
        if err != nil {
×
3438
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3439
        }
×
3440

3441
        // Convert the flat extra opaque data into a map of TLV types to
3442
        // values.
3443
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3444
        if err != nil {
×
3445
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3446
                        err)
×
3447
        }
×
3448

3449
        // Update the node's extra signed fields.
3450
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3451
        if err != nil {
×
3452
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3453
        }
×
3454

3455
        return nodeID, nil
×
3456
}
3457

3458
// upsertNodeFeatures updates the node's features node_features table. This
3459
// includes deleting any feature bits no longer present and inserting any new
3460
// feature bits. If the feature bit does not yet exist in the features table,
3461
// then an entry is created in that table first.
3462
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3463
        features *lnwire.FeatureVector) error {
×
3464

×
3465
        // Get any existing features for the node.
×
3466
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3467
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3468
                return err
×
3469
        }
×
3470

3471
        // Copy the nodes latest set of feature bits.
3472
        newFeatures := make(map[int32]struct{})
×
3473
        if features != nil {
×
3474
                for feature := range features.Features() {
×
3475
                        newFeatures[int32(feature)] = struct{}{}
×
3476
                }
×
3477
        }
3478

3479
        // For any current feature that already exists in the DB, remove it from
3480
        // the in-memory map. For any existing feature that does not exist in
3481
        // the in-memory map, delete it from the database.
3482
        for _, feature := range existingFeatures {
×
3483
                // The feature is still present, so there are no updates to be
×
3484
                // made.
×
3485
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3486
                        delete(newFeatures, feature.FeatureBit)
×
3487
                        continue
×
3488
                }
3489

3490
                // The feature is no longer present, so we remove it from the
3491
                // database.
3492
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3493
                        NodeID:     nodeID,
×
3494
                        FeatureBit: feature.FeatureBit,
×
3495
                })
×
3496
                if err != nil {
×
3497
                        return fmt.Errorf("unable to delete node(%d) "+
×
3498
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3499
                                err)
×
3500
                }
×
3501
        }
3502

3503
        // Any remaining entries in newFeatures are new features that need to be
3504
        // added to the database for the first time.
3505
        for feature := range newFeatures {
×
3506
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3507
                        NodeID:     nodeID,
×
3508
                        FeatureBit: feature,
×
3509
                })
×
3510
                if err != nil {
×
3511
                        return fmt.Errorf("unable to insert node(%d) "+
×
3512
                                "feature(%v): %w", nodeID, feature, err)
×
3513
                }
×
3514
        }
3515

3516
        return nil
×
3517
}
3518

3519
// fetchNodeFeatures fetches the features for a node with the given public key.
3520
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3521
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3522

×
3523
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3524
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3525
                        PubKey:  nodePub[:],
×
3526
                        Version: int16(ProtocolV1),
×
3527
                },
×
3528
        )
×
3529
        if err != nil {
×
3530
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3531
                        nodePub, err)
×
3532
        }
×
3533

3534
        features := lnwire.EmptyFeatureVector()
×
3535
        for _, bit := range rows {
×
3536
                features.Set(lnwire.FeatureBit(bit))
×
3537
        }
×
3538

3539
        return features, nil
×
3540
}
3541

3542
// dbAddressType identifies how an entry in the node_addresses table was
// encoded, and therefore how it must be decoded when the node's addresses are
// read back.
type dbAddressType uint8

const (
	// addressTypeIPv4 marks an IPv4 TCP address.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 marks an IPv6 TCP address.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 marks a Tor v2 onion address.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 marks a Tor v3 onion address.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeOpaque marks an lnwire.OpaqueAddrs payload, which is
	// stored in hex form. Its value is math.MaxInt8 (127).
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3554

3555
// upsertNodeAddresses updates the node's addresses in the database. This
3556
// includes deleting any existing addresses and inserting the new set of
3557
// addresses. The deletion is necessary since the ordering of the addresses may
3558
// change, and we need to ensure that the database reflects the latest set of
3559
// addresses so that at the time of reconstructing the node announcement, the
3560
// order is preserved and the signature over the message remains valid.
3561
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3562
        addresses []net.Addr) error {
×
3563

×
3564
        // Delete any existing addresses for the node. This is required since
×
3565
        // even if the new set of addresses is the same, the ordering may have
×
3566
        // changed for a given address type.
×
3567
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3568
        if err != nil {
×
3569
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3570
                        nodeID, err)
×
3571
        }
×
3572

3573
        // Copy the nodes latest set of addresses.
3574
        newAddresses := map[dbAddressType][]string{
×
3575
                addressTypeIPv4:   {},
×
3576
                addressTypeIPv6:   {},
×
3577
                addressTypeTorV2:  {},
×
3578
                addressTypeTorV3:  {},
×
3579
                addressTypeOpaque: {},
×
3580
        }
×
3581
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3582
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3583
        }
×
3584

3585
        for _, address := range addresses {
×
3586
                switch addr := address.(type) {
×
3587
                case *net.TCPAddr:
×
3588
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3589
                                addAddr(addressTypeIPv4, addr)
×
3590
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3591
                                addAddr(addressTypeIPv6, addr)
×
3592
                        } else {
×
3593
                                return fmt.Errorf("unhandled IP address: %v",
×
3594
                                        addr)
×
3595
                        }
×
3596

3597
                case *tor.OnionAddr:
×
3598
                        switch len(addr.OnionService) {
×
3599
                        case tor.V2Len:
×
3600
                                addAddr(addressTypeTorV2, addr)
×
3601
                        case tor.V3Len:
×
3602
                                addAddr(addressTypeTorV3, addr)
×
3603
                        default:
×
3604
                                return fmt.Errorf("invalid length for a tor " +
×
3605
                                        "address")
×
3606
                        }
3607

3608
                case *lnwire.OpaqueAddrs:
×
3609
                        addAddr(addressTypeOpaque, addr)
×
3610

3611
                default:
×
3612
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3613
                }
3614
        }
3615

3616
        // Any remaining entries in newAddresses are new addresses that need to
3617
        // be added to the database for the first time.
3618
        for addrType, addrList := range newAddresses {
×
3619
                for position, addr := range addrList {
×
3620
                        err := db.InsertNodeAddress(
×
3621
                                ctx, sqlc.InsertNodeAddressParams{
×
3622
                                        NodeID:   nodeID,
×
3623
                                        Type:     int16(addrType),
×
3624
                                        Address:  addr,
×
3625
                                        Position: int32(position),
×
3626
                                },
×
3627
                        )
×
3628
                        if err != nil {
×
3629
                                return fmt.Errorf("unable to insert "+
×
3630
                                        "node(%d) address(%v): %w", nodeID,
×
3631
                                        addr, err)
×
3632
                        }
×
3633
                }
3634
        }
3635

3636
        return nil
×
3637
}
3638

3639
// getNodeAddresses fetches the addresses for a node with the given public key.
3640
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3641
        []net.Addr, error) {
×
3642

×
3643
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
3644
        // are returned in the same order as they were inserted.
×
3645
        rows, err := db.GetNodeAddressesByPubKey(
×
3646
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
3647
                        Version: int16(ProtocolV1),
×
3648
                        PubKey:  nodePub,
×
3649
                },
×
3650
        )
×
3651
        if err != nil {
×
3652
                return false, nil, err
×
3653
        }
×
3654

3655
        // GetNodeAddressesByPubKey uses a left join so there should always be
3656
        // at least one row returned if the node exists even if it has no
3657
        // addresses.
3658
        if len(rows) == 0 {
×
3659
                return false, nil, nil
×
3660
        }
×
3661

3662
        addresses := make([]net.Addr, 0, len(rows))
×
3663
        for _, addr := range rows {
×
3664
                if !(addr.Type.Valid && addr.Address.Valid) {
×
3665
                        continue
×
3666
                }
3667

3668
                address := addr.Address.String
×
3669

×
3670
                switch dbAddressType(addr.Type.Int16) {
×
3671
                case addressTypeIPv4:
×
3672
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
3673
                        if err != nil {
×
3674
                                return false, nil, nil
×
3675
                        }
×
3676
                        tcp.IP = tcp.IP.To4()
×
3677

×
3678
                        addresses = append(addresses, tcp)
×
3679

3680
                case addressTypeIPv6:
×
3681
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
3682
                        if err != nil {
×
3683
                                return false, nil, nil
×
3684
                        }
×
3685
                        addresses = append(addresses, tcp)
×
3686

3687
                case addressTypeTorV3, addressTypeTorV2:
×
3688
                        service, portStr, err := net.SplitHostPort(address)
×
3689
                        if err != nil {
×
3690
                                return false, nil, fmt.Errorf("unable to "+
×
3691
                                        "split tor v3 address: %v",
×
3692
                                        addr.Address)
×
3693
                        }
×
3694

3695
                        port, err := strconv.Atoi(portStr)
×
3696
                        if err != nil {
×
3697
                                return false, nil, err
×
3698
                        }
×
3699

3700
                        addresses = append(addresses, &tor.OnionAddr{
×
3701
                                OnionService: service,
×
3702
                                Port:         port,
×
3703
                        })
×
3704

3705
                case addressTypeOpaque:
×
3706
                        opaque, err := hex.DecodeString(address)
×
3707
                        if err != nil {
×
3708
                                return false, nil, fmt.Errorf("unable to "+
×
3709
                                        "decode opaque address: %v", addr)
×
3710
                        }
×
3711

3712
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
3713
                                Payload: opaque,
×
3714
                        })
×
3715

3716
                default:
×
3717
                        return false, nil, fmt.Errorf("unknown address "+
×
3718
                                "type: %v", addr.Type)
×
3719
                }
3720
        }
3721

3722
        // If we have no addresses, then we'll return nil instead of an
3723
        // empty slice.
3724
        if len(addresses) == 0 {
×
3725
                addresses = nil
×
3726
        }
×
3727

3728
        return true, addresses, nil
×
3729
}
3730

3731
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3732
// database. This includes updating any existing types, inserting any new types,
3733
// and deleting any types that are no longer present.
3734
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3735
        nodeID int64, extraFields map[uint64][]byte) error {
×
3736

×
3737
        // Get any existing extra signed fields for the node.
×
3738
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3739
        if err != nil {
×
3740
                return err
×
3741
        }
×
3742

3743
        // Make a lookup map of the existing field types so that we can use it
3744
        // to keep track of any fields we should delete.
3745
        m := make(map[uint64]bool)
×
3746
        for _, field := range existingFields {
×
3747
                m[uint64(field.Type)] = true
×
3748
        }
×
3749

3750
        // For all the new fields, we'll upsert them and remove them from the
3751
        // map of existing fields.
3752
        for tlvType, value := range extraFields {
×
3753
                err = db.UpsertNodeExtraType(
×
3754
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3755
                                NodeID: nodeID,
×
3756
                                Type:   int64(tlvType),
×
3757
                                Value:  value,
×
3758
                        },
×
3759
                )
×
3760
                if err != nil {
×
3761
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3762
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3763
                }
×
3764

3765
                // Remove the field from the map of existing fields if it was
3766
                // present.
3767
                delete(m, tlvType)
×
3768
        }
3769

3770
        // For all the fields that are left in the map of existing fields, we'll
3771
        // delete them as they are no longer present in the new set of fields.
3772
        for tlvType := range m {
×
3773
                err = db.DeleteExtraNodeType(
×
3774
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3775
                                NodeID: nodeID,
×
3776
                                Type:   int64(tlvType),
×
3777
                        },
×
3778
                )
×
3779
                if err != nil {
×
3780
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3781
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3782
                }
×
3783
        }
3784

3785
        return nil
×
3786
}
3787

3788
// srcNodeInfo holds the information about the source node of the graph.
3789
type srcNodeInfo struct {
3790
        // id is the DB level ID of the source node entry in the "nodes" table.
3791
        id int64
3792

3793
        // pub is the public key of the source node.
3794
        pub route.Vertex
3795
}
3796

3797
// sourceNode returns the DB node ID and pub key of the source node for the
3798
// specified protocol version.
3799
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3800
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3801

×
3802
        s.srcNodeMu.Lock()
×
3803
        defer s.srcNodeMu.Unlock()
×
3804

×
3805
        // If we already have the source node ID and pub key cached, then
×
3806
        // return them.
×
3807
        if info, ok := s.srcNodes[version]; ok {
×
3808
                return info.id, info.pub, nil
×
3809
        }
×
3810

3811
        var pubKey route.Vertex
×
3812

×
3813
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3814
        if err != nil {
×
3815
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3816
                        err)
×
3817
        }
×
3818

3819
        if len(nodes) == 0 {
×
3820
                return 0, pubKey, ErrSourceNodeNotSet
×
3821
        } else if len(nodes) > 1 {
×
3822
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3823
                        "protocol %s found", version)
×
3824
        }
×
3825

3826
        copy(pubKey[:], nodes[0].PubKey)
×
3827

×
3828
        s.srcNodes[version] = &srcNodeInfo{
×
3829
                id:  nodes[0].NodeID,
×
3830
                pub: pubKey,
×
3831
        }
×
3832

×
3833
        return nodes[0].NodeID, pubKey, nil
×
3834
}
3835

3836
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3837
// This then produces a map from TLV type to value. If the input is not a
3838
// valid TLV stream, then an error is returned.
3839
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3840
        r := bytes.NewReader(data)
×
3841

×
3842
        tlvStream, err := tlv.NewStream()
×
3843
        if err != nil {
×
3844
                return nil, err
×
3845
        }
×
3846

3847
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3848
        // pass it into the P2P decoding variant.
3849
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3850
        if err != nil {
×
3851
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3852
        }
×
3853
        if len(parsedTypes) == 0 {
×
3854
                return nil, nil
×
3855
        }
×
3856

3857
        records := make(map[uint64][]byte)
×
3858
        for k, v := range parsedTypes {
×
3859
                records[uint64(k)] = v
×
3860
        }
×
3861

3862
        return records, nil
×
3863
}
3864

3865
// dbChanInfo bundles the DB level IDs of a channel row and of the two nodes it
// connects.
type dbChanInfo struct {
	// channelID is the DB ID of the channel row.
	channelID int64

	// node1ID is the DB ID of the channel's first node.
	node1ID int64

	// node2ID is the DB ID of the channel's second node.
	node2ID int64
}
3872

3873
// insertChannel inserts a new channel record into the database.
3874
func insertChannel(ctx context.Context, db SQLQueries,
3875
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
3876

×
3877
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3878

×
3879
        // Make sure that the channel doesn't already exist. We do this
×
3880
        // explicitly instead of relying on catching a unique constraint error
×
3881
        // because relying on SQL to throw that error would abort the entire
×
3882
        // batch of transactions.
×
3883
        _, err := db.GetChannelBySCID(
×
3884
                ctx, sqlc.GetChannelBySCIDParams{
×
3885
                        Scid:    chanIDB,
×
3886
                        Version: int16(ProtocolV1),
×
3887
                },
×
3888
        )
×
3889
        if err == nil {
×
3890
                return nil, ErrEdgeAlreadyExist
×
3891
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3892
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
×
3893
        }
×
3894

3895
        // Make sure that at least a "shell" entry for each node is present in
3896
        // the nodes table.
3897
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3898
        if err != nil {
×
3899
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3900
        }
×
3901

3902
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3903
        if err != nil {
×
3904
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3905
        }
×
3906

3907
        var capacity sql.NullInt64
×
3908
        if edge.Capacity != 0 {
×
3909
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3910
        }
×
3911

3912
        createParams := sqlc.CreateChannelParams{
×
3913
                Version:     int16(ProtocolV1),
×
3914
                Scid:        chanIDB,
×
3915
                NodeID1:     node1DBID,
×
3916
                NodeID2:     node2DBID,
×
3917
                Outpoint:    edge.ChannelPoint.String(),
×
3918
                Capacity:    capacity,
×
3919
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3920
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3921
        }
×
3922

×
3923
        if edge.AuthProof != nil {
×
3924
                proof := edge.AuthProof
×
3925

×
3926
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3927
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3928
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3929
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3930
        }
×
3931

3932
        // Insert the new channel record.
3933
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3934
        if err != nil {
×
3935
                return nil, err
×
3936
        }
×
3937

3938
        // Insert any channel features.
3939
        for feature := range edge.Features.Features() {
×
3940
                err = db.InsertChannelFeature(
×
3941
                        ctx, sqlc.InsertChannelFeatureParams{
×
3942
                                ChannelID:  dbChanID,
×
3943
                                FeatureBit: int32(feature),
×
3944
                        },
×
3945
                )
×
3946
                if err != nil {
×
3947
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
×
3948
                                "feature(%v): %w", dbChanID, feature, err)
×
3949
                }
×
3950
        }
3951

3952
        // Finally, insert any extra TLV fields in the channel announcement.
3953
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3954
        if err != nil {
×
3955
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
3956
                        "data: %w", err)
×
3957
        }
×
3958

3959
        for tlvType, value := range extra {
×
3960
                err := db.CreateChannelExtraType(
×
3961
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3962
                                ChannelID: dbChanID,
×
3963
                                Type:      int64(tlvType),
×
3964
                                Value:     value,
×
3965
                        },
×
3966
                )
×
3967
                if err != nil {
×
3968
                        return nil, fmt.Errorf("unable to upsert "+
×
3969
                                "channel(%d) extra signed field(%v): %w",
×
3970
                                edge.ChannelID, tlvType, err)
×
3971
                }
×
3972
        }
3973

3974
        return &dbChanInfo{
×
3975
                channelID: dbChanID,
×
3976
                node1ID:   node1DBID,
×
3977
                node2ID:   node2DBID,
×
3978
        }, nil
×
3979
}
3980

3981
// maybeCreateShellNode checks if a shell node entry exists for the
3982
// given public key. If it does not exist, then a new shell node entry is
3983
// created. The ID of the node is returned. A shell node only has a protocol
3984
// version and public key persisted.
3985
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3986
        pubKey route.Vertex) (int64, error) {
×
3987

×
3988
        dbNode, err := db.GetNodeByPubKey(
×
3989
                ctx, sqlc.GetNodeByPubKeyParams{
×
3990
                        PubKey:  pubKey[:],
×
3991
                        Version: int16(ProtocolV1),
×
3992
                },
×
3993
        )
×
3994
        // The node exists. Return the ID.
×
3995
        if err == nil {
×
3996
                return dbNode.ID, nil
×
3997
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3998
                return 0, err
×
3999
        }
×
4000

4001
        // Otherwise, the node does not exist, so we create a shell entry for
4002
        // it.
4003
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
4004
                Version: int16(ProtocolV1),
×
4005
                PubKey:  pubKey[:],
×
4006
        })
×
4007
        if err != nil {
×
4008
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
4009
        }
×
4010

4011
        return id, nil
×
4012
}
4013

4014
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
4015
// the database. This includes deleting any existing types and then inserting
4016
// the new types.
4017
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
4018
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
4019

×
4020
        // Delete all existing extra signed fields for the channel policy.
×
4021
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
4022
        if err != nil {
×
4023
                return fmt.Errorf("unable to delete "+
×
4024
                        "existing policy extra signed fields for policy %d: %w",
×
4025
                        chanPolicyID, err)
×
4026
        }
×
4027

4028
        // Insert all new extra signed fields for the channel policy.
4029
        for tlvType, value := range extraFields {
×
4030
                err = db.InsertChanPolicyExtraType(
×
4031
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
4032
                                ChannelPolicyID: chanPolicyID,
×
4033
                                Type:            int64(tlvType),
×
4034
                                Value:           value,
×
4035
                        },
×
4036
                )
×
4037
                if err != nil {
×
4038
                        return fmt.Errorf("unable to insert "+
×
4039
                                "channel_policy(%d) extra signed field(%v): %w",
×
4040
                                chanPolicyID, tlvType, err)
×
4041
                }
×
4042
        }
4043

4044
        return nil
×
4045
}
4046

4047
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
4048
// provided dbChanRow and also fetches any other required information
4049
// to construct the edge info.
4050
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
4051
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.GraphChannel, node1,
4052
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
4053

×
4054
        if dbChan.Version != int16(ProtocolV1) {
×
4055
                return nil, fmt.Errorf("unsupported channel version: %d",
×
4056
                        dbChan.Version)
×
4057
        }
×
4058

4059
        fv, extras, err := getChanFeaturesAndExtras(
×
4060
                ctx, db, dbChanID,
×
4061
        )
×
4062
        if err != nil {
×
4063
                return nil, err
×
4064
        }
×
4065

4066
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4067
        if err != nil {
×
4068
                return nil, err
×
4069
        }
×
4070

4071
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4072
        if err != nil {
×
4073
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4074
                        "fields: %w", err)
×
4075
        }
×
4076
        if recs == nil {
×
4077
                recs = make([]byte, 0)
×
4078
        }
×
4079

4080
        var btcKey1, btcKey2 route.Vertex
×
4081
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4082
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4083

×
4084
        channel := &models.ChannelEdgeInfo{
×
4085
                ChainHash:        chain,
×
4086
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4087
                NodeKey1Bytes:    node1,
×
4088
                NodeKey2Bytes:    node2,
×
4089
                BitcoinKey1Bytes: btcKey1,
×
4090
                BitcoinKey2Bytes: btcKey2,
×
4091
                ChannelPoint:     *op,
×
4092
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4093
                Features:         fv,
×
4094
                ExtraOpaqueData:  recs,
×
4095
        }
×
4096

×
4097
        // We always set all the signatures at the same time, so we can
×
4098
        // safely check if one signature is present to determine if we have the
×
4099
        // rest of the signatures for the auth proof.
×
4100
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4101
                channel.AuthProof = &models.ChannelAuthProof{
×
4102
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4103
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4104
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4105
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4106
                }
×
4107
        }
×
4108

4109
        return channel, nil
×
4110
}
4111

4112
// buildNodeVertices is a helper that converts raw node public keys
4113
// into route.Vertex instances.
4114
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4115
        route.Vertex, error) {
×
4116

×
4117
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4118
        if err != nil {
×
4119
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4120
                        "create vertex from node1 pubkey: %w", err)
×
4121
        }
×
4122

4123
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4124
        if err != nil {
×
4125
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4126
                        "create vertex from node2 pubkey: %w", err)
×
4127
        }
×
4128

4129
        return node1Vertex, node2Vertex, nil
×
4130
}
4131

4132
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
4133
// for a channel with the given ID.
4134
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
4135
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
4136

×
4137
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
4138
        if err != nil {
×
4139
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
4140
                        "features and extras: %w", err)
×
4141
        }
×
4142

4143
        var (
×
4144
                fv     = lnwire.EmptyFeatureVector()
×
4145
                extras = make(map[uint64][]byte)
×
4146
        )
×
4147
        for _, row := range rows {
×
4148
                if row.IsFeature {
×
4149
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
4150

×
4151
                        continue
×
4152
                }
4153

4154
                tlvType, ok := row.ExtraKey.(int64)
×
4155
                if !ok {
×
4156
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4157
                                "TLV type: %T", row.ExtraKey)
×
4158
                }
×
4159

4160
                valueBytes, ok := row.Value.([]byte)
×
4161
                if !ok {
×
4162
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4163
                                "Value: %T", row.Value)
×
4164
                }
×
4165

4166
                extras[uint64(tlvType)] = valueBytes
×
4167
        }
4168

4169
        return fv, extras, nil
×
4170
}
4171

4172
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4173
// retrieves all the extra info required to build the complete
4174
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4175
// the provided sqlc.GraphChannelPolicy records are nil.
4176
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4177
        dbPol1, dbPol2 *sqlc.GraphChannelPolicy, channelID uint64, node1,
4178
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4179
        *models.ChannelEdgePolicy, error) {
×
4180

×
4181
        if dbPol1 == nil && dbPol2 == nil {
×
4182
                return nil, nil, nil
×
4183
        }
×
4184

4185
        var (
×
4186
                policy1ID int64
×
4187
                policy2ID int64
×
4188
        )
×
4189
        if dbPol1 != nil {
×
4190
                policy1ID = dbPol1.ID
×
4191
        }
×
4192
        if dbPol2 != nil {
×
4193
                policy2ID = dbPol2.ID
×
4194
        }
×
4195
        rows, err := db.GetChannelPolicyExtraTypes(
×
4196
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
4197
                        ID:   policy1ID,
×
4198
                        ID_2: policy2ID,
×
4199
                },
×
4200
        )
×
4201
        if err != nil {
×
4202
                return nil, nil, err
×
4203
        }
×
4204

4205
        var (
×
4206
                dbPol1Extras = make(map[uint64][]byte)
×
4207
                dbPol2Extras = make(map[uint64][]byte)
×
4208
        )
×
4209
        for _, row := range rows {
×
4210
                switch row.PolicyID {
×
4211
                case policy1ID:
×
4212
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
4213
                case policy2ID:
×
4214
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
4215
                default:
×
4216
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
4217
                                "in row: %v", row.PolicyID, row)
×
4218
                }
4219
        }
4220

4221
        var pol1, pol2 *models.ChannelEdgePolicy
×
4222
        if dbPol1 != nil {
×
4223
                pol1, err = buildChanPolicy(
×
4224
                        *dbPol1, channelID, dbPol1Extras, node2,
×
4225
                )
×
4226
                if err != nil {
×
4227
                        return nil, nil, err
×
4228
                }
×
4229
        }
4230
        if dbPol2 != nil {
×
4231
                pol2, err = buildChanPolicy(
×
4232
                        *dbPol2, channelID, dbPol2Extras, node1,
×
4233
                )
×
4234
                if err != nil {
×
4235
                        return nil, nil, err
×
4236
                }
×
4237
        }
4238

4239
        return pol1, pol2, nil
×
4240
}
4241

4242
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4243
// provided sqlc.GraphChannelPolicy and other required information.
4244
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4245
        extras map[uint64][]byte,
4246
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4247

×
4248
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4249
        if err != nil {
×
4250
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4251
                        "fields: %w", err)
×
4252
        }
×
4253

4254
        var inboundFee fn.Option[lnwire.Fee]
×
4255
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4256
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4257

×
4258
                inboundFee = fn.Some(lnwire.Fee{
×
4259
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4260
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4261
                })
×
4262
        }
×
4263

4264
        return &models.ChannelEdgePolicy{
×
4265
                SigBytes:  dbPolicy.Signature,
×
4266
                ChannelID: channelID,
×
4267
                LastUpdate: time.Unix(
×
4268
                        dbPolicy.LastUpdate.Int64, 0,
×
4269
                ),
×
4270
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4271
                        dbPolicy.MessageFlags,
×
4272
                ),
×
4273
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4274
                        dbPolicy.ChannelFlags,
×
4275
                ),
×
4276
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4277
                MinHTLC: lnwire.MilliSatoshi(
×
4278
                        dbPolicy.MinHtlcMsat,
×
4279
                ),
×
4280
                MaxHTLC: lnwire.MilliSatoshi(
×
4281
                        dbPolicy.MaxHtlcMsat.Int64,
×
4282
                ),
×
4283
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4284
                        dbPolicy.BaseFeeMsat,
×
4285
                ),
×
4286
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4287
                ToNode:                    toNode,
×
4288
                InboundFee:                inboundFee,
×
4289
                ExtraOpaqueData:           recs,
×
4290
        }, nil
×
4291
}
4292

4293
// buildNodes builds the models.LightningNode instances for the
4294
// given row which is expected to be a sqlc type that contains node information.
4295
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4296
        dbNode2 sqlc.GraphNode) (*models.LightningNode, *models.LightningNode,
4297
        error) {
×
4298

×
4299
        node1, err := buildNode(ctx, db, &dbNode1)
×
4300
        if err != nil {
×
4301
                return nil, nil, err
×
4302
        }
×
4303

4304
        node2, err := buildNode(ctx, db, &dbNode2)
×
4305
        if err != nil {
×
4306
                return nil, nil, err
×
4307
        }
×
4308

4309
        return node1, node2, nil
×
4310
}
4311

4312
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the give
4313
// row which is expected to be a sqlc type that contains channel policy
4314
// information. It returns two policies, which may be nil if the policy
4315
// information is not present in the row.
4316
//
4317
//nolint:ll,dupl,funlen
4318
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4319
        *sqlc.GraphChannelPolicy, error) {
×
4320

×
4321
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4322
        switch r := row.(type) {
×
NEW
4323
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
NEW
4324
                if r.Policy1ID.Valid {
×
NEW
4325
                        policy1 = &sqlc.GraphChannelPolicy{
×
NEW
4326
                                ID:                      r.Policy1ID.Int64,
×
NEW
4327
                                Version:                 r.Policy1Version.Int16,
×
NEW
4328
                                ChannelID:               r.GraphChannel.ID,
×
NEW
4329
                                NodeID:                  r.Policy1NodeID.Int64,
×
NEW
4330
                                Timelock:                r.Policy1Timelock.Int32,
×
NEW
4331
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
NEW
4332
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
NEW
4333
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
NEW
4334
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
NEW
4335
                                LastUpdate:              r.Policy1LastUpdate,
×
NEW
4336
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
NEW
4337
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
NEW
4338
                                Disabled:                r.Policy1Disabled,
×
NEW
4339
                                MessageFlags:            r.Policy1MessageFlags,
×
NEW
4340
                                ChannelFlags:            r.Policy1ChannelFlags,
×
NEW
4341
                                Signature:               r.Policy1Signature,
×
NEW
4342
                        }
×
NEW
4343
                }
×
NEW
4344
                if r.Policy2ID.Valid {
×
NEW
4345
                        policy2 = &sqlc.GraphChannelPolicy{
×
NEW
4346
                                ID:                      r.Policy2ID.Int64,
×
NEW
4347
                                Version:                 r.Policy2Version.Int16,
×
NEW
4348
                                ChannelID:               r.GraphChannel.ID,
×
NEW
4349
                                NodeID:                  r.Policy2NodeID.Int64,
×
NEW
4350
                                Timelock:                r.Policy2Timelock.Int32,
×
NEW
4351
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
NEW
4352
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
NEW
4353
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
NEW
4354
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
NEW
4355
                                LastUpdate:              r.Policy2LastUpdate,
×
NEW
4356
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
NEW
4357
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
NEW
4358
                                Disabled:                r.Policy2Disabled,
×
NEW
4359
                                MessageFlags:            r.Policy2MessageFlags,
×
NEW
4360
                                ChannelFlags:            r.Policy2ChannelFlags,
×
NEW
4361
                                Signature:               r.Policy2Signature,
×
NEW
4362
                        }
×
NEW
4363
                }
×
4364

NEW
4365
                return policy1, policy2, nil
×
4366

4367
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4368
                if r.Policy1ID.Valid {
×
4369
                        policy1 = &sqlc.GraphChannelPolicy{
×
4370
                                ID:                      r.Policy1ID.Int64,
×
4371
                                Version:                 r.Policy1Version.Int16,
×
4372
                                ChannelID:               r.GraphChannel.ID,
×
4373
                                NodeID:                  r.Policy1NodeID.Int64,
×
4374
                                Timelock:                r.Policy1Timelock.Int32,
×
4375
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4376
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4377
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4378
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4379
                                LastUpdate:              r.Policy1LastUpdate,
×
4380
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4381
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4382
                                Disabled:                r.Policy1Disabled,
×
4383
                                MessageFlags:            r.Policy1MessageFlags,
×
4384
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4385
                                Signature:               r.Policy1Signature,
×
4386
                        }
×
4387
                }
×
4388
                if r.Policy2ID.Valid {
×
4389
                        policy2 = &sqlc.GraphChannelPolicy{
×
4390
                                ID:                      r.Policy2ID.Int64,
×
4391
                                Version:                 r.Policy2Version.Int16,
×
4392
                                ChannelID:               r.GraphChannel.ID,
×
4393
                                NodeID:                  r.Policy2NodeID.Int64,
×
4394
                                Timelock:                r.Policy2Timelock.Int32,
×
4395
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4396
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4397
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4398
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4399
                                LastUpdate:              r.Policy2LastUpdate,
×
4400
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4401
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4402
                                Disabled:                r.Policy2Disabled,
×
4403
                                MessageFlags:            r.Policy2MessageFlags,
×
4404
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4405
                                Signature:               r.Policy2Signature,
×
4406
                        }
×
4407
                }
×
4408

4409
                return policy1, policy2, nil
×
4410

4411
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4412
                if r.Policy1ID.Valid {
×
4413
                        policy1 = &sqlc.GraphChannelPolicy{
×
4414
                                ID:                      r.Policy1ID.Int64,
×
4415
                                Version:                 r.Policy1Version.Int16,
×
4416
                                ChannelID:               r.GraphChannel.ID,
×
4417
                                NodeID:                  r.Policy1NodeID.Int64,
×
4418
                                Timelock:                r.Policy1Timelock.Int32,
×
4419
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4420
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4421
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4422
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4423
                                LastUpdate:              r.Policy1LastUpdate,
×
4424
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4425
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4426
                                Disabled:                r.Policy1Disabled,
×
4427
                                MessageFlags:            r.Policy1MessageFlags,
×
4428
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4429
                                Signature:               r.Policy1Signature,
×
4430
                        }
×
4431
                }
×
4432
                if r.Policy2ID.Valid {
×
4433
                        policy2 = &sqlc.GraphChannelPolicy{
×
4434
                                ID:                      r.Policy2ID.Int64,
×
4435
                                Version:                 r.Policy2Version.Int16,
×
4436
                                ChannelID:               r.GraphChannel.ID,
×
4437
                                NodeID:                  r.Policy2NodeID.Int64,
×
4438
                                Timelock:                r.Policy2Timelock.Int32,
×
4439
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4440
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4441
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4442
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4443
                                LastUpdate:              r.Policy2LastUpdate,
×
4444
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4445
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4446
                                Disabled:                r.Policy2Disabled,
×
4447
                                MessageFlags:            r.Policy2MessageFlags,
×
4448
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4449
                                Signature:               r.Policy2Signature,
×
4450
                        }
×
4451
                }
×
4452

4453
                return policy1, policy2, nil
×
4454

4455
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4456
                if r.Policy1ID.Valid {
×
4457
                        policy1 = &sqlc.GraphChannelPolicy{
×
4458
                                ID:                      r.Policy1ID.Int64,
×
4459
                                Version:                 r.Policy1Version.Int16,
×
4460
                                ChannelID:               r.GraphChannel.ID,
×
4461
                                NodeID:                  r.Policy1NodeID.Int64,
×
4462
                                Timelock:                r.Policy1Timelock.Int32,
×
4463
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4464
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4465
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4466
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4467
                                LastUpdate:              r.Policy1LastUpdate,
×
4468
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4469
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4470
                                Disabled:                r.Policy1Disabled,
×
4471
                                MessageFlags:            r.Policy1MessageFlags,
×
4472
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4473
                                Signature:               r.Policy1Signature,
×
4474
                        }
×
4475
                }
×
4476
                if r.Policy2ID.Valid {
×
4477
                        policy2 = &sqlc.GraphChannelPolicy{
×
4478
                                ID:                      r.Policy2ID.Int64,
×
4479
                                Version:                 r.Policy2Version.Int16,
×
4480
                                ChannelID:               r.GraphChannel.ID,
×
4481
                                NodeID:                  r.Policy2NodeID.Int64,
×
4482
                                Timelock:                r.Policy2Timelock.Int32,
×
4483
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4484
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4485
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4486
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4487
                                LastUpdate:              r.Policy2LastUpdate,
×
4488
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4489
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4490
                                Disabled:                r.Policy2Disabled,
×
4491
                                MessageFlags:            r.Policy2MessageFlags,
×
4492
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4493
                                Signature:               r.Policy2Signature,
×
4494
                        }
×
4495
                }
×
4496

4497
                return policy1, policy2, nil
×
4498

4499
        case sqlc.ListChannelsByNodeIDRow:
×
4500
                if r.Policy1ID.Valid {
×
4501
                        policy1 = &sqlc.GraphChannelPolicy{
×
4502
                                ID:                      r.Policy1ID.Int64,
×
4503
                                Version:                 r.Policy1Version.Int16,
×
4504
                                ChannelID:               r.GraphChannel.ID,
×
4505
                                NodeID:                  r.Policy1NodeID.Int64,
×
4506
                                Timelock:                r.Policy1Timelock.Int32,
×
4507
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4508
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4509
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4510
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4511
                                LastUpdate:              r.Policy1LastUpdate,
×
4512
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4513
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4514
                                Disabled:                r.Policy1Disabled,
×
4515
                                MessageFlags:            r.Policy1MessageFlags,
×
4516
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4517
                                Signature:               r.Policy1Signature,
×
4518
                        }
×
4519
                }
×
4520
                if r.Policy2ID.Valid {
×
4521
                        policy2 = &sqlc.GraphChannelPolicy{
×
4522
                                ID:                      r.Policy2ID.Int64,
×
4523
                                Version:                 r.Policy2Version.Int16,
×
4524
                                ChannelID:               r.GraphChannel.ID,
×
4525
                                NodeID:                  r.Policy2NodeID.Int64,
×
4526
                                Timelock:                r.Policy2Timelock.Int32,
×
4527
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4528
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4529
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4530
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4531
                                LastUpdate:              r.Policy2LastUpdate,
×
4532
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4533
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4534
                                Disabled:                r.Policy2Disabled,
×
4535
                                MessageFlags:            r.Policy2MessageFlags,
×
4536
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4537
                                Signature:               r.Policy2Signature,
×
4538
                        }
×
4539
                }
×
4540

4541
                return policy1, policy2, nil
×
4542

4543
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4544
                if r.Policy1ID.Valid {
×
4545
                        policy1 = &sqlc.GraphChannelPolicy{
×
4546
                                ID:                      r.Policy1ID.Int64,
×
4547
                                Version:                 r.Policy1Version.Int16,
×
4548
                                ChannelID:               r.GraphChannel.ID,
×
4549
                                NodeID:                  r.Policy1NodeID.Int64,
×
4550
                                Timelock:                r.Policy1Timelock.Int32,
×
4551
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4552
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4553
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4554
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4555
                                LastUpdate:              r.Policy1LastUpdate,
×
4556
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4557
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4558
                                Disabled:                r.Policy1Disabled,
×
4559
                                MessageFlags:            r.Policy1MessageFlags,
×
4560
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4561
                                Signature:               r.Policy1Signature,
×
4562
                        }
×
4563
                }
×
4564
                if r.Policy2ID.Valid {
×
4565
                        policy2 = &sqlc.GraphChannelPolicy{
×
4566
                                ID:                      r.Policy2ID.Int64,
×
4567
                                Version:                 r.Policy2Version.Int16,
×
4568
                                ChannelID:               r.GraphChannel.ID,
×
4569
                                NodeID:                  r.Policy2NodeID.Int64,
×
4570
                                Timelock:                r.Policy2Timelock.Int32,
×
4571
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4572
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4573
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4574
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4575
                                LastUpdate:              r.Policy2LastUpdate,
×
4576
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4577
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4578
                                Disabled:                r.Policy2Disabled,
×
4579
                                MessageFlags:            r.Policy2MessageFlags,
×
4580
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4581
                                Signature:               r.Policy2Signature,
×
4582
                        }
×
4583
                }
×
4584

4585
                return policy1, policy2, nil
×
4586
        default:
×
4587
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4588
                        "extractChannelPolicies: %T", r)
×
4589
        }
4590
}
4591

4592
// channelIDToBytes converts a channel ID (SCID) to a byte array
4593
// representation.
4594
func channelIDToBytes(channelID uint64) []byte {
×
4595
        var chanIDB [8]byte
×
4596
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4597

×
4598
        return chanIDB[:]
×
4599
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc