lightningnetwork / lnd, build 17448539139
03 Sep 2025 11:19PM UTC, coverage: 66.664% (+9.3%) from 57.321%

Pull Request #10183: brontide: eliminate all allocations from WriteMessage+Flush
Merge a58e329b2 into ea6cc8154

47 of 55 new or added lines in 3 files covered (85.45%).
2794 existing lines in 19 files now uncovered.
136134 of 204209 relevant lines covered (66.66%).
21430.41 hits per line.
Source File: /graph/db/sql_store.go (file coverage: 0.0%)

package graphdb

import (
        "bytes"
        "context"
        "database/sql"
        "encoding/hex"
        "errors"
        "fmt"
        "maps"
        "math"
        "net"
        "slices"
        "strconv"
        "sync"
        "time"

        "github.com/btcsuite/btcd/btcec/v2"
        "github.com/btcsuite/btcd/btcutil"
        "github.com/btcsuite/btcd/chaincfg/chainhash"
        "github.com/btcsuite/btcd/wire"
        "github.com/lightningnetwork/lnd/aliasmgr"
        "github.com/lightningnetwork/lnd/batch"
        "github.com/lightningnetwork/lnd/fn/v2"
        "github.com/lightningnetwork/lnd/graph/db/models"
        "github.com/lightningnetwork/lnd/lnwire"
        "github.com/lightningnetwork/lnd/routing/route"
        "github.com/lightningnetwork/lnd/sqldb"
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
        "github.com/lightningnetwork/lnd/tlv"
        "github.com/lightningnetwork/lnd/tor"
)

// ProtocolVersion is an enum that defines the gossip protocol version of a
// message.
type ProtocolVersion uint8

const (
        // ProtocolV1 is the gossip protocol version defined in BOLT #7.
        ProtocolV1 ProtocolVersion = 1
)

// String returns a string representation of the protocol version.
func (v ProtocolVersion) String() string {
        return fmt.Sprintf("V%d", v)
}

// SQLQueries is a subset of the sqlc.Querier interface that can be used to
// execute queries against the SQL graph tables.
//
//nolint:ll,interfacebloat
type SQLQueries interface {
        /*
                Node queries.
        */
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
        GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error)
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
        DeleteNode(ctx context.Context, id int64) error

        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
        GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error

        UpsertNodeAddress(ctx context.Context, arg sqlc.UpsertNodeAddressParams) error
        GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
        GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error

        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
        GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error

        /*
                Source node queries.
        */
        AddSourceNode(ctx context.Context, nodeID int64) error
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)

        /*
                Channel queries.
        */
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
        GetChannelsByIDs(ctx context.Context, ids []int64) ([]sqlc.GetChannelsByIDsRow, error)
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
        ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
        ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
        DeleteChannels(ctx context.Context, ids []int64) error

        UpsertChannelExtraType(ctx context.Context, arg sqlc.UpsertChannelExtraTypeParams) error
        GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
        GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)

        /*
                Channel Policy table queries.
        */
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)

        UpsertChanPolicyExtraType(ctx context.Context, arg sqlc.UpsertChanPolicyExtraTypeParams) error
        GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error

        /*
                Zombie index queries.
        */
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
        GetZombieChannelsSCIDs(ctx context.Context, arg sqlc.GetZombieChannelsSCIDsParams) ([]sqlc.GraphZombieChannel, error)
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)

        /*
                Prune log table queries.
        */
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
        GetPruneEntriesForHeights(ctx context.Context, heights []int64) ([]sqlc.GraphPruneLog, error)
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error

        /*
                Closed SCID table queries.
        */
        InsertClosedChannel(ctx context.Context, scid []byte) error
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
        GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error)

        /*
                Migration specific queries.

                NOTE: these should not be used in code other than migrations.
                Once sqldbv2 is in place, these can be removed from this struct
                as then migrations will have their own dedicated queries
                structs.
        */
        InsertNodeMig(ctx context.Context, arg sqlc.InsertNodeMigParams) (int64, error)
        InsertChannelMig(ctx context.Context, arg sqlc.InsertChannelMigParams) (int64, error)
        InsertEdgePolicyMig(ctx context.Context, arg sqlc.InsertEdgePolicyMigParams) (int64, error)
}

// BatchedSQLQueries is a version of SQLQueries that's capable of batched
// database operations.
type BatchedSQLQueries interface {
        SQLQueries
        sqldb.BatchedTx[SQLQueries]
}

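// NOTE (reading aid, not part of the upstream file): every SQLStore method
// below follows the same BatchedTx pattern. A read-only call wraps its
// queries in ExecTx with a read transaction option and a no-op reset
// function, for example:
//
//      err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
//              // Run one or more SQLQueries methods against db here.
//              return nil
//      }, sqldb.NoOpReset)
//
// Write paths use sqldb.WriteTxOpt() instead. This mirrors the calls made by
// the methods that follow.
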
// SQLStore is an implementation of the V1Store interface that uses a SQL
// database as the backend.
type SQLStore struct {
        cfg *SQLStoreConfig
        db  BatchedSQLQueries

        // cacheMu guards all caches (rejectCache and chanCache). If
        // this mutex will be acquired at the same time as the DB mutex then
        // the cacheMu MUST be acquired first to prevent deadlock.
        cacheMu     sync.RWMutex
        rejectCache *rejectCache
        chanCache   *channelCache

        chanScheduler batch.Scheduler[SQLQueries]
        nodeScheduler batch.Scheduler[SQLQueries]

        srcNodes  map[ProtocolVersion]*srcNodeInfo
        srcNodeMu sync.Mutex
}

// A compile-time assertion to ensure that SQLStore implements the V1Store
// interface.
var _ V1Store = (*SQLStore)(nil)

// SQLStoreConfig holds the configuration for the SQLStore.
type SQLStoreConfig struct {
        // ChainHash is the genesis hash for the chain that all the gossip
        // messages in this store are aimed at.
        ChainHash chainhash.Hash

        // QueryCfg holds configuration values for SQL queries.
        QueryCfg *sqldb.QueryConfig
}

// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
// storage backend.
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
        options ...StoreOptionModifier) (*SQLStore, error) {

        opts := DefaultOptions()
        for _, o := range options {
                o(opts)
        }

        if opts.NoMigration {
                return nil, fmt.Errorf("the NoMigration option is not yet " +
                        "supported for SQL stores")
        }

        s := &SQLStore{
                cfg:         cfg,
                db:          db,
                rejectCache: newRejectCache(opts.RejectCacheSize),
                chanCache:   newChannelCache(opts.ChannelCacheSize),
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
        }

        s.chanScheduler = batch.NewTimeScheduler(
                db, &s.cacheMu, opts.BatchCommitInterval,
        )
        s.nodeScheduler = batch.NewTimeScheduler(
                db, nil, opts.BatchCommitInterval,
        )

        return s, nil
}

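// Example (illustrative sketch, not part of the upstream file): wiring up a
// SQLStore. The db value is assumed to be an existing BatchedSQLQueries
// implementation backed by the sqldb package, and queryCfg is assumed to be a
// *sqldb.QueryConfig supplied by the caller; how both are constructed is
// outside the scope of this file.
//
//      cfg := &SQLStoreConfig{
//              ChainHash: *chaincfg.MainNetParams.GenesisHash,
//              QueryCfg:  queryCfg,
//      }
//
//      store, err := NewSQLStore(cfg, db)
//      if err != nil {
//              return err
//      }
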
// AddLightningNode adds a vertex/node to the graph database. If the node is not
// in the database from before, this will add a new, unconnected one to the
// graph. If it is present from before, this will update that node's
// information.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddLightningNode(ctx context.Context,
        node *models.LightningNode, opts ...batch.SchedulerOption) error {

        r := &batch.Request[SQLQueries]{
                Opts: batch.NewSchedulerOptions(opts...),
                Do: func(queries SQLQueries) error {
                        _, err := upsertNode(ctx, queries, node)
                        return err
                },
        }

        return s.nodeScheduler.Execute(ctx, r)
}

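// Example (illustrative sketch, not part of the upstream file): because
// AddLightningNode goes through the node batch scheduler, many announcements
// can be coalesced into a single database transaction. Assuming the batch
// package's LazyAdd scheduler option (used elsewhere in lnd for the same
// purpose), a gossip-driven caller might look like:
//
//      if err := store.AddLightningNode(ctx, node, batch.LazyAdd()); err != nil {
//              return fmt.Errorf("unable to add node announcement: %w", err)
//      }
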
// FetchLightningNode attempts to look up a target node by its identity public
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
// returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchLightningNode(ctx context.Context,
        pubKey route.Vertex) (*models.LightningNode, error) {

        var node *models.LightningNode
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                var err error
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, pubKey)

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch node: %w", err)
        }

        return node, nil
}

// HasLightningNode determines if the graph has a vertex identified by the
// target node identity public key. If the node exists in the database, a
// timestamp of when the data for the node was last updated is returned along
// with a true boolean. Otherwise, an empty time.Time is returned with a false
// boolean.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) HasLightningNode(ctx context.Context,
        pubKey [33]byte) (time.Time, bool, error) {

        var (
                exists     bool
                lastUpdate time.Time
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pubKey[:],
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                exists = true

                if dbNode.LastUpdate.Valid {
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return time.Time{}, false,
                        fmt.Errorf("unable to fetch node: %w", err)
        }

        return lastUpdate, exists, nil
}

// AddrsForNode returns all known addresses for the target node public key
// that the graph DB is aware of. The returned boolean indicates whether the
// given node is known to the graph DB.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddrsForNode(ctx context.Context,
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {

        var (
                addresses []net.Addr
                known     bool
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                // First, check if the node exists and get its DB ID if it
                // does.
                dbID, err := db.GetNodeIDByPubKey(
                        ctx, sqlc.GetNodeIDByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  nodePub.SerializeCompressed(),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                }

                known = true

                addresses, err = getNodeAddresses(ctx, db, dbID)
                if err != nil {
                        return fmt.Errorf("unable to fetch node addresses: %w",
                                err)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return false, nil, fmt.Errorf("unable to get addresses for "+
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
        }

        return known, addresses, nil
}

// DeleteLightningNode starts a new database transaction to remove a vertex/node
// from the database according to the node's public key.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
        pubKey route.Vertex) error {

        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                res, err := db.DeleteNodeByPubKey(
                        ctx, sqlc.DeleteNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pubKey[:],
                        },
                )
                if err != nil {
                        return err
                }

                rows, err := res.RowsAffected()
                if err != nil {
                        return err
                }

                if rows == 0 {
                        return ErrGraphNodeNotFound
                } else if rows > 1 {
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
                }

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return fmt.Errorf("unable to delete node: %w", err)
        }

        return nil
}

// FetchNodeFeatures returns the features of the given node. If no features are
// known for the node, an empty feature vector is returned.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
        *lnwire.FeatureVector, error) {

        ctx := context.TODO()

        return fetchNodeFeatures(ctx, s.db, nodePub)
}

// DisabledChannelIDs returns the channel ids of disabled channels.
// A channel is disabled when two of the associated ChannelEdgePolicies
// have their disabled bit on.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
        var (
                ctx     = context.TODO()
                chanIDs []uint64
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
                if err != nil {
                        return fmt.Errorf("unable to fetch disabled "+
                                "channels: %w", err)
                }

                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
                        err)
        }

        return chanIDs, nil
}

// LookupAlias attempts to return the alias as advertised by the target node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) LookupAlias(ctx context.Context,
        pub *btcec.PublicKey) (string, error) {

        var alias string
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pub.SerializeCompressed(),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return ErrNodeAliasNotFound
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                if !dbNode.Alias.Valid {
                        return ErrNodeAliasNotFound
                }

                alias = dbNode.Alias.String

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return "", fmt.Errorf("unable to look up alias: %w", err)
        }

        return alias, nil
}

// SourceNode returns the source node of the graph. The source node is treated
// as the center node within a star-graph. This method may be used to kick off
// a path finding algorithm in order to explore the reachability of another
// node based off the source node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
        error) {

        var node *models.LightningNode
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
                if err != nil {
                        return fmt.Errorf("unable to fetch V1 source node: %w",
                                err)
                }

                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, nodePub)

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
        }

        return node, nil
}

// SetSourceNode sets the source node within the graph database. The source
// node is to be used as the center of a star-graph within path finding
// algorithms.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SetSourceNode(ctx context.Context,
        node *models.LightningNode) error {

        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                id, err := upsertNode(ctx, db, node)
                if err != nil {
                        return fmt.Errorf("unable to upsert source node: %w",
                                err)
                }

                // Make sure that if a source node for this version is already
                // set, then the ID is the same as the one we are about to set.
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
                        return fmt.Errorf("unable to fetch source node: %w",
                                err)
                } else if err == nil {
                        if dbSourceNodeID != id {
                                return fmt.Errorf("v1 source node already "+
                                        "set to a different node: %d vs %d",
                                        dbSourceNodeID, id)
                        }

                        return nil
                }

                return db.AddSourceNode(ctx, id)
        }, sqldb.NoOpReset)
}

// NodeUpdatesInHorizon returns all the known lightning nodes which have an
// update timestamp within the passed range. This method can be used by two
// nodes to quickly determine if they have the same set of up to date node
// announcements.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
        endTime time.Time) ([]models.LightningNode, error) {

        ctx := context.TODO()

        var nodes []models.LightningNode
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNodes, err := db.GetNodesByLastUpdateRange(
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch nodes: %w", err)
                }

                err = forEachNodeInBatch(
                        ctx, s.cfg.QueryCfg, db, dbNodes,
                        func(_ int64, node *models.LightningNode) error {
                                nodes = append(nodes, *node)

                                return nil
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to build nodes: %w", err)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
        }

        return nodes, nil
}

// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
// undirected edge between the two target nodes is created. The information stored
// denotes the static attributes of the channel, such as the channelID, the keys
// involved in creation of the channel, and the set of features that the channel
// supports. The chanPoint and chanID are used to uniquely identify the edge
// globally within the database.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddChannelEdge(ctx context.Context,
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {

        var alreadyExists bool
        r := &batch.Request[SQLQueries]{
                Opts: batch.NewSchedulerOptions(opts...),
                Reset: func() {
                        alreadyExists = false
                },
                Do: func(tx SQLQueries) error {
                        chanIDB := channelIDToBytes(edge.ChannelID)

                        // Make sure that the channel doesn't already exist. We
                        // do this explicitly instead of relying on catching a
                        // unique constraint error because relying on SQL to
                        // throw that error would abort the entire batch of
                        // transactions.
                        _, err := tx.GetChannelBySCID(
                                ctx, sqlc.GetChannelBySCIDParams{
                                        Scid:    chanIDB,
                                        Version: int16(ProtocolV1),
                                },
                        )
                        if err == nil {
                                alreadyExists = true
                                return nil
                        } else if !errors.Is(err, sql.ErrNoRows) {
                                return fmt.Errorf("unable to fetch channel: %w",
                                        err)
                        }

                        return insertChannel(ctx, tx, edge)
                },
                OnCommit: func(err error) error {
                        switch {
                        case err != nil:
                                return err
                        case alreadyExists:
                                return ErrEdgeAlreadyExist
                        default:
                                s.rejectCache.remove(edge.ChannelID)
                                s.chanCache.remove(edge.ChannelID)
                                return nil
                        }
                },
        }

        return s.chanScheduler.Execute(ctx, r)
}

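// Example (illustrative sketch, not part of the upstream file): AddChannelEdge
// surfaces a duplicate channel through OnCommit rather than failing the whole
// batch, so callers can treat ErrEdgeAlreadyExist as a benign outcome:
//
//      if err := store.AddChannelEdge(ctx, edge); err != nil &&
//              !errors.Is(err, ErrEdgeAlreadyExist) {
//
//              return fmt.Errorf("unable to add channel edge: %w", err)
//      }
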
// HighestChanID returns the "highest" known channel ID in the channel graph.
// This represents the "newest" channel from the PoV of the chain. This method
// can be used by peers to quickly determine if their graphs are in sync.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
        var highestChanID uint64
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
                                err)
                }

                highestChanID = byteOrder.Uint64(chanID)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
        }

        return highestChanID, nil
}

// UpdateEdgePolicy updates the edge routing policy for a single directed edge
// within the database for the referenced channel. The `flags` attribute within
// the ChannelEdgePolicy determines which of the directed edges are being
// updated. If the flag is 1, then the first node's information is being
// updated, otherwise it's the second node's information. The node ordering is
// determined by the lexicographical ordering of the identity public keys of the
// nodes on either side of the channel.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
        edge *models.ChannelEdgePolicy,
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {

        var (
                isUpdate1    bool
                edgeNotFound bool
                from, to     route.Vertex
        )

        r := &batch.Request[SQLQueries]{
                Opts: batch.NewSchedulerOptions(opts...),
                Reset: func() {
                        isUpdate1 = false
                        edgeNotFound = false
                },
                Do: func(tx SQLQueries) error {
                        var err error
                        from, to, isUpdate1, err = updateChanEdgePolicy(
                                ctx, tx, edge,
                        )
                        if err != nil {
                                log.Errorf("UpdateEdgePolicy failed: %v", err)
                        }

                        // Silence ErrEdgeNotFound so that the batch can
                        // succeed, but propagate the error via local state.
                        if errors.Is(err, ErrEdgeNotFound) {
                                edgeNotFound = true
                                return nil
                        }

                        return err
                },
                OnCommit: func(err error) error {
                        switch {
                        case err != nil:
                                return err
                        case edgeNotFound:
                                return ErrEdgeNotFound
                        default:
                                s.updateEdgeCache(edge, isUpdate1)
                                return nil
                        }
                },
        }

        err := s.chanScheduler.Execute(ctx, r)

        return from, to, err
}

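// Example (illustrative sketch, not part of the upstream file): the same
// OnCommit pattern applies to UpdateEdgePolicy, which reports a missing
// channel via ErrEdgeNotFound after the surrounding batch has committed:
//
//      _, _, err := store.UpdateEdgePolicy(ctx, policy)
//      if errors.Is(err, ErrEdgeNotFound) {
//              // The referenced channel is not in the graph; the update can
//              // be treated as stale.
//      }
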
// updateEdgeCache updates our reject and channel caches with the new
// edge policy information.
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
        isUpdate1 bool) {

        // If an entry for this channel is found in reject cache, we'll modify
        // the entry with the updated timestamp for the direction that was just
        // written. If the edge doesn't exist, we'll load the cache entry lazily
        // during the next query for this edge.
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
                if isUpdate1 {
                        entry.upd1Time = e.LastUpdate.Unix()
                } else {
                        entry.upd2Time = e.LastUpdate.Unix()
                }
                s.rejectCache.insert(e.ChannelID, entry)
        }

        // If an entry for this channel is found in channel cache, we'll modify
        // the entry with the updated policy for the direction that was just
        // written. If the edge doesn't exist, we'll defer loading the info and
        // policies and lazily read from disk during the next query.
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
                if isUpdate1 {
                        channel.Policy1 = e
                } else {
                        channel.Policy2 = e
                }
                s.chanCache.insert(e.ChannelID, channel)
        }
}

// ForEachSourceNodeChannel iterates through all channels of the source node,
// executing the passed callback on each. The call-back is provided with the
// channel's outpoint, whether we have a policy for the channel and the channel
// peer's node information.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
        cb func(chanPoint wire.OutPoint, havePolicy bool,
                otherNode *models.LightningNode) error, reset func()) error {

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
                if err != nil {
                        return fmt.Errorf("unable to fetch source node: %w",
                                err)
                }

                return forEachNodeChannel(
                        ctx, db, s.cfg, nodeID,
                        func(info *models.ChannelEdgeInfo,
                                outPolicy *models.ChannelEdgePolicy,
                                _ *models.ChannelEdgePolicy) error {

                                // Fetch the other node.
                                var (
                                        otherNodePub [33]byte
                                        node1        = info.NodeKey1Bytes
                                        node2        = info.NodeKey2Bytes
                                )
                                switch {
                                case bytes.Equal(node1[:], nodePub[:]):
                                        otherNodePub = node2
                                case bytes.Equal(node2[:], nodePub[:]):
                                        otherNodePub = node1
                                default:
                                        return fmt.Errorf("node not " +
                                                "participating in this channel")
                                }

                                _, otherNode, err := getNodeByPubKey(
                                        ctx, s.cfg.QueryCfg, db, otherNodePub,
                                )
                                if err != nil {
                                        return fmt.Errorf("unable to fetch "+
                                                "other node(%x): %w",
                                                otherNodePub, err)
                                }

                                return cb(
                                        info.ChannelPoint, outPolicy != nil,
                                        otherNode,
                                )
                        },
                )
        }, reset)
}

// ForEachNode iterates through all the stored vertices/nodes in the graph,
// executing the passed callback with each node encountered. If the callback
// returns an error, then the transaction is aborted and the iteration stops
// early.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNode(ctx context.Context,
        cb func(node *models.LightningNode) error, reset func()) error {

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return forEachNodePaginated(
                        ctx, s.cfg.QueryCfg, db,
                        ProtocolV1, func(_ context.Context, _ int64,
                                node *models.LightningNode) error {

                                return cb(node)
                        },
                )
        }, reset)
}

// ForEachNodeDirectedChannel iterates through all channels of a given node,
// executing the passed callback on the directed edge representing the channel
// and its incoming policy. If the callback returns an error, then the iteration
// is halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
        cb func(channel *DirectedChannel) error, reset func()) error {

        var ctx = context.TODO()

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
        }, reset)
}

// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
// graph, executing the passed callback with each node encountered. If the
// callback returns an error, then the transaction is aborted and the iteration
// stops early.
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
        cb func(route.Vertex, *lnwire.FeatureVector) error,
        reset func()) error {

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return forEachNodeCacheable(
                        ctx, s.cfg.QueryCfg, db,
                        func(_ int64, nodePub route.Vertex,
                                features *lnwire.FeatureVector) error {

                                return cb(nodePub, features)
                        },
                )
        }, reset)
        if err != nil {
                return fmt.Errorf("unable to fetch nodes: %w", err)
        }

        return nil
}

// ForEachNodeChannel iterates through all channels of the given node,
// executing the passed callback with an edge info structure and the policies
// of each end of the channel. The first edge policy is the outgoing edge *to*
// the connecting node, while the second is the incoming edge *from* the
// connecting node. If the callback returns an error, then the iteration is
// halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error, reset func()) error {

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  nodePub[:],
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
        }, reset)
}

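// Example (illustrative sketch, not part of the upstream file): counting a
// node's channels with ForEachNodeChannel. Unknown policies arrive as nil, as
// noted above, and the reset closure is assumed here to roll back local state
// if the transaction is retried:
//
//      var numChans int
//      err := store.ForEachNodeChannel(ctx, nodePub,
//              func(info *models.ChannelEdgeInfo, out,
//                      in *models.ChannelEdgePolicy) error {
//
//                      numChans++
//                      return nil
//              },
//              func() { numChans = 0 },
//      )
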
// ChanUpdatesInHorizon returns all the known channel edges which have at least
// one edge that has an update timestamp within the specified horizon.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
        endTime time.Time) ([]ChannelEdge, error) {

        s.cacheMu.Lock()
        defer s.cacheMu.Unlock()

        var (
                ctx = context.TODO()
                // To ensure we don't return duplicate ChannelEdges, we'll use
                // an additional map to keep track of the edges already seen to
                // prevent re-adding it.
                edgesSeen    = make(map[uint64]struct{})
                edgesToCache = make(map[uint64]ChannelEdge)
                edges        []ChannelEdge
                hits         int
        )

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
                                Version:   int16(ProtocolV1),
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
                        },
                )
                if err != nil {
                        return err
                }

                if len(rows) == 0 {
                        return nil
                }

                // We'll pre-allocate the slices and maps here with a best
                // effort size in order to avoid unnecessary allocations later
                // on.
                uncachedRows := make(
                        []sqlc.GetChannelsByPolicyLastUpdateRangeRow, 0,
                        len(rows),
                )
                edgesToCache = make(map[uint64]ChannelEdge, len(rows))
                edgesSeen = make(map[uint64]struct{}, len(rows))
                edges = make([]ChannelEdge, 0, len(rows))

                // Separate cached from non-cached channels since we will only
                // batch load the data for the ones we haven't cached yet.
                for _, row := range rows {
                        chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid)

                        // Skip duplicates.
                        if _, ok := edgesSeen[chanIDInt]; ok {
                                continue
                        }
                        edgesSeen[chanIDInt] = struct{}{}

                        // Check cache first.
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
                                hits++
                                edges = append(edges, channel)
                                continue
                        }

                        // Mark this row as one we need to batch load data for.
                        uncachedRows = append(uncachedRows, row)
                }

                // If there are no uncached rows, then we can return early.
                if len(uncachedRows) == 0 {
                        return nil
                }

                // Batch load data for all uncached channels.
                newEdges, err := batchBuildChannelEdges(
                        ctx, s.cfg, db, uncachedRows,
                )
                if err != nil {
                        return fmt.Errorf("unable to batch build channel "+
                                "edges: %w", err)
                }

                edges = append(edges, newEdges...)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
        }

        // Insert any edges loaded from disk into the cache.
        for chanid, channel := range edgesToCache {
                s.chanCache.insert(chanid, channel)
        }

        if len(edges) > 0 {
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
        } else {
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
                        "horizon (%s, %s)", startTime, endTime)
        }

        return edges, nil
}

// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
// data to the call-back. If withAddrs is true, then the call-back will also be
// provided with the addresses associated with the node. The address retrieval
// results in an additional round-trip to the database, so it should only be
// used if the addresses are actually needed.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNodeCached(ctx context.Context, withAddrs bool,
        cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
                chans map[uint64]*DirectedChannel) error, reset func()) error {

        type nodeCachedBatchData struct {
                features      map[int64][]int
                addrs         map[int64][]nodeAddress
                chanBatchData *batchChannelData
                chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
        }

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                // pageQueryFunc is used to query the next page of nodes.
                pageQueryFunc := func(ctx context.Context, lastID int64,
                        limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {

                        return db.ListNodeIDsAndPubKeys(
                                ctx, sqlc.ListNodeIDsAndPubKeysParams{
                                        Version: int16(ProtocolV1),
                                        ID:      lastID,
                                        Limit:   limit,
                                },
                        )
                }

                // batchDataFunc is then used to batch load the data required
                // for each page of nodes.
                batchDataFunc := func(ctx context.Context,
                        nodeIDs []int64) (*nodeCachedBatchData, error) {

                        // Batch load node features.
                        nodeFeatures, err := batchLoadNodeFeaturesHelper(
                                ctx, s.cfg.QueryCfg, db, nodeIDs,
                        )
                        if err != nil {
                                return nil, fmt.Errorf("unable to batch load "+
                                        "node features: %w", err)
                        }

                        // Maybe fetch the node's addresses if requested.
                        var nodeAddrs map[int64][]nodeAddress
                        if withAddrs {
                                nodeAddrs, err = batchLoadNodeAddressesHelper(
                                        ctx, s.cfg.QueryCfg, db, nodeIDs,
                                )
                                if err != nil {
                                        return nil, fmt.Errorf("unable to "+
                                                "batch load node "+
                                                "addresses: %w", err)
                                }
                        }

                        // Batch load ALL unique channels for ALL nodes in this
                        // page.
                        allChannels, err := db.ListChannelsForNodeIDs(
                                ctx, sqlc.ListChannelsForNodeIDsParams{
                                        Version:  int16(ProtocolV1),
                                        Node1Ids: nodeIDs,
                                        Node2Ids: nodeIDs,
                                },
                        )
                        if err != nil {
                                return nil, fmt.Errorf("unable to batch "+
                                        "fetch channels for nodes: %w", err)
                        }

                        // Deduplicate channels and collect IDs.
                        var (
                                allChannelIDs []int64
                                allPolicyIDs  []int64
                        )
                        uniqueChannels := make(
                                map[int64]sqlc.ListChannelsForNodeIDsRow,
                        )

                        for _, channel := range allChannels {
                                channelID := channel.GraphChannel.ID

                                // Only process each unique channel once.
                                _, exists := uniqueChannels[channelID]
                                if exists {
                                        continue
                                }

                                uniqueChannels[channelID] = channel
                                allChannelIDs = append(allChannelIDs, channelID)

                                if channel.Policy1ID.Valid {
                                        allPolicyIDs = append(
                                                allPolicyIDs,
                                                channel.Policy1ID.Int64,
                                        )
                                }
                                if channel.Policy2ID.Valid {
                                        allPolicyIDs = append(
                                                allPolicyIDs,
                                                channel.Policy2ID.Int64,
                                        )
                                }
                        }

                        // Batch load channel data for all unique channels.
                        channelBatchData, err := batchLoadChannelData(
                                ctx, s.cfg.QueryCfg, db, allChannelIDs,
                                allPolicyIDs,
                        )
                        if err != nil {
                                return nil, fmt.Errorf("unable to batch "+
                                        "load channel data: %w", err)
                        }

                        // Create map of node ID to channels that involve this
                        // node.
                        nodeIDSet := make(map[int64]bool)
                        for _, nodeID := range nodeIDs {
                                nodeIDSet[nodeID] = true
                        }

                        nodeChannelMap := make(
                                map[int64][]sqlc.ListChannelsForNodeIDsRow,
                        )
                        for _, channel := range uniqueChannels {
                                // Add channel to both nodes if they're in our
                                // current page.
                                node1 := channel.GraphChannel.NodeID1
                                if nodeIDSet[node1] {
                                        nodeChannelMap[node1] = append(
                                                nodeChannelMap[node1], channel,
                                        )
                                }
                                node2 := channel.GraphChannel.NodeID2
                                if nodeIDSet[node2] {
                                        nodeChannelMap[node2] = append(
                                                nodeChannelMap[node2], channel,
                                        )
                                }
                        }

                        return &nodeCachedBatchData{
                                features:      nodeFeatures,
                                addrs:         nodeAddrs,
                                chanBatchData: channelBatchData,
                                chanMap:       nodeChannelMap,
                        }, nil
                }

                // processItem is used to process each node in the current page.
                processItem := func(ctx context.Context,
                        nodeData sqlc.ListNodeIDsAndPubKeysRow,
                        batchData *nodeCachedBatchData) error {

                        // Build feature vector for this node.
                        fv := lnwire.EmptyFeatureVector()
                        features, exists := batchData.features[nodeData.ID]
                        if exists {
                                for _, bit := range features {
                                        fv.Set(lnwire.FeatureBit(bit))
                                }
                        }

                        var nodePub route.Vertex
                        copy(nodePub[:], nodeData.PubKey)

                        nodeChannels := batchData.chanMap[nodeData.ID]

                        toNodeCallback := func() route.Vertex {
                                return nodePub
                        }

                        // Build cached channels map for this node.
                        channels := make(map[uint64]*DirectedChannel)
×
1209
                        for _, channelRow := range nodeChannels {
×
UNCOV
1210
                                directedChan, err := buildDirectedChannel(
×
UNCOV
1211
                                        s.cfg.ChainHash, nodeData.ID, nodePub,
×
1212
                                        channelRow, batchData.chanBatchData, fv,
×
1213
                                        toNodeCallback,
×
1214
                                )
×
1215
                                if err != nil {
×
1216
                                        return err
×
1217
                                }
×
1218

UNCOV
1219
                                channels[directedChan.ChannelID] = directedChan
×
1220
                        }
1221

UNCOV
1222
                        addrs, err := buildNodeAddresses(
×
1223
                                batchData.addrs[nodeData.ID],
×
1224
                        )
×
1225
                        if err != nil {
×
1226
                                return fmt.Errorf("unable to build node "+
×
1227
                                        "addresses: %w", err)
×
UNCOV
1228
                        }
×
1229

1230
                        return cb(ctx, nodePub, addrs, channels)
×
1231
                }
1232

UNCOV
1233
                return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
UNCOV
1234
                        ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
×
UNCOV
1235
                        func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
UNCOV
1236
                                return node.ID
×
UNCOV
1237
                        },
×
1238
                        func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
UNCOV
1239
                                error) {
×
UNCOV
1240

×
UNCOV
1241
                                return node.ID, nil
×
UNCOV
1242
                        },
×
1243
                        batchDataFunc, processItem,
1244
                )
1245
        }, reset)
1246
}
1247

1248
// ForEachChannelCacheable iterates through all the channel edges stored
// within the graph and invokes the passed callback for each edge. The
// callback takes two edges since, as this is a directed graph, both the
// in/out edges are visited. If the callback returns an error, then the
// transaction is aborted and the iteration stops early.
//
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
// pointer for that particular channel edge routing policy will be
// passed into the callback.
//
// NOTE: this method is like ForEachChannel but fetches only the data
// required for the graph cache.
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
        reset func()) error {

        ctx := context.TODO()

        handleChannel := func(_ context.Context,
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {

                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return err
                }

                edge := buildCacheableChannelInfo(
                        row.Scid, row.Capacity.Int64, node1, node2,
                )

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return err
                }

                pol1, pol2, err := buildCachedChanPolicies(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
                )
                if err != nil {
                        return err
                }

                return cb(edge, pol1, pol2)
        }

        extractCursor := func(
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {

                return row.ID
        }

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                //nolint:ll
                queryFunc := func(ctx context.Context, lastID int64,
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
                        error) {

                        return db.ListChannelsWithPoliciesForCachePaginated(
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
                                        Version: int16(ProtocolV1),
                                        ID:      lastID,
                                        Limit:   limit,
                                },
                        )
                }

                return sqldb.ExecutePaginatedQuery(
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
                        extractCursor, handleChannel,
                )
        }, reset)
}

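// Illustrative usage sketch (not part of the upstream file): driving
// ForEachChannelCacheable from a caller, e.g. to count the cached edges.
// The function name is hypothetical.
func exampleCountCacheableChannels(s *SQLStore) (int, error) {
        var count int
        err := s.ForEachChannelCacheable(func(_ *models.CachedEdgeInfo,
                _ *models.CachedEdgePolicy, _ *models.CachedEdgePolicy) error {

                count++

                return nil
        }, func() {
                // The reset closure runs if the underlying read transaction
                // is retried, so any state gathered by the callback must be
                // cleared here.
                count = 0
        })

        return count, err
}
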
1323
// ForEachChannel iterates through all the channel edges stored within the
// graph and invokes the passed callback for each edge. The callback takes
// two edges since, as this is a directed graph, both the in/out edges are
// visited. If the callback returns an error, then the transaction is aborted
// and the iteration stops early.
//
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
// for that particular channel edge routing policy will be passed into the
// callback.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachChannel(ctx context.Context,
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error, reset func()) error {

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
        }, reset)
}

1343
// FilterChannelRange returns the channel ID's of all known channels which were
// mined in a block height within the passed range. The channel IDs are grouped
// by their common block height. This method can be used to quickly share with a
// peer the set of channels we know of within a particular range to catch them
// up after a period of time offline. If withTimestamps is true then the
// timestamp info of the latest received channel update messages of the channel
// will be included in the response.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
        withTimestamps bool) ([]BlockChannelRange, error) {

        var (
                ctx       = context.TODO()
                startSCID = &lnwire.ShortChannelID{
                        BlockHeight: startHeight,
                }
                endSCID = lnwire.ShortChannelID{
                        BlockHeight: endHeight,
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
                        TxPosition:  math.MaxUint16,
                }
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
        )

        // 1) get all channels where channelID is between start and end chan ID.
        // 2) skip if not public (ie, no channel_proof)
        // 3) collect that channel.
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
        //    and add those timestamps to the collected channel.
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbChans, err := db.GetPublicV1ChannelsBySCID(
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
                                StartScid: chanIDStart,
                                EndScid:   chanIDEnd,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch channel range: %w",
                                err)
                }

                for _, dbChan := range dbChans {
                        cid := lnwire.NewShortChanIDFromInt(
                                byteOrder.Uint64(dbChan.Scid),
                        )
                        chanInfo := NewChannelUpdateInfo(
                                cid, time.Time{}, time.Time{},
                        )

                        if !withTimestamps {
                                channelsPerBlock[cid.BlockHeight] = append(
                                        channelsPerBlock[cid.BlockHeight],
                                        chanInfo,
                                )

                                continue
                        }

                        //nolint:ll
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
                                        Version:   int16(ProtocolV1),
                                        ChannelID: dbChan.ID,
                                        NodeID:    dbChan.NodeID1,
                                },
                        )
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
                                return fmt.Errorf("unable to fetch node1 "+
                                        "policy: %w", err)
                        } else if err == nil {
                                chanInfo.Node1UpdateTimestamp = time.Unix(
                                        node1Policy.LastUpdate.Int64, 0,
                                )
                        }

                        //nolint:ll
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
                                        Version:   int16(ProtocolV1),
                                        ChannelID: dbChan.ID,
                                        NodeID:    dbChan.NodeID2,
                                },
                        )
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
                                return fmt.Errorf("unable to fetch node2 "+
                                        "policy: %w", err)
                        } else if err == nil {
                                chanInfo.Node2UpdateTimestamp = time.Unix(
                                        node2Policy.LastUpdate.Int64, 0,
                                )
                        }

                        channelsPerBlock[cid.BlockHeight] = append(
                                channelsPerBlock[cid.BlockHeight], chanInfo,
                        )
                }

                return nil
        }, func() {
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
        })
        if err != nil {
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
        }

        if len(channelsPerBlock) == 0 {
                return nil, nil
        }

        // Return the channel ranges in ascending block height order.
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
        slices.Sort(blocks)

        return fn.Map(blocks, func(block uint32) BlockChannelRange {
                return BlockChannelRange{
                        Height:   block,
                        Channels: channelsPerBlock[block],
                }
        }), nil
}

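// Illustrative usage sketch (not part of the upstream file): querying a short
// block range with timestamps enabled. The function name and height values
// are hypothetical.
func exampleFilterChannelRange(s *SQLStore) error {
        ranges, err := s.FilterChannelRange(800_000, 800_100, true)
        if err != nil {
                return err
        }

        for _, blockRange := range ranges {
                fmt.Printf("height=%d known channels=%d\n",
                        blockRange.Height, len(blockRange.Channels))
        }

        return nil
}
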
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1468
// zombie. This method is used on an ad-hoc basis, when channels need to be
1469
// marked as zombies outside the normal pruning cycle.
1470
//
1471
// NOTE: part of the V1Store interface.
1472
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1473
        pubKey1, pubKey2 [33]byte) error {
×
1474

×
1475
        ctx := context.TODO()
×
1476

×
1477
        s.cacheMu.Lock()
×
1478
        defer s.cacheMu.Unlock()
×
1479

×
1480
        chanIDB := channelIDToBytes(chanID)
×
1481

×
1482
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1483
                return db.UpsertZombieChannel(
×
1484
                        ctx, sqlc.UpsertZombieChannelParams{
×
1485
                                Version:  int16(ProtocolV1),
×
UNCOV
1486
                                Scid:     chanIDB,
×
1487
                                NodeKey1: pubKey1[:],
×
1488
                                NodeKey2: pubKey2[:],
×
1489
                        },
×
1490
                )
×
UNCOV
1491
        }, sqldb.NoOpReset)
×
UNCOV
1492
        if err != nil {
×
UNCOV
1493
                return fmt.Errorf("unable to upsert zombie channel "+
×
UNCOV
1494
                        "(channel_id=%d): %w", chanID, err)
×
UNCOV
1495
        }
×
1496

1497
        s.rejectCache.remove(chanID)
×
1498
        s.chanCache.remove(chanID)
×
1499

×
1500
        return nil
×
1501
}
1502

1503
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
        s.cacheMu.Lock()
        defer s.cacheMu.Unlock()

        var (
                ctx     = context.TODO()
                chanIDB = channelIDToBytes(chanID)
        )

        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                res, err := db.DeleteZombieChannel(
                        ctx, sqlc.DeleteZombieChannelParams{
                                Scid:    chanIDB,
                                Version: int16(ProtocolV1),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to delete zombie channel: %w",
                                err)
                }

                rows, err := res.RowsAffected()
                if err != nil {
                        return err
                }

                if rows == 0 {
                        return ErrZombieEdgeNotFound
                } else if rows > 1 {
                        return fmt.Errorf("deleted %d zombie rows, "+
                                "expected 1", rows)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return fmt.Errorf("unable to mark edge live "+
                        "(channel_id=%d): %w", chanID, err)
        }

        s.rejectCache.remove(chanID)
        s.chanCache.remove(chanID)

        return err
}

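// Illustrative usage sketch (not part of the upstream file): marking a
// channel as a zombie and later resurrecting it, tolerating the case where
// the zombie entry was already removed. The function name is hypothetical.
func exampleZombieRoundTrip(s *SQLStore, chanID uint64,
        pub1, pub2 [33]byte) error {

        if err := s.MarkEdgeZombie(chanID, pub1, pub2); err != nil {
                return err
        }

        // Resurrect the channel again. ErrZombieEdgeNotFound is only returned
        // if no zombie entry exists for the channel.
        err := s.MarkEdgeLive(chanID)
        if err != nil && !errors.Is(err, ErrZombieEdgeNotFound) {
                return err
        }

        return nil
}
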
1552
// IsZombieEdge returns whether the edge is considered zombie. If it is a
// zombie, then the two node public keys corresponding to this edge are also
// returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
        error) {

        var (
                ctx              = context.TODO()
                isZombie         bool
                pubKey1, pubKey2 route.Vertex
                chanIDB          = channelIDToBytes(chanID)
        )

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                zombie, err := db.GetZombieChannel(
                        ctx, sqlc.GetZombieChannelParams{
                                Scid:    chanIDB,
                                Version: int16(ProtocolV1),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                }
                if err != nil {
                        return fmt.Errorf("unable to fetch zombie channel: %w",
                                err)
                }

                copy(pubKey1[:], zombie.NodeKey1)
                copy(pubKey2[:], zombie.NodeKey2)
                isZombie = true

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return false, route.Vertex{}, route.Vertex{},
                        fmt.Errorf("%w: %w (chanID=%d)",
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
        }

        return isZombie, pubKey1, pubKey2, nil
}

1597
// NumZombies returns the current number of zombie channels in the graph.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) NumZombies() (uint64, error) {
        var (
                ctx        = context.TODO()
                numZombies uint64
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
                if err != nil {
                        return fmt.Errorf("unable to count zombie channels: %w",
                                err)
                }

                numZombies = uint64(count)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return 0, fmt.Errorf("unable to count zombies: %w", err)
        }

        return numZombies, nil
}

1623
// DeleteChannelEdges removes edges with the given channel IDs from the
// database and marks them as zombies. This ensures that we're unable to
// re-add them to our database once again. If an edge does not exist within
// the database, then ErrEdgeNotFound will be returned. If strictZombiePruning
// is true, then when we mark these edges as zombies, we'll set up the keys
// such that we require the node that failed to send the fresh update to be
// the one that resurrects the channel from its zombie state. The markZombie
// bool denotes whether to mark the channel as a zombie.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {

        s.cacheMu.Lock()
        defer s.cacheMu.Unlock()

        // Keep track of which channels we end up finding so that we can
        // correctly return ErrEdgeNotFound if we do not find a channel.
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
        for _, chanID := range chanIDs {
                chanLookup[chanID] = struct{}{}
        }

        var (
                ctx   = context.TODO()
                edges []*models.ChannelEdgeInfo
        )
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                // First, collect all channel rows.
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
                chanCallBack := func(ctx context.Context,
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {

                        // Deleting the entry from the map indicates that we
                        // have found the channel.
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
                        delete(chanLookup, scid)

                        channelRows = append(channelRows, row)

                        return nil
                }

                err := s.forEachChanWithPoliciesInSCIDList(
                        ctx, db, chanCallBack, chanIDs,
                )
                if err != nil {
                        return err
                }

                if len(chanLookup) > 0 {
                        return ErrEdgeNotFound
                }

                if len(channelRows) == 0 {
                        return nil
                }

                // Batch build all channel edges.
                var chanIDsToDelete []int64
                edges, chanIDsToDelete, err = batchBuildChannelInfo(
                        ctx, s.cfg, db, channelRows,
                )
                if err != nil {
                        return err
                }

                if markZombie {
                        for i, row := range channelRows {
                                scid := byteOrder.Uint64(row.GraphChannel.Scid)

                                err := handleZombieMarking(
                                        ctx, db, row, edges[i],
                                        strictZombiePruning, scid,
                                )
                                if err != nil {
                                        return fmt.Errorf("unable to mark "+
                                                "channel as zombie: %w", err)
                                }
                        }
                }

                return s.deleteChannels(ctx, db, chanIDsToDelete)
        }, func() {
                edges = nil

                // Re-fill the lookup map.
                for _, chanID := range chanIDs {
                        chanLookup[chanID] = struct{}{}
                }
        })
        if err != nil {
                return nil, fmt.Errorf("unable to delete channel edges: %w",
                        err)
        }

        for _, chanID := range chanIDs {
                s.rejectCache.remove(chanID)
                s.chanCache.remove(chanID)
        }

        return edges, nil
}

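// Illustrative usage sketch (not part of the upstream file): deleting a batch
// of closed channels while marking them as zombies, with strict zombie
// pruning disabled. The function name is hypothetical.
func exampleDeleteClosedChannels(s *SQLStore, chanIDs []uint64) error {
        edges, err := s.DeleteChannelEdges(false, true, chanIDs...)
        if errors.Is(err, ErrEdgeNotFound) {
                // At least one of the channels was never part of the graph.
                return nil
        }
        if err != nil {
                return err
        }

        fmt.Printf("removed %d channel edges\n", len(edges))

        return nil
}
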
1727
// FetchChannelEdgesByID attempts to look up the two directed edges for the
// channel identified by the channel ID. If the channel can't be found, then
// ErrEdgeNotFound is returned. A struct which houses the general information
// for the channel itself is returned as well as two structs that contain the
// routing policies for the channel in either direction.
//
// ErrZombieEdge can be returned if the edge is currently marked as a zombie
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
// the ChannelEdgeInfo will only include the public keys of each node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
        *models.ChannelEdgePolicy, error) {

        var (
                ctx              = context.TODO()
                edge             *models.ChannelEdgeInfo
                policy1, policy2 *models.ChannelEdgePolicy
                chanIDB          = channelIDToBytes(chanID)
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                row, err := db.GetChannelBySCIDWithPolicies(
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
                                Scid:    chanIDB,
                                Version: int16(ProtocolV1),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        // First check if this edge is perhaps in the zombie
                        // index.
                        zombie, err := db.GetZombieChannel(
                                ctx, sqlc.GetZombieChannelParams{
                                        Scid:    chanIDB,
                                        Version: int16(ProtocolV1),
                                },
                        )
                        if errors.Is(err, sql.ErrNoRows) {
                                return ErrEdgeNotFound
                        } else if err != nil {
                                return fmt.Errorf("unable to check if "+
                                        "channel is zombie: %w", err)
                        }

                        // At this point, we know the channel is a zombie, so
                        // we'll return an error indicating this, and we will
                        // populate the edge info with the public keys of each
                        // party as this is the only information we have about
                        // it.
                        edge = &models.ChannelEdgeInfo{}
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)

                        return ErrZombieEdge
                } else if err != nil {
                        return fmt.Errorf("unable to fetch channel: %w", err)
                }

                node1, node2, err := buildNodeVertices(
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
                )
                if err != nil {
                        return err
                }

                edge, err = getAndBuildEdgeInfo(
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel info: %w",
                                err)
                }

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }

                policy1, policy2, err = getAndBuildChanPolicies(
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
                        node1, node2,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel "+
                                "policies: %w", err)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                // If we are returning the ErrZombieEdge, then we also need to
                // return the edge info as the method comment indicates that
                // this will be populated when the edge is a zombie.
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
                        err)
        }

        return edge, policy1, policy2, nil
}

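// Illustrative usage sketch (not part of the upstream file): fetching a
// channel by ID and distinguishing zombie edges from channels that are simply
// unknown. The function name is hypothetical.
func exampleDescribeChannel(s *SQLStore, chanID uint64) error {
        info, pol1, pol2, err := s.FetchChannelEdgesByID(chanID)
        switch {
        case errors.Is(err, ErrZombieEdge):
                // Only the node public keys are populated for zombies.
                fmt.Printf("zombie channel between %x and %x\n",
                        info.NodeKey1Bytes, info.NodeKey2Bytes)

                return nil

        case errors.Is(err, ErrEdgeNotFound):
                fmt.Println("channel is unknown to the graph")

                return nil

        case err != nil:
                return err
        }

        fmt.Printf("channel %d has policies: %v/%v\n", info.ChannelID,
                pol1 != nil, pol2 != nil)

        return nil
}
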
1828
// FetchChannelEdgesByOutpoint attempts to look up the two directed edges for
// the channel identified by the funding outpoint. If the channel can't be
// found, then ErrEdgeNotFound is returned. A struct which houses the general
// information for the channel itself is returned as well as two structs that
// contain the routing policies for the channel in either direction.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
        *models.ChannelEdgePolicy, error) {

        var (
                ctx              = context.TODO()
                edge             *models.ChannelEdgeInfo
                policy1, policy2 *models.ChannelEdgePolicy
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                row, err := db.GetChannelByOutpointWithPolicies(
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
                                Outpoint: op.String(),
                                Version:  int16(ProtocolV1),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return ErrEdgeNotFound
                } else if err != nil {
                        return fmt.Errorf("unable to fetch channel: %w", err)
                }

                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return err
                }

                edge, err = getAndBuildEdgeInfo(
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel info: %w",
                                err)
                }

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }

                policy1, policy2, err = getAndBuildChanPolicies(
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
                        node1, node2,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel "+
                                "policies: %w", err)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
                        err)
        }

        return edge, policy1, policy2, nil
}

1897
// HasChannelEdge returns true if the database knows of a channel edge with the
// passed channel ID, and false otherwise. If an edge with that ID is found
// within the graph, then two time stamps representing the last time the edge
// was updated for both directed edges are returned along with the boolean. If
// it is not found, then the zombie index is checked and its result is returned
// as the second boolean.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
        bool, error) {

        ctx := context.TODO()

        var (
                exists          bool
                isZombie        bool
                node1LastUpdate time.Time
                node2LastUpdate time.Time
        )

        // We'll query the cache with the shared lock held to allow multiple
        // readers to access values in the cache concurrently if they exist.
        s.cacheMu.RLock()
        if entry, ok := s.rejectCache.get(chanID); ok {
                s.cacheMu.RUnlock()
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
                exists, isZombie = entry.flags.unpack()

                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
        }
        s.cacheMu.RUnlock()

        s.cacheMu.Lock()
        defer s.cacheMu.Unlock()

        // The item was not found with the shared lock, so we'll acquire the
        // exclusive lock and check the cache again in case another method
        // added the entry to the cache while no lock was held.
        if entry, ok := s.rejectCache.get(chanID); ok {
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
                exists, isZombie = entry.flags.unpack()

                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
        }

        chanIDB := channelIDToBytes(chanID)
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                channel, err := db.GetChannelBySCID(
                        ctx, sqlc.GetChannelBySCIDParams{
                                Scid:    chanIDB,
                                Version: int16(ProtocolV1),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        // Check if it is a zombie channel.
                        isZombie, err = db.IsZombieChannel(
                                ctx, sqlc.IsZombieChannelParams{
                                        Scid:    chanIDB,
                                        Version: int16(ProtocolV1),
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("could not check if channel "+
                                        "is zombie: %w", err)
                        }

                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch channel: %w", err)
                }

                exists = true

                policy1, err := db.GetChannelPolicyByChannelAndNode(
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
                                Version:   int16(ProtocolV1),
                                ChannelID: channel.ID,
                                NodeID:    channel.NodeID1,
                        },
                )
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
                        return fmt.Errorf("unable to fetch channel policy: %w",
                                err)
                } else if err == nil {
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
                }

                policy2, err := db.GetChannelPolicyByChannelAndNode(
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
                                Version:   int16(ProtocolV1),
                                ChannelID: channel.ID,
                                NodeID:    channel.NodeID2,
                        },
                )
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
                        return fmt.Errorf("unable to fetch channel policy: %w",
                                err)
                } else if err == nil {
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return time.Time{}, time.Time{}, false, false,
                        fmt.Errorf("unable to fetch channel: %w", err)
        }

        s.rejectCache.insert(chanID, rejectCacheEntry{
                upd1Time: node1LastUpdate.Unix(),
                upd2Time: node2LastUpdate.Unix(),
                flags:    packRejectFlags(exists, isZombie),
        })

        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
}

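// Illustrative usage sketch (not part of the upstream file): interpreting the
// return values of HasChannelEdge. The function name is hypothetical.
func exampleCheckChannelEdge(s *SQLStore, chanID uint64) error {
        upd1, upd2, exists, isZombie, err := s.HasChannelEdge(chanID)
        if err != nil {
                return err
        }

        switch {
        case isZombie:
                fmt.Println("channel is in the zombie index")
        case exists:
                fmt.Printf("channel known, last updates: %v / %v\n", upd1, upd2)
        default:
                fmt.Println("channel is unknown")
        }

        return nil
}
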
2016
// ChannelID attempts to look up the 8-byte compact channel ID which maps to
// the passed channel point (outpoint). If the passed channel doesn't exist
// within the database, then ErrEdgeNotFound is returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
        var (
                ctx       = context.TODO()
                channelID uint64
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                chanID, err := db.GetSCIDByOutpoint(
                        ctx, sqlc.GetSCIDByOutpointParams{
                                Outpoint: chanPoint.String(),
                                Version:  int16(ProtocolV1),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return ErrEdgeNotFound
                } else if err != nil {
                        return fmt.Errorf("unable to fetch channel ID: %w",
                                err)
                }

                channelID = byteOrder.Uint64(chanID)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
        }

        return channelID, nil
}

2051
// IsPublicNode is a helper method that determines whether the node with the
// given public key is seen as a public node in the graph from the graph's
// source node's point of view.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
        ctx := context.TODO()

        var isPublic bool
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                var err error
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return false, fmt.Errorf("unable to check if node is "+
                        "public: %w", err)
        }

        return isPublic, nil
}

2074
// FetchChanInfos returns the set of channel edges that correspond to the
// passed channel IDs. If an edge in the query is unknown to the database, it
// will be skipped and the result will contain only those edges that exist at
// the time of the query. This can be used to respond to peer queries that are
// seeking to fill in gaps in their view of the channel graph.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
        var (
                ctx   = context.TODO()
                edges = make(map[uint64]ChannelEdge)
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                // First, collect all channel rows.
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
                chanCallBack := func(ctx context.Context,
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {

                        channelRows = append(channelRows, row)
                        return nil
                }

                err := s.forEachChanWithPoliciesInSCIDList(
                        ctx, db, chanCallBack, chanIDs,
                )
                if err != nil {
                        return err
                }

                if len(channelRows) == 0 {
                        return nil
                }

                // Batch build all channel edges.
                chans, err := batchBuildChannelEdges(
                        ctx, s.cfg, db, channelRows,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel edges: %w",
                                err)
                }

                for _, c := range chans {
                        edges[c.Info.ChannelID] = c
                }

                return err
        }, func() {
                clear(edges)
        })
        if err != nil {
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
        }

        res := make([]ChannelEdge, 0, len(edges))
        for _, chanID := range chanIDs {
                edge, ok := edges[chanID]
                if !ok {
                        continue
                }

                res = append(res, edge)
        }

        return res, nil
}

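// Illustrative usage sketch (not part of the upstream file): answering a
// peer's gap query, silently dropping channels we do not know. The function
// name is hypothetical.
func exampleAnswerGapQuery(s *SQLStore, queried []uint64) error {
        known, err := s.FetchChanInfos(queried)
        if err != nil {
                return err
        }

        fmt.Printf("answering with %d of %d requested channels\n",
                len(known), len(queried))

        return nil
}
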
2141
// forEachChanWithPoliciesInSCIDList is a wrapper around the
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
// channels in a paginated manner.
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
        db SQLQueries, cb func(ctx context.Context,
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
        chanIDs []uint64) error {

        queryWrapper := func(ctx context.Context,
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
                error) {

                return db.GetChannelsBySCIDWithPolicies(
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
                                Version: int16(ProtocolV1),
                                Scids:   scids,
                        },
                )
        }

        return sqldb.ExecuteBatchQuery(
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
                cb,
        )
}

2167
// FilterKnownChanIDs takes a set of channel IDs and returns the subset of chan
// ID's that we don't know and are not known zombies of the passed set. In
// other words, we perform a set difference of our set of chan ID's and the
// ones passed in. This method can be used by callers to determine the set of
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
// known zombies are also returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
        []ChannelUpdateInfo, error) {

        var (
                ctx          = context.TODO()
                newChanIDs   []uint64
                knownZombies []ChannelUpdateInfo
                infoLookup   = make(
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
                )
        )

        // We first build a lookup map of the channel ID's to the
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
        // already know about.
        for _, chanInfo := range chansInfo {
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
        }

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                // The call-back function deletes known channels from
                // infoLookup, so that we can later check which channels are
                // zombies by only looking at the remaining channels in the set.
                cb := func(ctx context.Context,
                        channel sqlc.GraphChannel) error {

                        delete(infoLookup, byteOrder.Uint64(channel.Scid))

                        return nil
                }

                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
                if err != nil {
                        return fmt.Errorf("unable to iterate through "+
                                "channels: %w", err)
                }

                // We want to ensure that we deal with the channels in the
                // same order that they were passed in, so we iterate over the
                // original chansInfo slice and then check if that channel is
                // still in the infoLookup map.
                for _, chanInfo := range chansInfo {
                        channelID := chanInfo.ShortChannelID.ToUint64()
                        if _, ok := infoLookup[channelID]; !ok {
                                continue
                        }

                        isZombie, err := db.IsZombieChannel(
                                ctx, sqlc.IsZombieChannelParams{
                                        Scid:    channelIDToBytes(channelID),
                                        Version: int16(ProtocolV1),
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("unable to fetch zombie "+
                                        "channel: %w", err)
                        }

                        if isZombie {
                                knownZombies = append(knownZombies, chanInfo)

                                continue
                        }

                        newChanIDs = append(newChanIDs, channelID)
                }

                return nil
        }, func() {
                newChanIDs = nil
                knownZombies = nil

                // Rebuild the infoLookup map in case of a rollback.
                for _, chanInfo := range chansInfo {
                        scid := chanInfo.ShortChannelID.ToUint64()
                        infoLookup[scid] = chanInfo
                }
        })
        if err != nil {
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
        }

        return newChanIDs, knownZombies, nil
}

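// Illustrative usage sketch (not part of the upstream file): computing which
// of a peer's advertised channels are actually new to us. The function name
// is hypothetical.
func exampleFindNewChannels(s *SQLStore,
        advertised []ChannelUpdateInfo) ([]uint64, error) {

        newIDs, zombies, err := s.FilterKnownChanIDs(advertised)
        if err != nil {
                return nil, err
        }

        fmt.Printf("%d new channels, %d known zombies\n", len(newIDs),
                len(zombies))

        return newIDs, nil
}
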
2259
// forEachChanInSCIDList is a helper method that executes a paged query
// against the database to fetch all channels that match the passed
// ChannelUpdateInfo slice. The callback function is called for each channel
// that is found.
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
        chansInfo []ChannelUpdateInfo) error {

        queryWrapper := func(ctx context.Context,
                scids [][]byte) ([]sqlc.GraphChannel, error) {

                return db.GetChannelsBySCIDs(
                        ctx, sqlc.GetChannelsBySCIDsParams{
                                Version: int16(ProtocolV1),
                                Scids:   scids,
                        },
                )
        }

        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
                channelID := chanInfo.ShortChannelID.ToUint64()

                return channelIDToBytes(channelID)
        }

        return sqldb.ExecuteBatchQuery(
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
                cb,
        )
}

2290
// PruneGraphNodes is a garbage collection method which attempts to prune out
// any nodes from the channel graph that are currently unconnected. This
// ensures that we only maintain a graph of reachable nodes. In the event that
// a pruned node gains more channels, it will be re-added back to the graph.
//
// NOTE: this prunes nodes across protocol versions. It will never prune the
// source nodes.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
        var ctx = context.TODO()

        var prunedNodes []route.Vertex
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                var err error
                prunedNodes, err = s.pruneGraphNodes(ctx, db)

                return err
        }, func() {
                prunedNodes = nil
        })
        if err != nil {
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
        }

        return prunedNodes, nil
}

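// Illustrative usage sketch (not part of the upstream file): running the node
// garbage collector and logging the result. The function name is
// hypothetical.
func examplePruneUnconnectedNodes(s *SQLStore) error {
        pruned, err := s.PruneGraphNodes()
        if err != nil {
                return err
        }

        for _, nodePub := range pruned {
                fmt.Printf("pruned unconnected node %x\n", nodePub[:])
        }

        return nil
}
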
2318
// PruneGraph prunes newly closed channels from the channel graph in response
// to a new block being solved on the network. Any transactions which spend the
// funding output of any known channels within the graph will be deleted.
// Additionally, the "prune tip", or the last block which has been used to
// prune the graph is stored so callers can ensure the graph is fully in sync
// with the current UTXO state. A slice of channels that have been closed by
// the target block along with any pruned nodes are returned if the function
// succeeds without error.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
        blockHash *chainhash.Hash, blockHeight uint32) (
        []*models.ChannelEdgeInfo, []route.Vertex, error) {

        ctx := context.TODO()

        s.cacheMu.Lock()
        defer s.cacheMu.Unlock()

        var (
                closedChans []*models.ChannelEdgeInfo
                prunedNodes []route.Vertex
        )
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                // First, collect all channel rows that need to be pruned.
                var channelRows []sqlc.GetChannelsByOutpointsRow
                channelCallback := func(ctx context.Context,
                        row sqlc.GetChannelsByOutpointsRow) error {

                        channelRows = append(channelRows, row)

                        return nil
                }

                err := s.forEachChanInOutpoints(
                        ctx, db, spentOutputs, channelCallback,
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch channels by "+
                                "outpoints: %w", err)
                }

                if len(channelRows) == 0 {
                        // There are no channels to prune. So we can exit early
                        // after updating the prune log.
                        err = db.UpsertPruneLogEntry(
                                ctx, sqlc.UpsertPruneLogEntryParams{
                                        BlockHash:   blockHash[:],
                                        BlockHeight: int64(blockHeight),
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("unable to insert prune log "+
                                        "entry: %w", err)
                        }

                        return nil
                }

                // Batch build all channel edges for pruning.
                var chansToDelete []int64
                closedChans, chansToDelete, err = batchBuildChannelInfo(
                        ctx, s.cfg, db, channelRows,
                )
                if err != nil {
                        return err
                }

                err = s.deleteChannels(ctx, db, chansToDelete)
                if err != nil {
                        return fmt.Errorf("unable to delete channels: %w", err)
                }

                err = db.UpsertPruneLogEntry(
                        ctx, sqlc.UpsertPruneLogEntryParams{
                                BlockHash:   blockHash[:],
                                BlockHeight: int64(blockHeight),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to insert prune log "+
                                "entry: %w", err)
                }

                // Now that we've pruned some channels, we'll also prune any
                // nodes that no longer have any channels.
2404
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2405
                if err != nil {
×
2406
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2407
                                err)
×
UNCOV
2408
                }
×
2409

2410
                return nil
×
2411
        }, func() {
×
2412
                prunedNodes = nil
×
UNCOV
2413
                closedChans = nil
×
2414
        })
×
UNCOV
2415
        if err != nil {
×
UNCOV
2416
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
UNCOV
2417
        }
×
2418

UNCOV
2419
        for _, channel := range closedChans {
×
UNCOV
2420
                s.rejectCache.remove(channel.ChannelID)
×
UNCOV
2421
                s.chanCache.remove(channel.ChannelID)
×
UNCOV
2422
        }
×
2423

2424
        return closedChans, prunedNodes, nil
×
2425
}
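
// NOTE: examplePruneOnBlock is an illustrative sketch only and is not part of
// the upstream sql_store.go. It shows how a chain-event consumer might feed a
// newly connected block into PruneGraph: every input of every transaction in
// the block is treated as a potentially spent channel funding outpoint. The
// function name is hypothetical.
func examplePruneOnBlock(s *SQLStore, block *wire.MsgBlock,
	height uint32) ([]*models.ChannelEdgeInfo, []route.Vertex, error) {

	var spent []*wire.OutPoint
	for _, tx := range block.Transactions {
		for _, txIn := range tx.TxIn {
			// Copy the outpoint value before taking its address.
			op := txIn.PreviousOutPoint
			spent = append(spent, &op)
		}
	}

	blockHash := block.BlockHash()

	// PruneGraph removes any channel whose funding output was spent by
	// the block, records the new prune tip, and garbage collects nodes
	// that are left without channels.
	return s.PruneGraph(spent, &blockHash, height)
}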

// forEachChanInOutpoints is a helper function that executes a paginated
// query to fetch channels by their outpoints and applies the given call-back
// to each.
//
// NOTE: this fetches channels for all protocol versions.
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
	outpoints []*wire.OutPoint, cb func(ctx context.Context,
		row sqlc.GetChannelsByOutpointsRow) error) error {

	// Create a wrapper that uses the transaction's db instance to execute
	// the query.
	queryWrapper := func(ctx context.Context,
		pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
		error) {

		return db.GetChannelsByOutpoints(ctx, pageOutpoints)
	}

	// Define the conversion function from Outpoint to string.
	outpointToString := func(outpoint *wire.OutPoint) string {
		return outpoint.String()
	}

	return sqldb.ExecuteBatchQuery(
		ctx, s.cfg.QueryCfg, outpoints, outpointToString,
		queryWrapper, cb,
	)
}

func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
	dbIDs []int64) error {

	// Create a wrapper that uses the transaction's db instance to execute
	// the query.
	queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
		return nil, db.DeleteChannels(ctx, ids)
	}

	idConverter := func(id int64) int64 {
		return id
	}

	return sqldb.ExecuteBatchQuery(
		ctx, s.cfg.QueryCfg, dbIDs, idConverter,
		queryWrapper, func(ctx context.Context, _ any) error {
			return nil
		},
	)
}

// ChannelView returns the verifiable edge information for each active channel
// within the known channel graph. The set of UTXOs (along with their scripts)
// returned are the ones that need to be watched on chain to detect channel
// closes on the resident blockchain.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
	var (
		ctx        = context.TODO()
		edgePoints []EdgePoint
	)

	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		handleChannel := func(_ context.Context,
			channel sqlc.ListChannelsPaginatedRow) error {

			pkScript, err := genMultiSigP2WSH(
				channel.BitcoinKey1, channel.BitcoinKey2,
			)
			if err != nil {
				return err
			}

			op, err := wire.NewOutPointFromString(channel.Outpoint)
			if err != nil {
				return err
			}

			edgePoints = append(edgePoints, EdgePoint{
				FundingPkScript: pkScript,
				OutPoint:        *op,
			})

			return nil
		}

		queryFunc := func(ctx context.Context, lastID int64,
			limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {

			return db.ListChannelsPaginated(
				ctx, sqlc.ListChannelsPaginatedParams{
					Version: int16(ProtocolV1),
					ID:      lastID,
					Limit:   limit,
				},
			)
		}

		extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
			return row.ID
		}

		return sqldb.ExecutePaginatedQuery(
			ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
			extractCursor, handleChannel,
		)
	}, func() {
		edgePoints = nil
	})
	if err != nil {
		return nil, fmt.Errorf("unable to fetch channel view: %w", err)
	}

	return edgePoints, nil
}
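
// NOTE: exampleWatchedOutpoints is an illustrative sketch only and is not
// part of the upstream sql_store.go. A caller would typically hand every
// EdgePoint returned by ChannelView to a chain watcher so that spends of
// channel funding outputs are detected. The function name is hypothetical.
func exampleWatchedOutpoints(s *SQLStore) (map[wire.OutPoint][]byte, error) {
	edgePoints, err := s.ChannelView()
	if err != nil {
		return nil, err
	}

	// Map each funding outpoint to the pkScript that must be watched for
	// an on-chain spend.
	watched := make(map[wire.OutPoint][]byte, len(edgePoints))
	for _, ep := range edgePoints {
		watched[ep.OutPoint] = ep.FundingPkScript
	}

	return watched, nil
}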

// PruneTip returns the block height and hash of the latest block that has been
// used to prune channels in the graph. Knowing the "prune tip" allows callers
// to tell if the graph is currently in sync with the current best known UTXO
// state.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
	var (
		ctx       = context.TODO()
		tipHash   chainhash.Hash
		tipHeight uint32
	)
	err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		pruneTip, err := db.GetPruneTip(ctx)
		if errors.Is(err, sql.ErrNoRows) {
			return ErrGraphNeverPruned
		} else if err != nil {
			return fmt.Errorf("unable to fetch prune tip: %w", err)
		}

		tipHash = chainhash.Hash(pruneTip.BlockHash)
		tipHeight = uint32(pruneTip.BlockHeight)

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, 0, err
	}

	return &tipHash, tipHeight, nil
}
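
// NOTE: examplePruneStartHeight is an illustrative sketch only and is not
// part of the upstream sql_store.go. The typical way to consume PruneTip is
// to treat ErrGraphNeverPruned as "start from the beginning" rather than as a
// hard failure. The function name is hypothetical.
func examplePruneStartHeight(s *SQLStore) (uint32, error) {
	_, tipHeight, err := s.PruneTip()
	switch {
	case errors.Is(err, ErrGraphNeverPruned):
		// The graph has never been pruned, so begin scanning from
		// height zero (or the chain's first relevant height).
		return 0, nil

	case err != nil:
		return 0, err
	}

	return tipHeight, nil
}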

// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
//
// NOTE: this prunes nodes across protocol versions. It will never prune the
// source nodes.
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
	db SQLQueries) ([]route.Vertex, error) {

	nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
	if err != nil {
		return nil, fmt.Errorf("unable to delete unconnected "+
			"nodes: %w", err)
	}

	prunedNodes := make([]route.Vertex, len(nodeKeys))
	for i, nodeKey := range nodeKeys {
		pub, err := route.NewVertexFromBytes(nodeKey)
		if err != nil {
			return nil, fmt.Errorf("unable to parse pubkey "+
				"from bytes: %w", err)
		}

		prunedNodes[i] = pub
	}

	return prunedNodes, nil
}

// DisconnectBlockAtHeight is used to indicate that the block specified
// by the passed height has been disconnected from the main chain. This
// will "rewind" the graph back to the height below, deleting channels
// that are no longer confirmed from the graph. The prune log will be
// set to the last prune height valid for the remaining chain.
// Channels that were removed from the graph resulting from the
// disconnected block are returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
	[]*models.ChannelEdgeInfo, error) {

	ctx := context.TODO()

	var (
		// Every channel having a ShortChannelID starting at 'height'
		// will no longer be confirmed.
		startShortChanID = lnwire.ShortChannelID{
			BlockHeight: height,
		}

		// Delete everything after this height from the db up until the
		// SCID alias range.
		endShortChanID = aliasmgr.StartingAlias

		removedChans []*models.ChannelEdgeInfo

		chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
		chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
	)

	err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		rows, err := db.GetChannelsBySCIDRange(
			ctx, sqlc.GetChannelsBySCIDRangeParams{
				StartScid: chanIDStart,
				EndScid:   chanIDEnd,
			},
		)
		if err != nil {
			return fmt.Errorf("unable to fetch channels: %w", err)
		}

		if len(rows) == 0 {
			// No channels to disconnect, but still clean up prune
			// log.
			return db.DeletePruneLogEntriesInRange(
				ctx, sqlc.DeletePruneLogEntriesInRangeParams{
					StartHeight: int64(height),
					EndHeight: int64(
						endShortChanID.BlockHeight,
					),
				},
			)
		}

		// Batch build all channel edges for disconnection.
		channelEdges, chanIDsToDelete, err := batchBuildChannelInfo(
			ctx, s.cfg, db, rows,
		)
		if err != nil {
			return err
		}

		removedChans = channelEdges

		err = s.deleteChannels(ctx, db, chanIDsToDelete)
		if err != nil {
			return fmt.Errorf("unable to delete channels: %w", err)
		}

		return db.DeletePruneLogEntriesInRange(
			ctx, sqlc.DeletePruneLogEntriesInRangeParams{
				StartHeight: int64(height),
				EndHeight:   int64(endShortChanID.BlockHeight),
			},
		)
	}, func() {
		removedChans = nil
	})
	if err != nil {
		return nil, fmt.Errorf("unable to disconnect block at "+
			"height: %w", err)
	}

	for _, channel := range removedChans {
		s.rejectCache.remove(channel.ChannelID)
		s.chanCache.remove(channel.ChannelID)
	}

	return removedChans, nil
}
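
// NOTE: exampleRewindToForkPoint is an illustrative sketch only and is not
// part of the upstream sql_store.go. It shows one way a caller could rewind
// the graph during a chain reorganisation, disconnecting stale blocks one at
// a time from the old tip down to the fork point. The function name and
// parameters are hypothetical.
func exampleRewindToForkPoint(s *SQLStore, oldTip, forkHeight uint32) error {
	for height := oldTip; height > forkHeight; height-- {
		// Channels whose SCIDs start at or above this height are no
		// longer confirmed on the new best chain and are removed.
		removed, err := s.DisconnectBlockAtHeight(height)
		if err != nil {
			return err
		}

		// The removed channels may be re-added later if their funding
		// transactions confirm again on the new chain.
		_ = removed
	}

	return nil
}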

// AddEdgeProof sets the proof of an existing edge in the graph database.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
	proof *models.ChannelAuthProof) error {

	var (
		ctx       = context.TODO()
		scidBytes = channelIDToBytes(scid.ToUint64())
	)

	err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		res, err := db.AddV1ChannelProof(
			ctx, sqlc.AddV1ChannelProofParams{
				Scid:              scidBytes,
				Node1Signature:    proof.NodeSig1Bytes,
				Node2Signature:    proof.NodeSig2Bytes,
				Bitcoin1Signature: proof.BitcoinSig1Bytes,
				Bitcoin2Signature: proof.BitcoinSig2Bytes,
			},
		)
		if err != nil {
			return fmt.Errorf("unable to add edge proof: %w", err)
		}

		n, err := res.RowsAffected()
		if err != nil {
			return err
		}

		if n == 0 {
			return fmt.Errorf("no rows affected when adding edge "+
				"proof for SCID %v", scid)
		} else if n > 1 {
			return fmt.Errorf("multiple rows affected when adding "+
				"edge proof for SCID %v: %d rows affected",
				scid, n)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return fmt.Errorf("unable to add edge proof: %w", err)
	}

	return nil
}
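
// NOTE: exampleAttachProof is an illustrative sketch only and is not part of
// the upstream sql_store.go. Once the four announcement signatures for a
// previously stored channel edge become available, they are packaged into a
// models.ChannelAuthProof and attached via AddEdgeProof. The function name
// and signature parameters are hypothetical.
func exampleAttachProof(s *SQLStore, scid lnwire.ShortChannelID,
	nodeSig1, nodeSig2, btcSig1, btcSig2 []byte) error {

	proof := &models.ChannelAuthProof{
		NodeSig1Bytes:    nodeSig1,
		NodeSig2Bytes:    nodeSig2,
		BitcoinSig1Bytes: btcSig1,
		BitcoinSig2Bytes: btcSig2,
	}

	return s.AddEdgeProof(scid, proof)
}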

// PutClosedScid stores a SCID for a closed channel in the database. This is so
// that we can ignore channel announcements that we know to be closed without
// having to validate them and fetch a block.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
	var (
		ctx     = context.TODO()
		chanIDB = channelIDToBytes(scid.ToUint64())
	)

	return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		return db.InsertClosedChannel(ctx, chanIDB)
	}, sqldb.NoOpReset)
}

// IsClosedScid checks whether a channel identified by the passed in scid is
// closed. This helps avoid having to perform expensive validation checks.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
	var (
		ctx      = context.TODO()
		isClosed bool
		chanIDB  = channelIDToBytes(scid.ToUint64())
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		var err error
		isClosed, err = db.IsClosedChannel(ctx, chanIDB)
		if err != nil {
			return fmt.Errorf("unable to fetch closed channel: %w",
				err)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return false, fmt.Errorf("unable to fetch closed channel: %w",
			err)
	}

	return isClosed, nil
}
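
// NOTE: the following two functions are illustrative sketches only and are
// not part of the upstream sql_store.go. They show how the closed-SCID index
// is typically used by a gossip filter: record a channel as closed once its
// closure is observed, then cheaply drop any later announcement for the same
// SCID without on-chain validation. The function names are hypothetical.
func exampleMarkChannelClosed(s *SQLStore, scid lnwire.ShortChannelID) error {
	return s.PutClosedScid(scid)
}

func exampleShouldIgnoreAnnouncement(s *SQLStore,
	scid lnwire.ShortChannelID) (bool, error) {

	closed, err := s.IsClosedScid(scid)
	if err != nil {
		return false, err
	}

	// A true result means the announcement can be skipped entirely.
	return closed, nil
}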

// GraphSession will provide the call-back with access to a NodeTraverser
// instance which can be used to perform queries against the channel graph.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
	reset func()) error {

	var ctx = context.TODO()

	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
	}, reset)
}
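
// NOTE: exampleCountNodeChannels is an illustrative sketch only and is not
// part of the upstream sql_store.go. GraphSession runs the callback against a
// single read transaction, so a multi-step traversal sees one consistent
// snapshot of the graph. The function name is hypothetical, and the reset
// closures clear partial state if the underlying transaction is retried.
func exampleCountNodeChannels(s *SQLStore, node route.Vertex) (int, error) {
	var count int

	err := s.GraphSession(func(graph NodeTraverser) error {
		return graph.ForEachNodeDirectedChannel(
			node, func(_ *DirectedChannel) error {
				count++
				return nil
			}, func() {
				count = 0
			},
		)
	}, func() {
		count = 0
	})
	if err != nil {
		return 0, err
	}

	return count, nil
}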

// sqlNodeTraverser implements the NodeTraverser interface but with a backing
// read-only transaction for a consistent view of the graph.
type sqlNodeTraverser struct {
	db    SQLQueries
	chain chainhash.Hash
}

// A compile-time assertion to ensure that sqlNodeTraverser implements the
// NodeTraverser interface.
var _ NodeTraverser = (*sqlNodeTraverser)(nil)

// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
func newSQLNodeTraverser(db SQLQueries,
	chain chainhash.Hash) *sqlNodeTraverser {

	return &sqlNodeTraverser{
		db:    db,
		chain: chain,
	}
}

// ForEachNodeDirectedChannel calls the callback for every channel of the given
// node.
//
// NOTE: Part of the NodeTraverser interface.
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
	cb func(channel *DirectedChannel) error, _ func()) error {

	ctx := context.TODO()

	return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
}

// FetchNodeFeatures returns the features of the given node. If the node is
// unknown, assume no additional features are supported.
//
// NOTE: Part of the NodeTraverser interface.
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
	*lnwire.FeatureVector, error) {

	ctx := context.TODO()

	return fetchNodeFeatures(ctx, s.db, nodePub)
}

// forEachNodeDirectedChannel iterates through all channels of a given
// node, executing the passed callback on the directed edge representing the
// channel and its incoming policy. If the node is not found, no error is
// returned.
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
	nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {

	toNodeCallback := func() route.Vertex {
		return nodePub
	}

	dbID, err := db.GetNodeIDByPubKey(
		ctx, sqlc.GetNodeIDByPubKeyParams{
			Version: int16(ProtocolV1),
			PubKey:  nodePub[:],
		},
	)
	if errors.Is(err, sql.ErrNoRows) {
		return nil
	} else if err != nil {
		return fmt.Errorf("unable to fetch node: %w", err)
	}

	rows, err := db.ListChannelsByNodeID(
		ctx, sqlc.ListChannelsByNodeIDParams{
			Version: int16(ProtocolV1),
			NodeID1: dbID,
		},
	)
	if err != nil {
		return fmt.Errorf("unable to fetch channels: %w", err)
	}

	// Exit early if there are no channels for this node so we don't
	// do the unnecessary feature fetching.
	if len(rows) == 0 {
		return nil
	}

	features, err := getNodeFeatures(ctx, db, dbID)
	if err != nil {
		return fmt.Errorf("unable to fetch node features: %w", err)
	}

	for _, row := range rows {
		node1, node2, err := buildNodeVertices(
			row.Node1Pubkey, row.Node2Pubkey,
		)
		if err != nil {
			return fmt.Errorf("unable to build node vertices: %w",
				err)
		}

		edge := buildCacheableChannelInfo(
			row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
			node1, node2,
		)

		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return err
		}

		p1, p2, err := buildCachedChanPolicies(
			dbPol1, dbPol2, edge.ChannelID, node1, node2,
		)
		if err != nil {
			return err
		}

		// Determine the outgoing and incoming policy for this
		// channel and node combo.
		outPolicy, inPolicy := p1, p2
		if p1 != nil && node2 == nodePub {
			outPolicy, inPolicy = p2, p1
		} else if p2 != nil && node1 != nodePub {
			outPolicy, inPolicy = p2, p1
		}

		var cachedInPolicy *models.CachedEdgePolicy
		if inPolicy != nil {
			cachedInPolicy = inPolicy
			cachedInPolicy.ToNodePubKey = toNodeCallback
			cachedInPolicy.ToNodeFeatures = features
		}

		directedChannel := &DirectedChannel{
			ChannelID:    edge.ChannelID,
			IsNode1:      nodePub == edge.NodeKey1Bytes,
			OtherNode:    edge.NodeKey2Bytes,
			Capacity:     edge.Capacity,
			OutPolicySet: outPolicy != nil,
			InPolicy:     cachedInPolicy,
		}
		if outPolicy != nil {
			outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
				directedChannel.InboundFee = fee
			})
		}

		if nodePub == edge.NodeKey2Bytes {
			directedChannel.OtherNode = edge.NodeKey1Bytes
		}

		if err := cb(directedChannel); err != nil {
			return err
		}
	}

	return nil
}

// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
// and executes the provided callback for each node. It does so via pagination
// along with batch loading of the node feature bits.
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
	db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
		features *lnwire.FeatureVector) error) error {

	handleNode := func(_ context.Context,
		dbNode sqlc.ListNodeIDsAndPubKeysRow,
		featureBits map[int64][]int) error {

		fv := lnwire.EmptyFeatureVector()
		if features, exists := featureBits[dbNode.ID]; exists {
			for _, bit := range features {
				fv.Set(lnwire.FeatureBit(bit))
			}
		}

		var pub route.Vertex
		copy(pub[:], dbNode.PubKey)

		return processNode(dbNode.ID, pub, fv)
	}

	queryFunc := func(ctx context.Context, lastID int64,
		limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {

		return db.ListNodeIDsAndPubKeys(
			ctx, sqlc.ListNodeIDsAndPubKeysParams{
				Version: int16(ProtocolV1),
				ID:      lastID,
				Limit:   limit,
			},
		)
	}

	extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
		return row.ID
	}

	collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
		return node.ID, nil
	}

	batchQueryFunc := func(ctx context.Context,
		nodeIDs []int64) (map[int64][]int, error) {

		return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
	}

	return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
		ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
		batchQueryFunc, handleNode,
	)
}

// forEachNodeChannel iterates through all channels of a node, executing
// the passed callback on each. The call-back is provided with the channel's
// edge information, the outgoing policy and the incoming policy for the
// channel and node combo.
func forEachNodeChannel(ctx context.Context, db SQLQueries,
	cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
		*models.ChannelEdgePolicy,
		*models.ChannelEdgePolicy) error) error {

	// Get all the V1 channels for this node.
	rows, err := db.ListChannelsByNodeID(
		ctx, sqlc.ListChannelsByNodeIDParams{
			Version: int16(ProtocolV1),
			NodeID1: id,
		},
	)
	if err != nil {
		return fmt.Errorf("unable to fetch channels: %w", err)
	}

	// Collect all the channel and policy IDs.
	var (
		chanIDs   = make([]int64, 0, len(rows))
		policyIDs = make([]int64, 0, 2*len(rows))
	)
	for _, row := range rows {
		chanIDs = append(chanIDs, row.GraphChannel.ID)

		if row.Policy1ID.Valid {
			policyIDs = append(policyIDs, row.Policy1ID.Int64)
		}
		if row.Policy2ID.Valid {
			policyIDs = append(policyIDs, row.Policy2ID.Int64)
		}
	}

	batchData, err := batchLoadChannelData(
		ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
	)
	if err != nil {
		return fmt.Errorf("unable to batch load channel data: %w", err)
	}

	// Call the call-back for each channel and its known policies.
	for _, row := range rows {
		node1, node2, err := buildNodeVertices(
			row.Node1Pubkey, row.Node2Pubkey,
		)
		if err != nil {
			return fmt.Errorf("unable to build node vertices: %w",
				err)
		}

		edge, err := buildEdgeInfoWithBatchData(
			cfg.ChainHash, row.GraphChannel, node1, node2,
			batchData,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel info: %w",
				err)
		}

		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return fmt.Errorf("unable to extract channel "+
				"policies: %w", err)
		}

		p1, p2, err := buildChanPoliciesWithBatchData(
			dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel "+
				"policies: %w", err)
		}

		// Determine the outgoing and incoming policy for this
		// channel and node combo.
		p1ToNode := row.GraphChannel.NodeID2
		p2ToNode := row.GraphChannel.NodeID1
		outPolicy, inPolicy := p1, p2
		if (p1 != nil && p1ToNode == id) ||
			(p2 != nil && p2ToNode != id) {

			outPolicy, inPolicy = p2, p1
		}

		if err := cb(edge, outPolicy, inPolicy); err != nil {
			return err
		}
	}

	return nil
}

// updateChanEdgePolicy upserts the channel policy info we have stored for
// a channel we already know of.
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
	edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
	error) {

	var (
		node1Pub, node2Pub route.Vertex
		isNode1            bool
		chanIDB            = channelIDToBytes(edge.ChannelID)
	)

	// Check that this edge policy refers to a channel that we already
	// know of. We do this explicitly so that we can return the appropriate
	// ErrEdgeNotFound error if the channel doesn't exist, rather than
	// abort the transaction which would abort the entire batch.
	dbChan, err := tx.GetChannelAndNodesBySCID(
		ctx, sqlc.GetChannelAndNodesBySCIDParams{
			Scid:    chanIDB,
			Version: int16(ProtocolV1),
		},
	)
	if errors.Is(err, sql.ErrNoRows) {
		return node1Pub, node2Pub, false, ErrEdgeNotFound
	} else if err != nil {
		return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
			"fetch channel(%v): %w", edge.ChannelID, err)
	}

	copy(node1Pub[:], dbChan.Node1PubKey)
	copy(node2Pub[:], dbChan.Node2PubKey)

	// Figure out which node this edge is from.
	isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
	nodeID := dbChan.NodeID1
	if !isNode1 {
		nodeID = dbChan.NodeID2
	}

	var (
		inboundBase sql.NullInt64
		inboundRate sql.NullInt64
	)
	edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
		inboundRate = sqldb.SQLInt64(fee.FeeRate)
		inboundBase = sqldb.SQLInt64(fee.BaseFee)
	})

	id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
		Version:     int16(ProtocolV1),
		ChannelID:   dbChan.ID,
		NodeID:      nodeID,
		Timelock:    int32(edge.TimeLockDelta),
		FeePpm:      int64(edge.FeeProportionalMillionths),
		BaseFeeMsat: int64(edge.FeeBaseMSat),
		MinHtlcMsat: int64(edge.MinHTLC),
		LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
		Disabled: sql.NullBool{
			Valid: true,
			Bool:  edge.IsDisabled(),
		},
		MaxHtlcMsat: sql.NullInt64{
			Valid: edge.MessageFlags.HasMaxHtlc(),
			Int64: int64(edge.MaxHTLC),
		},
		MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
		ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
		InboundBaseFeeMsat:      inboundBase,
		InboundFeeRateMilliMsat: inboundRate,
		Signature:               edge.SigBytes,
	})
	if err != nil {
		return node1Pub, node2Pub, isNode1,
			fmt.Errorf("unable to upsert edge policy: %w", err)
	}

	// Convert the flat extra opaque data into a map of TLV types to
	// values.
	extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
	if err != nil {
		return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
			"marshal extra opaque data: %w", err)
	}

	// Update the channel policy's extra signed fields.
	err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
	if err != nil {
		return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
			"policy extra TLVs: %w", err)
	}

	return node1Pub, node2Pub, isNode1, nil
}

// getNodeByPubKey attempts to look up a target node by its public key.
func getNodeByPubKey(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
	pubKey route.Vertex) (int64, *models.LightningNode, error) {

	dbNode, err := db.GetNodeByPubKey(
		ctx, sqlc.GetNodeByPubKeyParams{
			Version: int16(ProtocolV1),
			PubKey:  pubKey[:],
		},
	)
	if errors.Is(err, sql.ErrNoRows) {
		return 0, nil, ErrGraphNodeNotFound
	} else if err != nil {
		return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
	}

	node, err := buildNode(ctx, cfg, db, dbNode)
	if err != nil {
		return 0, nil, fmt.Errorf("unable to build node: %w", err)
	}

	return dbNode.ID, node, nil
}

// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
// provided parameters.
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
	node2Pub route.Vertex) *models.CachedEdgeInfo {

	return &models.CachedEdgeInfo{
		ChannelID:     byteOrder.Uint64(scid),
		NodeKey1Bytes: node1Pub,
		NodeKey2Bytes: node2Pub,
		Capacity:      btcutil.Amount(capacity),
	}
}

// buildNode constructs a LightningNode instance from the given database node
// record. The node's features, addresses and extra signed fields are also
// fetched from the database and set on the node.
func buildNode(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
	dbNode sqlc.GraphNode) (*models.LightningNode, error) {

	data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
	if err != nil {
		return nil, fmt.Errorf("unable to batch load node data: %w",
			err)
	}

	return buildNodeWithBatchData(dbNode, data)
}

// buildNodeWithBatchData builds a models.LightningNode instance
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
// features/addresses/extra fields, then the corresponding fields are expected
// to be present in the batchNodeData.
func buildNodeWithBatchData(dbNode sqlc.GraphNode,
	batchData *batchNodeData) (*models.LightningNode, error) {

	if dbNode.Version != int16(ProtocolV1) {
		return nil, fmt.Errorf("unsupported node version: %d",
			dbNode.Version)
	}

	var pub [33]byte
	copy(pub[:], dbNode.PubKey)

	node := &models.LightningNode{
		PubKeyBytes: pub,
		Features:    lnwire.EmptyFeatureVector(),
		LastUpdate:  time.Unix(0, 0),
	}

	if len(dbNode.Signature) == 0 {
		return node, nil
	}

	node.HaveNodeAnnouncement = true
	node.AuthSigBytes = dbNode.Signature
	node.Alias = dbNode.Alias.String
	node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)

	var err error
	if dbNode.Color.Valid {
		node.Color, err = DecodeHexColor(dbNode.Color.String)
		if err != nil {
			return nil, fmt.Errorf("unable to decode color: %w",
				err)
		}
	}

	// Use preloaded features.
	if features, exists := batchData.features[dbNode.ID]; exists {
		fv := lnwire.EmptyFeatureVector()
		for _, bit := range features {
			fv.Set(lnwire.FeatureBit(bit))
		}
		node.Features = fv
	}

	// Use preloaded addresses.
	addresses, exists := batchData.addresses[dbNode.ID]
	if exists && len(addresses) > 0 {
		node.Addresses, err = buildNodeAddresses(addresses)
		if err != nil {
			return nil, fmt.Errorf("unable to build addresses "+
				"for node(%d): %w", dbNode.ID, err)
		}
	}

	// Use preloaded extra fields.
	if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
		recs, err := lnwire.CustomRecords(extraFields).Serialize()
		if err != nil {
			return nil, fmt.Errorf("unable to serialize extra "+
				"signed fields: %w", err)
		}
		if len(recs) != 0 {
			node.ExtraOpaqueData = recs
		}
	}

	return node, nil
}

// forEachNodeInBatch fetches all nodes in the provided batch, builds them
// with the preloaded data, and executes the provided callback for each node.
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
	db SQLQueries, nodes []sqlc.GraphNode,
	cb func(dbID int64, node *models.LightningNode) error) error {

	// Extract node IDs for batch loading.
	nodeIDs := make([]int64, len(nodes))
	for i, node := range nodes {
		nodeIDs[i] = node.ID
	}

	// Batch load all related data for this page.
	batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
	if err != nil {
		return fmt.Errorf("unable to batch load node data: %w", err)
	}

	for _, dbNode := range nodes {
		node, err := buildNodeWithBatchData(dbNode, batchData)
		if err != nil {
			return fmt.Errorf("unable to build node(id=%d): %w",
				dbNode.ID, err)
		}

		if err := cb(dbNode.ID, node); err != nil {
			return fmt.Errorf("callback failed for node(id=%d): %w",
				dbNode.ID, err)
		}
	}

	return nil
}

// getNodeFeatures fetches the feature bits and constructs the feature vector
// for a node with the given DB ID.
func getNodeFeatures(ctx context.Context, db SQLQueries,
	nodeID int64) (*lnwire.FeatureVector, error) {

	rows, err := db.GetNodeFeatures(ctx, nodeID)
	if err != nil {
		return nil, fmt.Errorf("unable to get node(%d) features: %w",
			nodeID, err)
	}

	features := lnwire.EmptyFeatureVector()
	for _, feature := range rows {
		features.Set(lnwire.FeatureBit(feature.FeatureBit))
	}

	return features, nil
}

// upsertNode upserts the node record into the database. If the node already
// exists, then the node's information is updated. If the node doesn't exist,
// then a new node is created. The node's features, addresses and extra TLV
// types are also updated. The node's DB ID is returned.
func upsertNode(ctx context.Context, db SQLQueries,
	node *models.LightningNode) (int64, error) {

	params := sqlc.UpsertNodeParams{
		Version: int16(ProtocolV1),
		PubKey:  node.PubKeyBytes[:],
	}

	if node.HaveNodeAnnouncement {
		params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
		params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
		params.Alias = sqldb.SQLStr(node.Alias)
		params.Signature = node.AuthSigBytes
	}

	nodeID, err := db.UpsertNode(ctx, params)
	if err != nil {
		return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
			err)
	}

	// We can exit here if we don't have the announcement yet.
	if !node.HaveNodeAnnouncement {
		return nodeID, nil
	}

	// Update the node's features.
	err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
	if err != nil {
		return 0, fmt.Errorf("inserting node features: %w", err)
	}

	// Update the node's addresses.
	err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
	if err != nil {
		return 0, fmt.Errorf("inserting node addresses: %w", err)
	}

	// Convert the flat extra opaque data into a map of TLV types to
	// values.
	extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
	if err != nil {
		return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
			err)
	}

	// Update the node's extra signed fields.
	err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
	if err != nil {
		return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
	}

	return nodeID, nil
}

// upsertNodeFeatures updates the node's features in the node_features table.
// This includes deleting any feature bits no longer present and inserting any
// new feature bits. If the feature bit does not yet exist in the features
// table, then an entry is created in that table first.
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
	features *lnwire.FeatureVector) error {

	// Get any existing features for the node.
	existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return err
	}

	// Copy the node's latest set of feature bits.
	newFeatures := make(map[int32]struct{})
	if features != nil {
		for feature := range features.Features() {
			newFeatures[int32(feature)] = struct{}{}
		}
	}

	// For any current feature that already exists in the DB, remove it
	// from the in-memory map. For any existing feature that does not exist
	// in the in-memory map, delete it from the database.
	for _, feature := range existingFeatures {
		// The feature is still present, so there are no updates to be
		// made.
		if _, ok := newFeatures[feature.FeatureBit]; ok {
			delete(newFeatures, feature.FeatureBit)
			continue
		}

		// The feature is no longer present, so we remove it from the
		// database.
		err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
			NodeID:     nodeID,
			FeatureBit: feature.FeatureBit,
		})
		if err != nil {
			return fmt.Errorf("unable to delete node(%d) "+
				"feature(%v): %w", nodeID, feature.FeatureBit,
				err)
		}
	}

	// Any remaining entries in newFeatures are new features that need to
	// be added to the database for the first time.
	for feature := range newFeatures {
		err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
			NodeID:     nodeID,
			FeatureBit: feature,
		})
		if err != nil {
			return fmt.Errorf("unable to insert node(%d) "+
				"feature(%v): %w", nodeID, feature, err)
		}
	}

	return nil
}

// fetchNodeFeatures fetches the features for a node with the given public key.
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
	nodePub route.Vertex) (*lnwire.FeatureVector, error) {

	rows, err := queries.GetNodeFeaturesByPubKey(
		ctx, sqlc.GetNodeFeaturesByPubKeyParams{
			PubKey:  nodePub[:],
			Version: int16(ProtocolV1),
		},
	)
	if err != nil {
		return nil, fmt.Errorf("unable to get node(%s) features: %w",
			nodePub, err)
	}

	features := lnwire.EmptyFeatureVector()
	for _, bit := range rows {
		features.Set(lnwire.FeatureBit(bit))
	}

	return features, nil
}

// dbAddressType is an enum type that represents the different address types
// that we store in the node_addresses table. The address type determines how
// the address is to be serialised/deserialised.
type dbAddressType uint8

const (
	addressTypeIPv4   dbAddressType = 1
	addressTypeIPv6   dbAddressType = 2
	addressTypeTorV2  dbAddressType = 3
	addressTypeTorV3  dbAddressType = 4
	addressTypeDNS    dbAddressType = 5
	addressTypeOpaque dbAddressType = math.MaxInt8
)
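
// NOTE: exampleBucketAddresses is an illustrative sketch only and is not part
// of the upstream sql_store.go. It shows how a node's advertised addresses
// are bucketed by dbAddressType via collectAddressRecords (defined below)
// before being written to the node_addresses table. The function name and
// sample addresses are hypothetical.
func exampleBucketAddresses() (map[dbAddressType][]string, error) {
	addrs := []net.Addr{
		// An IPv4 TCP address lands under addressTypeIPv4.
		&net.TCPAddr{IP: net.ParseIP("203.0.113.1"), Port: 9735},

		// An IPv6 TCP address lands under addressTypeIPv6.
		&net.TCPAddr{IP: net.ParseIP("2001:db8::1"), Port: 9735},
	}

	return collectAddressRecords(addrs)
}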
3536

// collectAddressRecords collects the addresses from the provided
// net.Addr slice and returns a map of dbAddressType to a slice of address
// strings.
func collectAddressRecords(addresses []net.Addr) (map[dbAddressType][]string,
        error) {

        // Copy the node's latest set of addresses.
        newAddresses := map[dbAddressType][]string{
                addressTypeIPv4:   {},
                addressTypeIPv6:   {},
                addressTypeTorV2:  {},
                addressTypeTorV3:  {},
                addressTypeDNS:    {},
                addressTypeOpaque: {},
        }
        addAddr := func(t dbAddressType, addr net.Addr) {
                newAddresses[t] = append(newAddresses[t], addr.String())
        }

        for _, address := range addresses {
                switch addr := address.(type) {
                case *net.TCPAddr:
                        if ip4 := addr.IP.To4(); ip4 != nil {
                                addAddr(addressTypeIPv4, addr)
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
                                addAddr(addressTypeIPv6, addr)
                        } else {
                                return nil, fmt.Errorf("unhandled IP "+
                                        "address: %v", addr)
                        }

                case *tor.OnionAddr:
                        switch len(addr.OnionService) {
                        case tor.V2Len:
                                addAddr(addressTypeTorV2, addr)
                        case tor.V3Len:
                                addAddr(addressTypeTorV3, addr)
                        default:
                                return nil, fmt.Errorf("invalid length for " +
                                        "a tor address")
                        }

                case *lnwire.DNSAddress:
                        addAddr(addressTypeDNS, addr)

                case *lnwire.OpaqueAddrs:
                        addAddr(addressTypeOpaque, addr)

                default:
                        return nil, fmt.Errorf("unhandled address type: %T",
                                addr)
                }
        }

        return newAddresses, nil
}
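
// exampleGroupAddresses is a hypothetical sketch (not part of the original
// file) showing how collectAddressRecords buckets a node's announced
// addresses by dbAddressType while keeping their string encodings: the IPv4
// address below ends up under addressTypeIPv4 as "10.0.0.1:9735" and the
// IPv6 address under addressTypeIPv6 as "[2001:db8::1]:9735".
func exampleGroupAddresses() (map[dbAddressType][]string, error) {
        return collectAddressRecords([]net.Addr{
                &net.TCPAddr{IP: net.IPv4(10, 0, 0, 1), Port: 9735},
                &net.TCPAddr{IP: net.ParseIP("2001:db8::1"), Port: 9735},
        })
}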

// upsertNodeAddresses updates the node's addresses in the database. This
// includes deleting any existing addresses and inserting the new set of
// addresses. The deletion is necessary since the ordering of the addresses may
// change, and we need to ensure that the database reflects the latest set of
// addresses so that at the time of reconstructing the node announcement, the
// order is preserved and the signature over the message remains valid.
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
        addresses []net.Addr) error {

        // Delete any existing addresses for the node. This is required since
        // even if the new set of addresses is the same, the ordering may have
        // changed for a given address type.
        err := db.DeleteNodeAddresses(ctx, nodeID)
        if err != nil {
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
                        nodeID, err)
        }

        newAddresses, err := collectAddressRecords(addresses)
        if err != nil {
                return err
        }

        // Any remaining entries in newAddresses are new addresses that need to
        // be added to the database for the first time.
        for addrType, addrList := range newAddresses {
                for position, addr := range addrList {
                        err := db.UpsertNodeAddress(
                                ctx, sqlc.UpsertNodeAddressParams{
                                        NodeID:   nodeID,
                                        Type:     int16(addrType),
                                        Address:  addr,
                                        Position: int32(position),
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("unable to insert "+
                                        "node(%d) address(%v): %w", nodeID,
                                        addr, err)
                        }
                }
        }

        return nil
}
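
// exampleAddressPositions is a hypothetical sketch (not part of the original
// file) of the ordering contract upsertNodeAddresses relies on: within each
// address type, the stored position column mirrors the slice index, so the
// original announcement order can be reproduced when the addresses are read
// back.
func exampleAddressPositions(addrList []string) []sqlc.UpsertNodeAddressParams {
        params := make([]sqlc.UpsertNodeAddressParams, 0, len(addrList))
        for position, addr := range addrList {
                params = append(params, sqlc.UpsertNodeAddressParams{
                        Type:     int16(addressTypeIPv4),
                        Address:  addr,
                        Position: int32(position),
                })
        }

        return params
}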

// getNodeAddresses fetches the addresses for a node with the given DB ID.
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
        error) {

        // GetNodeAddresses ensures that the addresses for a given type are
        // returned in the same order as they were inserted.
        rows, err := db.GetNodeAddresses(ctx, id)
        if err != nil {
                return nil, err
        }

        addresses := make([]net.Addr, 0, len(rows))
        for _, row := range rows {
                address := row.Address

                addr, err := parseAddress(dbAddressType(row.Type), address)
                if err != nil {
                        return nil, fmt.Errorf("unable to parse address "+
                                "for node(%d): %v: %w", id, address, err)
                }

                addresses = append(addresses, addr)
        }

        // If we have no addresses, then we'll return nil instead of an
        // empty slice.
        if len(addresses) == 0 {
                addresses = nil
        }

        return addresses, nil
}

// upsertNodeExtraSignedFields updates the node's extra signed fields in the
// database. This includes updating any existing types, inserting any new types,
// and deleting any types that are no longer present.
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
        nodeID int64, extraFields map[uint64][]byte) error {

        // Get any existing extra signed fields for the node.
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
        if err != nil {
                return err
        }

        // Make a lookup map of the existing field types so that we can use it
        // to keep track of any fields we should delete.
        m := make(map[uint64]bool)
        for _, field := range existingFields {
                m[uint64(field.Type)] = true
        }

        // For all the new fields, we'll upsert them and remove them from the
        // map of existing fields.
        for tlvType, value := range extraFields {
                err = db.UpsertNodeExtraType(
                        ctx, sqlc.UpsertNodeExtraTypeParams{
                                NodeID: nodeID,
                                Type:   int64(tlvType),
                                Value:  value,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to upsert node(%d) extra "+
                                "signed field(%v): %w", nodeID, tlvType, err)
                }

                // Remove the field from the map of existing fields if it was
                // present.
                delete(m, tlvType)
        }

        // For all the fields that are left in the map of existing fields, we'll
        // delete them as they are no longer present in the new set of fields.
        for tlvType := range m {
                err = db.DeleteExtraNodeType(
                        ctx, sqlc.DeleteExtraNodeTypeParams{
                                NodeID: nodeID,
                                Type:   int64(tlvType),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to delete node(%d) extra "+
                                "signed field(%v): %w", nodeID, tlvType, err)
                }
        }

        return nil
}
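
// exampleExtraFieldDiff is a hypothetical sketch (not part of the original
// file) of the bookkeeping used by upsertNodeExtraSignedFields: every TLV
// type in the new set is upserted, while any previously stored type that is
// absent from the new set is scheduled for deletion.
func exampleExtraFieldDiff(existing []uint64,
        updated map[uint64][]byte) []uint64 {

        stale := make(map[uint64]bool)
        for _, t := range existing {
                stale[t] = true
        }
        for t := range updated {
                delete(stale, t)
        }

        toDelete := make([]uint64, 0, len(stale))
        for t := range stale {
                toDelete = append(toDelete, t)
        }

        return toDelete
}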

// srcNodeInfo holds the information about the source node of the graph.
type srcNodeInfo struct {
        // id is the DB level ID of the source node entry in the "nodes" table.
        id int64

        // pub is the public key of the source node.
        pub route.Vertex
}

// getSourceNode returns the DB node ID and pub key of the source node for the
// specified protocol version.
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
        version ProtocolVersion) (int64, route.Vertex, error) {

        s.srcNodeMu.Lock()
        defer s.srcNodeMu.Unlock()

        // If we already have the source node ID and pub key cached, then
        // return them.
        if info, ok := s.srcNodes[version]; ok {
                return info.id, info.pub, nil
        }

        var pubKey route.Vertex

        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
        if err != nil {
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
                        err)
        }

        if len(nodes) == 0 {
                return 0, pubKey, ErrSourceNodeNotSet
        } else if len(nodes) > 1 {
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
                        "protocol %s found", version)
        }

        copy(pubKey[:], nodes[0].PubKey)

        s.srcNodes[version] = &srcNodeInfo{
                id:  nodes[0].NodeID,
                pub: pubKey,
        }

        return nodes[0].NodeID, pubKey, nil
}
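
// exampleCachedSourceNode is a hypothetical sketch (not part of the original
// file) of the memoisation pattern used by getSourceNode: the first
// successful lookup per protocol version is stored in a map and served from
// memory on subsequent calls, avoiding a DB round trip. Note that the real
// method additionally guards the cache with srcNodeMu.
func exampleCachedSourceNode(cache map[ProtocolVersion]*srcNodeInfo,
        version ProtocolVersion,
        fetch func() (*srcNodeInfo, error)) (*srcNodeInfo, error) {

        if info, ok := cache[version]; ok {
                return info, nil
        }

        info, err := fetch()
        if err != nil {
                return nil, err
        }
        cache[version] = info

        return info, nil
}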

// marshalExtraOpaqueData takes a flat byte slice and parses it as a TLV
// stream. This then produces a map from TLV type to value. If the input is
// not a valid TLV stream, then an error is returned.
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
        r := bytes.NewReader(data)

        tlvStream, err := tlv.NewStream()
        if err != nil {
                return nil, err
        }

        // Since ExtraOpaqueData is provided by a potentially malicious peer,
        // pass it into the P2P decoding variant.
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
        if err != nil {
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
        }
        if len(parsedTypes) == 0 {
                return nil, nil
        }

        records := make(map[uint64][]byte)
        for k, v := range parsedTypes {
                records[uint64(k)] = v
        }

        return records, nil
}
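
// exampleMarshalExtraData is a hypothetical illustration (not part of the
// original file): a minimal TLV stream holding a single odd record
// (type 1, length 1, value 0xff) is parsed by marshalExtraOpaqueData into a
// map keyed by TLV type, i.e. map[uint64][]byte{1: {0xff}}.
func exampleMarshalExtraData() (map[uint64][]byte, error) {
        return marshalExtraOpaqueData([]byte{0x01, 0x01, 0xff})
}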

// insertChannel inserts a new channel record into the database.
func insertChannel(ctx context.Context, db SQLQueries,
        edge *models.ChannelEdgeInfo) error {

        // Make sure that at least a "shell" entry for each node is present in
        // the nodes table.
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
        if err != nil {
                return fmt.Errorf("unable to create shell node: %w", err)
        }

        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
        if err != nil {
                return fmt.Errorf("unable to create shell node: %w", err)
        }

        var capacity sql.NullInt64
        if edge.Capacity != 0 {
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
        }

        createParams := sqlc.CreateChannelParams{
                Version:     int16(ProtocolV1),
                Scid:        channelIDToBytes(edge.ChannelID),
                NodeID1:     node1DBID,
                NodeID2:     node2DBID,
                Outpoint:    edge.ChannelPoint.String(),
                Capacity:    capacity,
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
        }

        if edge.AuthProof != nil {
                proof := edge.AuthProof

                createParams.Node1Signature = proof.NodeSig1Bytes
                createParams.Node2Signature = proof.NodeSig2Bytes
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
        }

        // Insert the new channel record.
        dbChanID, err := db.CreateChannel(ctx, createParams)
        if err != nil {
                return err
        }

        // Insert any channel features.
        for feature := range edge.Features.Features() {
                err = db.InsertChannelFeature(
                        ctx, sqlc.InsertChannelFeatureParams{
                                ChannelID:  dbChanID,
                                FeatureBit: int32(feature),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to insert channel(%d) "+
                                "feature(%v): %w", dbChanID, feature, err)
                }
        }

        // Finally, insert any extra TLV fields in the channel announcement.
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
        if err != nil {
                return fmt.Errorf("unable to marshal extra opaque data: %w",
                        err)
        }

        for tlvType, value := range extra {
                err := db.UpsertChannelExtraType(
                        ctx, sqlc.UpsertChannelExtraTypeParams{
                                ChannelID: dbChanID,
                                Type:      int64(tlvType),
                                Value:     value,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to upsert channel(%d) "+
                                "extra signed field(%v): %w", edge.ChannelID,
                                tlvType, err)
                }
        }

        return nil
}
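
// exampleNullableCapacity is a hypothetical sketch (not part of the original
// file) of the convention insertChannel follows for the capacity column: a
// zero capacity is stored as SQL NULL rather than as an explicit zero.
func exampleNullableCapacity(capacity btcutil.Amount) sql.NullInt64 {
        if capacity == 0 {
                return sql.NullInt64{}
        }

        return sqldb.SQLInt64(int64(capacity))
}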

// maybeCreateShellNode checks if a shell node entry exists for the
// given public key. If it does not exist, then a new shell node entry is
// created. The ID of the node is returned. A shell node only has a protocol
// version and public key persisted.
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
        pubKey route.Vertex) (int64, error) {

        dbNode, err := db.GetNodeByPubKey(
                ctx, sqlc.GetNodeByPubKeyParams{
                        PubKey:  pubKey[:],
                        Version: int16(ProtocolV1),
                },
        )
        // The node exists. Return the ID.
        if err == nil {
                return dbNode.ID, nil
        } else if !errors.Is(err, sql.ErrNoRows) {
                return 0, err
        }

        // Otherwise, the node does not exist, so we create a shell entry for
        // it.
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
                Version: int16(ProtocolV1),
                PubKey:  pubKey[:],
        })
        if err != nil {
                return 0, fmt.Errorf("unable to create shell node: %w", err)
        }

        return id, nil
}

// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
// the database. This includes deleting any existing types and then inserting
// the new types.
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
        chanPolicyID int64, extraFields map[uint64][]byte) error {

        // Delete all existing extra signed fields for the channel policy.
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
        if err != nil {
                return fmt.Errorf("unable to delete "+
                        "existing policy extra signed fields for policy %d: %w",
                        chanPolicyID, err)
        }

        // Insert all new extra signed fields for the channel policy.
        for tlvType, value := range extraFields {
                err = db.UpsertChanPolicyExtraType(
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
                                ChannelPolicyID: chanPolicyID,
                                Type:            int64(tlvType),
                                Value:           value,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to insert "+
                                "channel_policy(%d) extra signed field(%v): %w",
                                chanPolicyID, tlvType, err)
                }
        }

        return nil
}

// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
// provided dbChan row and also fetches any other required information
// to construct the edge info.
func getAndBuildEdgeInfo(ctx context.Context, cfg *SQLStoreConfig,
        db SQLQueries, dbChan sqlc.GraphChannel, node1,
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {

        data, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, []int64{dbChan.ID}, nil,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to batch load channel data: %w",
                        err)
        }

        return buildEdgeInfoWithBatchData(
                cfg.ChainHash, dbChan, node1, node2, data,
        )
}

// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {

        if dbChan.Version != int16(ProtocolV1) {
                return nil, fmt.Errorf("unsupported channel version: %d",
                        dbChan.Version)
        }

        // Use pre-loaded features and extras types.
        fv := lnwire.EmptyFeatureVector()
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
                for _, bit := range features {
                        fv.Set(lnwire.FeatureBit(bit))
                }
        }

        var extras map[uint64][]byte
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
        if exists {
                extras = channelExtras
        } else {
                extras = make(map[uint64][]byte)
        }

        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
        if err != nil {
                return nil, err
        }

        recs, err := lnwire.CustomRecords(extras).Serialize()
        if err != nil {
                return nil, fmt.Errorf("unable to serialize extra signed "+
                        "fields: %w", err)
        }
        if recs == nil {
                recs = make([]byte, 0)
        }

        var btcKey1, btcKey2 route.Vertex
        copy(btcKey1[:], dbChan.BitcoinKey1)
        copy(btcKey2[:], dbChan.BitcoinKey2)

        channel := &models.ChannelEdgeInfo{
                ChainHash:        chain,
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
                NodeKey1Bytes:    node1,
                NodeKey2Bytes:    node2,
                BitcoinKey1Bytes: btcKey1,
                BitcoinKey2Bytes: btcKey2,
                ChannelPoint:     *op,
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
                Features:         fv,
                ExtraOpaqueData:  recs,
        }

        // We always set all the signatures at the same time, so we can
        // safely check if one signature is present to determine if we have the
        // rest of the signatures for the auth proof.
        if len(dbChan.Bitcoin1Signature) > 0 {
                channel.AuthProof = &models.ChannelAuthProof{
                        NodeSig1Bytes:    dbChan.Node1Signature,
                        NodeSig2Bytes:    dbChan.Node2Signature,
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
                }
        }

        return channel, nil
}

// buildNodeVertices is a helper that converts raw node public keys
// into route.Vertex instances.
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
        route.Vertex, error) {

        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
        if err != nil {
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
                        "create vertex from node1 pubkey: %w", err)
        }

        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
        if err != nil {
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
                        "create vertex from node2 pubkey: %w", err)
        }

        return node1Vertex, node2Vertex, nil
}

// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
// retrieves all the extra info required to build the complete
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
// the provided sqlc.GraphChannelPolicy records are nil.
func getAndBuildChanPolicies(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
        channelID uint64, node1, node2 route.Vertex) (*models.ChannelEdgePolicy,
        *models.ChannelEdgePolicy, error) {

        if dbPol1 == nil && dbPol2 == nil {
                return nil, nil, nil
        }

        var policyIDs = make([]int64, 0, 2)
        if dbPol1 != nil {
                policyIDs = append(policyIDs, dbPol1.ID)
        }
        if dbPol2 != nil {
                policyIDs = append(policyIDs, dbPol2.ID)
        }

        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
        if err != nil {
                return nil, nil, fmt.Errorf("unable to batch load channel "+
                        "data: %w", err)
        }

        pol1, err := buildChanPolicyWithBatchData(
                dbPol1, channelID, node2, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
        }

        pol2, err := buildChanPolicyWithBatchData(
                dbPol2, channelID, node1, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
        }

        return pol1, pol2, nil
}

// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
// then nil is returned for it.
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
        *models.CachedEdgePolicy, error) {

        var p1, p2 *models.CachedEdgePolicy
        if dbPol1 != nil {
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
                if err != nil {
                        return nil, nil, err
                }

                p1 = models.NewCachedPolicy(policy1)
        }
        if dbPol2 != nil {
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
                if err != nil {
                        return nil, nil, err
                }

                p2 = models.NewCachedPolicy(policy2)
        }

        return p1, p2, nil
}

// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
// provided sqlc.GraphChannelPolicy and other required information.
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
        extras map[uint64][]byte,
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {

        recs, err := lnwire.CustomRecords(extras).Serialize()
        if err != nil {
                return nil, fmt.Errorf("unable to serialize extra signed "+
                        "fields: %w", err)
        }

        var inboundFee fn.Option[lnwire.Fee]
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
                dbPolicy.InboundBaseFeeMsat.Valid {

                inboundFee = fn.Some(lnwire.Fee{
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
                })
        }

        return &models.ChannelEdgePolicy{
                SigBytes:  dbPolicy.Signature,
                ChannelID: channelID,
                LastUpdate: time.Unix(
                        dbPolicy.LastUpdate.Int64, 0,
                ),
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
                        dbPolicy.MessageFlags,
                ),
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
                        dbPolicy.ChannelFlags,
                ),
                TimeLockDelta: uint16(dbPolicy.Timelock),
                MinHTLC: lnwire.MilliSatoshi(
                        dbPolicy.MinHtlcMsat,
                ),
                MaxHTLC: lnwire.MilliSatoshi(
                        dbPolicy.MaxHtlcMsat.Int64,
                ),
                FeeBaseMSat: lnwire.MilliSatoshi(
                        dbPolicy.BaseFeeMsat,
                ),
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
                ToNode:                    toNode,
                InboundFee:                inboundFee,
                ExtraOpaqueData:           recs,
        }, nil
}
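
// exampleInboundFee is a hypothetical sketch (not part of the original file)
// of the nullable-column handling in buildChanPolicy: the inbound fee option
// is only populated when at least one of the two inbound fee columns is
// non-NULL, otherwise it is left as fn.None.
func exampleInboundFee(baseMsat, rateMilliMsat sql.NullInt64) fn.Option[lnwire.Fee] {
        if !baseMsat.Valid && !rateMilliMsat.Valid {
                return fn.None[lnwire.Fee]()
        }

        return fn.Some(lnwire.Fee{
                BaseFee: int32(baseMsat.Int64),
                FeeRate: int32(rateMilliMsat.Int64),
        })
}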

// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the given
// row which is expected to be a sqlc type that contains channel policy
// information. It returns two policies, which may be nil if the policy
// information is not present in the row.
//
//nolint:ll,dupl,funlen
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
        *sqlc.GraphChannelPolicy, error) {

        var policy1, policy2 *sqlc.GraphChannelPolicy
        switch r := row.(type) {
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
                if r.Policy1Timelock.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                        }
                }
                if r.Policy2Timelock.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                        }
                }

                return policy1, policy2, nil

        case sqlc.GetChannelsBySCIDWithPoliciesRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.GetChannelByOutpointWithPoliciesRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.GetChannelBySCIDWithPoliciesRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.ListChannelsForNodeIDsRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.ListChannelsByNodeIDRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.ListChannelsWithPoliciesPaginatedRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.GetChannelsByIDsRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        default:
                return nil, nil, fmt.Errorf("unexpected row type in "+
                        "extractChannelPolicies: %T", r)
        }
}

// channelIDToBytes converts a channel ID (SCID) to a byte array
// representation.
func channelIDToBytes(channelID uint64) []byte {
        var chanIDB [8]byte
        byteOrder.PutUint64(chanIDB[:], channelID)

        return chanIDB[:]
}
×
4603
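
// For illustration only: callers that need the SCID encoding for a raw SQL
// parameter (for example the zombie-channel queries further below) pass the
// uint64 SCID straight through this helper. The variable names here are
// hypothetical and not part of this file:
//
//	scidBytes := channelIDToBytes(scid) // scid is a uint64 SCID.
//	// scidBytes is the 8-byte, byteOrder-encoded key used as the SCID
//	// column value in the graph channel queries.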

// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
        if len(addresses) == 0 {
                return nil, nil
        }

        result := make([]net.Addr, 0, len(addresses))
        for _, addr := range addresses {
                netAddr, err := parseAddress(addr.addrType, addr.address)
                if err != nil {
                        return nil, fmt.Errorf("unable to parse address %s "+
                                "of type %d: %w", addr.address, addr.addrType,
                                err)
                }
                if netAddr != nil {
                        result = append(result, netAddr)
                }
        }

        // If we have no valid addresses, return nil instead of an empty slice.
        if len(result) == 0 {
                return nil, nil
        }

        return result, nil
}

// parseAddress parses the given address string based on the address type
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
// DNS and opaque addresses.
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
        switch addrType {
        case addressTypeIPv4:
                tcp, err := net.ResolveTCPAddr("tcp4", address)
                if err != nil {
                        return nil, err
                }

                tcp.IP = tcp.IP.To4()

                return tcp, nil

        case addressTypeIPv6:
                tcp, err := net.ResolveTCPAddr("tcp6", address)
                if err != nil {
                        return nil, err
                }

                return tcp, nil

        case addressTypeTorV3, addressTypeTorV2:
                service, portStr, err := net.SplitHostPort(address)
                if err != nil {
                        return nil, fmt.Errorf("unable to split tor "+
                                "address: %v", address)
                }

                port, err := strconv.Atoi(portStr)
                if err != nil {
                        return nil, err
                }

                return &tor.OnionAddr{
                        OnionService: service,
                        Port:         port,
                }, nil

        case addressTypeDNS:
                hostname, portStr, err := net.SplitHostPort(address)
                if err != nil {
                        return nil, fmt.Errorf("unable to split DNS "+
                                "address: %v", address)
                }

                port, err := strconv.Atoi(portStr)
                if err != nil {
                        return nil, err
                }

                return &lnwire.DNSAddress{
                        Hostname: hostname,
                        Port:     uint16(port),
                }, nil

        case addressTypeOpaque:
                opaque, err := hex.DecodeString(address)
                if err != nil {
                        return nil, fmt.Errorf("unable to decode opaque "+
                                "address: %v", address)
                }

                return &lnwire.OpaqueAddrs{
                        Payload: opaque,
                }, nil

        default:
                return nil, fmt.Errorf("unknown address type: %v", addrType)
        }
}
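
// For illustration only: the stored address strings are host:port style for
// the network address types and hex for opaque payloads. The literals below
// are made-up examples; only the returned types follow from the switch cases
// in parseAddress:
//
//	addr, _ := parseAddress(addressTypeIPv4, "203.0.113.1:9735")    // *net.TCPAddr
//	onion, _ := parseAddress(addressTypeTorV3, "example.onion:9735") // *tor.OnionAddr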

// batchNodeData holds all the related data for a batch of nodes.
type batchNodeData struct {
        // features is a map from a DB node ID to the feature bits for that
        // node.
        features map[int64][]int

        // addresses is a map from a DB node ID to the node's addresses.
        addresses map[int64][]nodeAddress

        // extraFields is a map from a DB node ID to the extra signed fields
        // for that node.
        extraFields map[int64]map[uint64][]byte
}

// nodeAddress holds the address type, position and address string for a
// node. This is used to batch the fetching of node addresses.
type nodeAddress struct {
        addrType dbAddressType
        position int32
        address  string
}

// batchLoadNodeData loads all related data for a batch of node IDs using the
// provided SQLQueries interface. It returns a batchNodeData instance containing
// the node features, addresses and extra signed fields.
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {

        // Batch load the node features.
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node "+
                        "features: %w", err)
        }

        // Batch load the node addresses.
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node "+
                        "addresses: %w", err)
        }

        // Batch load the node extra signed fields.
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node extra "+
                        "signed fields: %w", err)
        }

        return &batchNodeData{
                features:    features,
                addresses:   addrs,
                extraFields: extraTypes,
        }, nil
}
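
// For illustration only: callers in this file first collect the DB node IDs
// for a page of rows and then hydrate them with one round of batched queries,
// roughly as follows (queryCfg is the *sqldb.QueryConfig in use; names are
// hypothetical):
//
//	data, err := batchLoadNodeData(ctx, queryCfg, db, nodeIDs)
//	if err != nil {
//	        return err
//	}
//	node, err := buildNodeWithBatchData(dbNode, data)
//
// buildNodeWithBatchData is defined elsewhere in this file; the snippet only
// sketches the intended call order.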

// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
// using the ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
func batchLoadNodeFeaturesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64][]int, error) {

        features := make(map[int64][]int)

        return features, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
                        error) {

                        return db.GetNodeFeaturesBatch(ctx, ids)
                },
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
                        features[feature.NodeID] = append(
                                features[feature.NodeID],
                                int(feature.FeatureBit),
                        )

                        return nil
                },
        )
}

// batchLoadNodeAddressesHelper loads node addresses using the
// ExecuteBatchQuery wrapper around the GetNodeAddressesBatch query. It returns
// a map from node ID to a slice of nodeAddress structs.
func batchLoadNodeAddressesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64][]nodeAddress, error) {

        addrs := make(map[int64][]nodeAddress)

        return addrs, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
                        error) {

                        return db.GetNodeAddressesBatch(ctx, ids)
                },
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
                        addrs[addr.NodeID] = append(
                                addrs[addr.NodeID], nodeAddress{
                                        addrType: dbAddressType(addr.Type),
                                        position: addr.Position,
                                        address:  addr.Address,
                                },
                        )

                        return nil
                },
        )
}

// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
// node IDs using the ExecuteBatchQuery wrapper around the
// GetNodeExtraTypesBatch query.
func batchLoadNodeExtraTypesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {

        extraFields := make(map[int64]map[uint64][]byte)

        callback := func(ctx context.Context,
                field sqlc.GraphNodeExtraType) error {

                if extraFields[field.NodeID] == nil {
                        extraFields[field.NodeID] = make(map[uint64][]byte)
                }
                extraFields[field.NodeID][uint64(field.Type)] = field.Value

                return nil
        }

        return extraFields, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) (
                        []sqlc.GraphNodeExtraType, error) {

                        return db.GetNodeExtraTypesBatch(ctx, ids)
                },
                callback,
        )
}

// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
// from the provided sqlc.GraphChannelPolicy records and the provided
// batchChannelData.
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
        channelID uint64, node1, node2 route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
        *models.ChannelEdgePolicy, error) {

        pol1, err := buildChanPolicyWithBatchData(
                dbPol1, channelID, node2, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
        }

        pol2, err := buildChanPolicyWithBatchData(
                dbPol2, channelID, node1, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
        }

        return pol1, pol2, nil
}
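
// Note on direction: policy1 is node1's outgoing update, so its "to node" is
// node2, and policy2's "to node" is node1 (this is also reflected in
// buildDirectedChannel below, where p1ToNode is the channel's NodeID2). That
// is why the calls above pass node2 for dbPol1 and node1 for dbPol2. A caller
// with extracted rows therefore invokes it as, for example:
//
//	p1, p2, err := buildChanPoliciesWithBatchData(
//	        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
//	)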

// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
        channelID uint64, toNode route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {

        if dbPol == nil {
                return nil, nil
        }

        var dbPol1Extras map[uint64][]byte
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
                dbPol1Extras = extras
        } else {
                dbPol1Extras = make(map[uint64][]byte)
        }

        return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode)
}

// batchChannelData holds all the related data for a batch of channels.
type batchChannelData struct {
        // chanfeatures is a map from DB channel ID to a slice of feature bits.
        chanfeatures map[int64][]int

        // chanExtraTypes is a map from DB channel ID to a map of TLV type to
        // extra signed field bytes.
        chanExtraTypes map[int64]map[uint64][]byte

        // policyExtras is a map from DB channel policy ID to a map of TLV type
        // to extra signed field bytes.
        policyExtras map[int64]map[uint64][]byte
}

// batchLoadChannelData loads all related data for batches of channels and
// policies.
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, channelIDs []int64,
        policyIDs []int64) (*batchChannelData, error) {

        batchData := &batchChannelData{
                chanfeatures:   make(map[int64][]int),
                chanExtraTypes: make(map[int64]map[uint64][]byte),
                policyExtras:   make(map[int64]map[uint64][]byte),
        }

        // Batch load the channel features and extra signed fields.
        var err error
        if len(channelIDs) > 0 {
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
                        ctx, cfg, db, channelIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "channel features: %w", err)
                }

                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
                        ctx, cfg, db, channelIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "channel extras: %w", err)
                }
        }

        // Batch load the policy extra signed fields.
        if len(policyIDs) > 0 {
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
                        ctx, cfg, db, policyIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "policy extras: %w", err)
                }
                batchData.policyExtras = policyExtras
        }

        return batchData, nil
}
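
// For illustration only: passing a nil (or empty) policyIDs slice skips the
// policy-extras query entirely, which is how batchBuildChannelInfo below
// loads channel-only data:
//
//	chanData, err := batchLoadChannelData(
//	        ctx, cfg.QueryCfg, db, channelIDs, nil,
//	)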

// batchLoadChannelFeaturesHelper loads channel features for a batch of
// channel IDs using the ExecuteBatchQuery wrapper around the
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
// slice of feature bits.
func batchLoadChannelFeaturesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        channelIDs []int64) (map[int64][]int, error) {

        features := make(map[int64][]int)

        return features, sqldb.ExecuteBatchQuery(
                ctx, cfg, channelIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context,
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {

                        return db.GetChannelFeaturesBatch(ctx, ids)
                },
                func(ctx context.Context,
                        feature sqlc.GraphChannelFeature) error {

                        features[feature.ChannelID] = append(
                                features[feature.ChannelID],
                                int(feature.FeatureBit),
                        )

                        return nil
                },
        )
}

// batchLoadChannelExtrasHelper loads channel extra types for a batch of
// channel IDs using the ExecuteBatchQuery wrapper around the
// GetChannelExtrasBatch query. It returns a map from DB channel ID to a map
// of TLV type to extra signed field bytes.
func batchLoadChannelExtrasHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {

        extras := make(map[int64]map[uint64][]byte)

        cb := func(ctx context.Context,
                extra sqlc.GraphChannelExtraType) error {

                if extras[extra.ChannelID] == nil {
                        extras[extra.ChannelID] = make(map[uint64][]byte)
                }
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value

                return nil
        }

        return extras, sqldb.ExecuteBatchQuery(
                ctx, cfg, channelIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context,
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {

                        return db.GetChannelExtrasBatch(ctx, ids)
                }, cb,
        )
}

// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
// batch of policy IDs using the ExecuteBatchQuery wrapper around the
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
// a map of TLV type to extra signed field bytes.
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {

        extras := make(map[int64]map[uint64][]byte)

        return extras, sqldb.ExecuteBatchQuery(
                ctx, cfg, policyIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) (
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {

                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
                },
                func(ctx context.Context,
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {

                        if extras[row.PolicyID] == nil {
                                extras[row.PolicyID] = make(map[uint64][]byte)
                        }
                        extras[row.PolicyID][uint64(row.Type)] = row.Value

                        return nil
                },
        )
}

// forEachNodePaginated executes a paginated query to process each node in the
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
// and applies the provided processNode function to each node.
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, protocol ProtocolVersion,
        processNode func(context.Context, int64,
                *models.LightningNode) error) error {

        pageQueryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.GraphNode, error) {

                return db.ListNodesPaginated(
                        ctx, sqlc.ListNodesPaginatedParams{
                                Version: int16(protocol),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractPageCursor := func(node sqlc.GraphNode) int64 {
                return node.ID
        }

        collectFunc := func(node sqlc.GraphNode) (int64, error) {
                return node.ID, nil
        }

        batchQueryFunc := func(ctx context.Context,
                nodeIDs []int64) (*batchNodeData, error) {

                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
        }

        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
                batchData *batchNodeData) error {

                node, err := buildNodeWithBatchData(dbNode, batchData)
                if err != nil {
                        return fmt.Errorf("unable to build "+
                                "node(id=%d): %w", dbNode.ID, err)
                }

                return processNode(ctx, dbNode.ID, node)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
                collectFunc, batchQueryFunc, processItem,
        )
}
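
// For illustration only: a store method that wants to visit every V1 node
// could wrap this helper roughly as follows (the visit callback is
// hypothetical):
//
//	err := forEachNodePaginated(
//	        ctx, cfg.QueryCfg, db, ProtocolV1,
//	        func(ctx context.Context, nodeID int64,
//	                node *models.LightningNode) error {
//
//	                return visit(node)
//	        },
//	)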

// forEachChannelWithPolicies executes a paginated query to process each channel
// with policies in the graph.
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
        cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
                *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error) error {

        type channelBatchIDs struct {
                channelID int64
                policyIDs []int64
        }

        pageQueryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
                error) {

                return db.ListChannelsWithPoliciesPaginated(
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
                                Version: int16(ProtocolV1),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractPageCursor := func(
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {

                return row.GraphChannel.ID
        }

        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
                channelBatchIDs, error) {

                ids := channelBatchIDs{
                        channelID: row.GraphChannel.ID,
                }

                // Extract policy IDs from the row.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return ids, err
                }

                if dbPol1 != nil {
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
                }
                if dbPol2 != nil {
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
                }

                return ids, nil
        }

        batchDataFunc := func(ctx context.Context,
                allIDs []channelBatchIDs) (*batchChannelData, error) {

                // Separate channel IDs from policy IDs.
                var (
                        channelIDs = make([]int64, len(allIDs))
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
                )

                for i, ids := range allIDs {
                        channelIDs[i] = ids.channelID
                        policyIDs = append(policyIDs, ids.policyIDs...)
                }

                return batchLoadChannelData(
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
                )
        }

        processItem := func(ctx context.Context,
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
                batchData *batchChannelData) error {

                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return err
                }

                edge, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.GraphChannel, node1, node2,
                        batchData,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel info: %w",
                                err)
                }

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return err
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
                )
                if err != nil {
                        return err
                }

                return processChannel(edge, p1, p2)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
                collectFunc, batchDataFunc, processItem,
        )
}
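
// For illustration only: the per-channel callback receives the edge info plus
// both directed policies, either of which may be nil if the corresponding
// channel update has not been persisted yet. A sketch of a caller (storeCfg
// is a hypothetical *SQLStoreConfig value):
//
//	err := forEachChannelWithPolicies(ctx, db, storeCfg,
//	        func(info *models.ChannelEdgeInfo,
//	                p1, p2 *models.ChannelEdgePolicy) error {
//
//	                // p1/p2 nil-ness must be checked by the callback.
//	                return nil
//	        },
//	)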

// buildDirectedChannel builds a DirectedChannel for the node identified by
// nodeID from the provided channel row and batched channel data.
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
        nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
        channelBatchData *batchChannelData, features *lnwire.FeatureVector,
        toNodeCallback func() route.Vertex) (*DirectedChannel, error) {

        node1, node2, err := buildNodeVertices(
                channelRow.Node1Pubkey, channelRow.Node2Pubkey,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build node vertices: %w", err)
        }

        edge, err := buildEdgeInfoWithBatchData(
                chain, channelRow.GraphChannel, node1, node2, channelBatchData,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build channel info: %w", err)
        }

        dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
        if err != nil {
                return nil, fmt.Errorf("unable to extract channel policies: %w",
                        err)
        }

        p1, p2, err := buildChanPoliciesWithBatchData(
                dbPol1, dbPol2, edge.ChannelID, node1, node2,
                channelBatchData,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build channel policies: %w",
                        err)
        }

        // Determine the outgoing and incoming policy for this specific node.
        // Policy1 points towards node2 and policy2 towards node1, so if
        // policy1 is the one pointing at this node (or policy2 is the one
        // pointing away from it), the defaults below are swapped.
        p1ToNode := channelRow.GraphChannel.NodeID2
        p2ToNode := channelRow.GraphChannel.NodeID1
        outPolicy, inPolicy := p1, p2
        if (p1 != nil && p1ToNode == nodeID) ||
                (p2 != nil && p2ToNode != nodeID) {

                outPolicy, inPolicy = p2, p1
        }

        // Build the cached incoming policy.
        var cachedInPolicy *models.CachedEdgePolicy
        if inPolicy != nil {
                cachedInPolicy = models.NewCachedPolicy(inPolicy)
                cachedInPolicy.ToNodePubKey = toNodeCallback
                cachedInPolicy.ToNodeFeatures = features
        }

        // Extract the inbound fee.
        var inboundFee lnwire.Fee
        if outPolicy != nil {
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
                        inboundFee = fee
                })
        }

        // Build the directed channel.
        directedChannel := &DirectedChannel{
                ChannelID:    edge.ChannelID,
                IsNode1:      nodePub == edge.NodeKey1Bytes,
                OtherNode:    edge.NodeKey2Bytes,
                Capacity:     edge.Capacity,
                OutPolicySet: outPolicy != nil,
                InPolicy:     cachedInPolicy,
                InboundFee:   inboundFee,
        }

        if nodePub == edge.NodeKey2Bytes {
                directedChannel.OtherNode = edge.NodeKey1Bytes
        }

        return directedChannel, nil
}

// batchBuildChannelEdges builds a slice of ChannelEdge instances from the
// provided rows. It uses batch loading for channels, policies, and nodes.
func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context,
        cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) {

        var (
                channelIDs = make([]int64, len(rows))
                policyIDs  = make([]int64, 0, len(rows)*2)
                nodeIDs    = make([]int64, 0, len(rows)*2)

                // nodeIDSet is used to ensure we only collect unique node IDs.
                nodeIDSet = make(map[int64]bool)

                // edges will hold the final channel edges built from the rows.
                edges = make([]ChannelEdge, 0, len(rows))
        )

        // Collect all IDs needed for batch loading.
        for i, row := range rows {
                channelIDs[i] = row.Channel().ID

                // Collect the policy IDs.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return nil, fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }
                if dbPol1 != nil {
                        policyIDs = append(policyIDs, dbPol1.ID)
                }
                if dbPol2 != nil {
                        policyIDs = append(policyIDs, dbPol2.ID)
                }

                var (
                        node1ID = row.Node1().ID
                        node2ID = row.Node2().ID
                )

                // Collect unique node IDs.
                if !nodeIDSet[node1ID] {
                        nodeIDs = append(nodeIDs, node1ID)
                        nodeIDSet[node1ID] = true
                }

                if !nodeIDSet[node2ID] {
                        nodeIDs = append(nodeIDs, node2ID)
                        nodeIDSet[node2ID] = true
                }
        }

        // Batch the data for all the channels and policies.
        channelBatchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to batch load channel and "+
                        "policy data: %w", err)
        }

        // Batch the data for all the nodes.
        nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node data: %w",
                        err)
        }

        // Build all channel edges using batch data.
        for _, row := range rows {
                // Build nodes using batch data.
                node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData)
                if err != nil {
                        return nil, fmt.Errorf("unable to build node1: %w", err)
                }

                node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData)
                if err != nil {
                        return nil, fmt.Errorf("unable to build node2: %w", err)
                }

                // Build channel info using batch data.
                channel, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.Channel(), node1.PubKeyBytes,
                        node2.PubKeyBytes, channelBatchData,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to build channel "+
                                "info: %w", err)
                }

                // Extract and build policies using batch data.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return nil, fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, channel.ChannelID,
                        node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to build channel "+
                                "policies: %w", err)
                }

                edges = append(edges, ChannelEdge{
                        Info:    channel,
                        Policy1: p1,
                        Policy2: p2,
                        Node1:   node1,
                        Node2:   node2,
                })
        }

        return edges, nil
}

// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo
// instances from the provided rows using batch loading for channel data.
func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context,
        cfg *SQLStoreConfig, db SQLQueries, rows []T) (
        []*models.ChannelEdgeInfo, []int64, error) {

        if len(rows) == 0 {
                return nil, nil, nil
        }

        // Collect all the channel IDs needed for batch loading.
        channelIDs := make([]int64, len(rows))
        for i, row := range rows {
                channelIDs[i] = row.Channel().ID
        }

        // Batch load the channel data.
        channelBatchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, channelIDs, nil,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to batch load channel "+
                        "data: %w", err)
        }

        // Build all channel edges using batch data.
        edges := make([]*models.ChannelEdgeInfo, 0, len(rows))
        for _, row := range rows {
                node1, node2, err := buildNodeVertices(
                        row.Node1Pub(), row.Node2Pub(),
                )
                if err != nil {
                        return nil, nil, err
                }

                // Build channel info using batch data.
                info, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.Channel(), node1, node2,
                        channelBatchData,
                )
                if err != nil {
                        return nil, nil, err
                }

                edges = append(edges, info)
        }

        return edges, channelIDs, nil
}

// handleZombieMarking marks the given channel as a zombie in the database. It
// takes into account whether we are in strict zombie pruning mode and adjusts
// the node public keys accordingly, based on the last update timestamps of
// the channel policies.
func handleZombieMarking(ctx context.Context, db SQLQueries,
        row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo,
        strictZombiePruning bool, scid uint64) error {

        nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes

        if strictZombiePruning {
                var e1UpdateTime, e2UpdateTime *time.Time
                if row.Policy1LastUpdate.Valid {
                        e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0)
                        e1UpdateTime = &e1Time
                }
                if row.Policy2LastUpdate.Valid {
                        e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0)
                        e2UpdateTime = &e2Time
                }

                nodeKey1, nodeKey2 = makeZombiePubkeys(
                        info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime,
                        e2UpdateTime,
                )
        }

        return db.UpsertZombieChannel(
                ctx, sqlc.UpsertZombieChannelParams{
                        Version:  int16(ProtocolV1),
                        Scid:     channelIDToBytes(scid),
                        NodeKey1: nodeKey1[:],
                        NodeKey2: nodeKey2[:],
                },
        )
}
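
// For illustration only: a pruning routine that has already fetched the
// channel row and rebuilt its edge info would mark it as a zombie roughly as
// follows (strict mode shown here; the boolean normally comes from the
// caller's configuration):
//
//	if err := handleZombieMarking(ctx, db, row, info, true, scid); err != nil {
//	        return fmt.Errorf("unable to mark channel as zombie: %w", err)
//	}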