lightningnetwork / lnd / 16796616815

07 Aug 2025 06:12AM UTC coverage: 66.9% (+0.03%) from 66.868%

Pull Request #10129: [8] graph/db: use batch loading for various graph SQL methods
Merge 522e200c0 into e5359f2f5

6 of 332 new or added lines in 4 files covered (1.81%).
118 existing lines in 23 files now uncovered.
135673 of 202800 relevant lines covered (66.9%).
21550.39 hits per line.

Source file: /graph/db/sql_store.go (file coverage: 0.0%)

package graphdb

import (
        "bytes"
        "context"
        "database/sql"
        "encoding/hex"
        "errors"
        "fmt"
        "maps"
        "math"
        "net"
        "slices"
        "strconv"
        "sync"
        "time"

        "github.com/btcsuite/btcd/btcec/v2"
        "github.com/btcsuite/btcd/btcutil"
        "github.com/btcsuite/btcd/chaincfg/chainhash"
        "github.com/btcsuite/btcd/wire"
        "github.com/lightningnetwork/lnd/aliasmgr"
        "github.com/lightningnetwork/lnd/batch"
        "github.com/lightningnetwork/lnd/fn/v2"
        "github.com/lightningnetwork/lnd/graph/db/models"
        "github.com/lightningnetwork/lnd/lnwire"
        "github.com/lightningnetwork/lnd/routing/route"
        "github.com/lightningnetwork/lnd/sqldb"
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
        "github.com/lightningnetwork/lnd/tlv"
        "github.com/lightningnetwork/lnd/tor"
)

// ProtocolVersion is an enum that defines the gossip protocol version of a
// message.
type ProtocolVersion uint8

const (
        // ProtocolV1 is the gossip protocol version defined in BOLT #7.
        ProtocolV1 ProtocolVersion = 1
)

// String returns a string representation of the protocol version.
func (v ProtocolVersion) String() string {
        return fmt.Sprintf("V%d", v)
}

// SQLQueries is a subset of the sqlc.Querier interface that can be used to
// execute queries against the SQL graph tables.
//
//nolint:ll,interfacebloat
type SQLQueries interface {
        /*
                Node queries.
        */
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
        DeleteNode(ctx context.Context, id int64) error

        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
        GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error

        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
        GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
        GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error

        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
        GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error

        /*
                Source node queries.
        */
        AddSourceNode(ctx context.Context, nodeID int64) error
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)

        /*
                Channel queries.
        */
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
        ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
        ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
        DeleteChannels(ctx context.Context, ids []int64) error

        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
        GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
        GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)

        /*
                Channel Policy table queries.
        */
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)

        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
        GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error

        /*
                Zombie index queries.
        */
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)

        /*
                Prune log table queries.
        */
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error

        /*
                Closed SCID table queries.
        */
        InsertClosedChannel(ctx context.Context, scid []byte) error
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
}

// BatchedSQLQueries is a version of SQLQueries that's capable of batched
// database operations.
type BatchedSQLQueries interface {
        SQLQueries
        sqldb.BatchedTx[SQLQueries]
}

// SQLStore is an implementation of the V1Store interface that uses a SQL
// database as the backend.
type SQLStore struct {
        cfg *SQLStoreConfig
        db  BatchedSQLQueries

        // cacheMu guards all caches (rejectCache and chanCache). If
        // this mutex will be acquired at the same time as the DB mutex then
        // the cacheMu MUST be acquired first to prevent deadlock.
        cacheMu     sync.RWMutex
        rejectCache *rejectCache
        chanCache   *channelCache

        chanScheduler batch.Scheduler[SQLQueries]
        nodeScheduler batch.Scheduler[SQLQueries]

        srcNodes  map[ProtocolVersion]*srcNodeInfo
        srcNodeMu sync.Mutex
}

// A compile-time assertion to ensure that SQLStore implements the V1Store
// interface.
var _ V1Store = (*SQLStore)(nil)

// SQLStoreConfig holds the configuration for the SQLStore.
type SQLStoreConfig struct {
        // ChainHash is the genesis hash for the chain that all the gossip
        // messages in this store are aimed at.
        ChainHash chainhash.Hash

        // QueryConfig holds configuration values for SQL queries.
        QueryCfg *sqldb.QueryConfig
}

// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
// storage backend.
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
        options ...StoreOptionModifier) (*SQLStore, error) {

        opts := DefaultOptions()
        for _, o := range options {
                o(opts)
        }

        if opts.NoMigration {
                return nil, fmt.Errorf("the NoMigration option is not yet " +
                        "supported for SQL stores")
        }

        s := &SQLStore{
                cfg:         cfg,
                db:          db,
                rejectCache: newRejectCache(opts.RejectCacheSize),
                chanCache:   newChannelCache(opts.ChannelCacheSize),
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
        }

        s.chanScheduler = batch.NewTimeScheduler(
                db, &s.cacheMu, opts.BatchCommitInterval,
        )
        s.nodeScheduler = batch.NewTimeScheduler(
                db, nil, opts.BatchCommitInterval,
        )

        return s, nil
}
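
// Example (editor's hedged sketch, not part of the original file): constructing
// a SQLStore. Only NewSQLStore, SQLStoreConfig and their fields are taken from
// this file; the concrete BatchedSQLQueries value ("db" below), the chaincfg
// import and the zero-valued sqldb.QueryConfig are assumptions, and real
// callers would likely want non-default query settings.
//
//      cfg := &SQLStoreConfig{
//              ChainHash: *chaincfg.MainNetParams.GenesisHash,
//              QueryCfg:  &sqldb.QueryConfig{},
//      }
//      store, err := NewSQLStore(cfg, db)
//      if err != nil {
//              // Handle the error.
//      }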

// AddLightningNode adds a vertex/node to the graph database. If the node is not
// in the database from before, this will add a new, unconnected one to the
// graph. If it is present from before, this will update that node's
// information.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddLightningNode(ctx context.Context,
        node *models.LightningNode, opts ...batch.SchedulerOption) error {

        r := &batch.Request[SQLQueries]{
                Opts: batch.NewSchedulerOptions(opts...),
                Do: func(queries SQLQueries) error {
                        _, err := upsertNode(ctx, queries, node)
                        return err
                },
        }

        return s.nodeScheduler.Execute(ctx, r)
}
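
// Example (editor's hedged sketch, not part of the original file): persisting a
// node announcement through the node scheduler above. The node value is an
// assumed *models.LightningNode built elsewhere, and batch.LazyAdd is assumed
// to be one of lnd's batch scheduler options for non-blocking batching.
//
//      if err := store.AddLightningNode(ctx, node, batch.LazyAdd()); err != nil {
//              // Handle the error.
//      }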

// FetchLightningNode attempts to look up a target node by its identity public
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
// returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchLightningNode(ctx context.Context,
        pubKey route.Vertex) (*models.LightningNode, error) {

        var node *models.LightningNode
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                var err error
                _, node, err = getNodeByPubKey(ctx, db, pubKey)

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch node: %w", err)
        }

        return node, nil
}

// HasLightningNode determines if the graph has a vertex identified by the
// target node identity public key. If the node exists in the database, a
// timestamp of when the data for the node was last updated is returned along
// with a true boolean. Otherwise, an empty time.Time is returned with a false
// boolean.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) HasLightningNode(ctx context.Context,
        pubKey [33]byte) (time.Time, bool, error) {

        var (
                exists     bool
                lastUpdate time.Time
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pubKey[:],
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                exists = true

                if dbNode.LastUpdate.Valid {
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return time.Time{}, false,
                        fmt.Errorf("unable to fetch node: %w", err)
        }

        return lastUpdate, exists, nil
}
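
// Example (editor's hedged sketch, not part of the original file): checking
// whether a node announcement is already stored and how fresh it is. The
// pubKey value is an assumed compressed public key copied into a [33]byte
// array; the 24 hour staleness window is arbitrary.
//
//      lastUpdate, exists, err := store.HasLightningNode(ctx, pubKey)
//      if err != nil {
//              // Handle the error.
//      }
//      if exists && time.Since(lastUpdate) < 24*time.Hour {
//              // The stored announcement is recent enough; skip re-fetching.
//      }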

// AddrsForNode returns all known addresses for the target node public key
// that the graph DB is aware of. The returned boolean indicates if the
// given node is unknown to the graph DB or not.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddrsForNode(ctx context.Context,
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {

        var (
                addresses []net.Addr
                known     bool
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                // First, check if the node exists and get its DB ID if it
                // does.
                dbID, err := db.GetNodeIDByPubKey(
                        ctx, sqlc.GetNodeIDByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  nodePub.SerializeCompressed(),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                }

                known = true

                addresses, err = getNodeAddresses(ctx, db, dbID)
                if err != nil {
                        return fmt.Errorf("unable to fetch node addresses: %w",
                                err)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return false, nil, fmt.Errorf("unable to get addresses for "+
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
        }

        return known, addresses, nil
}

// DeleteLightningNode starts a new database transaction to remove a vertex/node
// from the database according to the node's public key.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
        pubKey route.Vertex) error {

        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                res, err := db.DeleteNodeByPubKey(
                        ctx, sqlc.DeleteNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pubKey[:],
                        },
                )
                if err != nil {
                        return err
                }

                rows, err := res.RowsAffected()
                if err != nil {
                        return err
                }

                if rows == 0 {
                        return ErrGraphNodeNotFound
                } else if rows > 1 {
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
                }

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return fmt.Errorf("unable to delete node: %w", err)
        }

        return nil
}

// FetchNodeFeatures returns the features of the given node. If no features are
// known for the node, an empty feature vector is returned.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
        *lnwire.FeatureVector, error) {

        ctx := context.TODO()

        return fetchNodeFeatures(ctx, s.db, nodePub)
}

// DisabledChannelIDs returns the channel ids of disabled channels.
// A channel is disabled when two of the associated ChannelEdgePolicies
// have their disabled bit on.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
        var (
                ctx     = context.TODO()
                chanIDs []uint64
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
                if err != nil {
                        return fmt.Errorf("unable to fetch disabled "+
                                "channels: %w", err)
                }

                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
                        err)
        }

        return chanIDs, nil
}

// LookupAlias attempts to return the alias as advertised by the target node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) LookupAlias(ctx context.Context,
        pub *btcec.PublicKey) (string, error) {

        var alias string
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pub.SerializeCompressed(),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return ErrNodeAliasNotFound
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                if !dbNode.Alias.Valid {
                        return ErrNodeAliasNotFound
                }

                alias = dbNode.Alias.String

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return "", fmt.Errorf("unable to look up alias: %w", err)
        }

        return alias, nil
}

// SourceNode returns the source node of the graph. The source node is treated
// as the center node within a star-graph. This method may be used to kick off
// a path finding algorithm in order to explore the reachability of another
// node based off the source node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
        error) {

        var node *models.LightningNode
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
                if err != nil {
                        return fmt.Errorf("unable to fetch V1 source node: %w",
                                err)
                }

                _, node, err = getNodeByPubKey(ctx, db, nodePub)

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
        }

        return node, nil
}

// SetSourceNode sets the source node within the graph database. The source
// node is to be used as the center of a star-graph within path finding
// algorithms.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SetSourceNode(ctx context.Context,
        node *models.LightningNode) error {

        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                id, err := upsertNode(ctx, db, node)
                if err != nil {
                        return fmt.Errorf("unable to upsert source node: %w",
                                err)
                }

                // Make sure that if a source node for this version is already
                // set, then the ID is the same as the one we are about to set.
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
                        return fmt.Errorf("unable to fetch source node: %w",
                                err)
                } else if err == nil {
                        if dbSourceNodeID != id {
                                return fmt.Errorf("v1 source node already "+
                                        "set to a different node: %d vs %d",
                                        dbSourceNodeID, id)
                        }

                        return nil
                }

                return db.AddSourceNode(ctx, id)
        }, sqldb.NoOpReset)
}
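
// Example (editor's hedged sketch, not part of the original file): registering
// the local node as the source node and reading it back. The selfNode value is
// an assumed *models.LightningNode describing the local node.
//
//      if err := store.SetSourceNode(ctx, selfNode); err != nil {
//              // Handle the error.
//      }
//      src, err := store.SourceNode(ctx)
//      if err != nil {
//              // Handle the error.
//      }
//      _ = src // src now mirrors selfNode.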

// NodeUpdatesInHorizon returns all the known lightning nodes which have an
// update timestamp within the passed range. This method can be used by two
// nodes to quickly determine if they have the same set of up to date node
// announcements.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
        endTime time.Time) ([]models.LightningNode, error) {

        ctx := context.TODO()

        var nodes []models.LightningNode
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNodes, err := db.GetNodesByLastUpdateRange(
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch nodes: %w", err)
                }

                err = forEachNodeInBatch(
                        ctx, s.cfg.QueryCfg, db, dbNodes,
                        func(_ int64, node *models.LightningNode) error {
                                nodes = append(nodes, *node)

                                return nil
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to build nodes: %w", err)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
        }

        return nodes, nil
}

// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
// undirected edge from the two target nodes is created. The information stored
// denotes the static attributes of the channel, such as the channelID, the keys
// involved in creation of the channel, and the set of features that the channel
// supports. The chanPoint and chanID are used to uniquely identify the edge
// globally within the database.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddChannelEdge(ctx context.Context,
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {

        var alreadyExists bool
        r := &batch.Request[SQLQueries]{
                Opts: batch.NewSchedulerOptions(opts...),
                Reset: func() {
                        alreadyExists = false
                },
                Do: func(tx SQLQueries) error {
                        _, err := insertChannel(ctx, tx, edge)

                        // Silence ErrEdgeAlreadyExist so that the batch can
                        // succeed, but propagate the error via local state.
                        if errors.Is(err, ErrEdgeAlreadyExist) {
                                alreadyExists = true
                                return nil
                        }

                        return err
                },
                OnCommit: func(err error) error {
                        switch {
                        case err != nil:
                                return err
                        case alreadyExists:
                                return ErrEdgeAlreadyExist
                        default:
                                s.rejectCache.remove(edge.ChannelID)
                                s.chanCache.remove(edge.ChannelID)
                                return nil
                        }
                },
        }

        return s.chanScheduler.Execute(ctx, r)
}
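
// Example (editor's hedged sketch, not part of the original file): inserting a
// channel edge and tolerating duplicates, mirroring how the OnCommit callback
// above surfaces ErrEdgeAlreadyExist to the caller. The edge value is an
// assumed *models.ChannelEdgeInfo built from a channel announcement.
//
//      err := store.AddChannelEdge(ctx, edge)
//      switch {
//      case errors.Is(err, ErrEdgeAlreadyExist):
//              // The channel is already known; nothing to do.
//      case err != nil:
//              // Handle the error.
//      }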

// HighestChanID returns the "highest" known channel ID in the channel graph.
// This represents the "newest" channel from the PoV of the chain. This method
// can be used by peers to quickly determine if their graphs are in sync.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
        var highestChanID uint64
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
                                err)
                }

                highestChanID = byteOrder.Uint64(chanID)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
        }

        return highestChanID, nil
}

// UpdateEdgePolicy updates the edge routing policy for a single directed edge
// within the database for the referenced channel. The `flags` attribute within
// the ChannelEdgePolicy determines which of the directed edges are being
// updated. If the flag is 1, then the first node's information is being
// updated, otherwise it's the second node's information. The node ordering is
// determined by the lexicographical ordering of the identity public keys of the
// nodes on either side of the channel.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
        edge *models.ChannelEdgePolicy,
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {

        var (
                isUpdate1    bool
                edgeNotFound bool
                from, to     route.Vertex
        )

        r := &batch.Request[SQLQueries]{
                Opts: batch.NewSchedulerOptions(opts...),
                Reset: func() {
                        isUpdate1 = false
                        edgeNotFound = false
                },
                Do: func(tx SQLQueries) error {
                        var err error
                        from, to, isUpdate1, err = updateChanEdgePolicy(
                                ctx, tx, edge,
                        )
                        if err != nil {
                                log.Errorf("UpdateEdgePolicy failed: %v", err)
                        }

                        // Silence ErrEdgeNotFound so that the batch can
                        // succeed, but propagate the error via local state.
                        if errors.Is(err, ErrEdgeNotFound) {
                                edgeNotFound = true
                                return nil
                        }

                        return err
                },
                OnCommit: func(err error) error {
                        switch {
                        case err != nil:
                                return err
                        case edgeNotFound:
                                return ErrEdgeNotFound
                        default:
                                s.updateEdgeCache(edge, isUpdate1)
                                return nil
                        }
                },
        }

        err := s.chanScheduler.Execute(ctx, r)

        return from, to, err
}
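
// Example (editor's hedged sketch, not part of the original file): applying a
// policy update for one direction of a channel. The policy value is an assumed
// *models.ChannelEdgePolicy whose direction flag selects which node's policy is
// updated, as described in the comment above.
//
//      from, to, err := store.UpdateEdgePolicy(ctx, policy)
//      if errors.Is(err, ErrEdgeNotFound) {
//              // The referenced channel is not in the graph yet.
//      } else if err != nil {
//              // Handle the error.
//      }
//      _ = from
//      _ = to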

// updateEdgeCache updates our reject and channel caches with the new
// edge policy information.
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
        isUpdate1 bool) {

        // If an entry for this channel is found in reject cache, we'll modify
        // the entry with the updated timestamp for the direction that was just
        // written. If the edge doesn't exist, we'll load the cache entry lazily
        // during the next query for this edge.
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
                if isUpdate1 {
                        entry.upd1Time = e.LastUpdate.Unix()
                } else {
                        entry.upd2Time = e.LastUpdate.Unix()
                }
                s.rejectCache.insert(e.ChannelID, entry)
        }

        // If an entry for this channel is found in channel cache, we'll modify
        // the entry with the updated policy for the direction that was just
        // written. If the edge doesn't exist, we'll defer loading the info and
        // policies and lazily read from disk during the next query.
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
                if isUpdate1 {
                        channel.Policy1 = e
                } else {
                        channel.Policy2 = e
                }
                s.chanCache.insert(e.ChannelID, channel)
        }
}

// ForEachSourceNodeChannel iterates through all channels of the source node,
// executing the passed callback on each. The call-back is provided with the
// channel's outpoint, whether we have a policy for the channel and the channel
// peer's node information.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
        cb func(chanPoint wire.OutPoint, havePolicy bool,
                otherNode *models.LightningNode) error, reset func()) error {

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
                if err != nil {
                        return fmt.Errorf("unable to fetch source node: %w",
                                err)
                }

                return forEachNodeChannel(
                        ctx, db, s.cfg, nodeID,
                        func(info *models.ChannelEdgeInfo,
                                outPolicy *models.ChannelEdgePolicy,
                                _ *models.ChannelEdgePolicy) error {

                                // Fetch the other node.
                                var (
                                        otherNodePub [33]byte
                                        node1        = info.NodeKey1Bytes
                                        node2        = info.NodeKey2Bytes
                                )
                                switch {
                                case bytes.Equal(node1[:], nodePub[:]):
                                        otherNodePub = node2
                                case bytes.Equal(node2[:], nodePub[:]):
                                        otherNodePub = node1
                                default:
                                        return fmt.Errorf("node not " +
                                                "participating in this channel")
                                }

                                _, otherNode, err := getNodeByPubKey(
                                        ctx, db, otherNodePub,
                                )
                                if err != nil {
                                        return fmt.Errorf("unable to fetch "+
                                                "other node(%x): %w",
                                                otherNodePub, err)
                                }

                                return cb(
                                        info.ChannelPoint, outPolicy != nil,
                                        otherNode,
                                )
                        },
                )
        }, reset)
}

// ForEachNode iterates through all the stored vertices/nodes in the graph,
// executing the passed callback with each node encountered. If the callback
// returns an error, then the transaction is aborted and the iteration stops
// early.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNode(ctx context.Context,
        cb func(node *models.LightningNode) error, reset func()) error {

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return forEachNodePaginated(
                        ctx, s.cfg.QueryCfg, db,
                        ProtocolV1, func(_ context.Context, _ int64,
                                node *models.LightningNode) error {

                                return cb(node)
                        },
                )
        }, reset)
}
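
// Example (editor's hedged sketch, not part of the original file): counting all
// nodes in the graph with ForEachNode. The reset callback rewinds the counter
// so that a retried read transaction does not double count.
//
//      var numNodes int
//      err := store.ForEachNode(ctx, func(_ *models.LightningNode) error {
//              numNodes++
//              return nil
//      }, func() {
//              numNodes = 0
//      })
//      if err != nil {
//              // Handle the error.
//      }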

// ForEachNodeDirectedChannel iterates through all channels of a given node,
// executing the passed callback on the directed edge representing the channel
// and its incoming policy. If the callback returns an error, then the iteration
// is halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
        cb func(channel *DirectedChannel) error, reset func()) error {

        var ctx = context.TODO()

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
        }, reset)
}

// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
// graph, executing the passed callback with each node encountered. If the
// callback returns an error, then the transaction is aborted and the iteration
// stops early.
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
        cb func(route.Vertex, *lnwire.FeatureVector) error,
        reset func()) error {

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return forEachNodeCacheable(
                        ctx, s.cfg.QueryCfg, db,
                        func(_ int64, nodePub route.Vertex,
                                features *lnwire.FeatureVector) error {

                                return cb(nodePub, features)
                        },
                )
        }, reset)
        if err != nil {
                return fmt.Errorf("unable to fetch nodes: %w", err)
        }

        return nil
}

// ForEachNodeChannel iterates through all channels of the given node,
// executing the passed callback with an edge info structure and the policies
// of each end of the channel. The first edge policy is the outgoing edge *to*
// the connecting node, while the second is the incoming edge *from* the
// connecting node. If the callback returns an error, then the iteration is
// halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error, reset func()) error {

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  nodePub[:],
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
        }, reset)
}

// ChanUpdatesInHorizon returns all the known channel edges which have at least
// one edge that has an update timestamp within the specified horizon.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
        endTime time.Time) ([]ChannelEdge, error) {

        s.cacheMu.Lock()
        defer s.cacheMu.Unlock()

        var (
                ctx = context.TODO()
                // To ensure we don't return duplicate ChannelEdges, we'll use
                // an additional map to keep track of the edges already seen to
                // prevent re-adding them.
                edgesSeen    = make(map[uint64]struct{})
                edgesToCache = make(map[uint64]ChannelEdge)
                edges        []ChannelEdge
                hits         int
        )

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
                                Version:   int16(ProtocolV1),
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
                        },
                )
                if err != nil {
                        return err
                }

                if len(rows) == 0 {
                        return nil
                }

                // We'll pre-allocate the slices and maps here with a best
                // effort size in order to avoid unnecessary allocations later
                // on.
                uncachedRows := make(
                        []sqlc.GetChannelsByPolicyLastUpdateRangeRow, 0,
                        len(rows),
                )
                edgesToCache = make(map[uint64]ChannelEdge, len(rows))
                edgesSeen = make(map[uint64]struct{}, len(rows))
                edges = make([]ChannelEdge, 0, len(rows))

                // Separate cached from non-cached channels since we will only
                // batch load the data for the ones we haven't cached yet.
                for _, row := range rows {
                        chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid)

                        // Skip duplicates.
                        if _, ok := edgesSeen[chanIDInt]; ok {
                                continue
                        }
                        edgesSeen[chanIDInt] = struct{}{}

                        // Check cache first.
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
                                hits++
                                edges = append(edges, channel)
                                continue
                        }

                        // Mark this row as one we need to batch load data for.
                        uncachedRows = append(uncachedRows, row)
                }

                // If there are no uncached rows, then we can return early.
                if len(uncachedRows) == 0 {
                        return nil
                }

                // Batch load data for all uncached channels.
                newEdges, err := batchBuildChannelEdges(
                        ctx, s.cfg, db, uncachedRows,
                )
                if err != nil {
                        return fmt.Errorf("unable to batch build channel "+
                                "edges: %w", err)
                }

                edges = append(edges, newEdges...)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
        }

        // Insert any edges loaded from disk into the cache.
        for chanid, channel := range edgesToCache {
                s.chanCache.insert(chanid, channel)
        }

        if len(edges) > 0 {
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
        } else {
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
                        "horizon (%s, %s)", startTime, endTime)
        }

        return edges, nil
}
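
// Example (editor's hedged sketch, not part of the original file): fetching
// every channel that saw a policy update in the last six hours, e.g. when
// serving a peer's gossip timestamp filter. The six hour window is arbitrary.
//
//      now := time.Now()
//      edges, err := store.ChanUpdatesInHorizon(now.Add(-6*time.Hour), now)
//      if err != nil {
//              // Handle the error.
//      }
//      for _, edge := range edges {
//              _ = edge // Each entry bundles the channel info with its policies.
//      }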
1001

1002
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1003
// data to the call-back. If withAddrs is true, then the call-back will also be
1004
// provided with the addresses associated with the node. The address retrieval
1005
// result in an additional round-trip to the database, so it should only be used
1006
// if the addresses are actually needed.
1007
//
1008
// NOTE: part of the V1Store interface.
1009
func (s *SQLStore) ForEachNodeCached(ctx context.Context, withAddrs bool,
1010
        cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
1011
                chans map[uint64]*DirectedChannel) error, reset func()) error {
×
1012

×
1013
        type nodeCachedBatchData struct {
×
1014
                features      map[int64][]int
×
1015
                addrs         map[int64][]nodeAddress
×
1016
                chanBatchData *batchChannelData
×
1017
                chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
×
1018
        }
×
1019

×
1020
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1021
                // pageQueryFunc is used to query the next page of nodes.
×
1022
                pageQueryFunc := func(ctx context.Context, lastID int64,
×
1023
                        limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
1024

×
1025
                        return db.ListNodeIDsAndPubKeys(
×
1026
                                ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
1027
                                        Version: int16(ProtocolV1),
×
1028
                                        ID:      lastID,
×
1029
                                        Limit:   limit,
×
1030
                                },
×
1031
                        )
×
1032
                }
×
1033

1034
                // batchDataFunc is then used to batch load the data required
1035
                // for each page of nodes.
1036
                batchDataFunc := func(ctx context.Context,
×
1037
                        nodeIDs []int64) (*nodeCachedBatchData, error) {
×
1038

×
1039
                        // Batch load node features.
×
1040
                        nodeFeatures, err := batchLoadNodeFeaturesHelper(
×
1041
                                ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1042
                        )
×
1043
                        if err != nil {
×
1044
                                return nil, fmt.Errorf("unable to batch load "+
×
1045
                                        "node features: %w", err)
×
1046
                        }
×
1047

1048
                        // Maybe fetch the node's addresses if requested.
1049
                        var nodeAddrs map[int64][]nodeAddress
×
1050
                        if withAddrs {
×
1051
                                nodeAddrs, err = batchLoadNodeAddressesHelper(
×
1052
                                        ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1053
                                )
×
1054
                                if err != nil {
×
1055
                                        return nil, fmt.Errorf("unable to "+
×
1056
                                                "batch load node "+
×
1057
                                                "addresses: %w", err)
×
1058
                                }
×
1059
                        }
1060

1061
                        // Batch load ALL unique channels for ALL nodes in this
1062
                        // page.
1063
                        allChannels, err := db.ListChannelsForNodeIDs(
×
1064
                                ctx, sqlc.ListChannelsForNodeIDsParams{
×
1065
                                        Version:  int16(ProtocolV1),
×
1066
                                        Node1Ids: nodeIDs,
×
1067
                                        Node2Ids: nodeIDs,
×
1068
                                },
×
1069
                        )
×
1070
                        if err != nil {
×
1071
                                return nil, fmt.Errorf("unable to batch "+
×
1072
                                        "fetch channels for nodes: %w", err)
×
1073
                        }
×
1074

1075
                        // Deduplicate channels and collect IDs.
1076
                        var (
×
1077
                                allChannelIDs []int64
×
1078
                                allPolicyIDs  []int64
×
1079
                        )
×
1080
                        uniqueChannels := make(
×
1081
                                map[int64]sqlc.ListChannelsForNodeIDsRow,
×
1082
                        )
×
1083

×
1084
                        for _, channel := range allChannels {
×
1085
                                channelID := channel.GraphChannel.ID
×
1086

×
1087
                                // Only process each unique channel once.
×
1088
                                _, exists := uniqueChannels[channelID]
×
1089
                                if exists {
×
1090
                                        continue
×
1091
                                }
1092

1093
                                uniqueChannels[channelID] = channel
×
1094
                                allChannelIDs = append(allChannelIDs, channelID)
×
1095

×
1096
                                if channel.Policy1ID.Valid {
×
1097
                                        allPolicyIDs = append(
×
1098
                                                allPolicyIDs,
×
1099
                                                channel.Policy1ID.Int64,
×
1100
                                        )
×
1101
                                }
×
1102
                                if channel.Policy2ID.Valid {
×
1103
                                        allPolicyIDs = append(
×
1104
                                                allPolicyIDs,
×
1105
                                                channel.Policy2ID.Int64,
×
1106
                                        )
×
1107
                                }
×
1108
                        }
1109

1110
                        // Batch load channel data for all unique channels.
                        channelBatchData, err := batchLoadChannelData(
                                ctx, s.cfg.QueryCfg, db, allChannelIDs,
                                allPolicyIDs,
                        )
                        if err != nil {
                                return nil, fmt.Errorf("unable to batch "+
                                        "load channel data: %w", err)
                        }

                        // Create map of node ID to channels that involve this
                        // node.
                        nodeIDSet := make(map[int64]bool)
                        for _, nodeID := range nodeIDs {
                                nodeIDSet[nodeID] = true
                        }

                        nodeChannelMap := make(
                                map[int64][]sqlc.ListChannelsForNodeIDsRow,
                        )
                        for _, channel := range uniqueChannels {
                                // Add channel to both nodes if they're in our
                                // current page.
                                node1 := channel.GraphChannel.NodeID1
                                if nodeIDSet[node1] {
                                        nodeChannelMap[node1] = append(
                                                nodeChannelMap[node1], channel,
                                        )
                                }
                                node2 := channel.GraphChannel.NodeID2
                                if nodeIDSet[node2] {
                                        nodeChannelMap[node2] = append(
                                                nodeChannelMap[node2], channel,
                                        )
                                }
                        }

                        return &nodeCachedBatchData{
                                features:      nodeFeatures,
                                addrs:         nodeAddrs,
                                chanBatchData: channelBatchData,
                                chanMap:       nodeChannelMap,
                        }, nil
                }

                // processItem is used to process each node in the current page.
                processItem := func(ctx context.Context,
                        nodeData sqlc.ListNodeIDsAndPubKeysRow,
                        batchData *nodeCachedBatchData) error {

                        // Build feature vector for this node.
                        fv := lnwire.EmptyFeatureVector()
                        features, exists := batchData.features[nodeData.ID]
                        if exists {
                                for _, bit := range features {
                                        fv.Set(lnwire.FeatureBit(bit))
                                }
                        }

                        var nodePub route.Vertex
                        copy(nodePub[:], nodeData.PubKey)

                        nodeChannels := batchData.chanMap[nodeData.ID]

                        toNodeCallback := func() route.Vertex {
                                return nodePub
                        }

                        // Build cached channels map for this node.
                        channels := make(map[uint64]*DirectedChannel)
                        for _, channelRow := range nodeChannels {
                                directedChan, err := buildDirectedChannel(
                                        s.cfg.ChainHash, nodeData.ID, nodePub,
                                        channelRow, batchData.chanBatchData, fv,
                                        toNodeCallback,
                                )
                                if err != nil {
                                        return err
                                }

                                channels[directedChan.ChannelID] = directedChan
                        }

                        addrs, err := buildNodeAddresses(
                                batchData.addrs[nodeData.ID],
                        )
                        if err != nil {
                                return fmt.Errorf("unable to build node "+
                                        "addresses: %w", err)
                        }

                        return cb(ctx, nodePub, addrs, channels)
                }

                return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                        ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
                        func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
                                return node.ID
                        },
                        func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
                                error) {

                                return node.ID, nil
                        },
                        batchDataFunc, processItem,
                )
        }, reset)
}

// ForEachChannelCacheable iterates through all the channel edges stored
// within the graph and invokes the passed callback for each edge. The
// callback takes two edges since this is a directed graph: both the in/out
// edges are visited. If the callback returns an error, then the
// transaction is aborted and the iteration stops early.
//
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
// pointer for that particular channel edge routing policy will be
// passed into the callback.
//
// NOTE: this method is like ForEachChannel but fetches only the data
// required for the graph cache.
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
        reset func()) error {

        ctx := context.TODO()

        handleChannel := func(_ context.Context,
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {

                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return err
                }

                edge := buildCacheableChannelInfo(
                        row.Scid, row.Capacity.Int64, node1, node2,
                )

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return err
                }

                pol1, pol2, err := buildCachedChanPolicies(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
                )
                if err != nil {
                        return err
                }

                return cb(edge, pol1, pol2)
        }

        extractCursor := func(
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {

                return row.ID
        }

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                //nolint:ll
                queryFunc := func(ctx context.Context, lastID int64,
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
                        error) {

                        return db.ListChannelsWithPoliciesForCachePaginated(
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
                                        Version: int16(ProtocolV1),
                                        ID:      lastID,
                                        Limit:   limit,
                                },
                        )
                }

                return sqldb.ExecutePaginatedQuery(
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
                        extractCursor, handleChannel,
                )
        }, reset)
}
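
// The snippet below is an illustrative usage sketch rather than part of the
// store: it shows how a caller might count all known channel edges via the
// cache-oriented iterator. The name "store" is hypothetical and assumed to be
// a fully constructed *SQLStore.
//
//        var numChans int
//        err := store.ForEachChannelCacheable(
//                func(info *models.CachedEdgeInfo, _ *models.CachedEdgePolicy,
//                        _ *models.CachedEdgePolicy) error {
//
//                        numChans++
//                        return nil
//                },
//                func() {
//                        // The reset closure undoes partial work if the
//                        // underlying transaction needs to be retried.
//                        numChans = 0
//                },
//        )
//        if err != nil {
//                // Handle the error.
//        }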

// ForEachChannel iterates through all the channel edges stored within the
// graph and invokes the passed callback for each edge. The callback takes two
// edges since this is a directed graph: both the in/out edges are visited. If
// the callback returns an error, then the transaction is aborted and the
// iteration stops early.
//
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
// for that particular channel edge routing policy will be passed into the
// callback.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachChannel(ctx context.Context,
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error, reset func()) error {

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
        }, reset)
}
1313

1314
// FilterChannelRange returns the channel ID's of all known channels which were
1315
// mined in a block height within the passed range. The channel IDs are grouped
1316
// by their common block height. This method can be used to quickly share with a
1317
// peer the set of channels we know of within a particular range to catch them
1318
// up after a period of time offline. If withTimestamps is true then the
1319
// timestamp info of the latest received channel update messages of the channel
1320
// will be included in the response.
1321
//
1322
// NOTE: This is part of the V1Store interface.
1323
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1324
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1325

×
1326
        var (
×
1327
                ctx       = context.TODO()
×
1328
                startSCID = &lnwire.ShortChannelID{
×
1329
                        BlockHeight: startHeight,
×
1330
                }
×
1331
                endSCID = lnwire.ShortChannelID{
×
1332
                        BlockHeight: endHeight,
×
1333
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1334
                        TxPosition:  math.MaxUint16,
×
1335
                }
×
1336
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1337
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1338
        )
×
1339

×
1340
        // 1) get all channels where channelID is between start and end chan ID.
×
1341
        // 2) skip if not public (ie, no channel_proof)
×
1342
        // 3) collect that channel.
×
1343
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1344
        //    and add those timestamps to the collected channel.
×
1345
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1346
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1347
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1348
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1349
                                StartScid: chanIDStart,
×
1350
                                EndScid:   chanIDEnd,
×
1351
                        },
×
1352
                )
×
1353
                if err != nil {
×
1354
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1355
                                err)
×
1356
                }
×
1357

1358
                for _, dbChan := range dbChans {
×
1359
                        cid := lnwire.NewShortChanIDFromInt(
×
1360
                                byteOrder.Uint64(dbChan.Scid),
×
1361
                        )
×
1362
                        chanInfo := NewChannelUpdateInfo(
×
1363
                                cid, time.Time{}, time.Time{},
×
1364
                        )
×
1365

×
1366
                        if !withTimestamps {
×
1367
                                channelsPerBlock[cid.BlockHeight] = append(
×
1368
                                        channelsPerBlock[cid.BlockHeight],
×
1369
                                        chanInfo,
×
1370
                                )
×
1371

×
1372
                                continue
×
1373
                        }
1374

1375
                        //nolint:ll
1376
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1377
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1378
                                        Version:   int16(ProtocolV1),
×
1379
                                        ChannelID: dbChan.ID,
×
1380
                                        NodeID:    dbChan.NodeID1,
×
1381
                                },
×
1382
                        )
×
1383
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1384
                                return fmt.Errorf("unable to fetch node1 "+
×
1385
                                        "policy: %w", err)
×
1386
                        } else if err == nil {
×
1387
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1388
                                        node1Policy.LastUpdate.Int64, 0,
×
1389
                                )
×
1390
                        }
×
1391

1392
                        //nolint:ll
1393
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1394
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1395
                                        Version:   int16(ProtocolV1),
×
1396
                                        ChannelID: dbChan.ID,
×
1397
                                        NodeID:    dbChan.NodeID2,
×
1398
                                },
×
1399
                        )
×
1400
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1401
                                return fmt.Errorf("unable to fetch node2 "+
×
1402
                                        "policy: %w", err)
×
1403
                        } else if err == nil {
×
1404
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1405
                                        node2Policy.LastUpdate.Int64, 0,
×
1406
                                )
×
1407
                        }
×
1408

1409
                        channelsPerBlock[cid.BlockHeight] = append(
×
1410
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1411
                        )
×
1412
                }
1413

1414
                return nil
×
1415
        }, func() {
×
1416
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1417
        })
×
1418
        if err != nil {
×
1419
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1420
        }
×
1421

1422
        if len(channelsPerBlock) == 0 {
×
1423
                return nil, nil
×
1424
        }
×
1425

1426
        // Return the channel ranges in ascending block height order.
1427
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1428
        slices.Sort(blocks)
×
1429

×
1430
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1431
                return BlockChannelRange{
×
1432
                        Height:   block,
×
1433
                        Channels: channelsPerBlock[block],
×
1434
                }
×
1435
        }), nil
×
1436
}
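
// The following is an illustrative usage sketch (not part of the store): a
// gossip-syncer style caller could fetch the channels confirmed in a block
// range and flatten them into SCIDs. "store", "startHeight" and "endHeight"
// are hypothetical names.
//
//        ranges, err := store.FilterChannelRange(startHeight, endHeight, false)
//        if err != nil {
//                // Handle the error.
//        }
//
//        var scids []lnwire.ShortChannelID
//        for _, blockRange := range ranges {
//                for _, chanInfo := range blockRange.Channels {
//                        scids = append(scids, chanInfo.ShortChannelID)
//                }
//        }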
1437

1438
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1439
// zombie. This method is used on an ad-hoc basis, when channels need to be
1440
// marked as zombies outside the normal pruning cycle.
1441
//
1442
// NOTE: part of the V1Store interface.
1443
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1444
        pubKey1, pubKey2 [33]byte) error {
×
1445

×
1446
        ctx := context.TODO()
×
1447

×
1448
        s.cacheMu.Lock()
×
1449
        defer s.cacheMu.Unlock()
×
1450

×
1451
        chanIDB := channelIDToBytes(chanID)
×
1452

×
1453
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1454
                return db.UpsertZombieChannel(
×
1455
                        ctx, sqlc.UpsertZombieChannelParams{
×
1456
                                Version:  int16(ProtocolV1),
×
1457
                                Scid:     chanIDB,
×
1458
                                NodeKey1: pubKey1[:],
×
1459
                                NodeKey2: pubKey2[:],
×
1460
                        },
×
1461
                )
×
1462
        }, sqldb.NoOpReset)
×
1463
        if err != nil {
×
1464
                return fmt.Errorf("unable to upsert zombie channel "+
×
1465
                        "(channel_id=%d): %w", chanID, err)
×
1466
        }
×
1467

1468
        s.rejectCache.remove(chanID)
×
1469
        s.chanCache.remove(chanID)
×
1470

×
1471
        return nil
×
1472
}
1473

1474
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1475
//
1476
// NOTE: part of the V1Store interface.
1477
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1478
        s.cacheMu.Lock()
×
1479
        defer s.cacheMu.Unlock()
×
1480

×
1481
        var (
×
1482
                ctx     = context.TODO()
×
1483
                chanIDB = channelIDToBytes(chanID)
×
1484
        )
×
1485

×
1486
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1487
                res, err := db.DeleteZombieChannel(
×
1488
                        ctx, sqlc.DeleteZombieChannelParams{
×
1489
                                Scid:    chanIDB,
×
1490
                                Version: int16(ProtocolV1),
×
1491
                        },
×
1492
                )
×
1493
                if err != nil {
×
1494
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1495
                                err)
×
1496
                }
×
1497

1498
                rows, err := res.RowsAffected()
×
1499
                if err != nil {
×
1500
                        return err
×
1501
                }
×
1502

1503
                if rows == 0 {
×
1504
                        return ErrZombieEdgeNotFound
×
1505
                } else if rows > 1 {
×
1506
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1507
                                "expected 1", rows)
×
1508
                }
×
1509

1510
                return nil
×
1511
        }, sqldb.NoOpReset)
1512
        if err != nil {
×
1513
                return fmt.Errorf("unable to mark edge live "+
×
1514
                        "(channel_id=%d): %w", chanID, err)
×
1515
        }
×
1516

1517
        s.rejectCache.remove(chanID)
×
1518
        s.chanCache.remove(chanID)
×
1519

×
1520
        return err
×
1521
}
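
// Illustrative usage sketch (not part of the store): marking a channel as a
// zombie and later resurrecting it. "store", "scid", "pub1" and "pub2" are
// hypothetical; the pub keys are the 33-byte node keys of the channel peers.
//
//        if err := store.MarkEdgeZombie(scid, pub1, pub2); err != nil {
//                // Handle the error.
//        }
//
//        isZombie, _, _, err := store.IsZombieEdge(scid)
//        if err != nil {
//                // Handle the error.
//        }
//
//        if isZombie {
//                // A fresh update was received, so resurrect the edge.
//                if err := store.MarkEdgeLive(scid); err != nil {
//                        // ErrZombieEdgeNotFound is returned if the edge is
//                        // not in the zombie index.
//                }
//        }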
1522

1523
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1524
// zombie, then the two node public keys corresponding to this edge are also
1525
// returned.
1526
//
1527
// NOTE: part of the V1Store interface.
1528
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1529
        error) {
×
1530

×
1531
        var (
×
1532
                ctx              = context.TODO()
×
1533
                isZombie         bool
×
1534
                pubKey1, pubKey2 route.Vertex
×
1535
                chanIDB          = channelIDToBytes(chanID)
×
1536
        )
×
1537

×
1538
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1539
                zombie, err := db.GetZombieChannel(
×
1540
                        ctx, sqlc.GetZombieChannelParams{
×
1541
                                Scid:    chanIDB,
×
1542
                                Version: int16(ProtocolV1),
×
1543
                        },
×
1544
                )
×
1545
                if errors.Is(err, sql.ErrNoRows) {
×
1546
                        return nil
×
1547
                }
×
1548
                if err != nil {
×
1549
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1550
                                err)
×
1551
                }
×
1552

1553
                copy(pubKey1[:], zombie.NodeKey1)
×
1554
                copy(pubKey2[:], zombie.NodeKey2)
×
1555
                isZombie = true
×
1556

×
1557
                return nil
×
1558
        }, sqldb.NoOpReset)
1559
        if err != nil {
×
1560
                return false, route.Vertex{}, route.Vertex{},
×
1561
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1562
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1563
        }
×
1564

1565
        return isZombie, pubKey1, pubKey2, nil
×
1566
}
1567

1568
// NumZombies returns the current number of zombie channels in the graph.
1569
//
1570
// NOTE: part of the V1Store interface.
1571
func (s *SQLStore) NumZombies() (uint64, error) {
×
1572
        var (
×
1573
                ctx        = context.TODO()
×
1574
                numZombies uint64
×
1575
        )
×
1576
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1577
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1578
                if err != nil {
×
1579
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1580
                                err)
×
1581
                }
×
1582

1583
                numZombies = uint64(count)
×
1584

×
1585
                return nil
×
1586
        }, sqldb.NoOpReset)
1587
        if err != nil {
×
1588
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1589
        }
×
1590

1591
        return numZombies, nil
×
1592
}
1593

1594
// DeleteChannelEdges removes edges with the given channel IDs from the
// database and marks them as zombies. This ensures that we're unable to
// re-add them to our database once again. If an edge does not exist within
// the database, then ErrEdgeNotFound will be returned. If strictZombiePruning
// is true, then when we mark these edges as zombies, we'll set up the keys
// such that we require the node that failed to send the fresh update to be
// the one that resurrects the channel from its zombie state. The markZombie
// bool denotes whether to mark the channel as a zombie.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1606

×
1607
        s.cacheMu.Lock()
×
1608
        defer s.cacheMu.Unlock()
×
1609

×
1610
        // Keep track of which channels we end up finding so that we can
×
1611
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
1612
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
1613
        for _, chanID := range chanIDs {
×
1614
                chanLookup[chanID] = struct{}{}
×
1615
        }
×
1616

1617
        var (
×
NEW
1618
                ctx   = context.TODO()
×
NEW
1619
                edges []*models.ChannelEdgeInfo
×
1620
        )
×
1621
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
1622
                // First, collect all channel rows.
×
NEW
1623
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
1624
                chanCallBack := func(ctx context.Context,
×
1625
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1626

×
1627
                        // Deleting the entry from the map indicates that we
×
1628
                        // have found the channel.
×
1629
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1630
                        delete(chanLookup, scid)
×
1631

×
NEW
1632
                        channelRows = append(channelRows, row)
×
1633

×
1634
                        return nil
×
UNCOV
1635
                }
×
1636

1637
                err := s.forEachChanWithPoliciesInSCIDList(
×
1638
                        ctx, db, chanCallBack, chanIDs,
×
1639
                )
×
1640
                if err != nil {
×
1641
                        return err
×
1642
                }
×
1643

1644
                if len(chanLookup) > 0 {
×
1645
                        return ErrEdgeNotFound
×
1646
                }
×
1647

NEW
1648
                if len(channelRows) == 0 {
×
NEW
1649
                        return nil
×
NEW
1650
                }
×
1651

1652
                // Batch build all channel edges.
NEW
1653
                var chanIDsToDelete []int64
×
NEW
1654
                edges, chanIDsToDelete, err = batchBuildChannelInfo(
×
NEW
1655
                        ctx, s.cfg, db, channelRows,
×
NEW
1656
                )
×
NEW
1657
                if err != nil {
×
NEW
1658
                        return err
×
NEW
1659
                }
×
1660

NEW
1661
                if markZombie {
×
NEW
1662
                        for i, row := range channelRows {
×
NEW
1663
                                scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
NEW
1664

×
NEW
1665
                                err := handleZombieMarking(
×
NEW
1666
                                        ctx, db, row, edges[i],
×
NEW
1667
                                        strictZombiePruning, scid,
×
NEW
1668
                                )
×
NEW
1669
                                if err != nil {
×
NEW
1670
                                        return fmt.Errorf("unable to mark "+
×
NEW
1671
                                                "channel as zombie: %w", err)
×
NEW
1672
                                }
×
1673
                        }
1674
                }
1675

1676
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1677
        }, func() {
×
NEW
1678
                edges = nil
×
1679

×
1680
                // Re-fill the lookup map.
×
1681
                for _, chanID := range chanIDs {
×
1682
                        chanLookup[chanID] = struct{}{}
×
1683
                }
×
1684
        })
1685
        if err != nil {
×
1686
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1687
                        err)
×
1688
        }
×
1689

1690
        for _, chanID := range chanIDs {
×
1691
                s.rejectCache.remove(chanID)
×
1692
                s.chanCache.remove(chanID)
×
1693
        }
×
1694

NEW
1695
        return edges, nil
×
1696
}
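
// Illustrative usage sketch (not part of the store): removing a batch of
// channels and marking them as zombies, treating ErrEdgeNotFound as a
// distinct outcome. "store" and "scids" are hypothetical.
//
//        removed, err := store.DeleteChannelEdges(true, true, scids...)
//        switch {
//        case errors.Is(err, ErrEdgeNotFound):
//                // At least one of the channels was not found in the graph.
//        case err != nil:
//                // Handle the error.
//        default:
//                // removed holds the ChannelEdgeInfo of each deleted edge.
//                _ = removed
//        }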
1697

1698
// FetchChannelEdgesByID attempts to look up the two directed edges for the
// channel identified by the channel ID. If the channel can't be found, then
// ErrEdgeNotFound is returned. A struct which houses the general information
// for the channel itself is returned as well as two structs that contain the
// routing policies for the channel in either direction.
//
// ErrZombieEdge can be returned if the edge is currently marked as a zombie
// within the database. In this case, the ChannelEdgePolicy instances will be
// nil, and the ChannelEdgeInfo will only include the public keys of each
// node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
        *models.ChannelEdgePolicy, error) {
×
1712

×
1713
        var (
×
1714
                ctx              = context.TODO()
×
1715
                edge             *models.ChannelEdgeInfo
×
1716
                policy1, policy2 *models.ChannelEdgePolicy
×
1717
                chanIDB          = channelIDToBytes(chanID)
×
1718
        )
×
1719
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1720
                row, err := db.GetChannelBySCIDWithPolicies(
×
1721
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1722
                                Scid:    chanIDB,
×
1723
                                Version: int16(ProtocolV1),
×
1724
                        },
×
1725
                )
×
1726
                if errors.Is(err, sql.ErrNoRows) {
×
1727
                        // First check if this edge is perhaps in the zombie
×
1728
                        // index.
×
1729
                        zombie, err := db.GetZombieChannel(
×
1730
                                ctx, sqlc.GetZombieChannelParams{
×
1731
                                        Scid:    chanIDB,
×
1732
                                        Version: int16(ProtocolV1),
×
1733
                                },
×
1734
                        )
×
1735
                        if errors.Is(err, sql.ErrNoRows) {
×
1736
                                return ErrEdgeNotFound
×
1737
                        } else if err != nil {
×
1738
                                return fmt.Errorf("unable to check if "+
×
1739
                                        "channel is zombie: %w", err)
×
1740
                        }
×
1741

1742
                        // At this point, we know the channel is a zombie, so
1743
                        // we'll return an error indicating this, and we will
1744
                        // populate the edge info with the public keys of each
1745
                        // party as this is the only information we have about
1746
                        // it.
1747
                        edge = &models.ChannelEdgeInfo{}
×
1748
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1749
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1750

×
1751
                        return ErrZombieEdge
×
1752
                } else if err != nil {
×
1753
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1754
                }
×
1755

1756
                node1, node2, err := buildNodeVertices(
×
1757
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1758
                )
×
1759
                if err != nil {
×
1760
                        return err
×
1761
                }
×
1762

1763
                edge, err = getAndBuildEdgeInfo(
×
1764
                        ctx, db, s.cfg.ChainHash, row.GraphChannel, node1,
×
1765
                        node2,
×
1766
                )
×
1767
                if err != nil {
×
1768
                        return fmt.Errorf("unable to build channel info: %w",
×
1769
                                err)
×
1770
                }
×
1771

1772
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1773
                if err != nil {
×
1774
                        return fmt.Errorf("unable to extract channel "+
×
1775
                                "policies: %w", err)
×
1776
                }
×
1777

1778
                policy1, policy2, err = getAndBuildChanPolicies(
×
1779
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1780
                )
×
1781
                if err != nil {
×
1782
                        return fmt.Errorf("unable to build channel "+
×
1783
                                "policies: %w", err)
×
1784
                }
×
1785

1786
                return nil
×
1787
        }, sqldb.NoOpReset)
1788
        if err != nil {
×
1789
                // If we are returning the ErrZombieEdge, then we also need to
×
1790
                // return the edge info as the method comment indicates that
×
1791
                // this will be populated when the edge is a zombie.
×
1792
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1793
                        err)
×
1794
        }
×
1795

1796
        return edge, policy1, policy2, nil
×
1797
}
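
// Illustrative usage sketch (not part of the store): fetching both directed
// edges for an SCID while distinguishing the zombie case, in which only the
// node keys of the returned edge info are populated. "store" and "scid" are
// hypothetical.
//
//        info, pol1, pol2, err := store.FetchChannelEdgesByID(scid)
//        switch {
//        case errors.Is(err, ErrZombieEdge):
//                // info is non-nil but only carries NodeKey1Bytes and
//                // NodeKey2Bytes.
//        case errors.Is(err, ErrEdgeNotFound):
//                // The channel is unknown to the graph.
//        case err != nil:
//                // Handle the error.
//        default:
//                // pol1 and/or pol2 may be nil if a direction has not yet
//                // been advertised.
//                _, _, _ = info, pol1, pol2
//        }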
1798

1799
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1800
// the channel identified by the funding outpoint. If the channel can't be
1801
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1802
// information for the channel itself is returned as well as two structs that
1803
// contain the routing policies for the channel in either direction.
1804
//
1805
// NOTE: part of the V1Store interface.
1806
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1807
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1808
        *models.ChannelEdgePolicy, error) {
×
1809

×
1810
        var (
×
1811
                ctx              = context.TODO()
×
1812
                edge             *models.ChannelEdgeInfo
×
1813
                policy1, policy2 *models.ChannelEdgePolicy
×
1814
        )
×
1815
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1816
                row, err := db.GetChannelByOutpointWithPolicies(
×
1817
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1818
                                Outpoint: op.String(),
×
1819
                                Version:  int16(ProtocolV1),
×
1820
                        },
×
1821
                )
×
1822
                if errors.Is(err, sql.ErrNoRows) {
×
1823
                        return ErrEdgeNotFound
×
1824
                } else if err != nil {
×
1825
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1826
                }
×
1827

1828
                node1, node2, err := buildNodeVertices(
×
1829
                        row.Node1Pubkey, row.Node2Pubkey,
×
1830
                )
×
1831
                if err != nil {
×
1832
                        return err
×
1833
                }
×
1834

1835
                edge, err = getAndBuildEdgeInfo(
×
1836
                        ctx, db, s.cfg.ChainHash, row.GraphChannel, node1,
×
1837
                        node2,
×
1838
                )
×
1839
                if err != nil {
×
1840
                        return fmt.Errorf("unable to build channel info: %w",
×
1841
                                err)
×
1842
                }
×
1843

1844
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1845
                if err != nil {
×
1846
                        return fmt.Errorf("unable to extract channel "+
×
1847
                                "policies: %w", err)
×
1848
                }
×
1849

1850
                policy1, policy2, err = getAndBuildChanPolicies(
×
1851
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1852
                )
×
1853
                if err != nil {
×
1854
                        return fmt.Errorf("unable to build channel "+
×
1855
                                "policies: %w", err)
×
1856
                }
×
1857

1858
                return nil
×
1859
        }, sqldb.NoOpReset)
1860
        if err != nil {
×
1861
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1862
                        err)
×
1863
        }
×
1864

1865
        return edge, policy1, policy2, nil
×
1866
}
1867

1868
// HasChannelEdge returns true if the database knows of a channel edge with the
1869
// passed channel ID, and false otherwise. If an edge with that ID is found
1870
// within the graph, then two time stamps representing the last time the edge
1871
// was updated for both directed edges are returned along with the boolean. If
1872
// it is not found, then the zombie index is checked and its result is returned
1873
// as the second boolean.
1874
//
1875
// NOTE: part of the V1Store interface.
1876
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1877
        bool, error) {
×
1878

×
1879
        ctx := context.TODO()
×
1880

×
1881
        var (
×
1882
                exists          bool
×
1883
                isZombie        bool
×
1884
                node1LastUpdate time.Time
×
1885
                node2LastUpdate time.Time
×
1886
        )
×
1887

×
1888
        // We'll query the cache with the shared lock held to allow multiple
×
1889
        // readers to access values in the cache concurrently if they exist.
×
1890
        s.cacheMu.RLock()
×
1891
        if entry, ok := s.rejectCache.get(chanID); ok {
×
1892
                s.cacheMu.RUnlock()
×
1893
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
1894
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
1895
                exists, isZombie = entry.flags.unpack()
×
1896

×
1897
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
1898
        }
×
1899
        s.cacheMu.RUnlock()
×
1900

×
1901
        s.cacheMu.Lock()
×
1902
        defer s.cacheMu.Unlock()
×
1903

×
1904
        // The item was not found with the shared lock, so we'll acquire the
×
1905
        // exclusive lock and check the cache again in case another method added
×
1906
        // the entry to the cache while no lock was held.
×
1907
        if entry, ok := s.rejectCache.get(chanID); ok {
×
1908
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
1909
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
1910
                exists, isZombie = entry.flags.unpack()
×
1911

×
1912
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
1913
        }
×
1914

1915
        chanIDB := channelIDToBytes(chanID)
×
1916
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1917
                channel, err := db.GetChannelBySCID(
×
1918
                        ctx, sqlc.GetChannelBySCIDParams{
×
1919
                                Scid:    chanIDB,
×
1920
                                Version: int16(ProtocolV1),
×
1921
                        },
×
1922
                )
×
1923
                if errors.Is(err, sql.ErrNoRows) {
×
1924
                        // Check if it is a zombie channel.
×
1925
                        isZombie, err = db.IsZombieChannel(
×
1926
                                ctx, sqlc.IsZombieChannelParams{
×
1927
                                        Scid:    chanIDB,
×
1928
                                        Version: int16(ProtocolV1),
×
1929
                                },
×
1930
                        )
×
1931
                        if err != nil {
×
1932
                                return fmt.Errorf("could not check if channel "+
×
1933
                                        "is zombie: %w", err)
×
1934
                        }
×
1935

1936
                        return nil
×
1937
                } else if err != nil {
×
1938
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1939
                }
×
1940

1941
                exists = true
×
1942

×
1943
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
1944
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1945
                                Version:   int16(ProtocolV1),
×
1946
                                ChannelID: channel.ID,
×
1947
                                NodeID:    channel.NodeID1,
×
1948
                        },
×
1949
                )
×
1950
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1951
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
1952
                                err)
×
1953
                } else if err == nil {
×
1954
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
1955
                }
×
1956

1957
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
1958
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1959
                                Version:   int16(ProtocolV1),
×
1960
                                ChannelID: channel.ID,
×
1961
                                NodeID:    channel.NodeID2,
×
1962
                        },
×
1963
                )
×
1964
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1965
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
1966
                                err)
×
1967
                } else if err == nil {
×
1968
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
1969
                }
×
1970

1971
                return nil
×
1972
        }, sqldb.NoOpReset)
1973
        if err != nil {
×
1974
                return time.Time{}, time.Time{}, false, false,
×
1975
                        fmt.Errorf("unable to fetch channel: %w", err)
×
1976
        }
×
1977

1978
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
1979
                upd1Time: node1LastUpdate.Unix(),
×
1980
                upd2Time: node2LastUpdate.Unix(),
×
1981
                flags:    packRejectFlags(exists, isZombie),
×
1982
        })
×
1983

×
1984
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
1985
}
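
// Illustrative usage sketch (not part of the store): using HasChannelEdge to
// decide whether an incoming channel update for a given direction is stale.
// "store", "scid", "updateTime" and "isNode1" are hypothetical.
//
//        upd1Time, upd2Time, exists, isZombie, err := store.HasChannelEdge(scid)
//        if err != nil {
//                // Handle the error.
//        }
//
//        switch {
//        case !exists && !isZombie:
//                // The channel is unknown, so the update cannot be applied
//                // yet.
//        case isZombie:
//                // Ignore updates for zombie channels.
//        case isNode1 && !updateTime.After(upd1Time):
//                // Stale update for the node-1 direction.
//        case !isNode1 && !updateTime.After(upd2Time):
//                // Stale update for the node-2 direction.
//        }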
1986

1987
// ChannelID attempts to look up the 8-byte compact channel ID which maps to
// the passed channel point (outpoint). If the passed channel doesn't exist
// within the database, then ErrEdgeNotFound is returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
1993
        var (
×
1994
                ctx       = context.TODO()
×
1995
                channelID uint64
×
1996
        )
×
1997
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1998
                chanID, err := db.GetSCIDByOutpoint(
×
1999
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2000
                                Outpoint: chanPoint.String(),
×
2001
                                Version:  int16(ProtocolV1),
×
2002
                        },
×
2003
                )
×
2004
                if errors.Is(err, sql.ErrNoRows) {
×
2005
                        return ErrEdgeNotFound
×
2006
                } else if err != nil {
×
2007
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2008
                                err)
×
2009
                }
×
2010

2011
                channelID = byteOrder.Uint64(chanID)
×
2012

×
2013
                return nil
×
2014
        }, sqldb.NoOpReset)
2015
        if err != nil {
×
2016
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2017
        }
×
2018

2019
        return channelID, nil
×
2020
}
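
// Illustrative usage sketch (not part of the store): resolving the SCID that
// corresponds to a funding outpoint. "store" and "chanPoint" (a
// *wire.OutPoint) are hypothetical.
//
//        scid, err := store.ChannelID(chanPoint)
//        switch {
//        case errors.Is(err, ErrEdgeNotFound):
//                // The outpoint does not belong to a known channel.
//        case err != nil:
//                // Handle the error.
//        default:
//                cid := lnwire.NewShortChanIDFromInt(scid)
//                _ = cid
//        }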
2021

2022
// IsPublicNode is a helper method that determines whether the node with the
2023
// given public key is seen as a public node in the graph from the graph's
2024
// source node's point of view.
2025
//
2026
// NOTE: part of the V1Store interface.
2027
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2028
        ctx := context.TODO()
×
2029

×
2030
        var isPublic bool
×
2031
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2032
                var err error
×
2033
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2034

×
2035
                return err
×
2036
        }, sqldb.NoOpReset)
×
2037
        if err != nil {
×
2038
                return false, fmt.Errorf("unable to check if node is "+
×
2039
                        "public: %w", err)
×
2040
        }
×
2041

2042
        return isPublic, nil
×
2043
}
2044

2045
// FetchChanInfos returns the set of channel edges that correspond to the
// passed channel IDs. If an edge in the query is unknown to the database, it
// will be skipped and the result will contain only those edges that exist at
// the time of the query. This can be used to respond to peer queries that are
// seeking to fill in gaps in their view of the channel graph.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2053
        var (
×
2054
                ctx   = context.TODO()
×
2055
                edges = make(map[uint64]ChannelEdge)
×
2056
        )
×
2057
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2058
                // First, collect all channel rows.
×
NEW
2059
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
2060
                chanCallBack := func(ctx context.Context,
×
2061
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2062

×
NEW
2063
                        channelRows = append(channelRows, row)
×
NEW
2064
                        return nil
×
NEW
2065
                }
×
2066

NEW
2067
                err := s.forEachChanWithPoliciesInSCIDList(
×
NEW
2068
                        ctx, db, chanCallBack, chanIDs,
×
NEW
2069
                )
×
NEW
2070
                if err != nil {
×
NEW
2071
                        return err
×
NEW
2072
                }
×
2073

NEW
2074
                if len(channelRows) == 0 {
×
UNCOV
2075
                        return nil
×
UNCOV
2076
                }
×
2077

2078
                // Batch build all channel edges.
NEW
2079
                chans, err := batchBuildChannelEdges(
×
NEW
2080
                        ctx, s.cfg, db, channelRows,
×
2081
                )
×
NEW
2082
                if err != nil {
×
NEW
2083
                        return fmt.Errorf("unable to build channel edges: %w",
×
NEW
2084
                                err)
×
NEW
2085
                }
×
2086

NEW
2087
                for _, c := range chans {
×
NEW
2088
                        edges[c.Info.ChannelID] = c
×
NEW
2089
                }
×
2090

NEW
2091
                return err
×
2092
        }, func() {
×
2093
                clear(edges)
×
2094
        })
×
2095
        if err != nil {
×
2096
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2097
        }
×
2098

2099
        res := make([]ChannelEdge, 0, len(edges))
×
2100
        for _, chanID := range chanIDs {
×
2101
                edge, ok := edges[chanID]
×
2102
                if !ok {
×
2103
                        continue
×
2104
                }
2105

2106
                res = append(res, edge)
×
2107
        }
2108

2109
        return res, nil
×
2110
}
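
// Illustrative usage sketch (not part of the store): serving a peer's query
// for a set of SCIDs, where unknown channels are simply absent from the
// result. "store" and "scids" are hypothetical.
//
//        chans, err := store.FetchChanInfos(scids)
//        if err != nil {
//                // Handle the error.
//        }
//        for _, c := range chans {
//                // Each entry carries the ChannelEdgeInfo in c.Info; entries
//                // are returned in the same order as the queried SCIDs, with
//                // unknown ones skipped.
//                _ = c.Info.ChannelID
//        }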
2111

2112
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2113
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2114
// channels in a paginated manner.
2115
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2116
        db SQLQueries, cb func(ctx context.Context,
2117
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
2118
        chanIDs []uint64) error {
×
2119

×
2120
        queryWrapper := func(ctx context.Context,
×
2121
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
2122
                error) {
×
2123

×
2124
                return db.GetChannelsBySCIDWithPolicies(
×
2125
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
2126
                                Version: int16(ProtocolV1),
×
2127
                                Scids:   scids,
×
2128
                        },
×
2129
                )
×
2130
        }
×
2131

2132
        return sqldb.ExecuteBatchQuery(
×
2133
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
×
2134
                cb,
×
2135
        )
×
2136
}
2137

2138
// FilterKnownChanIDs takes a set of channel IDs and returns the subset of
// channel IDs that we don't know of and that are not known zombies of the
// passed set. In other words, we perform a set difference of our set of
// channel IDs and the ones passed in. This method can be used by callers to
// determine the set of channels another peer knows of that we don't. The
// ChannelUpdateInfos for the known zombies are also returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
        []ChannelUpdateInfo, error) {
×
2148

×
2149
        var (
×
2150
                ctx          = context.TODO()
×
2151
                newChanIDs   []uint64
×
2152
                knownZombies []ChannelUpdateInfo
×
2153
                infoLookup   = make(
×
2154
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2155
                )
×
2156
        )
×
2157

×
2158
        // We first build a lookup map of the channel ID's to the
×
2159
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
2160
        // already know about.
×
2161
        for _, chanInfo := range chansInfo {
×
2162
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
2163
        }
×
2164

2165
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2166
                // The call-back function deletes known channels from
×
2167
                // infoLookup, so that we can later check which channels are
×
2168
                // zombies by only looking at the remaining channels in the set.
×
2169
                cb := func(ctx context.Context,
×
2170
                        channel sqlc.GraphChannel) error {
×
2171

×
2172
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2173

×
2174
                        return nil
×
2175
                }
×
2176

2177
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
2178
                if err != nil {
×
2179
                        return fmt.Errorf("unable to iterate through "+
×
2180
                                "channels: %w", err)
×
2181
                }
×
2182

2183
                // We want to ensure that we deal with the channels in the
2184
                // same order that they were passed in, so we iterate over the
2185
                // original chansInfo slice and then check if that channel is
2186
                // still in the infoLookup map.
2187
                for _, chanInfo := range chansInfo {
×
2188
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2189
                        if _, ok := infoLookup[channelID]; !ok {
×
2190
                                continue
×
2191
                        }
2192

2193
                        isZombie, err := db.IsZombieChannel(
×
2194
                                ctx, sqlc.IsZombieChannelParams{
×
2195
                                        Scid:    channelIDToBytes(channelID),
×
2196
                                        Version: int16(ProtocolV1),
×
2197
                                },
×
2198
                        )
×
2199
                        if err != nil {
×
2200
                                return fmt.Errorf("unable to fetch zombie "+
×
2201
                                        "channel: %w", err)
×
2202
                        }
×
2203

2204
                        if isZombie {
×
2205
                                knownZombies = append(knownZombies, chanInfo)
×
2206

×
2207
                                continue
×
2208
                        }
2209

2210
                        newChanIDs = append(newChanIDs, channelID)
×
2211
                }
2212

2213
                return nil
×
2214
        }, func() {
×
2215
                newChanIDs = nil
×
2216
                knownZombies = nil
×
2217
                // Rebuild the infoLookup map in case of a rollback.
×
2218
                for _, chanInfo := range chansInfo {
×
2219
                        scid := chanInfo.ShortChannelID.ToUint64()
×
2220
                        infoLookup[scid] = chanInfo
×
2221
                }
×
2222
        })
2223
        if err != nil {
×
2224
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2225
        }
×
2226

2227
        return newChanIDs, knownZombies, nil
×
2228
}
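
// Illustrative usage sketch (not part of the store): determining which of a
// peer's advertised SCIDs are actually new to us. "store" and "peerSCIDs"
// (a []lnwire.ShortChannelID) are hypothetical.
//
//        chansInfo := make([]ChannelUpdateInfo, 0, len(peerSCIDs))
//        for _, scid := range peerSCIDs {
//                chansInfo = append(chansInfo, NewChannelUpdateInfo(
//                        scid, time.Time{}, time.Time{},
//                ))
//        }
//
//        newSCIDs, zombies, err := store.FilterKnownChanIDs(chansInfo)
//        if err != nil {
//                // Handle the error.
//        }
//        // newSCIDs contains the channels we should query the peer for, while
//        // zombies echoes back the entries we consider zombie channels.
//        _, _ = newSCIDs, zombies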
2229

2230
// forEachChanInSCIDList is a helper method that executes a paged query
2231
// against the database to fetch all channels that match the passed
2232
// ChannelUpdateInfo slice. The callback function is called for each channel
2233
// that is found.
2234
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2235
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
2236
        chansInfo []ChannelUpdateInfo) error {
×
2237

×
2238
        queryWrapper := func(ctx context.Context,
×
2239
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
2240

×
2241
                return db.GetChannelsBySCIDs(
×
2242
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
2243
                                Version: int16(ProtocolV1),
×
2244
                                Scids:   scids,
×
2245
                        },
×
2246
                )
×
2247
        }
×
2248

2249
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
2250
                channelID := chanInfo.ShortChannelID.ToUint64()
×
2251

×
2252
                return channelIDToBytes(channelID)
×
2253
        }
×
2254

2255
        return sqldb.ExecuteBatchQuery(
×
2256
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
×
2257
                cb,
×
2258
        )
×
2259
}
2260

2261
// PruneGraphNodes is a garbage collection method which attempts to prune out
// any nodes from the channel graph that are currently unconnected. This
// ensures that we only maintain a graph of reachable nodes. In the event that
// a pruned node gains more channels, it will be re-added back to the graph.
//
// NOTE: this prunes nodes across protocol versions. It will never prune the
// source nodes.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2271
        var ctx = context.TODO()
×
2272

×
2273
        var prunedNodes []route.Vertex
×
2274
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2275
                var err error
×
2276
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2277

×
2278
                return err
×
2279
        }, func() {
×
2280
                prunedNodes = nil
×
2281
        })
×
2282
        if err != nil {
×
2283
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2284
        }
×
2285

2286
        return prunedNodes, nil
×
2287
}
2288

2289
// PruneGraph prunes newly closed channels from the channel graph in response
// to a new block being solved on the network. Any transactions which spend the
// funding output of any known channels within the graph will be deleted.
// Additionally, the "prune tip", or the last block which has been used to
// prune the graph, is stored so callers can ensure the graph is fully in sync
// with the current UTXO state. A slice of channels that have been closed by
// the target block along with any pruned nodes are returned if the function
// succeeds without error.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
        blockHash *chainhash.Hash, blockHeight uint32) (
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2302

×
2303
        ctx := context.TODO()
×
2304

×
2305
        s.cacheMu.Lock()
×
2306
        defer s.cacheMu.Unlock()
×
2307

×
2308
        var (
×
2309
                closedChans []*models.ChannelEdgeInfo
×
2310
                prunedNodes []route.Vertex
×
2311
        )
×
2312
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
2313
                // First, collect all channel rows that need to be pruned.
×
NEW
2314
                var channelRows []sqlc.GetChannelsByOutpointsRow
×
2315
                channelCallback := func(ctx context.Context,
×
2316
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2317

×
NEW
2318
                        channelRows = append(channelRows, row)
×
2319

×
2320
                        return nil
×
UNCOV
2321
                }
×
2322

2323
                err := s.forEachChanInOutpoints(
×
2324
                        ctx, db, spentOutputs, channelCallback,
×
2325
                )
×
2326
                if err != nil {
×
2327
                        return fmt.Errorf("unable to fetch channels by "+
×
2328
                                "outpoints: %w", err)
×
2329
                }
×
2330

NEW
2331
                if len(channelRows) == 0 {
×
NEW
2332
                        // There are no channels to prune. So we can exit early
×
NEW
2333
                        // after updating the prune log.
×
NEW
2334
                        err = db.UpsertPruneLogEntry(
×
NEW
2335
                                ctx, sqlc.UpsertPruneLogEntryParams{
×
NEW
2336
                                        BlockHash:   blockHash[:],
×
NEW
2337
                                        BlockHeight: int64(blockHeight),
×
NEW
2338
                                },
×
NEW
2339
                        )
×
NEW
2340
                        if err != nil {
×
NEW
2341
                                return fmt.Errorf("unable to insert prune log "+
×
NEW
2342
                                        "entry: %w", err)
×
NEW
2343
                        }
×
2344

NEW
2345
                        return nil
×
2346
                }
2347

2348
                // Batch build all channel edges for pruning.
NEW
2349
                var chansToDelete []int64
×
NEW
2350
                closedChans, chansToDelete, err = batchBuildChannelInfo(
×
NEW
2351
                        ctx, s.cfg, db, channelRows,
×
NEW
2352
                )
×
NEW
2353
                if err != nil {
×
NEW
2354
                        return err
×
NEW
2355
                }
×
2356

2357
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2358
                if err != nil {
×
2359
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2360
                }
×
2361

2362
                err = db.UpsertPruneLogEntry(
×
2363
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2364
                                BlockHash:   blockHash[:],
×
2365
                                BlockHeight: int64(blockHeight),
×
2366
                        },
×
2367
                )
×
2368
                if err != nil {
×
2369
                        return fmt.Errorf("unable to insert prune log "+
×
2370
                                "entry: %w", err)
×
2371
                }
×
2372

2373
                // Now that we've pruned some channels, we'll also prune any
2374
                // nodes that no longer have any channels.
2375
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2376
                if err != nil {
×
2377
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2378
                                err)
×
2379
                }
×
2380

2381
                return nil
×
2382
        }, func() {
×
2383
                prunedNodes = nil
×
2384
                closedChans = nil
×
2385
        })
×
2386
        if err != nil {
×
2387
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2388
        }
×
2389

2390
        for _, channel := range closedChans {
×
2391
                s.rejectCache.remove(channel.ChannelID)
×
2392
                s.chanCache.remove(channel.ChannelID)
×
2393
        }
×
2394

2395
        return closedChans, prunedNodes, nil
×
2396
}
2397

2398
// forEachChanInOutpoints is a helper function that executes a paginated
2399
// query to fetch channels by their outpoints and applies the given call-back
2400
// to each.
2401
//
2402
// NOTE: this fetches channels for all protocol versions.
2403
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2404
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2405
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
2406

×
2407
        // Create a wrapper that uses the transaction's db instance to execute
×
2408
        // the query.
×
2409
        queryWrapper := func(ctx context.Context,
×
2410
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
2411
                error) {
×
2412

×
2413
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
2414
        }
×
2415

2416
        // Define the conversion function from Outpoint to string.
2417
        outpointToString := func(outpoint *wire.OutPoint) string {
×
2418
                return outpoint.String()
×
2419
        }
×
2420

2421
        return sqldb.ExecuteBatchQuery(
×
2422
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
×
2423
                queryWrapper, cb,
×
2424
        )
×
2425
}
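
// Illustrative sketch (not part of the original file): the batching pattern
// that forEachChanInOutpoints delegates to sqldb.ExecuteBatchQuery, written
// out long-hand. The keys are converted to their string form, split into
// fixed-size pages, queried one page at a time, and the call-back is applied
// to every returned row. The generic helper and its page size are assumptions
// made purely for illustration.
func batchQuerySketch[K any, R any](ctx context.Context, keys []K,
        convert func(K) string, query func(context.Context,
                []string) ([]R, error), cb func(context.Context, R) error) error {

        // Assumed page size for the sake of the example.
        const pageSize = 250

        // Convert all keys up front.
        converted := make([]string, len(keys))
        for i, k := range keys {
                converted[i] = convert(k)
        }

        // Query one page of keys at a time and hand each row to the
        // call-back.
        for start := 0; start < len(converted); start += pageSize {
                end := min(start+pageSize, len(converted))

                rows, err := query(ctx, converted[start:end])
                if err != nil {
                        return err
                }

                for _, row := range rows {
                        if err := cb(ctx, row); err != nil {
                                return err
                        }
                }
        }

        return nil
}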
2426

2427
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
2428
        dbIDs []int64) error {
×
2429

×
2430
        // Create a wrapper that uses the transaction's db instance to execute
×
2431
        // the query.
×
2432
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
2433
                return nil, db.DeleteChannels(ctx, ids)
×
2434
        }
×
2435

2436
        idConverter := func(id int64) int64 {
×
2437
                return id
×
2438
        }
×
2439

2440
        return sqldb.ExecuteBatchQuery(
×
2441
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
×
2442
                queryWrapper, func(ctx context.Context, _ any) error {
×
2443
                        return nil
×
2444
                },
×
2445
        )
2446
}
2447

2448
// ChannelView returns the verifiable edge information for each active channel
2449
// within the known channel graph. The set of UTXOs (along with their scripts)
2450
// returned are the ones that need to be watched on chain to detect channel
2451
// closes on the resident blockchain.
2452
//
2453
// NOTE: part of the V1Store interface.
2454
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2455
        var (
×
2456
                ctx        = context.TODO()
×
2457
                edgePoints []EdgePoint
×
2458
        )
×
2459

×
2460
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2461
                handleChannel := func(_ context.Context,
×
2462
                        channel sqlc.ListChannelsPaginatedRow) error {
×
2463

×
2464
                        pkScript, err := genMultiSigP2WSH(
×
2465
                                channel.BitcoinKey1, channel.BitcoinKey2,
×
2466
                        )
×
2467
                        if err != nil {
×
2468
                                return err
×
2469
                        }
×
2470

2471
                        op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2472
                        if err != nil {
×
2473
                                return err
×
2474
                        }
×
2475

2476
                        edgePoints = append(edgePoints, EdgePoint{
×
2477
                                FundingPkScript: pkScript,
×
2478
                                OutPoint:        *op,
×
2479
                        })
×
2480

×
2481
                        return nil
×
2482
                }
2483

2484
                queryFunc := func(ctx context.Context, lastID int64,
×
2485
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {
×
2486

×
2487
                        return db.ListChannelsPaginated(
×
2488
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2489
                                        Version: int16(ProtocolV1),
×
2490
                                        ID:      lastID,
×
2491
                                        Limit:   limit,
×
2492
                                },
×
2493
                        )
×
2494
                }
×
2495

2496
                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
×
2497
                        return row.ID
×
2498
                }
×
2499

2500
                return sqldb.ExecutePaginatedQuery(
×
2501
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
2502
                        extractCursor, handleChannel,
×
2503
                )
×
2504
        }, func() {
×
2505
                edgePoints = nil
×
2506
        })
×
2507
        if err != nil {
×
2508
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2509
        }
×
2510

2511
        return edgePoints, nil
×
2512
}
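
// Illustrative usage sketch (not part of the original file): how a caller
// might consume ChannelView to build the set of UTXOs that need to be watched
// for channel closes. The store variable is assumed to be an initialised
// *SQLStore.
func watchChannelPointsSketch(store *SQLStore) error {
        edgePoints, err := store.ChannelView()
        if err != nil {
                return err
        }

        for _, ep := range edgePoints {
                // ep.OutPoint is the funding outpoint and ep.FundingPkScript
                // is the multisig script that spends it; a real caller would
                // hand both to a chain watcher.
                fmt.Printf("watching %v (script=%x)\n", ep.OutPoint,
                        ep.FundingPkScript)
        }

        return nil
}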
2513

2514
// PruneTip returns the block height and hash of the latest block that has been
2515
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2516
// to tell if the graph is currently in sync with the current best known UTXO
2517
// state.
2518
//
2519
// NOTE: part of the V1Store interface.
2520
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2521
        var (
×
2522
                ctx       = context.TODO()
×
2523
                tipHash   chainhash.Hash
×
2524
                tipHeight uint32
×
2525
        )
×
2526
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2527
                pruneTip, err := db.GetPruneTip(ctx)
×
2528
                if errors.Is(err, sql.ErrNoRows) {
×
2529
                        return ErrGraphNeverPruned
×
2530
                } else if err != nil {
×
2531
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2532
                }
×
2533

2534
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2535
                tipHeight = uint32(pruneTip.BlockHeight)
×
2536

×
2537
                return nil
×
2538
        }, sqldb.NoOpReset)
2539
        if err != nil {
×
2540
                return nil, 0, err
×
2541
        }
×
2542

2543
        return &tipHash, tipHeight, nil
×
2544
}
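
// Illustrative usage sketch (not part of the original file): callers of
// PruneTip are expected to treat ErrGraphNeverPruned as "start from scratch"
// rather than as a hard failure.
func currentPruneHeightSketch(store *SQLStore) (uint32, error) {
        _, height, err := store.PruneTip()
        switch {
        case errors.Is(err, ErrGraphNeverPruned):
                // The graph has never been pruned, so pruning should begin at
                // height zero.
                return 0, nil

        case err != nil:
                return 0, err
        }

        return height, nil
}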
2545

2546
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2547
//
2548
// NOTE: this prunes nodes across protocol versions. It will never prune the
2549
// source nodes.
2550
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2551
        db SQLQueries) ([]route.Vertex, error) {
×
2552

×
2553
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2554
        if err != nil {
×
2555
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2556
                        "nodes: %w", err)
×
2557
        }
×
2558

2559
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2560
        for i, nodeKey := range nodeKeys {
×
2561
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2562
                if err != nil {
×
2563
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2564
                                "from bytes: %w", err)
×
2565
                }
×
2566

2567
                prunedNodes[i] = pub
×
2568
        }
2569

2570
        return prunedNodes, nil
×
2571
}
2572

2573
// DisconnectBlockAtHeight is used to indicate that the block specified
2574
// by the passed height has been disconnected from the main chain. This
2575
// will "rewind" the graph back to the height below, deleting channels
2576
// that are no longer confirmed from the graph. The prune log will be
2577
// set to the last prune height valid for the remaining chain.
2578
// Channels that were removed from the graph resulting from the
2579
// disconnected block are returned.
2580
//
2581
// NOTE: part of the V1Store interface.
2582
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2583
        []*models.ChannelEdgeInfo, error) {
×
2584

×
2585
        ctx := context.TODO()
×
2586

×
2587
        var (
×
2588
                // Every channel having a ShortChannelID starting at 'height'
×
2589
                // will no longer be confirmed.
×
2590
                startShortChanID = lnwire.ShortChannelID{
×
2591
                        BlockHeight: height,
×
2592
                }
×
2593

×
2594
                // Delete everything after this height from the db up until the
×
2595
                // SCID alias range.
×
2596
                endShortChanID = aliasmgr.StartingAlias
×
2597

×
2598
                removedChans []*models.ChannelEdgeInfo
×
2599

×
2600
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2601
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2602
        )
×
2603

×
2604
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2605
                rows, err := db.GetChannelsBySCIDRange(
×
2606
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2607
                                StartScid: chanIDStart,
×
2608
                                EndScid:   chanIDEnd,
×
2609
                        },
×
2610
                )
×
2611
                if err != nil {
×
2612
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2613
                }
×
2614

NEW
2615
                if len(rows) == 0 {
×
NEW
2616
                        // No channels to disconnect, but still clean up prune
×
NEW
2617
                        // log.
×
NEW
2618
                        return db.DeletePruneLogEntriesInRange(
×
NEW
2619
                                ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
NEW
2620
                                        StartHeight: int64(height),
×
NEW
2621
                                        EndHeight: int64(
×
NEW
2622
                                                endShortChanID.BlockHeight,
×
NEW
2623
                                        ),
×
NEW
2624
                                },
×
2625
                        )
×
NEW
2626
                }
×
2627

2628
                // Batch build all channel edges for disconnection.
NEW
2629
                channelEdges, chanIDsToDelete, err := batchBuildChannelInfo(
×
NEW
2630
                        ctx, s.cfg, db, rows,
×
NEW
2631
                )
×
NEW
2632
                if err != nil {
×
NEW
2633
                        return err
×
UNCOV
2634
                }
×
2635

NEW
2636
                removedChans = channelEdges
×
NEW
2637

×
2638
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
2639
                if err != nil {
×
2640
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2641
                }
×
2642

2643
                return db.DeletePruneLogEntriesInRange(
×
2644
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2645
                                StartHeight: int64(height),
×
2646
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2647
                        },
×
2648
                )
×
2649
        }, func() {
×
2650
                removedChans = nil
×
2651
        })
×
2652
        if err != nil {
×
2653
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2654
                        "height: %w", err)
×
2655
        }
×
2656

2657
        for _, channel := range removedChans {
×
2658
                s.rejectCache.remove(channel.ChannelID)
×
2659
                s.chanCache.remove(channel.ChannelID)
×
2660
        }
×
2661

2662
        return removedChans, nil
×
2663
}
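
// Illustrative sketch (not part of the original file): the SCID range that
// DisconnectBlockAtHeight removes. A short channel ID embeds the funding
// block height in its most significant bits, so every channel confirmed at or
// after the disconnected height falls into [start, end), where end is the
// first SCID-alias value (aliasmgr.StartingAlias).
func scidRangeForHeightSketch(height uint32) (uint64, uint64) {
        start := lnwire.ShortChannelID{BlockHeight: height}
        end := aliasmgr.StartingAlias

        return start.ToUint64(), end.ToUint64()
}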
2664

2665
// AddEdgeProof sets the proof of an existing edge in the graph database.
2666
//
2667
// NOTE: part of the V1Store interface.
2668
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2669
        proof *models.ChannelAuthProof) error {
×
2670

×
2671
        var (
×
2672
                ctx       = context.TODO()
×
2673
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2674
        )
×
2675

×
2676
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2677
                res, err := db.AddV1ChannelProof(
×
2678
                        ctx, sqlc.AddV1ChannelProofParams{
×
2679
                                Scid:              scidBytes,
×
2680
                                Node1Signature:    proof.NodeSig1Bytes,
×
2681
                                Node2Signature:    proof.NodeSig2Bytes,
×
2682
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2683
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2684
                        },
×
2685
                )
×
2686
                if err != nil {
×
2687
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2688
                }
×
2689

2690
                n, err := res.RowsAffected()
×
2691
                if err != nil {
×
2692
                        return err
×
2693
                }
×
2694

2695
                if n == 0 {
×
2696
                        return fmt.Errorf("no rows affected when adding edge "+
×
2697
                                "proof for SCID %v", scid)
×
2698
                } else if n > 1 {
×
2699
                        return fmt.Errorf("multiple rows affected when adding "+
×
2700
                                "edge proof for SCID %v: %d rows affected",
×
2701
                                scid, n)
×
2702
                }
×
2703

2704
                return nil
×
2705
        }, sqldb.NoOpReset)
2706
        if err != nil {
×
2707
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2708
        }
×
2709

2710
        return nil
×
2711
}
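
// Illustrative usage sketch (not part of the original file): once both halves
// of the channel announcement signatures are known, the proof can be attached
// to an already-stored edge. The signature byte slices here are placeholders.
func attachProofSketch(store *SQLStore, scid lnwire.ShortChannelID,
        nodeSig1, nodeSig2, btcSig1, btcSig2 []byte) error {

        proof := &models.ChannelAuthProof{
                NodeSig1Bytes:    nodeSig1,
                NodeSig2Bytes:    nodeSig2,
                BitcoinSig1Bytes: btcSig1,
                BitcoinSig2Bytes: btcSig2,
        }

        return store.AddEdgeProof(scid, proof)
}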
2712

2713
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2714
// that we can ignore channel announcements that we know to be closed without
2715
// having to validate them and fetch a block.
2716
//
2717
// NOTE: part of the V1Store interface.
2718
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2719
        var (
×
2720
                ctx     = context.TODO()
×
2721
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2722
        )
×
2723

×
2724
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2725
                return db.InsertClosedChannel(ctx, chanIDB)
×
2726
        }, sqldb.NoOpReset)
×
2727
}
2728

2729
// IsClosedScid checks whether a channel identified by the passed in scid is
2730
// closed. This helps avoid having to perform expensive validation checks.
2731
//
2732
// NOTE: part of the V1Store interface.
2733
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2734
        var (
×
2735
                ctx      = context.TODO()
×
2736
                isClosed bool
×
2737
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2738
        )
×
2739
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2740
                var err error
×
2741
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2742
                if err != nil {
×
2743
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2744
                                err)
×
2745
                }
×
2746

2747
                return nil
×
2748
        }, sqldb.NoOpReset)
2749
        if err != nil {
×
2750
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2751
                        err)
×
2752
        }
×
2753

2754
        return isClosed, nil
×
2755
}
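
// Illustrative usage sketch (not part of the original file): a gossip-level
// caller can record a closed SCID once and then cheaply reject any later
// announcement for it without fetching a block.
func rejectIfClosedSketch(store *SQLStore,
        scid lnwire.ShortChannelID) (bool, error) {

        closed, err := store.IsClosedScid(scid)
        if err != nil {
                return false, err
        }

        // If the channel is known to be closed, the announcement can be
        // dropped without any expensive validation.
        return closed, nil
}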
2756

2757
// GraphSession will provide the call-back with access to a NodeTraverser
2758
// instance which can be used to perform queries against the channel graph.
2759
//
2760
// NOTE: part of the V1Store interface.
2761
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
2762
        reset func()) error {
×
2763

×
2764
        var ctx = context.TODO()
×
2765

×
2766
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2767
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2768
        }, reset)
×
2769
}
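
// Illustrative usage sketch (not part of the original file): GraphSession
// hands the call-back a NodeTraverser that is backed by a single read
// transaction, so a pathfinding pass sees one consistent snapshot of the
// graph. The reset closure must clear any state captured by the call-back
// because the transaction may be retried.
func countChannelsSketch(store *SQLStore, self route.Vertex) (int, error) {
        var numChans int

        err := store.GraphSession(func(graph NodeTraverser) error {
                return graph.ForEachNodeDirectedChannel(
                        self, func(_ *DirectedChannel) error {
                                numChans++
                                return nil
                        }, func() {
                                numChans = 0
                        },
                )
        }, func() {
                numChans = 0
        })

        return numChans, err
}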
2770

2771
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2772
// read only transaction for a consistent view of the graph.
2773
type sqlNodeTraverser struct {
2774
        db    SQLQueries
2775
        chain chainhash.Hash
2776
}
2777

2778
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2779
// NodeTraverser interface.
2780
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2781

2782
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2783
func newSQLNodeTraverser(db SQLQueries,
2784
        chain chainhash.Hash) *sqlNodeTraverser {
×
2785

×
2786
        return &sqlNodeTraverser{
×
2787
                db:    db,
×
2788
                chain: chain,
×
2789
        }
×
2790
}
×
2791

2792
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2793
// node.
2794
//
2795
// NOTE: Part of the NodeTraverser interface.
2796
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2797
        cb func(channel *DirectedChannel) error, _ func()) error {
×
2798

×
2799
        ctx := context.TODO()
×
2800

×
2801
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2802
}
×
2803

2804
// FetchNodeFeatures returns the features of the given node. If the node is
2805
// unknown, assume no additional features are supported.
2806
//
2807
// NOTE: Part of the NodeTraverser interface.
2808
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2809
        *lnwire.FeatureVector, error) {
×
2810

×
2811
        ctx := context.TODO()
×
2812

×
2813
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2814
}
×
2815

2816
// forEachNodeDirectedChannel iterates through all channels of a given
2817
// node, executing the passed callback on the directed edge representing the
2818
// channel and its incoming policy. If the node is not found, no error is
2819
// returned.
2820
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2821
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2822

×
2823
        toNodeCallback := func() route.Vertex {
×
2824
                return nodePub
×
2825
        }
×
2826

2827
        dbID, err := db.GetNodeIDByPubKey(
×
2828
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2829
                        Version: int16(ProtocolV1),
×
2830
                        PubKey:  nodePub[:],
×
2831
                },
×
2832
        )
×
2833
        if errors.Is(err, sql.ErrNoRows) {
×
2834
                return nil
×
2835
        } else if err != nil {
×
2836
                return fmt.Errorf("unable to fetch node: %w", err)
×
2837
        }
×
2838

2839
        rows, err := db.ListChannelsByNodeID(
×
2840
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2841
                        Version: int16(ProtocolV1),
×
2842
                        NodeID1: dbID,
×
2843
                },
×
2844
        )
×
2845
        if err != nil {
×
2846
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2847
        }
×
2848

2849
        // Exit early if there are no channels for this node so we don't
2850
        // do the unnecessary feature fetching.
2851
        if len(rows) == 0 {
×
2852
                return nil
×
2853
        }
×
2854

2855
        features, err := getNodeFeatures(ctx, db, dbID)
×
2856
        if err != nil {
×
2857
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2858
        }
×
2859

2860
        for _, row := range rows {
×
2861
                node1, node2, err := buildNodeVertices(
×
2862
                        row.Node1Pubkey, row.Node2Pubkey,
×
2863
                )
×
2864
                if err != nil {
×
2865
                        return fmt.Errorf("unable to build node vertices: %w",
×
2866
                                err)
×
2867
                }
×
2868

2869
                edge := buildCacheableChannelInfo(
×
2870
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
2871
                        node1, node2,
×
2872
                )
×
2873

×
2874
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2875
                if err != nil {
×
2876
                        return err
×
2877
                }
×
2878

2879
                p1, p2, err := buildCachedChanPolicies(
×
2880
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
2881
                )
×
2882
                if err != nil {
×
2883
                        return err
×
2884
                }
×
2885

2886
                // Determine the outgoing and incoming policy for this
2887
                // channel and node combo.
2888
                outPolicy, inPolicy := p1, p2
×
2889
                if p1 != nil && node2 == nodePub {
×
2890
                        outPolicy, inPolicy = p2, p1
×
2891
                } else if p2 != nil && node1 != nodePub {
×
2892
                        outPolicy, inPolicy = p2, p1
×
2893
                }
×
2894

2895
                var cachedInPolicy *models.CachedEdgePolicy
×
2896
                if inPolicy != nil {
×
2897
                        cachedInPolicy = inPolicy
×
2898
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
2899
                        cachedInPolicy.ToNodeFeatures = features
×
2900
                }
×
2901

2902
                directedChannel := &DirectedChannel{
×
2903
                        ChannelID:    edge.ChannelID,
×
2904
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2905
                        OtherNode:    edge.NodeKey2Bytes,
×
2906
                        Capacity:     edge.Capacity,
×
2907
                        OutPolicySet: outPolicy != nil,
×
2908
                        InPolicy:     cachedInPolicy,
×
2909
                }
×
2910
                if outPolicy != nil {
×
2911
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2912
                                directedChannel.InboundFee = fee
×
2913
                        })
×
2914
                }
2915

2916
                if nodePub == edge.NodeKey2Bytes {
×
2917
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
2918
                }
×
2919

2920
                if err := cb(directedChannel); err != nil {
×
2921
                        return err
×
2922
                }
×
2923
        }
2924

2925
        return nil
×
2926
}
2927

2928
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
2929
// and executes the provided callback for each node. It does so via pagination
2930
// along with batch loading of the node feature bits.
2931
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
2932
        db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
2933
                features *lnwire.FeatureVector) error) error {
×
2934

×
2935
        handleNode := func(_ context.Context,
×
2936
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
×
2937
                featureBits map[int64][]int) error {
×
2938

×
2939
                fv := lnwire.EmptyFeatureVector()
×
2940
                if features, exists := featureBits[dbNode.ID]; exists {
×
2941
                        for _, bit := range features {
×
2942
                                fv.Set(lnwire.FeatureBit(bit))
×
2943
                        }
×
2944
                }
2945

2946
                var pub route.Vertex
×
2947
                copy(pub[:], dbNode.PubKey)
×
2948

×
2949
                return processNode(dbNode.ID, pub, fv)
×
2950
        }
2951

2952
        queryFunc := func(ctx context.Context, lastID int64,
×
2953
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
2954

×
2955
                return db.ListNodeIDsAndPubKeys(
×
2956
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
2957
                                Version: int16(ProtocolV1),
×
2958
                                ID:      lastID,
×
2959
                                Limit:   limit,
×
2960
                        },
×
2961
                )
×
2962
        }
×
2963

2964
        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
2965
                return row.ID
×
2966
        }
×
2967

2968
        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
×
2969
                return node.ID, nil
×
2970
        }
×
2971

2972
        batchQueryFunc := func(ctx context.Context,
×
2973
                nodeIDs []int64) (map[int64][]int, error) {
×
2974

×
2975
                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
2976
        }
×
2977

2978
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
2979
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
×
2980
                batchQueryFunc, handleNode,
×
2981
        )
×
2982
}
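
// Illustrative sketch (not part of the original file): the "collect and
// batch" pattern that forEachNodeCacheable relies on, written out long-hand.
// Rows are fetched with keyset pagination, the IDs of each page are
// collected, the shared data (feature bits) is loaded with a single batched
// query, and only then is the per-row handler invoked with both the row and
// the shared map. The helper parameters and page size are assumptions for
// illustration; the real driver is
// sqldb.ExecuteCollectAndBatchWithSharedDataQuery.
func pageThenBatchSketch(ctx context.Context,
        fetchPage func(context.Context, int64,
                int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error),
        loadFeatures func(context.Context, []int64) (map[int64][]int, error),
        handle func(sqlc.ListNodeIDsAndPubKeysRow,
                map[int64][]int) error) error {

        // Assumed page size for the sake of the example.
        const pageSize = 500

        lastID := int64(-1)
        for {
                rows, err := fetchPage(ctx, lastID, pageSize)
                if err != nil {
                        return err
                }
                if len(rows) == 0 {
                        return nil
                }

                // Collect the IDs of this page so that the shared data can be
                // loaded with one batched query.
                ids := make([]int64, len(rows))
                for i, row := range rows {
                        ids[i] = row.ID
                        lastID = row.ID
                }

                features, err := loadFeatures(ctx, ids)
                if err != nil {
                        return err
                }

                for _, row := range rows {
                        if err := handle(row, features); err != nil {
                                return err
                        }
                }
        }
}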
2983

2984
// forEachNodeChannel iterates through all channels of a node, executing
2985
// the passed callback on each. The call-back is provided with the channel's
2986
// edge information, the outgoing policy and the incoming policy for the
2987
// channel and node combo.
2988
func forEachNodeChannel(ctx context.Context, db SQLQueries,
2989
        cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
2990
                *models.ChannelEdgePolicy,
2991
                *models.ChannelEdgePolicy) error) error {
×
2992

×
2993
        // Get all the V1 channels for this node.
×
2994
        rows, err := db.ListChannelsByNodeID(
×
2995
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2996
                        Version: int16(ProtocolV1),
×
2997
                        NodeID1: id,
×
2998
                },
×
2999
        )
×
3000
        if err != nil {
×
3001
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3002
        }
×
3003

3004
        // Collect all the channel and policy IDs.
3005
        var (
×
3006
                chanIDs   = make([]int64, 0, len(rows))
×
3007
                policyIDs = make([]int64, 0, 2*len(rows))
×
3008
        )
×
3009
        for _, row := range rows {
×
3010
                chanIDs = append(chanIDs, row.GraphChannel.ID)
×
3011

×
3012
                if row.Policy1ID.Valid {
×
3013
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
×
3014
                }
×
3015
                if row.Policy2ID.Valid {
×
3016
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
×
3017
                }
×
3018
        }
3019

3020
        batchData, err := batchLoadChannelData(
×
3021
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
×
3022
        )
×
3023
        if err != nil {
×
3024
                return fmt.Errorf("unable to batch load channel data: %w", err)
×
3025
        }
×
3026

3027
        // Call the call-back for each channel and its known policies.
3028
        for _, row := range rows {
×
3029
                node1, node2, err := buildNodeVertices(
×
3030
                        row.Node1Pubkey, row.Node2Pubkey,
×
3031
                )
×
3032
                if err != nil {
×
3033
                        return fmt.Errorf("unable to build node vertices: %w",
×
3034
                                err)
×
3035
                }
×
3036

3037
                edge, err := buildEdgeInfoWithBatchData(
×
3038
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
3039
                        batchData,
×
3040
                )
×
3041
                if err != nil {
×
3042
                        return fmt.Errorf("unable to build channel info: %w",
×
3043
                                err)
×
3044
                }
×
3045

3046
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3047
                if err != nil {
×
3048
                        return fmt.Errorf("unable to extract channel "+
×
3049
                                "policies: %w", err)
×
3050
                }
×
3051

3052
                p1, p2, err := buildChanPoliciesWithBatchData(
×
3053
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
3054
                )
×
3055
                if err != nil {
×
3056
                        return fmt.Errorf("unable to build channel "+
×
3057
                                "policies: %w", err)
×
3058
                }
×
3059

3060
                // Determine the outgoing and incoming policy for this
3061
                // channel and node combo.
3062
                p1ToNode := row.GraphChannel.NodeID2
×
3063
                p2ToNode := row.GraphChannel.NodeID1
×
3064
                outPolicy, inPolicy := p1, p2
×
3065
                if (p1 != nil && p1ToNode == id) ||
×
3066
                        (p2 != nil && p2ToNode != id) {
×
3067

×
3068
                        outPolicy, inPolicy = p2, p1
×
3069
                }
×
3070

3071
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3072
                        return err
×
3073
                }
×
3074
        }
3075

3076
        return nil
×
3077
}
3078

3079
// updateChanEdgePolicy upserts the channel policy info we have stored for
3080
// a channel we already know of.
3081
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3082
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3083
        error) {
×
3084

×
3085
        var (
×
3086
                node1Pub, node2Pub route.Vertex
×
3087
                isNode1            bool
×
3088
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3089
        )
×
3090

×
3091
        // Check that this edge policy refers to a channel that we already
×
3092
        // know of. We do this explicitly so that we can return the appropriate
×
3093
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3094
        // abort the transaction which would abort the entire batch.
×
3095
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3096
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3097
                        Scid:    chanIDB,
×
3098
                        Version: int16(ProtocolV1),
×
3099
                },
×
3100
        )
×
3101
        if errors.Is(err, sql.ErrNoRows) {
×
3102
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3103
        } else if err != nil {
×
3104
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3105
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3106
        }
×
3107

3108
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3109
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3110

×
3111
        // Figure out which node this edge is from.
×
3112
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3113
        nodeID := dbChan.NodeID1
×
3114
        if !isNode1 {
×
3115
                nodeID = dbChan.NodeID2
×
3116
        }
×
3117

3118
        var (
×
3119
                inboundBase sql.NullInt64
×
3120
                inboundRate sql.NullInt64
×
3121
        )
×
3122
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3123
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3124
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3125
        })
×
3126

3127
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3128
                Version:     int16(ProtocolV1),
×
3129
                ChannelID:   dbChan.ID,
×
3130
                NodeID:      nodeID,
×
3131
                Timelock:    int32(edge.TimeLockDelta),
×
3132
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3133
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3134
                MinHtlcMsat: int64(edge.MinHTLC),
×
3135
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3136
                Disabled: sql.NullBool{
×
3137
                        Valid: true,
×
3138
                        Bool:  edge.IsDisabled(),
×
3139
                },
×
3140
                MaxHtlcMsat: sql.NullInt64{
×
3141
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3142
                        Int64: int64(edge.MaxHTLC),
×
3143
                },
×
3144
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3145
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3146
                InboundBaseFeeMsat:      inboundBase,
×
3147
                InboundFeeRateMilliMsat: inboundRate,
×
3148
                Signature:               edge.SigBytes,
×
3149
        })
×
3150
        if err != nil {
×
3151
                return node1Pub, node2Pub, isNode1,
×
3152
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3153
        }
×
3154

3155
        // Convert the flat extra opaque data into a map of TLV types to
3156
        // values.
3157
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3158
        if err != nil {
×
3159
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3160
                        "marshal extra opaque data: %w", err)
×
3161
        }
×
3162

3163
        // Update the channel policy's extra signed fields.
3164
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3165
        if err != nil {
×
3166
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3167
                        "policy extra TLVs: %w", err)
×
3168
        }
×
3169

3170
        return node1Pub, node2Pub, isNode1, nil
×
3171
}
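
// Illustrative sketch (not part of the original file): how
// updateChanEdgePolicy above decides which node a policy belongs to. The low
// bit of the channel_update channel flags (lnwire.ChanUpdateDirection) is
// unset for node 1 and set for node 2.
func policyNodeSketch(flags lnwire.ChanUpdateChanFlags, node1ID,
        node2ID int64) int64 {

        if flags&lnwire.ChanUpdateDirection == 0 {
                // The update was produced by node 1.
                return node1ID
        }

        // The update was produced by node 2.
        return node2ID
}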
3172

3173
// getNodeByPubKey attempts to look up a target node by its public key.
3174
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3175
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3176

×
3177
        dbNode, err := db.GetNodeByPubKey(
×
3178
                ctx, sqlc.GetNodeByPubKeyParams{
×
3179
                        Version: int16(ProtocolV1),
×
3180
                        PubKey:  pubKey[:],
×
3181
                },
×
3182
        )
×
3183
        if errors.Is(err, sql.ErrNoRows) {
×
3184
                return 0, nil, ErrGraphNodeNotFound
×
3185
        } else if err != nil {
×
3186
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3187
        }
×
3188

NEW
3189
        node, err := buildNode(ctx, db, dbNode)
×
3190
        if err != nil {
×
3191
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3192
        }
×
3193

3194
        return dbNode.ID, node, nil
×
3195
}
3196

3197
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3198
// provided parameters.
3199
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
3200
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3201

×
3202
        return &models.CachedEdgeInfo{
×
3203
                ChannelID:     byteOrder.Uint64(scid),
×
3204
                NodeKey1Bytes: node1Pub,
×
3205
                NodeKey2Bytes: node2Pub,
×
3206
                Capacity:      btcutil.Amount(capacity),
×
3207
        }
×
3208
}
×
3209

3210
// buildNode constructs a LightningNode instance from the given database node
3211
// record. The node's features, addresses and extra signed fields are also
3212
// fetched from the database and set on the node.
3213
func buildNode(ctx context.Context, db SQLQueries,
NEW
3214
        dbNode sqlc.GraphNode) (*models.LightningNode, error) {
×
3215

×
3216
        // NOTE: buildNode is only used to load the data for a single node, and
×
3217
        // so no paged queries will be performed. This means that it's ok to
×
3218
        // pass in the default config values here.
×
3219
        cfg := sqldb.DefaultQueryConfig()
×
3220

×
3221
        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
×
3222
        if err != nil {
×
3223
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
3224
                        err)
×
3225
        }
×
3226

3227
        return buildNodeWithBatchData(dbNode, data)
×
3228
}
3229

3230
// buildNodeWithBatchData builds a models.LightningNode instance
3231
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
3232
// features/addresses/extra fields, then the corresponding fields are expected
3233
// to be present in the batchNodeData.
3234
func buildNodeWithBatchData(dbNode sqlc.GraphNode,
3235
        batchData *batchNodeData) (*models.LightningNode, error) {
×
3236

×
3237
        if dbNode.Version != int16(ProtocolV1) {
×
3238
                return nil, fmt.Errorf("unsupported node version: %d",
×
3239
                        dbNode.Version)
×
3240
        }
×
3241

3242
        var pub [33]byte
×
3243
        copy(pub[:], dbNode.PubKey)
×
3244

×
3245
        node := &models.LightningNode{
×
3246
                PubKeyBytes: pub,
×
3247
                Features:    lnwire.EmptyFeatureVector(),
×
3248
                LastUpdate:  time.Unix(0, 0),
×
3249
        }
×
3250

×
3251
        if len(dbNode.Signature) == 0 {
×
3252
                return node, nil
×
3253
        }
×
3254

3255
        node.HaveNodeAnnouncement = true
×
3256
        node.AuthSigBytes = dbNode.Signature
×
3257
        node.Alias = dbNode.Alias.String
×
3258
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3259

×
3260
        var err error
×
3261
        if dbNode.Color.Valid {
×
3262
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3263
                if err != nil {
×
3264
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3265
                                err)
×
3266
                }
×
3267
        }
3268

3269
        // Use preloaded features.
3270
        if features, exists := batchData.features[dbNode.ID]; exists {
×
3271
                fv := lnwire.EmptyFeatureVector()
×
3272
                for _, bit := range features {
×
3273
                        fv.Set(lnwire.FeatureBit(bit))
×
3274
                }
×
3275
                node.Features = fv
×
3276
        }
3277

3278
        // Use preloaded addresses.
3279
        addresses, exists := batchData.addresses[dbNode.ID]
×
3280
        if exists && len(addresses) > 0 {
×
3281
                node.Addresses, err = buildNodeAddresses(addresses)
×
3282
                if err != nil {
×
3283
                        return nil, fmt.Errorf("unable to build addresses "+
×
3284
                                "for node(%d): %w", dbNode.ID, err)
×
3285
                }
×
3286
        }
3287

3288
        // Use preloaded extra fields.
3289
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
×
3290
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
×
3291
                if err != nil {
×
3292
                        return nil, fmt.Errorf("unable to serialize extra "+
×
3293
                                "signed fields: %w", err)
×
3294
                }
×
3295
                if len(recs) != 0 {
×
3296
                        node.ExtraOpaqueData = recs
×
3297
                }
×
3298
        }
3299

3300
        return node, nil
×
3301
}
3302

3303
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
3304
// with the preloaded data, and executes the provided callback for each node.
3305
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
3306
        db SQLQueries, nodes []sqlc.GraphNode,
3307
        cb func(dbID int64, node *models.LightningNode) error) error {
×
3308

×
3309
        // Extract node IDs for batch loading.
×
3310
        nodeIDs := make([]int64, len(nodes))
×
3311
        for i, node := range nodes {
×
3312
                nodeIDs[i] = node.ID
×
3313
        }
×
3314

3315
        // Batch load all related data for this page.
3316
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
3317
        if err != nil {
×
3318
                return fmt.Errorf("unable to batch load node data: %w", err)
×
3319
        }
×
3320

3321
        for _, dbNode := range nodes {
×
NEW
3322
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
3323
                if err != nil {
×
3324
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
3325
                                dbNode.ID, err)
×
3326
                }
×
3327

3328
                if err := cb(dbNode.ID, node); err != nil {
×
3329
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
3330
                                dbNode.ID, err)
×
3331
                }
×
3332
        }
3333

3334
        return nil
×
3335
}
3336

3337
// getNodeFeatures fetches the feature bits and constructs the feature vector
3338
// for a node with the given DB ID.
3339
func getNodeFeatures(ctx context.Context, db SQLQueries,
3340
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3341

×
3342
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3343
        if err != nil {
×
3344
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3345
                        nodeID, err)
×
3346
        }
×
3347

3348
        features := lnwire.EmptyFeatureVector()
×
3349
        for _, feature := range rows {
×
3350
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3351
        }
×
3352

3353
        return features, nil
×
3354
}
3355

3356
// upsertNode upserts the node record into the database. If the node already
3357
// exists, then the node's information is updated. If the node doesn't exist,
3358
// then a new node is created. The node's features, addresses and extra TLV
3359
// types are also updated. The node's DB ID is returned.
3360
func upsertNode(ctx context.Context, db SQLQueries,
3361
        node *models.LightningNode) (int64, error) {
×
3362

×
3363
        params := sqlc.UpsertNodeParams{
×
3364
                Version: int16(ProtocolV1),
×
3365
                PubKey:  node.PubKeyBytes[:],
×
3366
        }
×
3367

×
3368
        if node.HaveNodeAnnouncement {
×
3369
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3370
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3371
                params.Alias = sqldb.SQLStr(node.Alias)
×
3372
                params.Signature = node.AuthSigBytes
×
3373
        }
×
3374

3375
        nodeID, err := db.UpsertNode(ctx, params)
×
3376
        if err != nil {
×
3377
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3378
                        err)
×
3379
        }
×
3380

3381
        // We can exit here if we don't have the announcement yet.
3382
        if !node.HaveNodeAnnouncement {
×
3383
                return nodeID, nil
×
3384
        }
×
3385

3386
        // Update the node's features.
3387
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3388
        if err != nil {
×
3389
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3390
        }
×
3391

3392
        // Update the node's addresses.
3393
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3394
        if err != nil {
×
3395
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3396
        }
×
3397

3398
        // Convert the flat extra opaque data into a map of TLV types to
3399
        // values.
3400
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3401
        if err != nil {
×
3402
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3403
                        err)
×
3404
        }
×
3405

3406
        // Update the node's extra signed fields.
3407
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3408
        if err != nil {
×
3409
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3410
        }
×
3411

3412
        return nodeID, nil
×
3413
}
3414

3415
// upsertNodeFeatures updates the node_features table for the given node. This
3416
// includes deleting any feature bits no longer present and inserting any new
3417
// feature bits. If the feature bit does not yet exist in the features table,
3418
// then an entry is created in that table first.
3419
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3420
        features *lnwire.FeatureVector) error {
×
3421

×
3422
        // Get any existing features for the node.
×
3423
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3424
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3425
                return err
×
3426
        }
×
3427

3428
        // Copy the node's latest set of feature bits.
3429
        newFeatures := make(map[int32]struct{})
×
3430
        if features != nil {
×
3431
                for feature := range features.Features() {
×
3432
                        newFeatures[int32(feature)] = struct{}{}
×
3433
                }
×
3434
        }
3435

3436
        // For any current feature that already exists in the DB, remove it from
3437
        // the in-memory map. For any existing feature that does not exist in
3438
        // the in-memory map, delete it from the database.
3439
        for _, feature := range existingFeatures {
×
3440
                // The feature is still present, so there are no updates to be
×
3441
                // made.
×
3442
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3443
                        delete(newFeatures, feature.FeatureBit)
×
3444
                        continue
×
3445
                }
3446

3447
                // The feature is no longer present, so we remove it from the
3448
                // database.
3449
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3450
                        NodeID:     nodeID,
×
3451
                        FeatureBit: feature.FeatureBit,
×
3452
                })
×
3453
                if err != nil {
×
3454
                        return fmt.Errorf("unable to delete node(%d) "+
×
3455
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3456
                                err)
×
3457
                }
×
3458
        }
3459

3460
        // Any remaining entries in newFeatures are new features that need to be
3461
        // added to the database for the first time.
3462
        for feature := range newFeatures {
×
3463
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3464
                        NodeID:     nodeID,
×
3465
                        FeatureBit: feature,
×
3466
                })
×
3467
                if err != nil {
×
3468
                        return fmt.Errorf("unable to insert node(%d) "+
×
3469
                                "feature(%v): %w", nodeID, feature, err)
×
3470
                }
×
3471
        }
3472

3473
        return nil
×
3474
}
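
// Illustrative sketch (not part of the original file): the set-difference
// logic used by upsertNodeFeatures above, in isolation. Given the feature
// bits currently stored and the bits from the latest announcement, it returns
// the bits to insert and the bits to delete so that the table ends up
// matching the new set exactly.
func featureDiffSketch(existing, latest []int32) (toInsert, toDelete []int32) {
        newSet := make(map[int32]struct{}, len(latest))
        for _, bit := range latest {
                newSet[bit] = struct{}{}
        }

        for _, bit := range existing {
                if _, ok := newSet[bit]; ok {
                        // Already stored, nothing to do.
                        delete(newSet, bit)
                        continue
                }

                // Stored but no longer announced, so it must be deleted.
                toDelete = append(toDelete, bit)
        }

        // Whatever is left in the map was never stored before.
        for bit := range newSet {
                toInsert = append(toInsert, bit)
        }

        return toInsert, toDelete
}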
3475

3476
// fetchNodeFeatures fetches the features for a node with the given public key.
3477
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3478
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3479

×
3480
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3481
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3482
                        PubKey:  nodePub[:],
×
3483
                        Version: int16(ProtocolV1),
×
3484
                },
×
3485
        )
×
3486
        if err != nil {
×
3487
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3488
                        nodePub, err)
×
3489
        }
×
3490

3491
        features := lnwire.EmptyFeatureVector()
×
3492
        for _, bit := range rows {
×
3493
                features.Set(lnwire.FeatureBit(bit))
×
3494
        }
×
3495

3496
        return features, nil
×
3497
}
3498

3499
// dbAddressType is an enum type that represents the different address types
3500
// that we store in the node_addresses table. The address type determines how
3501
// the address is to be serialized/deserialized.
3502
type dbAddressType uint8
3503

3504
const (
3505
        addressTypeIPv4   dbAddressType = 1
3506
        addressTypeIPv6   dbAddressType = 2
3507
        addressTypeTorV2  dbAddressType = 3
3508
        addressTypeTorV3  dbAddressType = 4
3509
        addressTypeOpaque dbAddressType = math.MaxInt8
3510
)
3511

3512
// upsertNodeAddresses updates the node's addresses in the database. This
3513
// includes deleting any existing addresses and inserting the new set of
3514
// addresses. The deletion is necessary since the ordering of the addresses may
3515
// change, and we need to ensure that the database reflects the latest set of
3516
// addresses so that at the time of reconstructing the node announcement, the
3517
// order is preserved and the signature over the message remains valid.
3518
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3519
        addresses []net.Addr) error {
×
3520

×
3521
        // Delete any existing addresses for the node. This is required since
×
3522
        // even if the new set of addresses is the same, the ordering may have
×
3523
        // changed for a given address type.
×
3524
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3525
        if err != nil {
×
3526
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3527
                        nodeID, err)
×
3528
        }
×
3529

3530
        // Copy the node's latest set of addresses.
3531
        newAddresses := map[dbAddressType][]string{
×
3532
                addressTypeIPv4:   {},
×
3533
                addressTypeIPv6:   {},
×
3534
                addressTypeTorV2:  {},
×
3535
                addressTypeTorV3:  {},
×
3536
                addressTypeOpaque: {},
×
3537
        }
×
3538
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3539
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3540
        }
×
3541

3542
        for _, address := range addresses {
×
3543
                switch addr := address.(type) {
×
3544
                case *net.TCPAddr:
×
3545
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3546
                                addAddr(addressTypeIPv4, addr)
×
3547
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3548
                                addAddr(addressTypeIPv6, addr)
×
3549
                        } else {
×
3550
                                return fmt.Errorf("unhandled IP address: %v",
×
3551
                                        addr)
×
3552
                        }
×
3553

3554
                case *tor.OnionAddr:
×
3555
                        switch len(addr.OnionService) {
×
3556
                        case tor.V2Len:
×
3557
                                addAddr(addressTypeTorV2, addr)
×
3558
                        case tor.V3Len:
×
3559
                                addAddr(addressTypeTorV3, addr)
×
3560
                        default:
×
3561
                                return fmt.Errorf("invalid length for a tor " +
×
3562
                                        "address")
×
3563
                        }
3564

3565
                case *lnwire.OpaqueAddrs:
×
3566
                        addAddr(addressTypeOpaque, addr)
×
3567

3568
                default:
×
3569
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3570
                }
3571
        }
3572

3573
        // Any remaining entries in newAddresses are new addresses that need to
3574
        // be added to the database for the first time.
3575
        for addrType, addrList := range newAddresses {
×
3576
                for position, addr := range addrList {
×
3577
                        err := db.InsertNodeAddress(
×
3578
                                ctx, sqlc.InsertNodeAddressParams{
×
3579
                                        NodeID:   nodeID,
×
3580
                                        Type:     int16(addrType),
×
3581
                                        Address:  addr,
×
3582
                                        Position: int32(position),
×
3583
                                },
×
3584
                        )
×
3585
                        if err != nil {
×
3586
                                return fmt.Errorf("unable to insert "+
×
3587
                                        "node(%d) address(%v): %w", nodeID,
×
3588
                                        addr, err)
×
3589
                        }
×
3590
                }
3591
        }
3592

3593
        return nil
×
3594
}
3595

3596
// getNodeAddresses fetches the addresses for a node with the given DB ID.
3597
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
3598
        error) {
×
3599

×
3600
        // GetNodeAddresses ensures that the addresses for a given type are
×
3601
        // returned in the same order as they were inserted.
×
3602
        rows, err := db.GetNodeAddresses(ctx, id)
×
3603
        if err != nil {
×
3604
                return nil, err
×
3605
        }
×
3606

3607
        addresses := make([]net.Addr, 0, len(rows))
×
3608
        for _, row := range rows {
×
3609
                address := row.Address
×
3610

×
3611
                addr, err := parseAddress(dbAddressType(row.Type), address)
×
3612
                if err != nil {
×
3613
                        return nil, fmt.Errorf("unable to parse address "+
×
3614
                                "for node(%d): %v: %w", id, address, err)
×
3615
                }
×
3616

3617
                addresses = append(addresses, addr)
×
3618
        }
3619

3620
        // If we have no addresses, then we'll return nil instead of an
3621
        // empty slice.
3622
        if len(addresses) == 0 {
×
3623
                addresses = nil
×
3624
        }
×
3625

3626
        return addresses, nil
×
3627
}
3628

3629
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3630
// database. This includes updating any existing types, inserting any new types,
3631
// and deleting any types that are no longer present.
3632
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3633
        nodeID int64, extraFields map[uint64][]byte) error {
×
3634

×
3635
        // Get any existing extra signed fields for the node.
×
3636
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3637
        if err != nil {
×
3638
                return err
×
3639
        }
×
3640

3641
        // Make a lookup map of the existing field types so that we can use it
3642
        // to keep track of any fields we should delete.
3643
        m := make(map[uint64]bool)
×
3644
        for _, field := range existingFields {
×
3645
                m[uint64(field.Type)] = true
×
3646
        }
×
3647

3648
        // For all the new fields, we'll upsert them and remove them from the
3649
        // map of existing fields.
3650
        for tlvType, value := range extraFields {
×
3651
                err = db.UpsertNodeExtraType(
×
3652
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3653
                                NodeID: nodeID,
×
3654
                                Type:   int64(tlvType),
×
3655
                                Value:  value,
×
3656
                        },
×
3657
                )
×
3658
                if err != nil {
×
3659
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3660
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3661
                }
×
3662

3663
                // Remove the field from the map of existing fields if it was
3664
                // present.
3665
                delete(m, tlvType)
×
3666
        }
3667

3668
        // For all the fields that are left in the map of existing fields, we'll
3669
        // delete them as they are no longer present in the new set of fields.
3670
        for tlvType := range m {
×
3671
                err = db.DeleteExtraNodeType(
×
3672
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3673
                                NodeID: nodeID,
×
3674
                                Type:   int64(tlvType),
×
3675
                        },
×
3676
                )
×
3677
                if err != nil {
×
3678
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3679
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3680
                }
×
3681
        }
3682

3683
        return nil
×
3684
}
3685
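// The diff step above reduces to simple set arithmetic: the types to delete
// are exactly those that were present before but are absent from the new
// announcement. A minimal illustrative sketch of just that step (not part of
// the original file; it assumes plain map inputs rather than the sqlc row
// type used above):
func staleExtraTypes(existing, updated map[uint64][]byte) []uint64 {
        stale := make([]uint64, 0, len(existing))
        for tlvType := range existing {
                // A type is stale only if the new announcement no longer
                // carries it.
                if _, ok := updated[tlvType]; !ok {
                        stale = append(stale, tlvType)
                }
        }

        return stale
}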

3686
// srcNodeInfo holds the information about the source node of the graph.
3687
type srcNodeInfo struct {
3688
        // id is the DB level ID of the source node entry in the "nodes" table.
3689
        id int64
3690

3691
        // pub is the public key of the source node.
3692
        pub route.Vertex
3693
}
3694

3695
// getSourceNode returns the DB node ID and pub key of the source node for
3696
// the specified protocol version.
3697
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3698
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3699

×
3700
        s.srcNodeMu.Lock()
×
3701
        defer s.srcNodeMu.Unlock()
×
3702

×
3703
        // If we already have the source node ID and pub key cached, then
×
3704
        // return them.
×
3705
        if info, ok := s.srcNodes[version]; ok {
×
3706
                return info.id, info.pub, nil
×
3707
        }
×
3708

3709
        var pubKey route.Vertex
×
3710

×
3711
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3712
        if err != nil {
×
3713
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3714
                        err)
×
3715
        }
×
3716

3717
        if len(nodes) == 0 {
×
3718
                return 0, pubKey, ErrSourceNodeNotSet
×
3719
        } else if len(nodes) > 1 {
×
3720
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3721
                        "protocol %s found", version)
×
3722
        }
×
3723

3724
        copy(pubKey[:], nodes[0].PubKey)
×
3725

×
3726
        s.srcNodes[version] = &srcNodeInfo{
×
3727
                id:  nodes[0].NodeID,
×
3728
                pub: pubKey,
×
3729
        }
×
3730

×
3731
        return nodes[0].NodeID, pubKey, nil
×
3732
}
3733
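// Usage sketch (illustrative, not part of the original file): store methods
// that already hold a SQLQueries handle are assumed to resolve the source
// node with a single call, and repeat calls for the same version are served
// from the in-memory s.srcNodes cache rather than the database:
//
//      srcID, srcPub, err := s.getSourceNode(ctx, db, ProtocolV1)
//      if err != nil {
//              return err
//      }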

3734
// marshalExtraOpaqueData takes a flat byte slice and parses it as a TLV
3735
// stream. This then produces a map from TLV type to value. If the input is
3736
// not a valid TLV stream, then an error is returned.
3737
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3738
        r := bytes.NewReader(data)
×
3739

×
3740
        tlvStream, err := tlv.NewStream()
×
3741
        if err != nil {
×
3742
                return nil, err
×
3743
        }
×
3744

3745
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3746
        // pass it into the P2P decoding variant.
3747
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3748
        if err != nil {
×
3749
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3750
        }
×
3751
        if len(parsedTypes) == 0 {
×
3752
                return nil, nil
×
3753
        }
×
3754

3755
        records := make(map[uint64][]byte)
×
3756
        for k, v := range parsedTypes {
×
3757
                records[uint64(k)] = v
×
3758
        }
×
3759

3760
        return records, nil
×
3761
}
3762
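// Round-trip sketch (illustrative, not part of the original file): the
// encoding direction used elsewhere in this file goes through
// lnwire.CustomRecords, so under that assumption an encode-then-parse cycle
// returns the original map. The TLV type below is an arbitrary value in the
// custom range:
//
//      extra := map[uint64][]byte{65537: {0x01, 0x02}}
//      raw, err := lnwire.CustomRecords(extra).Serialize()
//      if err != nil {
//              return err
//      }
//      parsed, err := marshalExtraOpaqueData(raw)
//      // parsed is expected to equal extra; an empty input yields a nil
//      // map rather than an error.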

3763
// dbChanInfo holds the DB level IDs of a channel and the nodes involved in the
3764
// channel.
3765
type dbChanInfo struct {
3766
        channelID int64
3767
        node1ID   int64
3768
        node2ID   int64
3769
}
3770

3771
// insertChannel inserts a new channel record into the database.
3772
func insertChannel(ctx context.Context, db SQLQueries,
3773
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
3774

×
3775
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3776

×
3777
        // Make sure that the channel doesn't already exist. We do this
×
3778
        // explicitly instead of relying on catching a unique constraint error
×
3779
                // because letting SQL raise that error would abort the entire
×
3780
        // batch of transactions.
×
3781
        _, err := db.GetChannelBySCID(
×
3782
                ctx, sqlc.GetChannelBySCIDParams{
×
3783
                        Scid:    chanIDB,
×
3784
                        Version: int16(ProtocolV1),
×
3785
                },
×
3786
        )
×
3787
        if err == nil {
×
3788
                return nil, ErrEdgeAlreadyExist
×
3789
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3790
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
×
3791
        }
×
3792

3793
        // Make sure that at least a "shell" entry for each node is present in
3794
        // the nodes table.
3795
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3796
        if err != nil {
×
3797
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3798
        }
×
3799

3800
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3801
        if err != nil {
×
3802
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3803
        }
×
3804

3805
        var capacity sql.NullInt64
×
3806
        if edge.Capacity != 0 {
×
3807
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3808
        }
×
3809

3810
        createParams := sqlc.CreateChannelParams{
×
3811
                Version:     int16(ProtocolV1),
×
3812
                Scid:        chanIDB,
×
3813
                NodeID1:     node1DBID,
×
3814
                NodeID2:     node2DBID,
×
3815
                Outpoint:    edge.ChannelPoint.String(),
×
3816
                Capacity:    capacity,
×
3817
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3818
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3819
        }
×
3820

×
3821
        if edge.AuthProof != nil {
×
3822
                proof := edge.AuthProof
×
3823

×
3824
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3825
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3826
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3827
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3828
        }
×
3829

3830
        // Insert the new channel record.
3831
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3832
        if err != nil {
×
3833
                return nil, err
×
3834
        }
×
3835

3836
        // Insert any channel features.
3837
        for feature := range edge.Features.Features() {
×
3838
                err = db.InsertChannelFeature(
×
3839
                        ctx, sqlc.InsertChannelFeatureParams{
×
3840
                                ChannelID:  dbChanID,
×
3841
                                FeatureBit: int32(feature),
×
3842
                        },
×
3843
                )
×
3844
                if err != nil {
×
3845
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
×
3846
                                "feature(%v): %w", dbChanID, feature, err)
×
3847
                }
×
3848
        }
3849

3850
        // Finally, insert any extra TLV fields in the channel announcement.
3851
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3852
        if err != nil {
×
3853
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
3854
                        "data: %w", err)
×
3855
        }
×
3856

3857
        for tlvType, value := range extra {
×
3858
                err := db.CreateChannelExtraType(
×
3859
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3860
                                ChannelID: dbChanID,
×
3861
                                Type:      int64(tlvType),
×
3862
                                Value:     value,
×
3863
                        },
×
3864
                )
×
3865
                if err != nil {
×
3866
                        return nil, fmt.Errorf("unable to upsert "+
×
3867
                                "channel(%d) extra signed field(%v): %w",
×
3868
                                edge.ChannelID, tlvType, err)
×
3869
                }
×
3870
        }
3871

3872
        return &dbChanInfo{
×
3873
                channelID: dbChanID,
×
3874
                node1ID:   node1DBID,
×
3875
                node2ID:   node2DBID,
×
3876
        }, nil
×
3877
}
3878
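// The existence check at the top of insertChannel follows the standard
// database/sql idiom, shown here in isolation as an illustrative sketch (not
// part of the original file; params stands in for the
// sqlc.GetChannelBySCIDParams value built above):
//
//      _, err := db.GetChannelBySCID(ctx, params)
//      switch {
//      case err == nil:
//              // The row already exists, so the caller bails out early.
//      case errors.Is(err, sql.ErrNoRows):
//              // Not found: safe to insert without tripping a unique
//              // constraint that would abort the whole batch.
//      default:
//              // Any other error is a real failure.
//      }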

3879
// maybeCreateShellNode checks if a shell node entry exists for the
3880
// given public key. If it does not exist, then a new shell node entry is
3881
// created. The ID of the node is returned. A shell node only has a protocol
3882
// version and public key persisted.
3883
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3884
        pubKey route.Vertex) (int64, error) {
×
3885

×
3886
        dbNode, err := db.GetNodeByPubKey(
×
3887
                ctx, sqlc.GetNodeByPubKeyParams{
×
3888
                        PubKey:  pubKey[:],
×
3889
                        Version: int16(ProtocolV1),
×
3890
                },
×
3891
        )
×
3892
        // The node exists. Return the ID.
×
3893
        if err == nil {
×
3894
                return dbNode.ID, nil
×
3895
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3896
                return 0, err
×
3897
        }
×
3898

3899
        // Otherwise, the node does not exist, so we create a shell entry for
3900
        // it.
3901
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3902
                Version: int16(ProtocolV1),
×
3903
                PubKey:  pubKey[:],
×
3904
        })
×
3905
        if err != nil {
×
3906
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3907
        }
×
3908

3909
        return id, nil
×
3910
}
3911
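// Idempotency sketch (illustrative, not part of the original file): calling
// the helper twice with the same pubkey is expected to return the same DB
// ID, which is what lets insertChannel reference nodes whose announcements
// have not been processed yet:
//
//      id1, _ := maybeCreateShellNode(ctx, db, pubKey)
//      id2, _ := maybeCreateShellNode(ctx, db, pubKey)
//      // id1 == id2: the second call finds the existing row instead of
//      // inserting a duplicate.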

3912
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3913
// the database. This includes deleting any existing types and then inserting
3914
// the new types.
3915
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3916
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3917

×
3918
        // Delete all existing extra signed fields for the channel policy.
×
3919
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
3920
        if err != nil {
×
3921
                return fmt.Errorf("unable to delete "+
×
3922
                        "existing policy extra signed fields for policy %d: %w",
×
3923
                        chanPolicyID, err)
×
3924
        }
×
3925

3926
        // Insert all new extra signed fields for the channel policy.
3927
        for tlvType, value := range extraFields {
×
3928
                err = db.InsertChanPolicyExtraType(
×
3929
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
3930
                                ChannelPolicyID: chanPolicyID,
×
3931
                                Type:            int64(tlvType),
×
3932
                                Value:           value,
×
3933
                        },
×
3934
                )
×
3935
                if err != nil {
×
3936
                        return fmt.Errorf("unable to insert "+
×
3937
                                "channel_policy(%d) extra signed field(%v): %w",
×
3938
                                chanPolicyID, tlvType, err)
×
3939
                }
×
3940
        }
3941

3942
        return nil
×
3943
}
3944

3945
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3946
// provided dbChanRow and also fetches any other required information
3947
// to construct the edge info.
3948
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
3949
        chain chainhash.Hash, dbChan sqlc.GraphChannel, node1,
3950
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
3951

×
3952
        // NOTE: getAndBuildEdgeInfo is only used to load the data for a single
×
3953
        // edge, and so no paged queries will be performed. This means that
×
3954
        // it's ok to pass in the default config values here.
×
3955
        cfg := sqldb.DefaultQueryConfig()
×
3956

×
3957
        data, err := batchLoadChannelData(ctx, cfg, db, []int64{dbChan.ID}, nil)
×
3958
        if err != nil {
×
3959
                return nil, fmt.Errorf("unable to batch load channel data: %w",
×
3960
                        err)
×
3961
        }
×
3962

3963
        return buildEdgeInfoWithBatchData(chain, dbChan, node1, node2, data)
×
3964
}
3965

3966
// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
3967
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
3968
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
3969
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {
×
3970

×
3971
        if dbChan.Version != int16(ProtocolV1) {
×
3972
                return nil, fmt.Errorf("unsupported channel version: %d",
×
3973
                        dbChan.Version)
×
3974
        }
×
3975

3976
        // Use the pre-loaded features and extra types.
3977
        fv := lnwire.EmptyFeatureVector()
×
3978
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
×
3979
                for _, bit := range features {
×
3980
                        fv.Set(lnwire.FeatureBit(bit))
×
3981
                }
×
3982
        }
3983

3984
        var extras map[uint64][]byte
×
3985
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
×
3986
        if exists {
×
3987
                extras = channelExtras
×
3988
        } else {
×
3989
                extras = make(map[uint64][]byte)
×
3990
        }
×
3991

3992
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
3993
        if err != nil {
×
3994
                return nil, err
×
3995
        }
×
3996

3997
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
3998
        if err != nil {
×
3999
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4000
                        "fields: %w", err)
×
4001
        }
×
4002
        if recs == nil {
×
4003
                recs = make([]byte, 0)
×
4004
        }
×
4005

4006
        var btcKey1, btcKey2 route.Vertex
×
4007
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4008
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4009

×
4010
        channel := &models.ChannelEdgeInfo{
×
4011
                ChainHash:        chain,
×
4012
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4013
                NodeKey1Bytes:    node1,
×
4014
                NodeKey2Bytes:    node2,
×
4015
                BitcoinKey1Bytes: btcKey1,
×
4016
                BitcoinKey2Bytes: btcKey2,
×
4017
                ChannelPoint:     *op,
×
4018
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4019
                Features:         fv,
×
4020
                ExtraOpaqueData:  recs,
×
4021
        }
×
4022

×
4023
        // We always set all the signatures at the same time, so we can
×
4024
        // safely check if one signature is present to determine if we have the
×
4025
        // rest of the signatures for the auth proof.
×
4026
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4027
                channel.AuthProof = &models.ChannelAuthProof{
×
4028
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4029
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4030
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4031
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4032
                }
×
4033
        }
×
4034

4035
        return channel, nil
×
4036
}
4037

4038
// buildNodeVertices is a helper that converts raw node public keys
4039
// into route.Vertex instances.
4040
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4041
        route.Vertex, error) {
×
4042

×
4043
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4044
        if err != nil {
×
4045
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4046
                        "create vertex from node1 pubkey: %w", err)
×
4047
        }
×
4048

4049
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4050
        if err != nil {
×
4051
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4052
                        "create vertex from node2 pubkey: %w", err)
×
4053
        }
×
4054

4055
        return node1Vertex, node2Vertex, nil
×
4056
}
4057

4058
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy records and
4059
// retrieves all the extra info required to build the complete
4060
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4061
// the provided sqlc.GraphChannelPolicy records are nil.
4062
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4063
        dbPol1, dbPol2 *sqlc.GraphChannelPolicy, channelID uint64, node1,
4064
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4065
        *models.ChannelEdgePolicy, error) {
×
4066

×
4067
        if dbPol1 == nil && dbPol2 == nil {
×
4068
                return nil, nil, nil
×
4069
        }
×
4070

4071
        var policyIDs = make([]int64, 0, 2)
×
4072
        if dbPol1 != nil {
×
4073
                policyIDs = append(policyIDs, dbPol1.ID)
×
4074
        }
×
4075
        if dbPol2 != nil {
×
4076
                policyIDs = append(policyIDs, dbPol2.ID)
×
4077
        }
×
4078

4079
        // NOTE: getAndBuildChanPolicies is only used to load the data for
4080
        // a maximum of two policies, and so no paged queries will be
4081
        // performed (unless the page size is one). So it's ok to use
4082
        // the default config values here.
4083
        cfg := sqldb.DefaultQueryConfig()
×
4084

×
4085
        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
×
4086
        if err != nil {
×
4087
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
4088
                        "data: %w", err)
×
4089
        }
×
4090

4091
        pol1, err := buildChanPolicyWithBatchData(
×
4092
                dbPol1, channelID, node2, batchData,
×
4093
        )
×
4094
        if err != nil {
×
4095
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4096
        }
×
4097

4098
        pol2, err := buildChanPolicyWithBatchData(
×
4099
                dbPol2, channelID, node1, batchData,
×
4100
        )
×
4101
        if err != nil {
×
4102
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4103
        }
×
4104

4105
        return pol1, pol2, nil
×
4106
}
4107

4108
// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
4109
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
4110
// then nil is returned for it.
4111
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4112
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
4113
        *models.CachedEdgePolicy, error) {
×
4114

×
4115
        var p1, p2 *models.CachedEdgePolicy
×
4116
        if dbPol1 != nil {
×
4117
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
×
4118
                if err != nil {
×
4119
                        return nil, nil, err
×
4120
                }
×
4121

4122
                p1 = models.NewCachedPolicy(policy1)
×
4123
        }
4124
        if dbPol2 != nil {
×
4125
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
×
4126
                if err != nil {
×
4127
                        return nil, nil, err
×
4128
                }
×
4129

4130
                p2 = models.NewCachedPolicy(policy2)
×
4131
        }
4132

4133
        return p1, p2, nil
×
4134
}
4135

4136
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4137
// provided sqlc.GraphChannelPolicy and other required information.
4138
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4139
        extras map[uint64][]byte,
4140
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4141

×
4142
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4143
        if err != nil {
×
4144
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4145
                        "fields: %w", err)
×
4146
        }
×
4147

4148
        var inboundFee fn.Option[lnwire.Fee]
×
4149
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4150
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4151

×
4152
                inboundFee = fn.Some(lnwire.Fee{
×
4153
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4154
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4155
                })
×
4156
        }
×
4157

4158
        return &models.ChannelEdgePolicy{
×
4159
                SigBytes:  dbPolicy.Signature,
×
4160
                ChannelID: channelID,
×
4161
                LastUpdate: time.Unix(
×
4162
                        dbPolicy.LastUpdate.Int64, 0,
×
4163
                ),
×
4164
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4165
                        dbPolicy.MessageFlags,
×
4166
                ),
×
4167
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4168
                        dbPolicy.ChannelFlags,
×
4169
                ),
×
4170
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4171
                MinHTLC: lnwire.MilliSatoshi(
×
4172
                        dbPolicy.MinHtlcMsat,
×
4173
                ),
×
4174
                MaxHTLC: lnwire.MilliSatoshi(
×
4175
                        dbPolicy.MaxHtlcMsat.Int64,
×
4176
                ),
×
4177
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4178
                        dbPolicy.BaseFeeMsat,
×
4179
                ),
×
4180
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4181
                ToNode:                    toNode,
×
4182
                InboundFee:                inboundFee,
×
4183
                ExtraOpaqueData:           recs,
×
4184
        }, nil
×
4185
}
4186
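// Inbound-fee mapping sketch (illustrative, not part of the original file):
// the two nullable columns above collapse into a single optional value, so
// callers can tell "never advertised" apart from "advertised as zero".
// Assuming base and rate are the sql.NullInt64 columns:
//
//      var inboundFee fn.Option[lnwire.Fee]
//      if base.Valid || rate.Valid {
//              inboundFee = fn.Some(lnwire.Fee{
//                      BaseFee: int32(base.Int64),
//                      FeeRate: int32(rate.Int64),
//              })
//      }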

4187
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the given
4188
// row which is expected to be a sqlc type that contains channel policy
4189
// information. It returns two policies, which may be nil if the policy
4190
// information is not present in the row.
4191
//
4192
//nolint:ll,dupl,funlen
4193
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4194
        *sqlc.GraphChannelPolicy, error) {
×
4195

×
4196
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4197
        switch r := row.(type) {
×
4198
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
×
4199
                if r.Policy1Timelock.Valid {
×
4200
                        policy1 = &sqlc.GraphChannelPolicy{
×
4201
                                Timelock:                r.Policy1Timelock.Int32,
×
4202
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4203
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4204
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4205
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4206
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4207
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4208
                                Disabled:                r.Policy1Disabled,
×
4209
                                MessageFlags:            r.Policy1MessageFlags,
×
4210
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4211
                        }
×
4212
                }
×
4213
                if r.Policy2Timelock.Valid {
×
4214
                        policy2 = &sqlc.GraphChannelPolicy{
×
4215
                                Timelock:                r.Policy2Timelock.Int32,
×
4216
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4217
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4218
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4219
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4220
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4221
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4222
                                Disabled:                r.Policy2Disabled,
×
4223
                                MessageFlags:            r.Policy2MessageFlags,
×
4224
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4225
                        }
×
4226
                }
×
4227

4228
                return policy1, policy2, nil
×
4229

4230
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
4231
                if r.Policy1ID.Valid {
×
4232
                        policy1 = &sqlc.GraphChannelPolicy{
×
4233
                                ID:                      r.Policy1ID.Int64,
×
4234
                                Version:                 r.Policy1Version.Int16,
×
4235
                                ChannelID:               r.GraphChannel.ID,
×
4236
                                NodeID:                  r.Policy1NodeID.Int64,
×
4237
                                Timelock:                r.Policy1Timelock.Int32,
×
4238
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4239
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4240
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4241
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4242
                                LastUpdate:              r.Policy1LastUpdate,
×
4243
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4244
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4245
                                Disabled:                r.Policy1Disabled,
×
4246
                                MessageFlags:            r.Policy1MessageFlags,
×
4247
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4248
                                Signature:               r.Policy1Signature,
×
4249
                        }
×
4250
                }
×
4251
                if r.Policy2ID.Valid {
×
4252
                        policy2 = &sqlc.GraphChannelPolicy{
×
4253
                                ID:                      r.Policy2ID.Int64,
×
4254
                                Version:                 r.Policy2Version.Int16,
×
4255
                                ChannelID:               r.GraphChannel.ID,
×
4256
                                NodeID:                  r.Policy2NodeID.Int64,
×
4257
                                Timelock:                r.Policy2Timelock.Int32,
×
4258
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4259
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4260
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4261
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4262
                                LastUpdate:              r.Policy2LastUpdate,
×
4263
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4264
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4265
                                Disabled:                r.Policy2Disabled,
×
4266
                                MessageFlags:            r.Policy2MessageFlags,
×
4267
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4268
                                Signature:               r.Policy2Signature,
×
4269
                        }
×
4270
                }
×
4271

4272
                return policy1, policy2, nil
×
4273

4274
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4275
                if r.Policy1ID.Valid {
×
4276
                        policy1 = &sqlc.GraphChannelPolicy{
×
4277
                                ID:                      r.Policy1ID.Int64,
×
4278
                                Version:                 r.Policy1Version.Int16,
×
4279
                                ChannelID:               r.GraphChannel.ID,
×
4280
                                NodeID:                  r.Policy1NodeID.Int64,
×
4281
                                Timelock:                r.Policy1Timelock.Int32,
×
4282
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4283
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4284
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4285
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4286
                                LastUpdate:              r.Policy1LastUpdate,
×
4287
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4288
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4289
                                Disabled:                r.Policy1Disabled,
×
4290
                                MessageFlags:            r.Policy1MessageFlags,
×
4291
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4292
                                Signature:               r.Policy1Signature,
×
4293
                        }
×
4294
                }
×
4295
                if r.Policy2ID.Valid {
×
4296
                        policy2 = &sqlc.GraphChannelPolicy{
×
4297
                                ID:                      r.Policy2ID.Int64,
×
4298
                                Version:                 r.Policy2Version.Int16,
×
4299
                                ChannelID:               r.GraphChannel.ID,
×
4300
                                NodeID:                  r.Policy2NodeID.Int64,
×
4301
                                Timelock:                r.Policy2Timelock.Int32,
×
4302
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4303
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4304
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4305
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4306
                                LastUpdate:              r.Policy2LastUpdate,
×
4307
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4308
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4309
                                Disabled:                r.Policy2Disabled,
×
4310
                                MessageFlags:            r.Policy2MessageFlags,
×
4311
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4312
                                Signature:               r.Policy2Signature,
×
4313
                        }
×
4314
                }
×
4315

4316
                return policy1, policy2, nil
×
4317

4318
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4319
                if r.Policy1ID.Valid {
×
4320
                        policy1 = &sqlc.GraphChannelPolicy{
×
4321
                                ID:                      r.Policy1ID.Int64,
×
4322
                                Version:                 r.Policy1Version.Int16,
×
4323
                                ChannelID:               r.GraphChannel.ID,
×
4324
                                NodeID:                  r.Policy1NodeID.Int64,
×
4325
                                Timelock:                r.Policy1Timelock.Int32,
×
4326
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4327
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4328
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4329
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4330
                                LastUpdate:              r.Policy1LastUpdate,
×
4331
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4332
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4333
                                Disabled:                r.Policy1Disabled,
×
4334
                                MessageFlags:            r.Policy1MessageFlags,
×
4335
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4336
                                Signature:               r.Policy1Signature,
×
4337
                        }
×
4338
                }
×
4339
                if r.Policy2ID.Valid {
×
4340
                        policy2 = &sqlc.GraphChannelPolicy{
×
4341
                                ID:                      r.Policy2ID.Int64,
×
4342
                                Version:                 r.Policy2Version.Int16,
×
4343
                                ChannelID:               r.GraphChannel.ID,
×
4344
                                NodeID:                  r.Policy2NodeID.Int64,
×
4345
                                Timelock:                r.Policy2Timelock.Int32,
×
4346
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4347
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4348
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4349
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4350
                                LastUpdate:              r.Policy2LastUpdate,
×
4351
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4352
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4353
                                Disabled:                r.Policy2Disabled,
×
4354
                                MessageFlags:            r.Policy2MessageFlags,
×
4355
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4356
                                Signature:               r.Policy2Signature,
×
4357
                        }
×
4358
                }
×
4359

4360
                return policy1, policy2, nil
×
4361

4362
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4363
                if r.Policy1ID.Valid {
×
4364
                        policy1 = &sqlc.GraphChannelPolicy{
×
4365
                                ID:                      r.Policy1ID.Int64,
×
4366
                                Version:                 r.Policy1Version.Int16,
×
4367
                                ChannelID:               r.GraphChannel.ID,
×
4368
                                NodeID:                  r.Policy1NodeID.Int64,
×
4369
                                Timelock:                r.Policy1Timelock.Int32,
×
4370
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4371
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4372
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4373
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4374
                                LastUpdate:              r.Policy1LastUpdate,
×
4375
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4376
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4377
                                Disabled:                r.Policy1Disabled,
×
4378
                                MessageFlags:            r.Policy1MessageFlags,
×
4379
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4380
                                Signature:               r.Policy1Signature,
×
4381
                        }
×
4382
                }
×
4383
                if r.Policy2ID.Valid {
×
4384
                        policy2 = &sqlc.GraphChannelPolicy{
×
4385
                                ID:                      r.Policy2ID.Int64,
×
4386
                                Version:                 r.Policy2Version.Int16,
×
4387
                                ChannelID:               r.GraphChannel.ID,
×
4388
                                NodeID:                  r.Policy2NodeID.Int64,
×
4389
                                Timelock:                r.Policy2Timelock.Int32,
×
4390
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4391
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4392
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4393
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4394
                                LastUpdate:              r.Policy2LastUpdate,
×
4395
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4396
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4397
                                Disabled:                r.Policy2Disabled,
×
4398
                                MessageFlags:            r.Policy2MessageFlags,
×
4399
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4400
                                Signature:               r.Policy2Signature,
×
4401
                        }
×
4402
                }
×
4403

4404
                return policy1, policy2, nil
×
4405

4406
        case sqlc.ListChannelsForNodeIDsRow:
×
4407
                if r.Policy1ID.Valid {
×
4408
                        policy1 = &sqlc.GraphChannelPolicy{
×
4409
                                ID:                      r.Policy1ID.Int64,
×
4410
                                Version:                 r.Policy1Version.Int16,
×
4411
                                ChannelID:               r.GraphChannel.ID,
×
4412
                                NodeID:                  r.Policy1NodeID.Int64,
×
4413
                                Timelock:                r.Policy1Timelock.Int32,
×
4414
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4415
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4416
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4417
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4418
                                LastUpdate:              r.Policy1LastUpdate,
×
4419
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4420
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4421
                                Disabled:                r.Policy1Disabled,
×
4422
                                MessageFlags:            r.Policy1MessageFlags,
×
4423
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4424
                                Signature:               r.Policy1Signature,
×
4425
                        }
×
4426
                }
×
4427
                if r.Policy2ID.Valid {
×
4428
                        policy2 = &sqlc.GraphChannelPolicy{
×
4429
                                ID:                      r.Policy2ID.Int64,
×
4430
                                Version:                 r.Policy2Version.Int16,
×
4431
                                ChannelID:               r.GraphChannel.ID,
×
4432
                                NodeID:                  r.Policy2NodeID.Int64,
×
4433
                                Timelock:                r.Policy2Timelock.Int32,
×
4434
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4435
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4436
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4437
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4438
                                LastUpdate:              r.Policy2LastUpdate,
×
4439
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4440
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4441
                                Disabled:                r.Policy2Disabled,
×
4442
                                MessageFlags:            r.Policy2MessageFlags,
×
4443
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4444
                                Signature:               r.Policy2Signature,
×
4445
                        }
×
4446
                }
×
4447

4448
                return policy1, policy2, nil
×
4449

4450
        case sqlc.ListChannelsByNodeIDRow:
×
4451
                if r.Policy1ID.Valid {
×
4452
                        policy1 = &sqlc.GraphChannelPolicy{
×
4453
                                ID:                      r.Policy1ID.Int64,
×
4454
                                Version:                 r.Policy1Version.Int16,
×
4455
                                ChannelID:               r.GraphChannel.ID,
×
4456
                                NodeID:                  r.Policy1NodeID.Int64,
×
4457
                                Timelock:                r.Policy1Timelock.Int32,
×
4458
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4459
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4460
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4461
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4462
                                LastUpdate:              r.Policy1LastUpdate,
×
4463
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4464
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4465
                                Disabled:                r.Policy1Disabled,
×
4466
                                MessageFlags:            r.Policy1MessageFlags,
×
4467
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4468
                                Signature:               r.Policy1Signature,
×
4469
                        }
×
4470
                }
×
4471
                if r.Policy2ID.Valid {
×
4472
                        policy2 = &sqlc.GraphChannelPolicy{
×
4473
                                ID:                      r.Policy2ID.Int64,
×
4474
                                Version:                 r.Policy2Version.Int16,
×
4475
                                ChannelID:               r.GraphChannel.ID,
×
4476
                                NodeID:                  r.Policy2NodeID.Int64,
×
4477
                                Timelock:                r.Policy2Timelock.Int32,
×
4478
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4479
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4480
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4481
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4482
                                LastUpdate:              r.Policy2LastUpdate,
×
4483
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4484
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4485
                                Disabled:                r.Policy2Disabled,
×
4486
                                MessageFlags:            r.Policy2MessageFlags,
×
4487
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4488
                                Signature:               r.Policy2Signature,
×
4489
                        }
×
4490
                }
×
4491

4492
                return policy1, policy2, nil
×
4493

4494
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4495
                if r.Policy1ID.Valid {
×
4496
                        policy1 = &sqlc.GraphChannelPolicy{
×
4497
                                ID:                      r.Policy1ID.Int64,
×
4498
                                Version:                 r.Policy1Version.Int16,
×
4499
                                ChannelID:               r.GraphChannel.ID,
×
4500
                                NodeID:                  r.Policy1NodeID.Int64,
×
4501
                                Timelock:                r.Policy1Timelock.Int32,
×
4502
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4503
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4504
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4505
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4506
                                LastUpdate:              r.Policy1LastUpdate,
×
4507
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4508
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4509
                                Disabled:                r.Policy1Disabled,
×
4510
                                MessageFlags:            r.Policy1MessageFlags,
×
4511
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4512
                                Signature:               r.Policy1Signature,
×
4513
                        }
×
4514
                }
×
4515
                if r.Policy2ID.Valid {
×
4516
                        policy2 = &sqlc.GraphChannelPolicy{
×
4517
                                ID:                      r.Policy2ID.Int64,
×
4518
                                Version:                 r.Policy2Version.Int16,
×
4519
                                ChannelID:               r.GraphChannel.ID,
×
4520
                                NodeID:                  r.Policy2NodeID.Int64,
×
4521
                                Timelock:                r.Policy2Timelock.Int32,
×
4522
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4523
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4524
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4525
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4526
                                LastUpdate:              r.Policy2LastUpdate,
×
4527
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4528
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4529
                                Disabled:                r.Policy2Disabled,
×
4530
                                MessageFlags:            r.Policy2MessageFlags,
×
4531
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4532
                                Signature:               r.Policy2Signature,
×
4533
                        }
×
4534
                }
×
4535

4536
                return policy1, policy2, nil
×
4537
        default:
×
4538
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4539
                        "extractChannelPolicies: %T", r)
×
4540
        }
4541
}
4542
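// Call-site sketch (illustrative, not part of the original file): every row
// type handled above is consumed the same way, with a nil policy meaning
// that no policy information was present in that row:
//
//      dbPol1, dbPol2, err := extractChannelPolicies(row)
//      if err != nil {
//              return err
//      }
//      // dbPol1 and/or dbPol2 may be nil, so callers must nil-check
//      // before dereferencing either one.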

4543
// channelIDToBytes converts a channel ID (SCID) to a byte array
4544
// representation.
4545
func channelIDToBytes(channelID uint64) []byte {
×
4546
        var chanIDB [8]byte
×
4547
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4548

×
4549
        return chanIDB[:]
×
4550
}
×
4551
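// Round-trip sketch (illustrative, not part of the original file): the
// encoding is symmetric with the byteOrder.Uint64 decoding used by the edge
// builders above:
//
//      raw := channelIDToBytes(123456789)
//      scid := byteOrder.Uint64(raw) // scid == 123456789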

4552
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
4553
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
×
4554
        if len(addresses) == 0 {
×
4555
                return nil, nil
×
4556
        }
×
4557

4558
        result := make([]net.Addr, 0, len(addresses))
×
4559
        for _, addr := range addresses {
×
4560
                netAddr, err := parseAddress(addr.addrType, addr.address)
×
4561
                if err != nil {
×
4562
                        return nil, fmt.Errorf("unable to parse address %s "+
×
4563
                                "of type %d: %w", addr.address, addr.addrType,
×
4564
                                err)
×
4565
                }
×
4566
                if netAddr != nil {
×
4567
                        result = append(result, netAddr)
×
4568
                }
×
4569
        }
4570

4571
        // If we have no valid addresses, return nil instead of empty slice.
4572
        if len(result) == 0 {
×
4573
                return nil, nil
×
4574
        }
×
4575

4576
        return result, nil
×
4577
}
4578

4579
// parseAddress parses the given address string based on the address type
4580
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
4581
// and opaque addresses.
4582
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
×
4583
        switch addrType {
×
4584
        case addressTypeIPv4:
×
4585
                tcp, err := net.ResolveTCPAddr("tcp4", address)
×
4586
                if err != nil {
×
4587
                        return nil, err
×
4588
                }
×
4589

4590
                tcp.IP = tcp.IP.To4()
×
4591

×
4592
                return tcp, nil
×
4593

4594
        case addressTypeIPv6:
×
4595
                tcp, err := net.ResolveTCPAddr("tcp6", address)
×
4596
                if err != nil {
×
4597
                        return nil, err
×
4598
                }
×
4599

4600
                return tcp, nil
×
4601

4602
        case addressTypeTorV3, addressTypeTorV2:
×
4603
                service, portStr, err := net.SplitHostPort(address)
×
4604
                if err != nil {
×
4605
                        return nil, fmt.Errorf("unable to split tor "+
×
4606
                                "address: %v", address)
×
4607
                }
×
4608

4609
                port, err := strconv.Atoi(portStr)
×
4610
                if err != nil {
×
4611
                        return nil, err
×
4612
                }
×
4613

4614
                return &tor.OnionAddr{
×
4615
                        OnionService: service,
×
4616
                        Port:         port,
×
4617
                }, nil
×
4618

4619
        case addressTypeOpaque:
×
4620
                opaque, err := hex.DecodeString(address)
×
4621
                if err != nil {
×
4622
                        return nil, fmt.Errorf("unable to decode opaque "+
×
4623
                                "address: %v", address)
×
4624
                }
×
4625

4626
                return &lnwire.OpaqueAddrs{
×
4627
                        Payload: opaque,
×
4628
                }, nil
×
4629

4630
        default:
×
4631
                return nil, fmt.Errorf("unknown address type: %v", addrType)
×
4632
        }
4633
}
4634
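// Parsing sketch (illustrative, not part of the original file): the expected
// mapping from stored (type, string) pairs back to net.Addr values. The
// concrete inputs below are made up for the example:
//
//      a1, _ := parseAddress(addressTypeIPv4, "127.0.0.1:9735")   // *net.TCPAddr
//      a2, _ := parseAddress(addressTypeTorV3, "exampleonion.onion:9735") // *tor.OnionAddr
//      a3, _ := parseAddress(addressTypeOpaque, "0102ff")         // *lnwire.OpaqueAddrs
//      // Unknown address types return an error rather than being silently
//      // dropped.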

4635
// batchNodeData holds all the related data for a batch of nodes.
4636
type batchNodeData struct {
4637
        // features is a map from a DB node ID to the feature bits for that
4638
        // node.
4639
        features map[int64][]int
4640

4641
        // addresses is a map from a DB node ID to the node's addresses.
4642
        addresses map[int64][]nodeAddress
4643

4644
        // extraFields is a map from a DB node ID to the extra signed fields
4645
        // for that node.
4646
        extraFields map[int64]map[uint64][]byte
4647
}
4648

4649
// nodeAddress holds the address type, position and address string for a
4650
// node. This is used to batch the fetching of node addresses.
4651
type nodeAddress struct {
4652
        addrType dbAddressType
4653
        position int32
4654
        address  string
4655
}
4656

4657
// batchLoadNodeData loads all related data for a batch of node IDs using the
4658
// provided SQLQueries interface. It returns a batchNodeData instance containing
4659
// the node features, addresses and extra signed fields.
4660
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
4661
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {
×
4662

×
4663
        // Batch load the node features.
×
4664
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
4665
        if err != nil {
×
4666
                return nil, fmt.Errorf("unable to batch load node "+
×
4667
                        "features: %w", err)
×
4668
        }
×
4669

4670
        // Batch load the node addresses.
4671
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
×
4672
        if err != nil {
×
4673
                return nil, fmt.Errorf("unable to batch load node "+
×
4674
                        "addresses: %w", err)
×
4675
        }
×
4676

4677
        // Batch load the node extra signed fields.
4678
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
×
4679
        if err != nil {
×
4680
                return nil, fmt.Errorf("unable to batch load node extra "+
×
4681
                        "signed fields: %w", err)
×
4682
        }
×
4683

4684
        return &batchNodeData{
×
4685
                features:    features,
×
4686
                addresses:   addrs,
×
4687
                extraFields: extraTypes,
×
4688
        }, nil
×
4689
}
4690
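// Usage sketch (illustrative, not part of the original file): a caller that
// has paged a slice of node rows out of the database can hydrate all of
// them with a few batched queries instead of three queries per node. The
// variable names are placeholders:
//
//      ids := make([]int64, 0, len(dbNodes))
//      for _, n := range dbNodes {
//              ids = append(ids, n.ID)
//      }
//      data, err := batchLoadNodeData(ctx, cfg, db, ids)
//      if err != nil {
//              return err
//      }
//      // data.features[id], data.addresses[id] and data.extraFields[id]
//      // are now available without further per-node round trips.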

4691
// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
4692
// using the ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
4693
func batchLoadNodeFeaturesHelper(ctx context.Context,
4694
        cfg *sqldb.QueryConfig, db SQLQueries,
4695
        nodeIDs []int64) (map[int64][]int, error) {
×
4696

×
4697
        features := make(map[int64][]int)
×
4698

×
4699
        return features, sqldb.ExecuteBatchQuery(
×
4700
                ctx, cfg, nodeIDs,
×
4701
                func(id int64) int64 {
×
4702
                        return id
×
4703
                },
×
4704
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
4705
                        error) {
×
4706

×
4707
                        return db.GetNodeFeaturesBatch(ctx, ids)
×
4708
                },
×
4709
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
×
4710
                        features[feature.NodeID] = append(
×
4711
                                features[feature.NodeID],
×
4712
                                int(feature.FeatureBit),
×
4713
                        )
×
4714

×
4715
                        return nil
×
4716
                },
×
4717
        )
4718
}
4719
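// Shape sketch (illustrative, not part of the original file): the three
// helpers in this section drive sqldb.ExecuteBatchQuery with the same trio
// of callbacks: a key extractor, a batched query per chunk of IDs, and a
// per-row accumulator. Schematically, with Item, Row, SomeBatchQuery and
// convert as placeholders:
//
//      out := make(map[int64][]Item)
//      err := sqldb.ExecuteBatchQuery(
//              ctx, cfg, ids,
//              func(id int64) int64 { return id },
//              func(ctx context.Context, chunk []int64) ([]Row, error) {
//                      return db.SomeBatchQuery(ctx, chunk)
//              },
//              func(ctx context.Context, r Row) error {
//                      out[r.NodeID] = append(out[r.NodeID], convert(r))
//                      return nil
//              },
//      )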

4720
// batchLoadNodeAddressesHelper loads node addresses using the
4721
// ExecuteBatchQuery wrapper around the GetNodeAddressesBatch query. It
4722
// returns a map from node ID to a slice of nodeAddress structs.
4723
func batchLoadNodeAddressesHelper(ctx context.Context,
4724
        cfg *sqldb.QueryConfig, db SQLQueries,
4725
        nodeIDs []int64) (map[int64][]nodeAddress, error) {
×
4726

×
4727
        addrs := make(map[int64][]nodeAddress)
×
4728

×
4729
        return addrs, sqldb.ExecuteBatchQuery(
×
4730
                ctx, cfg, nodeIDs,
×
4731
                func(id int64) int64 {
×
4732
                        return id
×
4733
                },
×
4734
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
4735
                        error) {
×
4736

×
4737
                        return db.GetNodeAddressesBatch(ctx, ids)
×
4738
                },
×
4739
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
×
4740
                        addrs[addr.NodeID] = append(
×
4741
                                addrs[addr.NodeID], nodeAddress{
×
4742
                                        addrType: dbAddressType(addr.Type),
×
4743
                                        position: addr.Position,
×
4744
                                        address:  addr.Address,
×
4745
                                },
×
4746
                        )
×
4747

×
4748
                        return nil
×
4749
                },
×
4750
        )
4751
}
4752

4753
// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
4754
// node IDs using ExecuteBatchQuery wrapper around the GetNodeExtraTypesBatch
4755
// query.
4756
func batchLoadNodeExtraTypesHelper(ctx context.Context,
4757
        cfg *sqldb.QueryConfig, db SQLQueries,
4758
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {
×
4759

×
4760
        extraFields := make(map[int64]map[uint64][]byte)
×
4761

×
4762
        callback := func(ctx context.Context,
×
4763
                field sqlc.GraphNodeExtraType) error {
×
4764

×
4765
                if extraFields[field.NodeID] == nil {
×
4766
                        extraFields[field.NodeID] = make(map[uint64][]byte)
×
4767
                }
×
4768
                extraFields[field.NodeID][uint64(field.Type)] = field.Value
×
4769

×
4770
                return nil
×
4771
        }
4772

4773
        return extraFields, sqldb.ExecuteBatchQuery(
×
4774
                ctx, cfg, nodeIDs,
×
4775
                func(id int64) int64 {
×
4776
                        return id
×
4777
                },
×
4778
                func(ctx context.Context, ids []int64) (
4779
                        []sqlc.GraphNodeExtraType, error) {
×
4780

×
4781
                        return db.GetNodeExtraTypesBatch(ctx, ids)
×
4782
                },
×
4783
                callback,
4784
        )
4785
}

// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
// from the provided sqlc.GraphChannelPolicy records and the provided
// batchChannelData.
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
        channelID uint64, node1, node2 route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
        *models.ChannelEdgePolicy, error) {

        pol1, err := buildChanPolicyWithBatchData(
                dbPol1, channelID, node2, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
        }

        pol2, err := buildChanPolicyWithBatchData(
                dbPol2, channelID, node1, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
        }

        return pol1, pol2, nil
}

// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
        channelID uint64, toNode route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {

        if dbPol == nil {
                return nil, nil
        }

        // Use the batch-loaded extra signed fields for this policy if any
        // exist, otherwise fall back to an empty map.
        var dbPolExtras map[uint64][]byte
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
                dbPolExtras = extras
        } else {
                dbPolExtras = make(map[uint64][]byte)
        }

        return buildChanPolicy(*dbPol, channelID, dbPolExtras, toNode)
}

// batchChannelData holds all the related data for a batch of channels.
type batchChannelData struct {
        // chanfeatures is a map from DB channel ID to a slice of feature bits.
        chanfeatures map[int64][]int

        // chanExtraTypes is a map from DB channel ID to a map of TLV type to
        // extra signed field bytes.
        chanExtraTypes map[int64]map[uint64][]byte

        // policyExtras is a map from DB channel policy ID to a map of TLV type
        // to extra signed field bytes.
        policyExtras map[int64]map[uint64][]byte
}

// batchLoadChannelData loads all related data for batches of channels and
// policies.
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, channelIDs []int64,
        policyIDs []int64) (*batchChannelData, error) {

        batchData := &batchChannelData{
                chanfeatures:   make(map[int64][]int),
                chanExtraTypes: make(map[int64]map[uint64][]byte),
                policyExtras:   make(map[int64]map[uint64][]byte),
        }

        // Batch load channel features and extras.
        var err error
        if len(channelIDs) > 0 {
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
                        ctx, cfg, db, channelIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "channel features: %w", err)
                }

                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
                        ctx, cfg, db, channelIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "channel extras: %w", err)
                }
        }

        if len(policyIDs) > 0 {
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
                        ctx, cfg, db, policyIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "policy extras: %w", err)
                }
                batchData.policyExtras = policyExtras
        }

        return batchData, nil
}
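
// Rough call-site sketch (assumed, not taken from this file): loading the
// shared batch data for a page of channels before building their edges and
// policies. The channelIDs and policyIDs slices are assumed to have been
// collected from the page rows, and cfg to be the store's SQLStoreConfig.
//
//        batchData, err := batchLoadChannelData(
//                ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
//        )
//        if err != nil {
//                return err
//        }
//        // batchData.chanfeatures, batchData.chanExtraTypes and
//        // batchData.policyExtras are now keyed by the DB IDs passed in.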

// batchLoadChannelFeaturesHelper loads channel features for a batch of
// channel IDs using the ExecuteBatchQuery wrapper around the
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
// slice of feature bits.
func batchLoadChannelFeaturesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        channelIDs []int64) (map[int64][]int, error) {

        features := make(map[int64][]int)

        return features, sqldb.ExecuteBatchQuery(
                ctx, cfg, channelIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context,
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {

                        return db.GetChannelFeaturesBatch(ctx, ids)
                },
                func(ctx context.Context,
                        feature sqlc.GraphChannelFeature) error {

                        features[feature.ChannelID] = append(
                                features[feature.ChannelID],
                                int(feature.FeatureBit),
                        )

                        return nil
                },
        )
}

// batchLoadChannelExtrasHelper loads channel extra types for a batch of
// channel IDs using the ExecuteBatchQuery wrapper around the
// GetChannelExtrasBatch query. It returns a map from DB channel ID to a map
// of TLV type to extra signed field bytes.
func batchLoadChannelExtrasHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {

        extras := make(map[int64]map[uint64][]byte)

        cb := func(ctx context.Context,
                extra sqlc.GraphChannelExtraType) error {

                if extras[extra.ChannelID] == nil {
                        extras[extra.ChannelID] = make(map[uint64][]byte)
                }
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value

                return nil
        }

        return extras, sqldb.ExecuteBatchQuery(
                ctx, cfg, channelIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context,
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {

                        return db.GetChannelExtrasBatch(ctx, ids)
                }, cb,
        )
}

// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
// batch of policy IDs using the ExecuteBatchQuery wrapper around the
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID
// to a map of TLV type to extra signed field bytes.
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {

        extras := make(map[int64]map[uint64][]byte)

        return extras, sqldb.ExecuteBatchQuery(
                ctx, cfg, policyIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) (
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {

                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
                },
                func(ctx context.Context,
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {

                        if extras[row.PolicyID] == nil {
                                extras[row.PolicyID] = make(map[uint64][]byte)
                        }
                        extras[row.PolicyID][uint64(row.Type)] = row.Value

                        return nil
                },
        )
}

// forEachNodePaginated executes a paginated query to process each node in the
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
// and applies the provided processNode function to each node.
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, protocol ProtocolVersion,
        processNode func(context.Context, int64,
                *models.LightningNode) error) error {

        pageQueryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.GraphNode, error) {

                return db.ListNodesPaginated(
                        ctx, sqlc.ListNodesPaginatedParams{
                                Version: int16(protocol),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractPageCursor := func(node sqlc.GraphNode) int64 {
                return node.ID
        }

        collectFunc := func(node sqlc.GraphNode) (int64, error) {
                return node.ID, nil
        }

        batchQueryFunc := func(ctx context.Context,
                nodeIDs []int64) (*batchNodeData, error) {

                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
        }

        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
                batchData *batchNodeData) error {

                node, err := buildNodeWithBatchData(dbNode, batchData)
                if err != nil {
                        return fmt.Errorf("unable to build "+
                                "node(id=%d): %w", dbNode.ID, err)
                }

                return processNode(ctx, dbNode.ID, node)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
                collectFunc, batchQueryFunc, processItem,
        )
}
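
// Illustrative usage sketch (not part of the original change): counting all
// V1 nodes in the graph via forEachNodePaginated. The ctx, cfg and db values
// are assumed to come from the surrounding store/transaction.
//
//        var numNodes int
//        err := forEachNodePaginated(
//                ctx, cfg, db, ProtocolV1,
//                func(_ context.Context, _ int64,
//                        _ *models.LightningNode) error {
//
//                        numNodes++
//                        return nil
//                },
//        )
//        if err != nil {
//                return err
//        }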

// forEachChannelWithPolicies executes a paginated query to process each
// channel with policies in the graph.
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
        cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
                *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error) error {

        type channelBatchIDs struct {
                channelID int64
                policyIDs []int64
        }

        pageQueryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
                error) {

                return db.ListChannelsWithPoliciesPaginated(
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
                                Version: int16(ProtocolV1),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractPageCursor := func(
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {

                return row.GraphChannel.ID
        }

        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
                channelBatchIDs, error) {

                ids := channelBatchIDs{
                        channelID: row.GraphChannel.ID,
                }

                // Extract policy IDs from the row.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return ids, err
                }

                if dbPol1 != nil {
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
                }
                if dbPol2 != nil {
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
                }

                return ids, nil
        }

        batchDataFunc := func(ctx context.Context,
                allIDs []channelBatchIDs) (*batchChannelData, error) {

                // Separate channel IDs from policy IDs.
                var (
                        channelIDs = make([]int64, len(allIDs))
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
                )

                for i, ids := range allIDs {
                        channelIDs[i] = ids.channelID
                        policyIDs = append(policyIDs, ids.policyIDs...)
                }

                return batchLoadChannelData(
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
                )
        }

        processItem := func(ctx context.Context,
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
                batchData *batchChannelData) error {

                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return err
                }

                edge, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.GraphChannel, node1, node2,
                        batchData,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel info: %w",
                                err)
                }

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return err
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
                )
                if err != nil {
                        return err
                }

                return processChannel(edge, p1, p2)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
                collectFunc, batchDataFunc, processItem,
        )
}
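
// Illustrative usage sketch (not part of the original change): iterating all
// channels together with their (possibly nil) policies, e.g. to count how
// many channels have both directions advertised. The ctx, db and cfg values
// are assumed to be in scope.
//
//        var bothDirs int
//        err := forEachChannelWithPolicies(
//                ctx, db, cfg,
//                func(_ *models.ChannelEdgeInfo, p1,
//                        p2 *models.ChannelEdgePolicy) error {
//
//                        if p1 != nil && p2 != nil {
//                                bothDirs++
//                        }
//                        return nil
//                },
//        )
//        if err != nil {
//                return err
//        }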

// buildDirectedChannel builds a DirectedChannel instance for the node
// identified by nodeID/nodePub from the provided channel row and the
// batch-loaded channel data.
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
        nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
        channelBatchData *batchChannelData, features *lnwire.FeatureVector,
        toNodeCallback func() route.Vertex) (*DirectedChannel, error) {

        node1, node2, err := buildNodeVertices(
                channelRow.Node1Pubkey, channelRow.Node2Pubkey,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build node vertices: %w", err)
        }

        edge, err := buildEdgeInfoWithBatchData(
                chain, channelRow.GraphChannel, node1, node2, channelBatchData,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build channel info: %w", err)
        }

        dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
        if err != nil {
                return nil, fmt.Errorf("unable to extract channel policies: %w",
                        err)
        }

        p1, p2, err := buildChanPoliciesWithBatchData(
                dbPol1, dbPol2, edge.ChannelID, node1, node2,
                channelBatchData,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build channel policies: %w",
                        err)
        }

        // Determine outgoing and incoming policy for this specific node.
        p1ToNode := channelRow.GraphChannel.NodeID2
        p2ToNode := channelRow.GraphChannel.NodeID1
        outPolicy, inPolicy := p1, p2
        if (p1 != nil && p1ToNode == nodeID) ||
                (p2 != nil && p2ToNode != nodeID) {

                outPolicy, inPolicy = p2, p1
        }

        // Build cached policy.
        var cachedInPolicy *models.CachedEdgePolicy
        if inPolicy != nil {
                cachedInPolicy = models.NewCachedPolicy(inPolicy)
                cachedInPolicy.ToNodePubKey = toNodeCallback
                cachedInPolicy.ToNodeFeatures = features
        }

        // Extract inbound fee.
        var inboundFee lnwire.Fee
        if outPolicy != nil {
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
                        inboundFee = fee
                })
        }

        // Build directed channel.
        directedChannel := &DirectedChannel{
                ChannelID:    edge.ChannelID,
                IsNode1:      nodePub == edge.NodeKey1Bytes,
                OtherNode:    edge.NodeKey2Bytes,
                Capacity:     edge.Capacity,
                OutPolicySet: outPolicy != nil,
                InPolicy:     cachedInPolicy,
                InboundFee:   inboundFee,
        }

        if nodePub == edge.NodeKey2Bytes {
                directedChannel.OtherNode = edge.NodeKey1Bytes
        }

        return directedChannel, nil
}

// batchBuildChannelEdges builds a slice of ChannelEdge instances from the
// provided rows. It uses batch loading for channels, policies, and nodes.
func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context,
        cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) {

        var (
                channelIDs = make([]int64, len(rows))
                policyIDs  = make([]int64, 0, len(rows)*2)
                nodeIDs    = make([]int64, 0, len(rows)*2)

                // nodeIDSet is used to ensure we only collect unique node IDs.
                nodeIDSet = make(map[int64]bool)

                // edges will hold the final channel edges built from the rows.
                edges = make([]ChannelEdge, 0, len(rows))
        )

        // Collect all IDs needed for batch loading.
        for i, row := range rows {
                channelIDs[i] = row.Channel().ID

                // Collect policy IDs.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return nil, fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }
                if dbPol1 != nil {
                        policyIDs = append(policyIDs, dbPol1.ID)
                }
                if dbPol2 != nil {
                        policyIDs = append(policyIDs, dbPol2.ID)
                }

                var (
                        node1ID = row.Node1().ID
                        node2ID = row.Node2().ID
                )

                // Collect unique node IDs.
                if !nodeIDSet[node1ID] {
                        nodeIDs = append(nodeIDs, node1ID)
                        nodeIDSet[node1ID] = true
                }

                if !nodeIDSet[node2ID] {
                        nodeIDs = append(nodeIDs, node2ID)
                        nodeIDSet[node2ID] = true
                }
        }

        // Batch the data for all the channels and policies.
        channelBatchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to batch load channel and "+
                        "policy data: %w", err)
        }

        // Batch the data for all the nodes.
        nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node data: %w",
                        err)
        }

        // Build all channel edges using batch data.
        for _, row := range rows {
                // Build nodes using batch data.
                node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData)
                if err != nil {
                        return nil, fmt.Errorf("unable to build node1: %w", err)
                }

                node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData)
                if err != nil {
                        return nil, fmt.Errorf("unable to build node2: %w", err)
                }

                // Build channel info using batch data.
                channel, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.Channel(), node1.PubKeyBytes,
                        node2.PubKeyBytes, channelBatchData,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to build channel "+
                                "info: %w", err)
                }

                // Extract and build policies using batch data.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return nil, fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, channel.ChannelID,
                        node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to build channel "+
                                "policies: %w", err)
                }

                edges = append(edges, ChannelEdge{
                        Info:    channel,
                        Policy1: p1,
                        Policy2: p2,
                        Node1:   node1,
                        Node2:   node2,
                })
        }

        return edges, nil
}
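
// Rough usage sketch (assumed, not taken from this file): building fully
// populated ChannelEdge values for a page of rows that satisfy
// sqlc.ChannelAndNodes. The rows value is assumed to come from a paginated
// channels-with-nodes query executed by the caller.
//
//        edges, err := batchBuildChannelEdges(ctx, cfg, db, rows)
//        if err != nil {
//                return err
//        }
//        for _, edge := range edges {
//                // edge.Info, edge.Policy1/Policy2 and edge.Node1/Node2 are
//                // all populated from the shared batch-loaded data.
//                _ = edge.Info.ChannelID
//        }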

// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo instances
// from the provided rows using batch loading for channel data. The DB
// channel IDs for the returned edges are also returned, in the same order.
func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context,
        cfg *SQLStoreConfig, db SQLQueries, rows []T) (
        []*models.ChannelEdgeInfo, []int64, error) {

        if len(rows) == 0 {
                return nil, nil, nil
        }

        // Collect all the channel IDs needed for batch loading.
        channelIDs := make([]int64, len(rows))
        for i, row := range rows {
                channelIDs[i] = row.Channel().ID
        }

        // Batch load the channel data.
        channelBatchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, channelIDs, nil,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to batch load channel "+
                        "data: %w", err)
        }

        // Build all channel edges using batch data.
        edges := make([]*models.ChannelEdgeInfo, 0, len(rows))
        for _, row := range rows {
                node1, node2, err := buildNodeVertices(
                        row.Node1Pub(), row.Node2Pub(),
                )
                if err != nil {
                        return nil, nil, err
                }

                // Build channel info using batch data.
                info, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.Channel(), node1, node2,
                        channelBatchData,
                )
                if err != nil {
                        return nil, nil, err
                }

                edges = append(edges, info)
        }

        return edges, channelIDs, nil
}

// handleZombieMarking marks a channel as a zombie in the database. It takes
// into account whether we are in strict zombie pruning mode, and adjusts the
// node public keys accordingly based on the last update timestamps of the
// channel policies.
func handleZombieMarking(ctx context.Context, db SQLQueries,
        row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo,
        strictZombiePruning bool, scid uint64) error {

        nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes

        if strictZombiePruning {
                var e1UpdateTime, e2UpdateTime *time.Time
                if row.Policy1LastUpdate.Valid {
                        e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0)
                        e1UpdateTime = &e1Time
                }
                if row.Policy2LastUpdate.Valid {
                        e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0)
                        e2UpdateTime = &e2Time
                }

                nodeKey1, nodeKey2 = makeZombiePubkeys(
                        info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime,
                        e2UpdateTime,
                )
        }

        return db.UpsertZombieChannel(
                ctx, sqlc.UpsertZombieChannelParams{
                        Version:  int16(ProtocolV1),
                        Scid:     channelIDToBytes(scid),
                        NodeKey1: nodeKey1[:],
                        NodeKey2: nodeKey2[:],
                },
        )
}