• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 16683051882

01 Aug 2025 07:03PM UTC coverage: 54.949% (-12.1%) from 67.047%
16683051882

Pull #9455

github

web-flow
Merge 3f1f50be8 into 37523b6cb
Pull Request #9455: discovery+lnwire: add support for DNS host name in NodeAnnouncement msg

144 of 226 new or added lines in 7 files covered. (63.72%)

23852 existing lines in 290 files now uncovered.

108751 of 197912 relevant lines covered (54.95%)

22080.83 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// pageSize caps how many records a single paginated query may return. The
// current value is a starting point that can be tuned once benchmarks exist.
//
// TODO(elle): make this configurable & have different defaults for SQLite and
// Postgres.
const pageSize = 10_000
40

41
// ProtocolVersion enumerates the gossip protocol version that a message
// belongs to.
type ProtocolVersion uint8

const (
	// ProtocolV1 is the gossip protocol version defined in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String renders the protocol version as "V" followed by its decimal value,
// e.g. "V1".
func (v ProtocolVersion) String() string {
	// Convert to the underlying integer so the value prints as a plain
	// number. No separator is inserted because the first operand is a
	// string.
	return fmt.Sprint("V", uint8(v))
}
×
54

55
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
56
// execute queries against the SQL graph tables.
57
//
58
//nolint:ll,interfacebloat
59
type SQLQueries interface {
60
        /*
61
                Node queries.
62
        */
63
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
64
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
65
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
66
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
67
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
68
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
69
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
70
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
71
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
72
        DeleteNode(ctx context.Context, id int64) error
73

74
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
75
        GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
76
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
77
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
78

79
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
80
        GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
81
        GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
82
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
83

84
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
85
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
86
        GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
87
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
88
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
89

90
        /*
91
                Source node queries.
92
        */
93
        AddSourceNode(ctx context.Context, nodeID int64) error
94
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
95

96
        /*
97
                Channel queries.
98
        */
99
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
100
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
101
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
102
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
103
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
104
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
105
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
106
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
107
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
108
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
109
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
110
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
111
        ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
112
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
113
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
114
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
115
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
116
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
117
        DeleteChannels(ctx context.Context, ids []int64) error
118

119
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
120
        GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
121
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
122
        GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)
123

124
        /*
125
                Channel Policy table queries.
126
        */
127
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
128
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
129
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
130

131
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
132
        GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
133
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
134

135
        /*
136
                Zombie index queries.
137
        */
138
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
139
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
140
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
141
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
142
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
143

144
        /*
145
                Prune log table queries.
146
        */
147
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
148
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
149
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
150
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
151

152
        /*
153
                Closed SCID table queries.
154
        */
155
        InsertClosedChannel(ctx context.Context, scid []byte) error
156
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
157
}
158

159
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
160
// database operations.
161
type BatchedSQLQueries interface {
162
        SQLQueries
163
        sqldb.BatchedTx[SQLQueries]
164
}
165

166
// SQLStore is an implementation of the V1Store interface that uses a SQL
167
// database as the backend.
168
type SQLStore struct {
169
        cfg *SQLStoreConfig
170
        db  BatchedSQLQueries
171

172
        // cacheMu guards all caches (rejectCache and chanCache). If
173
        // this mutex will be acquired at the same time as the DB mutex then
174
        // the cacheMu MUST be acquired first to prevent deadlock.
175
        cacheMu     sync.RWMutex
176
        rejectCache *rejectCache
177
        chanCache   *channelCache
178

179
        chanScheduler batch.Scheduler[SQLQueries]
180
        nodeScheduler batch.Scheduler[SQLQueries]
181

182
        srcNodes  map[ProtocolVersion]*srcNodeInfo
183
        srcNodeMu sync.Mutex
184
}
185

186
// A compile-time assertion to ensure that SQLStore implements the V1Store
187
// interface.
188
var _ V1Store = (*SQLStore)(nil)
189

190
// SQLStoreConfig holds the configuration for the SQLStore.
191
type SQLStoreConfig struct {
192
        // ChainHash is the genesis hash for the chain that all the gossip
193
        // messages in this store are aimed at.
194
        ChainHash chainhash.Hash
195

196
        // PaginationCfg is the configuration for paginated queries.
197
        PaginationCfg *sqldb.PagedQueryConfig
198
}
199

200
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
201
// storage backend.
202
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
203
        options ...StoreOptionModifier) (*SQLStore, error) {
×
204

×
205
        opts := DefaultOptions()
×
206
        for _, o := range options {
×
207
                o(opts)
×
208
        }
×
209

210
        if opts.NoMigration {
×
211
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
212
                        "supported for SQL stores")
×
213
        }
×
214

215
        s := &SQLStore{
×
216
                cfg:         cfg,
×
217
                db:          db,
×
218
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
219
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
220
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
221
        }
×
222

×
223
        s.chanScheduler = batch.NewTimeScheduler(
×
224
                db, &s.cacheMu, opts.BatchCommitInterval,
×
225
        )
×
226
        s.nodeScheduler = batch.NewTimeScheduler(
×
227
                db, nil, opts.BatchCommitInterval,
×
228
        )
×
229

×
230
        return s, nil
×
231
}
232

233
// AddLightningNode adds a vertex/node to the graph database. If the node is not
234
// in the database from before, this will add a new, unconnected one to the
235
// graph. If it is present from before, this will update that node's
236
// information.
237
//
238
// NOTE: part of the V1Store interface.
239
func (s *SQLStore) AddLightningNode(ctx context.Context,
240
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
241

×
242
        r := &batch.Request[SQLQueries]{
×
243
                Opts: batch.NewSchedulerOptions(opts...),
×
244
                Do: func(queries SQLQueries) error {
×
245
                        _, err := upsertNode(ctx, queries, node)
×
246
                        return err
×
247
                },
×
248
        }
249

250
        return s.nodeScheduler.Execute(ctx, r)
×
251
}
252

253
// FetchLightningNode attempts to look up a target node by its identity public
254
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
255
// returned.
256
//
257
// NOTE: part of the V1Store interface.
258
func (s *SQLStore) FetchLightningNode(ctx context.Context,
259
        pubKey route.Vertex) (*models.LightningNode, error) {
×
260

×
261
        var node *models.LightningNode
×
262
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
263
                var err error
×
264
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
265

×
266
                return err
×
267
        }, sqldb.NoOpReset)
×
268
        if err != nil {
×
269
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
270
        }
×
271

272
        return node, nil
×
273
}
274

275
// HasLightningNode determines if the graph has a vertex identified by the
276
// target node identity public key. If the node exists in the database, a
277
// timestamp of when the data for the node was lasted updated is returned along
278
// with a true boolean. Otherwise, an empty time.Time is returned with a false
279
// boolean.
280
//
281
// NOTE: part of the V1Store interface.
282
func (s *SQLStore) HasLightningNode(ctx context.Context,
283
        pubKey [33]byte) (time.Time, bool, error) {
×
284

×
285
        var (
×
286
                exists     bool
×
287
                lastUpdate time.Time
×
288
        )
×
289
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
290
                dbNode, err := db.GetNodeByPubKey(
×
291
                        ctx, sqlc.GetNodeByPubKeyParams{
×
292
                                Version: int16(ProtocolV1),
×
293
                                PubKey:  pubKey[:],
×
294
                        },
×
295
                )
×
296
                if errors.Is(err, sql.ErrNoRows) {
×
297
                        return nil
×
298
                } else if err != nil {
×
299
                        return fmt.Errorf("unable to fetch node: %w", err)
×
300
                }
×
301

302
                exists = true
×
303

×
304
                if dbNode.LastUpdate.Valid {
×
305
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
306
                }
×
307

308
                return nil
×
309
        }, sqldb.NoOpReset)
310
        if err != nil {
×
311
                return time.Time{}, false,
×
312
                        fmt.Errorf("unable to fetch node: %w", err)
×
313
        }
×
314

315
        return lastUpdate, exists, nil
×
316
}
317

318
// AddrsForNode returns all known addresses for the target node public key
319
// that the graph DB is aware of. The returned boolean indicates if the
320
// given node is unknown to the graph DB or not.
321
//
322
// NOTE: part of the V1Store interface.
323
func (s *SQLStore) AddrsForNode(ctx context.Context,
324
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
325

×
326
        var (
×
327
                addresses []net.Addr
×
328
                known     bool
×
329
        )
×
330
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
331
                // First, check if the node exists and get its DB ID if it
×
332
                // does.
×
333
                dbID, err := db.GetNodeIDByPubKey(
×
334
                        ctx, sqlc.GetNodeIDByPubKeyParams{
×
335
                                Version: int16(ProtocolV1),
×
336
                                PubKey:  nodePub.SerializeCompressed(),
×
337
                        },
×
338
                )
×
339
                if errors.Is(err, sql.ErrNoRows) {
×
340
                        return nil
×
341
                }
×
342

343
                known = true
×
344

×
345
                addresses, err = getNodeAddresses(ctx, db, dbID)
×
346
                if err != nil {
×
347
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
348
                                err)
×
349
                }
×
350

351
                return nil
×
352
        }, sqldb.NoOpReset)
353
        if err != nil {
×
354
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
355
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
356
        }
×
357

358
        return known, addresses, nil
×
359
}
360

361
// DeleteLightningNode starts a new database transaction to remove a vertex/node
362
// from the database according to the node's public key.
363
//
364
// NOTE: part of the V1Store interface.
365
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
366
        pubKey route.Vertex) error {
×
367

×
368
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
369
                res, err := db.DeleteNodeByPubKey(
×
370
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
371
                                Version: int16(ProtocolV1),
×
372
                                PubKey:  pubKey[:],
×
373
                        },
×
374
                )
×
375
                if err != nil {
×
376
                        return err
×
377
                }
×
378

379
                rows, err := res.RowsAffected()
×
380
                if err != nil {
×
381
                        return err
×
382
                }
×
383

384
                if rows == 0 {
×
385
                        return ErrGraphNodeNotFound
×
386
                } else if rows > 1 {
×
387
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
388
                }
×
389

390
                return err
×
391
        }, sqldb.NoOpReset)
392
        if err != nil {
×
393
                return fmt.Errorf("unable to delete node: %w", err)
×
394
        }
×
395

396
        return nil
×
397
}
398

399
// FetchNodeFeatures returns the features of the given node. If no features are
400
// known for the node, an empty feature vector is returned.
401
//
402
// NOTE: this is part of the graphdb.NodeTraverser interface.
403
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
404
        *lnwire.FeatureVector, error) {
×
405

×
406
        ctx := context.TODO()
×
407

×
408
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
409
}
×
410

411
// DisabledChannelIDs returns the channel ids of disabled channels.
412
// A channel is disabled when two of the associated ChanelEdgePolicies
413
// have their disabled bit on.
414
//
415
// NOTE: part of the V1Store interface.
416
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
417
        var (
×
418
                ctx     = context.TODO()
×
419
                chanIDs []uint64
×
420
        )
×
421
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
422
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
423
                if err != nil {
×
424
                        return fmt.Errorf("unable to fetch disabled "+
×
425
                                "channels: %w", err)
×
426
                }
×
427

428
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
429

×
430
                return nil
×
431
        }, sqldb.NoOpReset)
432
        if err != nil {
×
433
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
434
                        err)
×
435
        }
×
436

437
        return chanIDs, nil
×
438
}
439

440
// LookupAlias attempts to return the alias as advertised by the target node.
441
//
442
// NOTE: part of the V1Store interface.
443
func (s *SQLStore) LookupAlias(ctx context.Context,
444
        pub *btcec.PublicKey) (string, error) {
×
445

×
446
        var alias string
×
447
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
448
                dbNode, err := db.GetNodeByPubKey(
×
449
                        ctx, sqlc.GetNodeByPubKeyParams{
×
450
                                Version: int16(ProtocolV1),
×
451
                                PubKey:  pub.SerializeCompressed(),
×
452
                        },
×
453
                )
×
454
                if errors.Is(err, sql.ErrNoRows) {
×
455
                        return ErrNodeAliasNotFound
×
456
                } else if err != nil {
×
457
                        return fmt.Errorf("unable to fetch node: %w", err)
×
458
                }
×
459

460
                if !dbNode.Alias.Valid {
×
461
                        return ErrNodeAliasNotFound
×
462
                }
×
463

464
                alias = dbNode.Alias.String
×
465

×
466
                return nil
×
467
        }, sqldb.NoOpReset)
468
        if err != nil {
×
469
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
470
        }
×
471

472
        return alias, nil
×
473
}
474

475
// SourceNode returns the source node of the graph. The source node is treated
476
// as the center node within a star-graph. This method may be used to kick off
477
// a path finding algorithm in order to explore the reachability of another
478
// node based off the source node.
479
//
480
// NOTE: part of the V1Store interface.
481
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
482
        error) {
×
483

×
484
        var node *models.LightningNode
×
485
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
486
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
487
                if err != nil {
×
488
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
489
                                err)
×
490
                }
×
491

492
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
493

×
494
                return err
×
495
        }, sqldb.NoOpReset)
496
        if err != nil {
×
497
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
498
        }
×
499

500
        return node, nil
×
501
}
502

503
// SetSourceNode sets the source node within the graph database. The source
504
// node is to be used as the center of a star-graph within path finding
505
// algorithms.
506
//
507
// NOTE: part of the V1Store interface.
508
func (s *SQLStore) SetSourceNode(ctx context.Context,
509
        node *models.LightningNode) error {
×
510

×
511
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
512
                id, err := upsertNode(ctx, db, node)
×
513
                if err != nil {
×
514
                        return fmt.Errorf("unable to upsert source node: %w",
×
515
                                err)
×
516
                }
×
517

518
                // Make sure that if a source node for this version is already
519
                // set, then the ID is the same as the one we are about to set.
520
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
521
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
522
                        return fmt.Errorf("unable to fetch source node: %w",
×
523
                                err)
×
524
                } else if err == nil {
×
525
                        if dbSourceNodeID != id {
×
526
                                return fmt.Errorf("v1 source node already "+
×
527
                                        "set to a different node: %d vs %d",
×
528
                                        dbSourceNodeID, id)
×
529
                        }
×
530

531
                        return nil
×
532
                }
533

534
                return db.AddSourceNode(ctx, id)
×
535
        }, sqldb.NoOpReset)
536
}
537

538
// NodeUpdatesInHorizon returns all the known lightning node which have an
539
// update timestamp within the passed range. This method can be used by two
540
// nodes to quickly determine if they have the same set of up to date node
541
// announcements.
542
//
543
// NOTE: This is part of the V1Store interface.
544
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
545
        endTime time.Time) ([]models.LightningNode, error) {
×
546

×
547
        ctx := context.TODO()
×
548

×
549
        var nodes []models.LightningNode
×
550
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
551
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
552
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
553
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
554
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
555
                        },
×
556
                )
×
557
                if err != nil {
×
558
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
559
                }
×
560

561
                err = forEachNodeInBatch(
×
562
                        ctx, s.cfg.PaginationCfg, db, dbNodes,
×
563
                        func(_ int64, node *models.LightningNode) error {
×
564
                                nodes = append(nodes, *node)
×
565

×
566
                                return nil
×
567
                        },
×
568
                )
569
                if err != nil {
×
570
                        return fmt.Errorf("unable to build nodes: %w", err)
×
571
                }
×
572

573
                return nil
×
574
        }, sqldb.NoOpReset)
575
        if err != nil {
×
576
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
577
        }
×
578

579
        return nodes, nil
×
580
}
581

582
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
583
// undirected edge from the two target nodes are created. The information stored
584
// denotes the static attributes of the channel, such as the channelID, the keys
585
// involved in creation of the channel, and the set of features that the channel
586
// supports. The chanPoint and chanID are used to uniquely identify the edge
587
// globally within the database.
588
//
589
// NOTE: part of the V1Store interface.
590
func (s *SQLStore) AddChannelEdge(ctx context.Context,
591
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
592

×
593
        var alreadyExists bool
×
594
        r := &batch.Request[SQLQueries]{
×
595
                Opts: batch.NewSchedulerOptions(opts...),
×
596
                Reset: func() {
×
597
                        alreadyExists = false
×
598
                },
×
599
                Do: func(tx SQLQueries) error {
×
600
                        _, err := insertChannel(ctx, tx, edge)
×
601

×
602
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
603
                        // succeed, but propagate the error via local state.
×
604
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
605
                                alreadyExists = true
×
606
                                return nil
×
607
                        }
×
608

609
                        return err
×
610
                },
611
                OnCommit: func(err error) error {
×
612
                        switch {
×
613
                        case err != nil:
×
614
                                return err
×
615
                        case alreadyExists:
×
616
                                return ErrEdgeAlreadyExist
×
617
                        default:
×
618
                                s.rejectCache.remove(edge.ChannelID)
×
619
                                s.chanCache.remove(edge.ChannelID)
×
620
                                return nil
×
621
                        }
622
                },
623
        }
624

625
        return s.chanScheduler.Execute(ctx, r)
×
626
}
627

628
// HighestChanID returns the "highest" known channel ID in the channel graph.
629
// This represents the "newest" channel from the PoV of the chain. This method
630
// can be used by peers to quickly determine if their graphs are in sync.
631
//
632
// NOTE: This is part of the V1Store interface.
633
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
634
        var highestChanID uint64
×
635
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
636
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
637
                if errors.Is(err, sql.ErrNoRows) {
×
638
                        return nil
×
639
                } else if err != nil {
×
640
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
641
                                err)
×
642
                }
×
643

644
                highestChanID = byteOrder.Uint64(chanID)
×
645

×
646
                return nil
×
647
        }, sqldb.NoOpReset)
648
        if err != nil {
×
649
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
650
        }
×
651

652
        return highestChanID, nil
×
653
}
654

655
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
656
// within the database for the referenced channel. The `flags` attribute within
657
// the ChannelEdgePolicy determines which of the directed edges are being
658
// updated. If the flag is 1, then the first node's information is being
659
// updated, otherwise it's the second node's information. The node ordering is
660
// determined by the lexicographical ordering of the identity public keys of the
661
// nodes on either side of the channel.
662
//
663
// NOTE: part of the V1Store interface.
664
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
665
        edge *models.ChannelEdgePolicy,
666
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
667

×
668
        var (
×
669
                isUpdate1    bool
×
670
                edgeNotFound bool
×
671
                from, to     route.Vertex
×
672
        )
×
673

×
674
        r := &batch.Request[SQLQueries]{
×
675
                Opts: batch.NewSchedulerOptions(opts...),
×
676
                Reset: func() {
×
677
                        isUpdate1 = false
×
678
                        edgeNotFound = false
×
679
                },
×
680
                Do: func(tx SQLQueries) error {
×
681
                        var err error
×
682
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
683
                                ctx, tx, edge,
×
684
                        )
×
685
                        if err != nil {
×
686
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
687
                        }
×
688

689
                        // Silence ErrEdgeNotFound so that the batch can
690
                        // succeed, but propagate the error via local state.
691
                        if errors.Is(err, ErrEdgeNotFound) {
×
692
                                edgeNotFound = true
×
693
                                return nil
×
694
                        }
×
695

696
                        return err
×
697
                },
698
                OnCommit: func(err error) error {
×
699
                        switch {
×
700
                        case err != nil:
×
701
                                return err
×
702
                        case edgeNotFound:
×
703
                                return ErrEdgeNotFound
×
704
                        default:
×
705
                                s.updateEdgeCache(edge, isUpdate1)
×
706
                                return nil
×
707
                        }
708
                },
709
        }
710

711
        err := s.chanScheduler.Execute(ctx, r)
×
712

×
713
        return from, to, err
×
714
}
715

716
// updateEdgeCache updates our reject and channel caches with the new
717
// edge policy information.
718
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
719
        isUpdate1 bool) {
×
720

×
721
        // If an entry for this channel is found in reject cache, we'll modify
×
722
        // the entry with the updated timestamp for the direction that was just
×
723
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
724
        // during the next query for this edge.
×
725
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
726
                if isUpdate1 {
×
727
                        entry.upd1Time = e.LastUpdate.Unix()
×
728
                } else {
×
729
                        entry.upd2Time = e.LastUpdate.Unix()
×
730
                }
×
731
                s.rejectCache.insert(e.ChannelID, entry)
×
732
        }
733

734
        // If an entry for this channel is found in channel cache, we'll modify
735
        // the entry with the updated policy for the direction that was just
736
        // written. If the edge doesn't exist, we'll defer loading the info and
737
        // policies and lazily read from disk during the next query.
738
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
739
                if isUpdate1 {
×
740
                        channel.Policy1 = e
×
741
                } else {
×
742
                        channel.Policy2 = e
×
743
                }
×
744
                s.chanCache.insert(e.ChannelID, channel)
×
745
        }
746
}
747

748
// ForEachSourceNodeChannel iterates through all channels of the source node,
749
// executing the passed callback on each. The call-back is provided with the
750
// channel's outpoint, whether we have a policy for the channel and the channel
751
// peer's node information.
752
//
753
// NOTE: part of the V1Store interface.
754
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
755
        cb func(chanPoint wire.OutPoint, havePolicy bool,
756
                otherNode *models.LightningNode) error, reset func()) error {
×
757

×
758
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
759
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
760
                if err != nil {
×
761
                        return fmt.Errorf("unable to fetch source node: %w",
×
762
                                err)
×
763
                }
×
764

765
                return forEachNodeChannel(
×
766
                        ctx, db, s.cfg.ChainHash, nodeID,
×
767
                        func(info *models.ChannelEdgeInfo,
×
768
                                outPolicy *models.ChannelEdgePolicy,
×
769
                                _ *models.ChannelEdgePolicy) error {
×
770

×
771
                                // Fetch the other node.
×
772
                                var (
×
773
                                        otherNodePub [33]byte
×
774
                                        node1        = info.NodeKey1Bytes
×
775
                                        node2        = info.NodeKey2Bytes
×
776
                                )
×
777
                                switch {
×
778
                                case bytes.Equal(node1[:], nodePub[:]):
×
779
                                        otherNodePub = node2
×
780
                                case bytes.Equal(node2[:], nodePub[:]):
×
781
                                        otherNodePub = node1
×
782
                                default:
×
783
                                        return fmt.Errorf("node not " +
×
784
                                                "participating in this channel")
×
785
                                }
786

787
                                _, otherNode, err := getNodeByPubKey(
×
788
                                        ctx, db, otherNodePub,
×
789
                                )
×
790
                                if err != nil {
×
791
                                        return fmt.Errorf("unable to fetch "+
×
792
                                                "other node(%x): %w",
×
793
                                                otherNodePub, err)
×
794
                                }
×
795

796
                                return cb(
×
797
                                        info.ChannelPoint, outPolicy != nil,
×
798
                                        otherNode,
×
799
                                )
×
800
                        },
801
                )
802
        }, reset)
803
}
804

805
// ForEachNode iterates through all the stored vertices/nodes in the graph,
806
// executing the passed callback with each node encountered. If the callback
807
// returns an error, then the transaction is aborted and the iteration stops
808
// early. Any operations performed on the NodeTx passed to the call-back are
809
// executed under the same read transaction and so, methods on the NodeTx object
810
// _MUST_ only be called from within the call-back.
811
//
812
// NOTE: part of the V1Store interface.
813
func (s *SQLStore) ForEachNode(ctx context.Context,
814
        cb func(tx NodeRTx) error, reset func()) error {
×
815

×
816
        var lastID int64
×
817

×
818
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
819
                nodeCB := func(dbID int64, node *models.LightningNode) error {
×
820
                        err := cb(newSQLGraphNodeTx(
×
821
                                db, s.cfg.ChainHash, dbID, node,
×
822
                        ))
×
823
                        if err != nil {
×
824
                                return fmt.Errorf("callback failed for "+
×
825
                                        "node(id=%d): %w", dbID, err)
×
826
                        }
×
827
                        lastID = dbID
×
828

×
829
                        return nil
×
830
                }
831

832
                for {
×
833
                        nodes, err := db.ListNodesPaginated(
×
834
                                ctx, sqlc.ListNodesPaginatedParams{
×
835
                                        Version: int16(ProtocolV1),
×
836
                                        ID:      lastID,
×
837
                                        Limit:   pageSize,
×
838
                                },
×
839
                        )
×
840
                        if err != nil {
×
841
                                return fmt.Errorf("unable to fetch nodes: %w",
×
842
                                        err)
×
843
                        }
×
844

845
                        if len(nodes) == 0 {
×
846
                                break
×
847
                        }
848

849
                        err = forEachNodeInBatch(
×
850
                                ctx, s.cfg.PaginationCfg, db, nodes, nodeCB,
×
851
                        )
×
852
                        if err != nil {
×
853
                                return fmt.Errorf("unable to iterate over "+
×
854
                                        "nodes: %w", err)
×
855
                        }
×
856
                }
857

858
                return nil
×
859
        }, reset)
860
}
861

862
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
863
// SQLStore and a SQL transaction.
864
type sqlGraphNodeTx struct {
865
        db    SQLQueries
866
        id    int64
867
        node  *models.LightningNode
868
        chain chainhash.Hash
869
}
870

871
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
872
// interface.
873
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
874

875
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
876
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
877

×
878
        return &sqlGraphNodeTx{
×
879
                db:    db,
×
880
                chain: chain,
×
881
                id:    id,
×
882
                node:  node,
×
883
        }
×
884
}
×
885

886
// Node returns the raw information of the node.
887
//
888
// NOTE: This is a part of the NodeRTx interface.
889
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
890
        return s.node
×
891
}
×
892

893
// ForEachChannel can be used to iterate over the node's channels under the same
894
// transaction used to fetch the node.
895
//
896
// NOTE: This is a part of the NodeRTx interface.
897
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
898
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
899

×
900
        ctx := context.TODO()
×
901

×
902
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
903
}
×
904

905
// FetchNode fetches the node with the given pub key under the same transaction
906
// used to fetch the current node. The returned node is also a NodeRTx and any
907
// operations on that NodeRTx will also be done under the same transaction.
908
//
909
// NOTE: This is a part of the NodeRTx interface.
910
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
911
        ctx := context.TODO()
×
912

×
913
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
914
        if err != nil {
×
915
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
916
                        nodePub, err)
×
917
        }
×
918

919
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
920
}
921

922
// ForEachNodeDirectedChannel iterates through all channels of a given node,
923
// executing the passed callback on the directed edge representing the channel
924
// and its incoming policy. If the callback returns an error, then the iteration
925
// is halted with the error propagated back up to the caller.
926
//
927
// Unknown policies are passed into the callback as nil values.
928
//
929
// NOTE: this is part of the graphdb.NodeTraverser interface.
930
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
931
        cb func(channel *DirectedChannel) error, reset func()) error {
×
932

×
933
        var ctx = context.TODO()
×
934

×
935
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
936
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
937
        }, reset)
×
938
}
939

940
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
941
// graph, executing the passed callback with each node encountered. If the
942
// callback returns an error, then the transaction is aborted and the iteration
943
// stops early.
944
//
945
// NOTE: This is a part of the V1Store interface.
946
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
947
        cb func(route.Vertex, *lnwire.FeatureVector) error,
948
        reset func()) error {
×
949

×
950
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
951
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
952
                        nodePub route.Vertex) error {
×
953

×
954
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
955
                        if err != nil {
×
956
                                return fmt.Errorf("unable to fetch node "+
×
957
                                        "features: %w", err)
×
958
                        }
×
959

960
                        return cb(nodePub, features)
×
961
                })
962
        }, reset)
963
        if err != nil {
×
964
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
965
        }
×
966

967
        return nil
×
968
}
969

970
// ForEachNodeChannel iterates through all channels of the given node,
971
// executing the passed callback with an edge info structure and the policies
972
// of each end of the channel. The first edge policy is the outgoing edge *to*
973
// the connecting node, while the second is the incoming edge *from* the
974
// connecting node. If the callback returns an error, then the iteration is
975
// halted with the error propagated back up to the caller.
976
//
977
// Unknown policies are passed into the callback as nil values.
978
//
979
// NOTE: part of the V1Store interface.
980
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
981
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
982
                *models.ChannelEdgePolicy) error, reset func()) error {
×
983

×
984
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
985
                dbNode, err := db.GetNodeByPubKey(
×
986
                        ctx, sqlc.GetNodeByPubKeyParams{
×
987
                                Version: int16(ProtocolV1),
×
988
                                PubKey:  nodePub[:],
×
989
                        },
×
990
                )
×
991
                if errors.Is(err, sql.ErrNoRows) {
×
992
                        return nil
×
993
                } else if err != nil {
×
994
                        return fmt.Errorf("unable to fetch node: %w", err)
×
995
                }
×
996

997
                return forEachNodeChannel(
×
998
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
999
                )
×
1000
        }, reset)
1001
}
1002

1003
// ChanUpdatesInHorizon returns all the known channel edges which have at least
1004
// one edge that has an update timestamp within the specified horizon.
1005
//
1006
// NOTE: This is part of the V1Store interface.
1007
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
1008
        endTime time.Time) ([]ChannelEdge, error) {
×
1009

×
1010
        s.cacheMu.Lock()
×
1011
        defer s.cacheMu.Unlock()
×
1012

×
1013
        var (
×
1014
                ctx = context.TODO()
×
1015
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
1016
                // an additional map to keep track of the edges already seen to
×
1017
                // prevent re-adding it.
×
1018
                edgesSeen    = make(map[uint64]struct{})
×
1019
                edgesToCache = make(map[uint64]ChannelEdge)
×
1020
                edges        []ChannelEdge
×
1021
                hits         int
×
1022
        )
×
1023
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1024
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1025
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1026
                                Version:   int16(ProtocolV1),
×
1027
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
1028
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
1029
                        },
×
1030
                )
×
1031
                if err != nil {
×
1032
                        return err
×
1033
                }
×
1034

1035
                for _, row := range rows {
×
1036
                        // If we've already retrieved the info and policies for
×
1037
                        // this edge, then we can skip it as we don't need to do
×
1038
                        // so again.
×
1039
                        chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid)
×
1040
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1041
                                continue
×
1042
                        }
1043

1044
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1045
                                hits++
×
1046
                                edgesSeen[chanIDInt] = struct{}{}
×
1047
                                edges = append(edges, channel)
×
1048

×
1049
                                continue
×
1050
                        }
1051

1052
                        node1, node2, err := buildNodes(
×
1053
                                ctx, db, row.GraphNode, row.GraphNode_2,
×
1054
                        )
×
1055
                        if err != nil {
×
1056
                                return err
×
1057
                        }
×
1058

1059
                        channel, err := getAndBuildEdgeInfo(
×
1060
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
1061
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1062
                        )
×
1063
                        if err != nil {
×
1064
                                return fmt.Errorf("unable to build channel "+
×
1065
                                        "info: %w", err)
×
1066
                        }
×
1067

1068
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1069
                        if err != nil {
×
1070
                                return fmt.Errorf("unable to extract channel "+
×
1071
                                        "policies: %w", err)
×
1072
                        }
×
1073

1074
                        p1, p2, err := getAndBuildChanPolicies(
×
1075
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1076
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1077
                        )
×
1078
                        if err != nil {
×
1079
                                return fmt.Errorf("unable to build channel "+
×
1080
                                        "policies: %w", err)
×
1081
                        }
×
1082

1083
                        edgesSeen[chanIDInt] = struct{}{}
×
1084
                        chanEdge := ChannelEdge{
×
1085
                                Info:    channel,
×
1086
                                Policy1: p1,
×
1087
                                Policy2: p2,
×
1088
                                Node1:   node1,
×
1089
                                Node2:   node2,
×
1090
                        }
×
1091
                        edges = append(edges, chanEdge)
×
1092
                        edgesToCache[chanIDInt] = chanEdge
×
1093
                }
1094

1095
                return nil
×
1096
        }, func() {
×
1097
                edgesSeen = make(map[uint64]struct{})
×
1098
                edgesToCache = make(map[uint64]ChannelEdge)
×
1099
                edges = nil
×
1100
        })
×
1101
        if err != nil {
×
1102
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1103
        }
×
1104

1105
        // Insert any edges loaded from disk into the cache.
1106
        for chanid, channel := range edgesToCache {
×
1107
                s.chanCache.insert(chanid, channel)
×
1108
        }
×
1109

1110
        if len(edges) > 0 {
×
1111
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
1112
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1113
        } else {
×
1114
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1115
                        "horizon (%s, %s)", startTime, endTime)
×
1116
        }
×
1117

1118
        return edges, nil
×
1119
}
1120

1121
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1122
// data to the call-back.
1123
//
1124
// NOTE: The callback contents MUST not be modified.
1125
//
1126
// NOTE: part of the V1Store interface.
1127
func (s *SQLStore) ForEachNodeCached(ctx context.Context,
1128
        cb func(node route.Vertex, chans map[uint64]*DirectedChannel) error,
1129
        reset func()) error {
×
1130

×
1131
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1132
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
1133
                        nodePub route.Vertex) error {
×
1134

×
1135
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
1136
                        if err != nil {
×
1137
                                return fmt.Errorf("unable to fetch "+
×
1138
                                        "node(id=%d) features: %w", nodeID, err)
×
1139
                        }
×
1140

1141
                        toNodeCallback := func() route.Vertex {
×
1142
                                return nodePub
×
1143
                        }
×
1144

1145
                        rows, err := db.ListChannelsByNodeID(
×
1146
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
1147
                                        Version: int16(ProtocolV1),
×
1148
                                        NodeID1: nodeID,
×
1149
                                },
×
1150
                        )
×
1151
                        if err != nil {
×
1152
                                return fmt.Errorf("unable to fetch channels "+
×
1153
                                        "of node(id=%d): %w", nodeID, err)
×
1154
                        }
×
1155

1156
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
1157
                        for _, row := range rows {
×
1158
                                node1, node2, err := buildNodeVertices(
×
1159
                                        row.Node1Pubkey, row.Node2Pubkey,
×
1160
                                )
×
1161
                                if err != nil {
×
1162
                                        return err
×
1163
                                }
×
1164

1165
                                e, err := getAndBuildEdgeInfo(
×
1166
                                        ctx, db, s.cfg.ChainHash,
×
1167
                                        row.GraphChannel, node1, node2,
×
1168
                                )
×
1169
                                if err != nil {
×
1170
                                        return fmt.Errorf("unable to build "+
×
1171
                                                "channel info: %w", err)
×
1172
                                }
×
1173

1174
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1175
                                        row,
×
1176
                                )
×
1177
                                if err != nil {
×
1178
                                        return fmt.Errorf("unable to "+
×
1179
                                                "extract channel "+
×
1180
                                                "policies: %w", err)
×
1181
                                }
×
1182

1183
                                p1, p2, err := getAndBuildChanPolicies(
×
1184
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
1185
                                        node1, node2,
×
1186
                                )
×
1187
                                if err != nil {
×
1188
                                        return fmt.Errorf("unable to "+
×
1189
                                                "build channel policies: %w",
×
1190
                                                err)
×
1191
                                }
×
1192

1193
                                // Determine the outgoing and incoming policy
1194
                                // for this channel and node combo.
1195
                                outPolicy, inPolicy := p1, p2
×
1196
                                if p1 != nil && p1.ToNode == nodePub {
×
1197
                                        outPolicy, inPolicy = p2, p1
×
1198
                                } else if p2 != nil && p2.ToNode != nodePub {
×
1199
                                        outPolicy, inPolicy = p2, p1
×
1200
                                }
×
1201

1202
                                var cachedInPolicy *models.CachedEdgePolicy
×
1203
                                if inPolicy != nil {
×
1204
                                        cachedInPolicy = models.NewCachedPolicy(
×
1205
                                                inPolicy,
×
1206
                                        )
×
1207
                                        cachedInPolicy.ToNodePubKey =
×
1208
                                                toNodeCallback
×
1209
                                        cachedInPolicy.ToNodeFeatures =
×
1210
                                                features
×
1211
                                }
×
1212

1213
                                var inboundFee lnwire.Fee
×
1214
                                if outPolicy != nil {
×
1215
                                        outPolicy.InboundFee.WhenSome(
×
1216
                                                func(fee lnwire.Fee) {
×
1217
                                                        inboundFee = fee
×
1218
                                                },
×
1219
                                        )
1220
                                }
1221

1222
                                directedChannel := &DirectedChannel{
×
1223
                                        ChannelID: e.ChannelID,
×
1224
                                        IsNode1: nodePub ==
×
1225
                                                e.NodeKey1Bytes,
×
1226
                                        OtherNode:    e.NodeKey2Bytes,
×
1227
                                        Capacity:     e.Capacity,
×
1228
                                        OutPolicySet: outPolicy != nil,
×
1229
                                        InPolicy:     cachedInPolicy,
×
1230
                                        InboundFee:   inboundFee,
×
1231
                                }
×
1232

×
1233
                                if nodePub == e.NodeKey2Bytes {
×
1234
                                        directedChannel.OtherNode =
×
1235
                                                e.NodeKey1Bytes
×
1236
                                }
×
1237

1238
                                channels[e.ChannelID] = directedChannel
×
1239
                        }
1240

1241
                        return cb(nodePub, channels)
×
1242
                })
1243
        }, reset)
1244
}
1245

1246
// ForEachChannelCacheable iterates through all the channel edges stored
1247
// within the graph and invokes the passed callback for each edge. The
1248
// callback takes two edges as since this is a directed graph, both the
1249
// in/out edges are visited. If the callback returns an error, then the
1250
// transaction is aborted and the iteration stops early.
1251
//
1252
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1253
// pointer for that particular channel edge routing policy will be
1254
// passed into the callback.
1255
//
1256
// NOTE: this method is like ForEachChannel but fetches only the data
1257
// required for the graph cache.
1258
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1259
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1260
        reset func()) error {
×
1261

×
1262
        ctx := context.TODO()
×
1263

×
1264
        handleChannel := func(
×
1265
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
1266

×
1267
                node1, node2, err := buildNodeVertices(
×
1268
                        row.Node1Pubkey, row.Node2Pubkey,
×
1269
                )
×
1270
                if err != nil {
×
1271
                        return err
×
1272
                }
×
1273

1274
                edge := buildCacheableChannelInfo(
×
1275
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1276
                )
×
1277

×
1278
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1279
                if err != nil {
×
1280
                        return err
×
1281
                }
×
1282

1283
                var pol1, pol2 *models.CachedEdgePolicy
×
1284
                if dbPol1 != nil {
×
1285
                        policy1, err := buildChanPolicy(
×
1286
                                *dbPol1, edge.ChannelID, nil, node2,
×
1287
                        )
×
1288
                        if err != nil {
×
1289
                                return err
×
1290
                        }
×
1291

1292
                        pol1 = models.NewCachedPolicy(policy1)
×
1293
                }
1294
                if dbPol2 != nil {
×
1295
                        policy2, err := buildChanPolicy(
×
1296
                                *dbPol2, edge.ChannelID, nil, node1,
×
1297
                        )
×
1298
                        if err != nil {
×
1299
                                return err
×
1300
                        }
×
1301

1302
                        pol2 = models.NewCachedPolicy(policy2)
×
1303
                }
1304

1305
                if err := cb(edge, pol1, pol2); err != nil {
×
1306
                        return err
×
1307
                }
×
1308

1309
                return nil
×
1310
        }
1311

1312
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1313
                lastID := int64(-1)
×
1314
                for {
×
1315
                        //nolint:ll
×
1316
                        rows, err := db.ListChannelsWithPoliciesForCachePaginated(
×
1317
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
1318
                                        Version: int16(ProtocolV1),
×
1319
                                        ID:      lastID,
×
1320
                                        Limit:   pageSize,
×
1321
                                },
×
1322
                        )
×
1323
                        if err != nil {
×
1324
                                return err
×
1325
                        }
×
1326

1327
                        if len(rows) == 0 {
×
1328
                                break
×
1329
                        }
1330

1331
                        for _, row := range rows {
×
1332
                                err := handleChannel(row)
×
1333
                                if err != nil {
×
1334
                                        return err
×
1335
                                }
×
1336

1337
                                lastID = row.ID
×
1338
                        }
1339
                }
1340

1341
                return nil
×
1342
        }, reset)
1343
}
1344

1345
// ForEachChannel iterates through all the channel edges stored within the
1346
// graph and invokes the passed callback for each edge. The callback takes two
1347
// edges as since this is a directed graph, both the in/out edges are visited.
1348
// If the callback returns an error, then the transaction is aborted and the
1349
// iteration stops early.
1350
//
1351
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1352
// for that particular channel edge routing policy will be passed into the
1353
// callback.
1354
//
1355
// NOTE: part of the V1Store interface.
1356
func (s *SQLStore) ForEachChannel(ctx context.Context,
1357
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1358
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1359

×
1360
        handleChannel := func(db SQLQueries, batchData *batchChannelData,
×
1361
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1362

×
1363
                node1, node2, err := buildNodeVertices(
×
1364
                        row.Node1Pubkey, row.Node2Pubkey,
×
1365
                )
×
1366
                if err != nil {
×
1367
                        return fmt.Errorf("unable to build node vertices: %w",
×
1368
                                err)
×
1369
                }
×
1370

1371
                edge, err := buildEdgeInfoWithBatchData(
×
1372
                        s.cfg.ChainHash, row.GraphChannel, node1, node2,
×
1373
                        batchData,
×
1374
                )
×
1375
                if err != nil {
×
1376
                        return fmt.Errorf("unable to build channel info: %w",
×
1377
                                err)
×
1378
                }
×
1379

1380
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1381
                if err != nil {
×
1382
                        return fmt.Errorf("unable to extract channel "+
×
1383
                                "policies: %w", err)
×
1384
                }
×
1385

1386
                p1, p2, err := buildChanPoliciesWithBatchData(
×
1387
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
1388
                )
×
1389
                if err != nil {
×
1390
                        return fmt.Errorf("unable to build channel "+
×
1391
                                "policies: %w", err)
×
1392
                }
×
1393

1394
                err = cb(edge, p1, p2)
×
1395
                if err != nil {
×
1396
                        return fmt.Errorf("callback failed for channel "+
×
1397
                                "id=%d: %w", edge.ChannelID, err)
×
1398
                }
×
1399

1400
                return nil
×
1401
        }
1402

1403
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1404
                lastID := int64(-1)
×
1405
                for {
×
1406
                        //nolint:ll
×
1407
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1408
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1409
                                        Version: int16(ProtocolV1),
×
1410
                                        ID:      lastID,
×
1411
                                        Limit:   pageSize,
×
1412
                                },
×
1413
                        )
×
1414
                        if err != nil {
×
1415
                                return err
×
1416
                        }
×
1417

1418
                        if len(rows) == 0 {
×
1419
                                break
×
1420
                        }
1421

1422
                        // Collect the channel & policy IDs that we want to
1423
                        // do a batch collection for.
1424
                        var (
×
1425
                                channelIDs = make([]int64, len(rows))
×
1426
                                policyIDs  = make([]int64, 0, len(rows)*2)
×
1427
                        )
×
1428
                        for i, row := range rows {
×
1429
                                channelIDs[i] = row.GraphChannel.ID
×
1430

×
1431
                                // Extract policy IDs from the row
×
1432
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1433
                                        row,
×
1434
                                )
×
1435
                                if err != nil {
×
1436
                                        return fmt.Errorf("unable to extract "+
×
1437
                                                "channel policies: %w", err)
×
1438
                                }
×
1439

1440
                                if dbPol1 != nil {
×
1441
                                        policyIDs = append(policyIDs, dbPol1.ID)
×
1442
                                }
×
1443

1444
                                if dbPol2 != nil {
×
1445
                                        policyIDs = append(policyIDs, dbPol2.ID)
×
1446
                                }
×
1447
                        }
1448

1449
                        batchData, err := batchLoadChannelData(
×
1450
                                ctx, s.cfg.PaginationCfg, db, channelIDs,
×
1451
                                policyIDs,
×
1452
                        )
×
1453
                        if err != nil {
×
1454
                                return fmt.Errorf("unable to batch load "+
×
1455
                                        "channel data: %w", err)
×
1456
                        }
×
1457

1458
                        for _, row := range rows {
×
1459
                                err := handleChannel(db, batchData, row)
×
1460
                                if err != nil {
×
1461
                                        return err
×
1462
                                }
×
1463

1464
                                lastID = row.GraphChannel.ID
×
1465
                        }
1466
                }
1467

1468
                return nil
×
1469
        }, reset)
1470
}
1471

1472
// FilterChannelRange returns the channel ID's of all known channels which were
1473
// mined in a block height within the passed range. The channel IDs are grouped
1474
// by their common block height. This method can be used to quickly share with a
1475
// peer the set of channels we know of within a particular range to catch them
1476
// up after a period of time offline. If withTimestamps is true then the
1477
// timestamp info of the latest received channel update messages of the channel
1478
// will be included in the response.
1479
//
1480
// NOTE: This is part of the V1Store interface.
1481
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1482
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1483

×
1484
        var (
×
1485
                ctx       = context.TODO()
×
1486
                startSCID = &lnwire.ShortChannelID{
×
1487
                        BlockHeight: startHeight,
×
1488
                }
×
1489
                endSCID = lnwire.ShortChannelID{
×
1490
                        BlockHeight: endHeight,
×
1491
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1492
                        TxPosition:  math.MaxUint16,
×
1493
                }
×
1494
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1495
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1496
        )
×
1497

×
1498
        // 1) get all channels where channelID is between start and end chan ID.
×
1499
        // 2) skip if not public (ie, no channel_proof)
×
1500
        // 3) collect that channel.
×
1501
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1502
        //    and add those timestamps to the collected channel.
×
1503
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1504
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1505
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1506
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1507
                                StartScid: chanIDStart,
×
1508
                                EndScid:   chanIDEnd,
×
1509
                        },
×
1510
                )
×
1511
                if err != nil {
×
1512
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1513
                                err)
×
1514
                }
×
1515

1516
                for _, dbChan := range dbChans {
×
1517
                        cid := lnwire.NewShortChanIDFromInt(
×
1518
                                byteOrder.Uint64(dbChan.Scid),
×
1519
                        )
×
1520
                        chanInfo := NewChannelUpdateInfo(
×
1521
                                cid, time.Time{}, time.Time{},
×
1522
                        )
×
1523

×
1524
                        if !withTimestamps {
×
1525
                                channelsPerBlock[cid.BlockHeight] = append(
×
1526
                                        channelsPerBlock[cid.BlockHeight],
×
1527
                                        chanInfo,
×
1528
                                )
×
1529

×
1530
                                continue
×
1531
                        }
1532

1533
                        //nolint:ll
1534
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1535
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1536
                                        Version:   int16(ProtocolV1),
×
1537
                                        ChannelID: dbChan.ID,
×
1538
                                        NodeID:    dbChan.NodeID1,
×
1539
                                },
×
1540
                        )
×
1541
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1542
                                return fmt.Errorf("unable to fetch node1 "+
×
1543
                                        "policy: %w", err)
×
1544
                        } else if err == nil {
×
1545
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1546
                                        node1Policy.LastUpdate.Int64, 0,
×
1547
                                )
×
1548
                        }
×
1549

1550
                        //nolint:ll
1551
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1552
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1553
                                        Version:   int16(ProtocolV1),
×
1554
                                        ChannelID: dbChan.ID,
×
1555
                                        NodeID:    dbChan.NodeID2,
×
1556
                                },
×
1557
                        )
×
1558
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1559
                                return fmt.Errorf("unable to fetch node2 "+
×
1560
                                        "policy: %w", err)
×
1561
                        } else if err == nil {
×
1562
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1563
                                        node2Policy.LastUpdate.Int64, 0,
×
1564
                                )
×
1565
                        }
×
1566

1567
                        channelsPerBlock[cid.BlockHeight] = append(
×
1568
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1569
                        )
×
1570
                }
1571

1572
                return nil
×
1573
        }, func() {
×
1574
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1575
        })
×
1576
        if err != nil {
×
1577
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1578
        }
×
1579

1580
        if len(channelsPerBlock) == 0 {
×
1581
                return nil, nil
×
1582
        }
×
1583

1584
        // Return the channel ranges in ascending block height order.
1585
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1586
        slices.Sort(blocks)
×
1587

×
1588
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1589
                return BlockChannelRange{
×
1590
                        Height:   block,
×
1591
                        Channels: channelsPerBlock[block],
×
1592
                }
×
1593
        }), nil
×
1594
}
1595

1596
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1597
// zombie. This method is used on an ad-hoc basis, when channels need to be
1598
// marked as zombies outside the normal pruning cycle.
1599
//
1600
// NOTE: part of the V1Store interface.
1601
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1602
        pubKey1, pubKey2 [33]byte) error {
×
1603

×
1604
        ctx := context.TODO()
×
1605

×
1606
        s.cacheMu.Lock()
×
1607
        defer s.cacheMu.Unlock()
×
1608

×
1609
        chanIDB := channelIDToBytes(chanID)
×
1610

×
1611
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1612
                return db.UpsertZombieChannel(
×
1613
                        ctx, sqlc.UpsertZombieChannelParams{
×
1614
                                Version:  int16(ProtocolV1),
×
1615
                                Scid:     chanIDB,
×
1616
                                NodeKey1: pubKey1[:],
×
1617
                                NodeKey2: pubKey2[:],
×
1618
                        },
×
1619
                )
×
1620
        }, sqldb.NoOpReset)
×
1621
        if err != nil {
×
1622
                return fmt.Errorf("unable to upsert zombie channel "+
×
1623
                        "(channel_id=%d): %w", chanID, err)
×
1624
        }
×
1625

1626
        s.rejectCache.remove(chanID)
×
1627
        s.chanCache.remove(chanID)
×
1628

×
1629
        return nil
×
1630
}
1631

1632
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1633
//
1634
// NOTE: part of the V1Store interface.
1635
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1636
        s.cacheMu.Lock()
×
1637
        defer s.cacheMu.Unlock()
×
1638

×
1639
        var (
×
1640
                ctx     = context.TODO()
×
1641
                chanIDB = channelIDToBytes(chanID)
×
1642
        )
×
1643

×
1644
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1645
                res, err := db.DeleteZombieChannel(
×
1646
                        ctx, sqlc.DeleteZombieChannelParams{
×
1647
                                Scid:    chanIDB,
×
1648
                                Version: int16(ProtocolV1),
×
1649
                        },
×
1650
                )
×
1651
                if err != nil {
×
1652
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1653
                                err)
×
1654
                }
×
1655

1656
                rows, err := res.RowsAffected()
×
1657
                if err != nil {
×
1658
                        return err
×
1659
                }
×
1660

1661
                if rows == 0 {
×
1662
                        return ErrZombieEdgeNotFound
×
1663
                } else if rows > 1 {
×
1664
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1665
                                "expected 1", rows)
×
1666
                }
×
1667

1668
                return nil
×
1669
        }, sqldb.NoOpReset)
1670
        if err != nil {
×
1671
                return fmt.Errorf("unable to mark edge live "+
×
1672
                        "(channel_id=%d): %w", chanID, err)
×
1673
        }
×
1674

1675
        s.rejectCache.remove(chanID)
×
1676
        s.chanCache.remove(chanID)
×
1677

×
1678
        return err
×
1679
}
1680

1681
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1682
// zombie, then the two node public keys corresponding to this edge are also
1683
// returned.
1684
//
1685
// NOTE: part of the V1Store interface.
1686
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1687
        error) {
×
1688

×
1689
        var (
×
1690
                ctx              = context.TODO()
×
1691
                isZombie         bool
×
1692
                pubKey1, pubKey2 route.Vertex
×
1693
                chanIDB          = channelIDToBytes(chanID)
×
1694
        )
×
1695

×
1696
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1697
                zombie, err := db.GetZombieChannel(
×
1698
                        ctx, sqlc.GetZombieChannelParams{
×
1699
                                Scid:    chanIDB,
×
1700
                                Version: int16(ProtocolV1),
×
1701
                        },
×
1702
                )
×
1703
                if errors.Is(err, sql.ErrNoRows) {
×
1704
                        return nil
×
1705
                }
×
1706
                if err != nil {
×
1707
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1708
                                err)
×
1709
                }
×
1710

1711
                copy(pubKey1[:], zombie.NodeKey1)
×
1712
                copy(pubKey2[:], zombie.NodeKey2)
×
1713
                isZombie = true
×
1714

×
1715
                return nil
×
1716
        }, sqldb.NoOpReset)
1717
        if err != nil {
×
1718
                return false, route.Vertex{}, route.Vertex{},
×
1719
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1720
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1721
        }
×
1722

1723
        return isZombie, pubKey1, pubKey2, nil
×
1724
}
1725

1726
// NumZombies returns the current number of zombie channels in the graph.
1727
//
1728
// NOTE: part of the V1Store interface.
1729
func (s *SQLStore) NumZombies() (uint64, error) {
×
1730
        var (
×
1731
                ctx        = context.TODO()
×
1732
                numZombies uint64
×
1733
        )
×
1734
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1735
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1736
                if err != nil {
×
1737
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1738
                                err)
×
1739
                }
×
1740

1741
                numZombies = uint64(count)
×
1742

×
1743
                return nil
×
1744
        }, sqldb.NoOpReset)
1745
        if err != nil {
×
1746
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1747
        }
×
1748

1749
        return numZombies, nil
×
1750
}
1751

1752
// DeleteChannelEdges removes edges with the given channel IDs from the
1753
// database and marks them as zombies. This ensures that we're unable to re-add
1754
// it to our database once again. If an edge does not exist within the
1755
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1756
// true, then when we mark these edges as zombies, we'll set up the keys such
1757
// that we require the node that failed to send the fresh update to be the one
1758
// that resurrects the channel from its zombie state. The markZombie bool
1759
// denotes whether to mark the channel as a zombie.
1760
//
1761
// NOTE: part of the V1Store interface.
1762
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1763
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1764

×
1765
        s.cacheMu.Lock()
×
1766
        defer s.cacheMu.Unlock()
×
1767

×
1768
        // Keep track of which channels we end up finding so that we can
×
1769
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
1770
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
1771
        for _, chanID := range chanIDs {
×
1772
                chanLookup[chanID] = struct{}{}
×
1773
        }
×
1774

1775
        var (
×
1776
                ctx     = context.TODO()
×
1777
                deleted []*models.ChannelEdgeInfo
×
1778
        )
×
1779
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1780
                chanIDsToDelete := make([]int64, 0, len(chanIDs))
×
1781
                chanCallBack := func(ctx context.Context,
×
1782
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1783

×
1784
                        // Deleting the entry from the map indicates that we
×
1785
                        // have found the channel.
×
1786
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1787
                        delete(chanLookup, scid)
×
1788

×
1789
                        node1, node2, err := buildNodeVertices(
×
1790
                                row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1791
                        )
×
1792
                        if err != nil {
×
1793
                                return err
×
1794
                        }
×
1795

1796
                        info, err := getAndBuildEdgeInfo(
×
1797
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
1798
                                node1, node2,
×
1799
                        )
×
1800
                        if err != nil {
×
1801
                                return err
×
1802
                        }
×
1803

1804
                        deleted = append(deleted, info)
×
1805
                        chanIDsToDelete = append(
×
1806
                                chanIDsToDelete, row.GraphChannel.ID,
×
1807
                        )
×
1808

×
1809
                        if !markZombie {
×
1810
                                return nil
×
1811
                        }
×
1812

1813
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1814
                                info.NodeKey2Bytes
×
1815
                        if strictZombiePruning {
×
1816
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1817
                                if row.Policy1LastUpdate.Valid {
×
1818
                                        e1Time := time.Unix(
×
1819
                                                row.Policy1LastUpdate.Int64, 0,
×
1820
                                        )
×
1821
                                        e1UpdateTime = &e1Time
×
1822
                                }
×
1823
                                if row.Policy2LastUpdate.Valid {
×
1824
                                        e2Time := time.Unix(
×
1825
                                                row.Policy2LastUpdate.Int64, 0,
×
1826
                                        )
×
1827
                                        e2UpdateTime = &e2Time
×
1828
                                }
×
1829

1830
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1831
                                        info, e1UpdateTime, e2UpdateTime,
×
1832
                                )
×
1833
                        }
1834

1835
                        err = db.UpsertZombieChannel(
×
1836
                                ctx, sqlc.UpsertZombieChannelParams{
×
1837
                                        Version:  int16(ProtocolV1),
×
1838
                                        Scid:     channelIDToBytes(scid),
×
1839
                                        NodeKey1: nodeKey1[:],
×
1840
                                        NodeKey2: nodeKey2[:],
×
1841
                                },
×
1842
                        )
×
1843
                        if err != nil {
×
1844
                                return fmt.Errorf("unable to mark channel as "+
×
1845
                                        "zombie: %w", err)
×
1846
                        }
×
1847

1848
                        return nil
×
1849
                }
1850

1851
                err := s.forEachChanWithPoliciesInSCIDList(
×
1852
                        ctx, db, chanCallBack, chanIDs,
×
1853
                )
×
1854
                if err != nil {
×
1855
                        return err
×
1856
                }
×
1857

1858
                if len(chanLookup) > 0 {
×
1859
                        return ErrEdgeNotFound
×
1860
                }
×
1861

1862
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1863
        }, func() {
×
1864
                deleted = nil
×
1865

×
1866
                // Re-fill the lookup map.
×
1867
                for _, chanID := range chanIDs {
×
1868
                        chanLookup[chanID] = struct{}{}
×
1869
                }
×
1870
        })
1871
        if err != nil {
×
1872
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1873
                        err)
×
1874
        }
×
1875

1876
        for _, chanID := range chanIDs {
×
1877
                s.rejectCache.remove(chanID)
×
1878
                s.chanCache.remove(chanID)
×
1879
        }
×
1880

1881
        return deleted, nil
×
1882
}
1883

1884
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1885
// channel identified by the channel ID. If the channel can't be found, then
1886
// ErrEdgeNotFound is returned. A struct which houses the general information
1887
// for the channel itself is returned as well as two structs that contain the
1888
// routing policies for the channel in either direction.
1889
//
1890
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1891
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1892
// the ChannelEdgeInfo will only include the public keys of each node.
1893
//
1894
// NOTE: part of the V1Store interface.
1895
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1896
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1897
        *models.ChannelEdgePolicy, error) {
×
1898

×
1899
        var (
×
1900
                ctx              = context.TODO()
×
1901
                edge             *models.ChannelEdgeInfo
×
1902
                policy1, policy2 *models.ChannelEdgePolicy
×
1903
                chanIDB          = channelIDToBytes(chanID)
×
1904
        )
×
1905
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1906
                row, err := db.GetChannelBySCIDWithPolicies(
×
1907
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1908
                                Scid:    chanIDB,
×
1909
                                Version: int16(ProtocolV1),
×
1910
                        },
×
1911
                )
×
1912
                if errors.Is(err, sql.ErrNoRows) {
×
1913
                        // First check if this edge is perhaps in the zombie
×
1914
                        // index.
×
1915
                        zombie, err := db.GetZombieChannel(
×
1916
                                ctx, sqlc.GetZombieChannelParams{
×
1917
                                        Scid:    chanIDB,
×
1918
                                        Version: int16(ProtocolV1),
×
1919
                                },
×
1920
                        )
×
1921
                        if errors.Is(err, sql.ErrNoRows) {
×
1922
                                return ErrEdgeNotFound
×
1923
                        } else if err != nil {
×
1924
                                return fmt.Errorf("unable to check if "+
×
1925
                                        "channel is zombie: %w", err)
×
1926
                        }
×
1927

1928
                        // At this point, we know the channel is a zombie, so
1929
                        // we'll return an error indicating this, and we will
1930
                        // populate the edge info with the public keys of each
1931
                        // party as this is the only information we have about
1932
                        // it.
1933
                        edge = &models.ChannelEdgeInfo{}
×
1934
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1935
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1936

×
1937
                        return ErrZombieEdge
×
1938
                } else if err != nil {
×
1939
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1940
                }
×
1941

1942
                node1, node2, err := buildNodeVertices(
×
1943
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1944
                )
×
1945
                if err != nil {
×
1946
                        return err
×
1947
                }
×
1948

1949
                edge, err = getAndBuildEdgeInfo(
×
1950
                        ctx, db, s.cfg.ChainHash, row.GraphChannel, node1,
×
1951
                        node2,
×
1952
                )
×
1953
                if err != nil {
×
1954
                        return fmt.Errorf("unable to build channel info: %w",
×
1955
                                err)
×
1956
                }
×
1957

1958
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1959
                if err != nil {
×
1960
                        return fmt.Errorf("unable to extract channel "+
×
1961
                                "policies: %w", err)
×
1962
                }
×
1963

1964
                policy1, policy2, err = getAndBuildChanPolicies(
×
1965
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1966
                )
×
1967
                if err != nil {
×
1968
                        return fmt.Errorf("unable to build channel "+
×
1969
                                "policies: %w", err)
×
1970
                }
×
1971

1972
                return nil
×
1973
        }, sqldb.NoOpReset)
1974
        if err != nil {
×
1975
                // If we are returning the ErrZombieEdge, then we also need to
×
1976
                // return the edge info as the method comment indicates that
×
1977
                // this will be populated when the edge is a zombie.
×
1978
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1979
                        err)
×
1980
        }
×
1981

1982
        return edge, policy1, policy2, nil
×
1983
}
1984

1985
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1986
// the channel identified by the funding outpoint. If the channel can't be
1987
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1988
// information for the channel itself is returned as well as two structs that
1989
// contain the routing policies for the channel in either direction.
1990
//
1991
// NOTE: part of the V1Store interface.
1992
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1993
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1994
        *models.ChannelEdgePolicy, error) {
×
1995

×
1996
        var (
×
1997
                ctx              = context.TODO()
×
1998
                edge             *models.ChannelEdgeInfo
×
1999
                policy1, policy2 *models.ChannelEdgePolicy
×
2000
        )
×
2001
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2002
                row, err := db.GetChannelByOutpointWithPolicies(
×
2003
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
2004
                                Outpoint: op.String(),
×
2005
                                Version:  int16(ProtocolV1),
×
2006
                        },
×
2007
                )
×
2008
                if errors.Is(err, sql.ErrNoRows) {
×
2009
                        return ErrEdgeNotFound
×
2010
                } else if err != nil {
×
2011
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2012
                }
×
2013

2014
                node1, node2, err := buildNodeVertices(
×
2015
                        row.Node1Pubkey, row.Node2Pubkey,
×
2016
                )
×
2017
                if err != nil {
×
2018
                        return err
×
2019
                }
×
2020

2021
                edge, err = getAndBuildEdgeInfo(
×
2022
                        ctx, db, s.cfg.ChainHash, row.GraphChannel, node1,
×
2023
                        node2,
×
2024
                )
×
2025
                if err != nil {
×
2026
                        return fmt.Errorf("unable to build channel info: %w",
×
2027
                                err)
×
2028
                }
×
2029

2030
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2031
                if err != nil {
×
2032
                        return fmt.Errorf("unable to extract channel "+
×
2033
                                "policies: %w", err)
×
2034
                }
×
2035

2036
                policy1, policy2, err = getAndBuildChanPolicies(
×
2037
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
2038
                )
×
2039
                if err != nil {
×
2040
                        return fmt.Errorf("unable to build channel "+
×
2041
                                "policies: %w", err)
×
2042
                }
×
2043

2044
                return nil
×
2045
        }, sqldb.NoOpReset)
2046
        if err != nil {
×
2047
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2048
                        err)
×
2049
        }
×
2050

2051
        return edge, policy1, policy2, nil
×
2052
}
2053

2054
// HasChannelEdge returns true if the database knows of a channel edge with the
2055
// passed channel ID, and false otherwise. If an edge with that ID is found
2056
// within the graph, then two time stamps representing the last time the edge
2057
// was updated for both directed edges are returned along with the boolean. If
2058
// it is not found, then the zombie index is checked and its result is returned
2059
// as the second boolean.
2060
//
2061
// NOTE: part of the V1Store interface.
2062
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
2063
        bool, error) {
×
2064

×
2065
        ctx := context.TODO()
×
2066

×
2067
        var (
×
2068
                exists          bool
×
2069
                isZombie        bool
×
2070
                node1LastUpdate time.Time
×
2071
                node2LastUpdate time.Time
×
2072
        )
×
2073

×
2074
        // We'll query the cache with the shared lock held to allow multiple
×
2075
        // readers to access values in the cache concurrently if they exist.
×
2076
        s.cacheMu.RLock()
×
2077
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2078
                s.cacheMu.RUnlock()
×
2079
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2080
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2081
                exists, isZombie = entry.flags.unpack()
×
2082

×
2083
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2084
        }
×
2085
        s.cacheMu.RUnlock()
×
2086

×
2087
        s.cacheMu.Lock()
×
2088
        defer s.cacheMu.Unlock()
×
2089

×
2090
        // The item was not found with the shared lock, so we'll acquire the
×
2091
        // exclusive lock and check the cache again in case another method added
×
2092
        // the entry to the cache while no lock was held.
×
2093
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2094
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2095
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2096
                exists, isZombie = entry.flags.unpack()
×
2097

×
2098
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2099
        }
×
2100

2101
        chanIDB := channelIDToBytes(chanID)
×
2102
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2103
                channel, err := db.GetChannelBySCID(
×
2104
                        ctx, sqlc.GetChannelBySCIDParams{
×
2105
                                Scid:    chanIDB,
×
2106
                                Version: int16(ProtocolV1),
×
2107
                        },
×
2108
                )
×
2109
                if errors.Is(err, sql.ErrNoRows) {
×
2110
                        // Check if it is a zombie channel.
×
2111
                        isZombie, err = db.IsZombieChannel(
×
2112
                                ctx, sqlc.IsZombieChannelParams{
×
2113
                                        Scid:    chanIDB,
×
2114
                                        Version: int16(ProtocolV1),
×
2115
                                },
×
2116
                        )
×
2117
                        if err != nil {
×
2118
                                return fmt.Errorf("could not check if channel "+
×
2119
                                        "is zombie: %w", err)
×
2120
                        }
×
2121

2122
                        return nil
×
2123
                } else if err != nil {
×
2124
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2125
                }
×
2126

2127
                exists = true
×
2128

×
2129
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2130
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2131
                                Version:   int16(ProtocolV1),
×
2132
                                ChannelID: channel.ID,
×
2133
                                NodeID:    channel.NodeID1,
×
2134
                        },
×
2135
                )
×
2136
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2137
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2138
                                err)
×
2139
                } else if err == nil {
×
2140
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2141
                }
×
2142

2143
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2144
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2145
                                Version:   int16(ProtocolV1),
×
2146
                                ChannelID: channel.ID,
×
2147
                                NodeID:    channel.NodeID2,
×
2148
                        },
×
2149
                )
×
2150
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2151
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2152
                                err)
×
2153
                } else if err == nil {
×
2154
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2155
                }
×
2156

2157
                return nil
×
2158
        }, sqldb.NoOpReset)
2159
        if err != nil {
×
2160
                return time.Time{}, time.Time{}, false, false,
×
2161
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2162
        }
×
2163

2164
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2165
                upd1Time: node1LastUpdate.Unix(),
×
2166
                upd2Time: node2LastUpdate.Unix(),
×
2167
                flags:    packRejectFlags(exists, isZombie),
×
2168
        })
×
2169

×
2170
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2171
}
2172

2173
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2174
// passed channel point (outpoint). If the passed channel doesn't exist within
2175
// the database, then ErrEdgeNotFound is returned.
2176
//
2177
// NOTE: part of the V1Store interface.
2178
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2179
        var (
×
2180
                ctx       = context.TODO()
×
2181
                channelID uint64
×
2182
        )
×
2183
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2184
                chanID, err := db.GetSCIDByOutpoint(
×
2185
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2186
                                Outpoint: chanPoint.String(),
×
2187
                                Version:  int16(ProtocolV1),
×
2188
                        },
×
2189
                )
×
2190
                if errors.Is(err, sql.ErrNoRows) {
×
2191
                        return ErrEdgeNotFound
×
2192
                } else if err != nil {
×
2193
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2194
                                err)
×
2195
                }
×
2196

2197
                channelID = byteOrder.Uint64(chanID)
×
2198

×
2199
                return nil
×
2200
        }, sqldb.NoOpReset)
2201
        if err != nil {
×
2202
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2203
        }
×
2204

2205
        return channelID, nil
×
2206
}
2207

2208
// IsPublicNode is a helper method that determines whether the node with the
2209
// given public key is seen as a public node in the graph from the graph's
2210
// source node's point of view.
2211
//
2212
// NOTE: part of the V1Store interface.
2213
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2214
        ctx := context.TODO()
×
2215

×
2216
        var isPublic bool
×
2217
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2218
                var err error
×
2219
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2220

×
2221
                return err
×
2222
        }, sqldb.NoOpReset)
×
2223
        if err != nil {
×
2224
                return false, fmt.Errorf("unable to check if node is "+
×
2225
                        "public: %w", err)
×
2226
        }
×
2227

2228
        return isPublic, nil
×
2229
}
2230

2231
// FetchChanInfos returns the set of channel edges that correspond to the passed
2232
// channel ID's. If an edge is the query is unknown to the database, it will
2233
// skipped and the result will contain only those edges that exist at the time
2234
// of the query. This can be used to respond to peer queries that are seeking to
2235
// fill in gaps in their view of the channel graph.
2236
//
2237
// NOTE: part of the V1Store interface.
2238
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2239
        var (
×
2240
                ctx   = context.TODO()
×
2241
                edges = make(map[uint64]ChannelEdge)
×
2242
        )
×
2243
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2244
                chanCallBack := func(ctx context.Context,
×
2245
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2246

×
2247
                        node1, node2, err := buildNodes(
×
2248
                                ctx, db, row.GraphNode, row.GraphNode_2,
×
2249
                        )
×
2250
                        if err != nil {
×
2251
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2252
                                        err)
×
2253
                        }
×
2254

2255
                        edge, err := getAndBuildEdgeInfo(
×
2256
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
2257
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2258
                        )
×
2259
                        if err != nil {
×
2260
                                return fmt.Errorf("unable to build "+
×
2261
                                        "channel info: %w", err)
×
2262
                        }
×
2263

2264
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2265
                        if err != nil {
×
2266
                                return fmt.Errorf("unable to extract channel "+
×
2267
                                        "policies: %w", err)
×
2268
                        }
×
2269

2270
                        p1, p2, err := getAndBuildChanPolicies(
×
2271
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2272
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2273
                        )
×
2274
                        if err != nil {
×
2275
                                return fmt.Errorf("unable to build channel "+
×
2276
                                        "policies: %w", err)
×
2277
                        }
×
2278

2279
                        edges[edge.ChannelID] = ChannelEdge{
×
2280
                                Info:    edge,
×
2281
                                Policy1: p1,
×
2282
                                Policy2: p2,
×
2283
                                Node1:   node1,
×
2284
                                Node2:   node2,
×
2285
                        }
×
2286

×
2287
                        return nil
×
2288
                }
2289

2290
                return s.forEachChanWithPoliciesInSCIDList(
×
2291
                        ctx, db, chanCallBack, chanIDs,
×
2292
                )
×
2293
        }, func() {
×
2294
                clear(edges)
×
2295
        })
×
2296
        if err != nil {
×
2297
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2298
        }
×
2299

2300
        res := make([]ChannelEdge, 0, len(edges))
×
2301
        for _, chanID := range chanIDs {
×
2302
                edge, ok := edges[chanID]
×
2303
                if !ok {
×
2304
                        continue
×
2305
                }
2306

2307
                res = append(res, edge)
×
2308
        }
2309

2310
        return res, nil
×
2311
}
2312

2313
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2314
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2315
// channels in a paginated manner.
2316
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2317
        db SQLQueries, cb func(ctx context.Context,
2318
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
2319
        chanIDs []uint64) error {
×
2320

×
2321
        queryWrapper := func(ctx context.Context,
×
2322
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
2323
                error) {
×
2324

×
2325
                return db.GetChannelsBySCIDWithPolicies(
×
2326
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
2327
                                Version: int16(ProtocolV1),
×
2328
                                Scids:   scids,
×
2329
                        },
×
2330
                )
×
2331
        }
×
2332

2333
        return sqldb.ExecutePagedQuery(
×
2334
                ctx, s.cfg.PaginationCfg, chanIDs, channelIDToBytes,
×
2335
                queryWrapper, cb,
×
2336
        )
×
2337
}
2338

2339
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2340
// ID's that we don't know and are not known zombies of the passed set. In other
2341
// words, we perform a set difference of our set of chan ID's and the ones
2342
// passed in. This method can be used by callers to determine the set of
2343
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2344
// known zombies is also returned.
2345
//
2346
// NOTE: part of the V1Store interface.
2347
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2348
        []ChannelUpdateInfo, error) {
×
2349

×
2350
        var (
×
2351
                ctx          = context.TODO()
×
2352
                newChanIDs   []uint64
×
2353
                knownZombies []ChannelUpdateInfo
×
2354
                infoLookup   = make(
×
2355
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2356
                )
×
2357
        )
×
2358

×
2359
        // We first build a lookup map of the channel ID's to the
×
2360
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
2361
        // already know about.
×
2362
        for _, chanInfo := range chansInfo {
×
2363
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
2364
        }
×
2365

2366
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2367
                // The call-back function deletes known channels from
×
2368
                // infoLookup, so that we can later check which channels are
×
2369
                // zombies by only looking at the remaining channels in the set.
×
2370
                cb := func(ctx context.Context,
×
2371
                        channel sqlc.GraphChannel) error {
×
2372

×
2373
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2374

×
2375
                        return nil
×
2376
                }
×
2377

2378
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
2379
                if err != nil {
×
2380
                        return fmt.Errorf("unable to iterate through "+
×
2381
                                "channels: %w", err)
×
2382
                }
×
2383

2384
                // We want to ensure that we deal with the channels in the
2385
                // same order that they were passed in, so we iterate over the
2386
                // original chansInfo slice and then check if that channel is
2387
                // still in the infoLookup map.
2388
                for _, chanInfo := range chansInfo {
×
2389
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2390
                        if _, ok := infoLookup[channelID]; !ok {
×
2391
                                continue
×
2392
                        }
2393

2394
                        isZombie, err := db.IsZombieChannel(
×
2395
                                ctx, sqlc.IsZombieChannelParams{
×
2396
                                        Scid:    channelIDToBytes(channelID),
×
2397
                                        Version: int16(ProtocolV1),
×
2398
                                },
×
2399
                        )
×
2400
                        if err != nil {
×
2401
                                return fmt.Errorf("unable to fetch zombie "+
×
2402
                                        "channel: %w", err)
×
2403
                        }
×
2404

2405
                        if isZombie {
×
2406
                                knownZombies = append(knownZombies, chanInfo)
×
2407

×
2408
                                continue
×
2409
                        }
2410

2411
                        newChanIDs = append(newChanIDs, channelID)
×
2412
                }
2413

2414
                return nil
×
2415
        }, func() {
×
2416
                newChanIDs = nil
×
2417
                knownZombies = nil
×
2418
                // Rebuild the infoLookup map in case of a rollback.
×
2419
                for _, chanInfo := range chansInfo {
×
2420
                        scid := chanInfo.ShortChannelID.ToUint64()
×
2421
                        infoLookup[scid] = chanInfo
×
2422
                }
×
2423
        })
2424
        if err != nil {
×
2425
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2426
        }
×
2427

2428
        return newChanIDs, knownZombies, nil
×
2429
}
2430

2431
// forEachChanInSCIDList is a helper method that executes a paged query
2432
// against the database to fetch all channels that match the passed
2433
// ChannelUpdateInfo slice. The callback function is called for each channel
2434
// that is found.
2435
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2436
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
2437
        chansInfo []ChannelUpdateInfo) error {
×
2438

×
2439
        queryWrapper := func(ctx context.Context,
×
2440
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
2441

×
2442
                return db.GetChannelsBySCIDs(
×
2443
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
2444
                                Version: int16(ProtocolV1),
×
2445
                                Scids:   scids,
×
2446
                        },
×
2447
                )
×
2448
        }
×
2449

2450
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
2451
                channelID := chanInfo.ShortChannelID.ToUint64()
×
2452

×
2453
                return channelIDToBytes(channelID)
×
2454
        }
×
2455

2456
        return sqldb.ExecutePagedQuery(
×
2457
                ctx, s.cfg.PaginationCfg, chansInfo, chanIDConverter,
×
2458
                queryWrapper, cb,
×
2459
        )
×
2460
}
2461

2462
// PruneGraphNodes is a garbage collection method which attempts to prune out
2463
// any nodes from the channel graph that are currently unconnected. This ensure
2464
// that we only maintain a graph of reachable nodes. In the event that a pruned
2465
// node gains more channels, it will be re-added back to the graph.
2466
//
2467
// NOTE: this prunes nodes across protocol versions. It will never prune the
2468
// source nodes.
2469
//
2470
// NOTE: part of the V1Store interface.
2471
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2472
        var ctx = context.TODO()
×
2473

×
2474
        var prunedNodes []route.Vertex
×
2475
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2476
                var err error
×
2477
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2478

×
2479
                return err
×
2480
        }, func() {
×
2481
                prunedNodes = nil
×
2482
        })
×
2483
        if err != nil {
×
2484
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2485
        }
×
2486

2487
        return prunedNodes, nil
×
2488
}
2489

2490
// PruneGraph prunes newly closed channels from the channel graph in response
2491
// to a new block being solved on the network. Any transactions which spend the
2492
// funding output of any known channels within he graph will be deleted.
2493
// Additionally, the "prune tip", or the last block which has been used to
2494
// prune the graph is stored so callers can ensure the graph is fully in sync
2495
// with the current UTXO state. A slice of channels that have been closed by
2496
// the target block along with any pruned nodes are returned if the function
2497
// succeeds without error.
2498
//
2499
// NOTE: part of the V1Store interface.
2500
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2501
        blockHash *chainhash.Hash, blockHeight uint32) (
2502
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2503

×
2504
        ctx := context.TODO()
×
2505

×
2506
        s.cacheMu.Lock()
×
2507
        defer s.cacheMu.Unlock()
×
2508

×
2509
        var (
×
2510
                closedChans []*models.ChannelEdgeInfo
×
2511
                prunedNodes []route.Vertex
×
2512
        )
×
2513
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2514
                var chansToDelete []int64
×
2515

×
2516
                // Define the callback function for processing each channel.
×
2517
                channelCallback := func(ctx context.Context,
×
2518
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2519

×
2520
                        node1, node2, err := buildNodeVertices(
×
2521
                                row.Node1Pubkey, row.Node2Pubkey,
×
2522
                        )
×
2523
                        if err != nil {
×
2524
                                return err
×
2525
                        }
×
2526

2527
                        info, err := getAndBuildEdgeInfo(
×
2528
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
2529
                                node1, node2,
×
2530
                        )
×
2531
                        if err != nil {
×
2532
                                return err
×
2533
                        }
×
2534

2535
                        closedChans = append(closedChans, info)
×
2536
                        chansToDelete = append(
×
2537
                                chansToDelete, row.GraphChannel.ID,
×
2538
                        )
×
2539

×
2540
                        return nil
×
2541
                }
2542

2543
                err := s.forEachChanInOutpoints(
×
2544
                        ctx, db, spentOutputs, channelCallback,
×
2545
                )
×
2546
                if err != nil {
×
2547
                        return fmt.Errorf("unable to fetch channels by "+
×
2548
                                "outpoints: %w", err)
×
2549
                }
×
2550

2551
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2552
                if err != nil {
×
2553
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2554
                }
×
2555

2556
                err = db.UpsertPruneLogEntry(
×
2557
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2558
                                BlockHash:   blockHash[:],
×
2559
                                BlockHeight: int64(blockHeight),
×
2560
                        },
×
2561
                )
×
2562
                if err != nil {
×
2563
                        return fmt.Errorf("unable to insert prune log "+
×
2564
                                "entry: %w", err)
×
2565
                }
×
2566

2567
                // Now that we've pruned some channels, we'll also prune any
2568
                // nodes that no longer have any channels.
2569
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2570
                if err != nil {
×
2571
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2572
                                err)
×
2573
                }
×
2574

2575
                return nil
×
2576
        }, func() {
×
2577
                prunedNodes = nil
×
2578
                closedChans = nil
×
2579
        })
×
2580
        if err != nil {
×
2581
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2582
        }
×
2583

2584
        for _, channel := range closedChans {
×
2585
                s.rejectCache.remove(channel.ChannelID)
×
2586
                s.chanCache.remove(channel.ChannelID)
×
2587
        }
×
2588

2589
        return closedChans, prunedNodes, nil
×
2590
}
2591

2592
// forEachChanInOutpoints is a helper function that executes a paginated
2593
// query to fetch channels by their outpoints and applies the given call-back
2594
// to each.
2595
//
2596
// NOTE: this fetches channels for all protocol versions.
2597
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2598
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2599
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
2600

×
2601
        // Create a wrapper that uses the transaction's db instance to execute
×
2602
        // the query.
×
2603
        queryWrapper := func(ctx context.Context,
×
2604
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
2605
                error) {
×
2606

×
2607
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
2608
        }
×
2609

2610
        // Define the conversion function from Outpoint to string.
2611
        outpointToString := func(outpoint *wire.OutPoint) string {
×
2612
                return outpoint.String()
×
2613
        }
×
2614

2615
        return sqldb.ExecutePagedQuery(
×
2616
                ctx, s.cfg.PaginationCfg, outpoints, outpointToString,
×
2617
                queryWrapper, cb,
×
2618
        )
×
2619
}
2620

2621
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
2622
        dbIDs []int64) error {
×
2623

×
2624
        // Create a wrapper that uses the transaction's db instance to execute
×
2625
        // the query.
×
2626
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
2627
                return nil, db.DeleteChannels(ctx, ids)
×
2628
        }
×
2629

2630
        idConverter := func(id int64) int64 {
×
2631
                return id
×
2632
        }
×
2633

2634
        return sqldb.ExecutePagedQuery(
×
2635
                ctx, s.cfg.PaginationCfg, dbIDs, idConverter,
×
2636
                queryWrapper, func(ctx context.Context, _ any) error {
×
2637
                        return nil
×
2638
                },
×
2639
        )
2640
}
2641

2642
// ChannelView returns the verifiable edge information for each active channel
2643
// within the known channel graph. The set of UTXOs (along with their scripts)
2644
// returned are the ones that need to be watched on chain to detect channel
2645
// closes on the resident blockchain.
2646
//
2647
// NOTE: part of the V1Store interface.
2648
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2649
        var (
×
2650
                ctx        = context.TODO()
×
2651
                edgePoints []EdgePoint
×
2652
        )
×
2653

×
2654
        handleChannel := func(db SQLQueries,
×
2655
                channel sqlc.ListChannelsPaginatedRow) error {
×
2656

×
2657
                pkScript, err := genMultiSigP2WSH(
×
2658
                        channel.BitcoinKey1, channel.BitcoinKey2,
×
2659
                )
×
2660
                if err != nil {
×
2661
                        return err
×
2662
                }
×
2663

2664
                op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2665
                if err != nil {
×
2666
                        return err
×
2667
                }
×
2668

2669
                edgePoints = append(edgePoints, EdgePoint{
×
2670
                        FundingPkScript: pkScript,
×
2671
                        OutPoint:        *op,
×
2672
                })
×
2673

×
2674
                return nil
×
2675
        }
2676

2677
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2678
                lastID := int64(-1)
×
2679
                for {
×
2680
                        rows, err := db.ListChannelsPaginated(
×
2681
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2682
                                        Version: int16(ProtocolV1),
×
2683
                                        ID:      lastID,
×
2684
                                        Limit:   pageSize,
×
2685
                                },
×
2686
                        )
×
2687
                        if err != nil {
×
2688
                                return err
×
2689
                        }
×
2690

2691
                        if len(rows) == 0 {
×
2692
                                break
×
2693
                        }
2694

2695
                        for _, row := range rows {
×
2696
                                err := handleChannel(db, row)
×
2697
                                if err != nil {
×
2698
                                        return err
×
2699
                                }
×
2700

2701
                                lastID = row.ID
×
2702
                        }
2703
                }
2704

2705
                return nil
×
2706
        }, func() {
×
2707
                edgePoints = nil
×
2708
        })
×
2709
        if err != nil {
×
2710
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2711
        }
×
2712

2713
        return edgePoints, nil
×
2714
}
2715

2716
// PruneTip returns the block height and hash of the latest block that has been
2717
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2718
// to tell if the graph is currently in sync with the current best known UTXO
2719
// state.
2720
//
2721
// NOTE: part of the V1Store interface.
2722
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2723
        var (
×
2724
                ctx       = context.TODO()
×
2725
                tipHash   chainhash.Hash
×
2726
                tipHeight uint32
×
2727
        )
×
2728
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2729
                pruneTip, err := db.GetPruneTip(ctx)
×
2730
                if errors.Is(err, sql.ErrNoRows) {
×
2731
                        return ErrGraphNeverPruned
×
2732
                } else if err != nil {
×
2733
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2734
                }
×
2735

2736
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2737
                tipHeight = uint32(pruneTip.BlockHeight)
×
2738

×
2739
                return nil
×
2740
        }, sqldb.NoOpReset)
2741
        if err != nil {
×
2742
                return nil, 0, err
×
2743
        }
×
2744

2745
        return &tipHash, tipHeight, nil
×
2746
}
2747

2748
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2749
//
2750
// NOTE: this prunes nodes across protocol versions. It will never prune the
2751
// source nodes.
2752
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2753
        db SQLQueries) ([]route.Vertex, error) {
×
2754

×
2755
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2756
        if err != nil {
×
2757
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2758
                        "nodes: %w", err)
×
2759
        }
×
2760

2761
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2762
        for i, nodeKey := range nodeKeys {
×
2763
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2764
                if err != nil {
×
2765
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2766
                                "from bytes: %w", err)
×
2767
                }
×
2768

2769
                prunedNodes[i] = pub
×
2770
        }
2771

2772
        return prunedNodes, nil
×
2773
}
2774

2775
// DisconnectBlockAtHeight is used to indicate that the block specified
2776
// by the passed height has been disconnected from the main chain. This
2777
// will "rewind" the graph back to the height below, deleting channels
2778
// that are no longer confirmed from the graph. The prune log will be
2779
// set to the last prune height valid for the remaining chain.
2780
// Channels that were removed from the graph resulting from the
2781
// disconnected block are returned.
2782
//
2783
// NOTE: part of the V1Store interface.
2784
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2785
        []*models.ChannelEdgeInfo, error) {
×
2786

×
2787
        ctx := context.TODO()
×
2788

×
2789
        var (
×
2790
                // Every channel having a ShortChannelID starting at 'height'
×
2791
                // will no longer be confirmed.
×
2792
                startShortChanID = lnwire.ShortChannelID{
×
2793
                        BlockHeight: height,
×
2794
                }
×
2795

×
2796
                // Delete everything after this height from the db up until the
×
2797
                // SCID alias range.
×
2798
                endShortChanID = aliasmgr.StartingAlias
×
2799

×
2800
                removedChans []*models.ChannelEdgeInfo
×
2801

×
2802
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2803
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2804
        )
×
2805

×
2806
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2807
                rows, err := db.GetChannelsBySCIDRange(
×
2808
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2809
                                StartScid: chanIDStart,
×
2810
                                EndScid:   chanIDEnd,
×
2811
                        },
×
2812
                )
×
2813
                if err != nil {
×
2814
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2815
                }
×
2816

2817
                chanIDsToDelete := make([]int64, len(rows))
×
2818
                for i, row := range rows {
×
2819
                        node1, node2, err := buildNodeVertices(
×
2820
                                row.Node1PubKey, row.Node2PubKey,
×
2821
                        )
×
2822
                        if err != nil {
×
2823
                                return err
×
2824
                        }
×
2825

2826
                        channel, err := getAndBuildEdgeInfo(
×
2827
                                ctx, db, s.cfg.ChainHash, row.GraphChannel,
×
2828
                                node1, node2,
×
2829
                        )
×
2830
                        if err != nil {
×
2831
                                return err
×
2832
                        }
×
2833

2834
                        chanIDsToDelete[i] = row.GraphChannel.ID
×
2835
                        removedChans = append(removedChans, channel)
×
2836
                }
2837

2838
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
2839
                if err != nil {
×
2840
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2841
                }
×
2842

2843
                return db.DeletePruneLogEntriesInRange(
×
2844
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2845
                                StartHeight: int64(height),
×
2846
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2847
                        },
×
2848
                )
×
2849
        }, func() {
×
2850
                removedChans = nil
×
2851
        })
×
2852
        if err != nil {
×
2853
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2854
                        "height: %w", err)
×
2855
        }
×
2856

2857
        for _, channel := range removedChans {
×
2858
                s.rejectCache.remove(channel.ChannelID)
×
2859
                s.chanCache.remove(channel.ChannelID)
×
2860
        }
×
2861

2862
        return removedChans, nil
×
2863
}
2864

2865
// AddEdgeProof sets the proof of an existing edge in the graph database.
2866
//
2867
// NOTE: part of the V1Store interface.
2868
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2869
        proof *models.ChannelAuthProof) error {
×
2870

×
2871
        var (
×
2872
                ctx       = context.TODO()
×
2873
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2874
        )
×
2875

×
2876
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2877
                res, err := db.AddV1ChannelProof(
×
2878
                        ctx, sqlc.AddV1ChannelProofParams{
×
2879
                                Scid:              scidBytes,
×
2880
                                Node1Signature:    proof.NodeSig1Bytes,
×
2881
                                Node2Signature:    proof.NodeSig2Bytes,
×
2882
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2883
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2884
                        },
×
2885
                )
×
2886
                if err != nil {
×
2887
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2888
                }
×
2889

2890
                n, err := res.RowsAffected()
×
2891
                if err != nil {
×
2892
                        return err
×
2893
                }
×
2894

2895
                if n == 0 {
×
2896
                        return fmt.Errorf("no rows affected when adding edge "+
×
2897
                                "proof for SCID %v", scid)
×
2898
                } else if n > 1 {
×
2899
                        return fmt.Errorf("multiple rows affected when adding "+
×
2900
                                "edge proof for SCID %v: %d rows affected",
×
2901
                                scid, n)
×
2902
                }
×
2903

2904
                return nil
×
2905
        }, sqldb.NoOpReset)
2906
        if err != nil {
×
2907
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2908
        }
×
2909

2910
        return nil
×
2911
}
2912

2913
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2914
// that we can ignore channel announcements that we know to be closed without
2915
// having to validate them and fetch a block.
2916
//
2917
// NOTE: part of the V1Store interface.
2918
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2919
        var (
×
2920
                ctx     = context.TODO()
×
2921
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2922
        )
×
2923

×
2924
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2925
                return db.InsertClosedChannel(ctx, chanIDB)
×
2926
        }, sqldb.NoOpReset)
×
2927
}
2928

2929
// IsClosedScid checks whether a channel identified by the passed in scid is
2930
// closed. This helps avoid having to perform expensive validation checks.
2931
//
2932
// NOTE: part of the V1Store interface.
2933
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2934
        var (
×
2935
                ctx      = context.TODO()
×
2936
                isClosed bool
×
2937
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2938
        )
×
2939
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2940
                var err error
×
2941
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2942
                if err != nil {
×
2943
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2944
                                err)
×
2945
                }
×
2946

2947
                return nil
×
2948
        }, sqldb.NoOpReset)
2949
        if err != nil {
×
2950
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2951
                        err)
×
2952
        }
×
2953

2954
        return isClosed, nil
×
2955
}
2956

2957
// GraphSession will provide the call-back with access to a NodeTraverser
2958
// instance which can be used to perform queries against the channel graph.
2959
//
2960
// NOTE: part of the V1Store interface.
2961
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
2962
        reset func()) error {
×
2963

×
2964
        var ctx = context.TODO()
×
2965

×
2966
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2967
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2968
        }, reset)
×
2969
}
2970

2971
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2972
// read only transaction for a consistent view of the graph.
2973
type sqlNodeTraverser struct {
2974
        db    SQLQueries
2975
        chain chainhash.Hash
2976
}
2977

2978
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2979
// NodeTraverser interface.
2980
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2981

2982
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2983
func newSQLNodeTraverser(db SQLQueries,
2984
        chain chainhash.Hash) *sqlNodeTraverser {
×
2985

×
2986
        return &sqlNodeTraverser{
×
2987
                db:    db,
×
2988
                chain: chain,
×
2989
        }
×
2990
}
×
2991

2992
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2993
// node.
2994
//
2995
// NOTE: Part of the NodeTraverser interface.
2996
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2997
        cb func(channel *DirectedChannel) error, _ func()) error {
×
2998

×
2999
        ctx := context.TODO()
×
3000

×
3001
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
3002
}
×
3003

3004
// FetchNodeFeatures returns the features of the given node. If the node is
3005
// unknown, assume no additional features are supported.
3006
//
3007
// NOTE: Part of the NodeTraverser interface.
3008
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
3009
        *lnwire.FeatureVector, error) {
×
3010

×
3011
        ctx := context.TODO()
×
3012

×
3013
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
3014
}
×
3015

3016
// forEachNodeDirectedChannel iterates through all channels of a given
3017
// node, executing the passed callback on the directed edge representing the
3018
// channel and its incoming policy. If the node is not found, no error is
3019
// returned.
3020
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
3021
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
3022

×
3023
        toNodeCallback := func() route.Vertex {
×
3024
                return nodePub
×
3025
        }
×
3026

3027
        dbID, err := db.GetNodeIDByPubKey(
×
3028
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
3029
                        Version: int16(ProtocolV1),
×
3030
                        PubKey:  nodePub[:],
×
3031
                },
×
3032
        )
×
3033
        if errors.Is(err, sql.ErrNoRows) {
×
3034
                return nil
×
3035
        } else if err != nil {
×
3036
                return fmt.Errorf("unable to fetch node: %w", err)
×
3037
        }
×
3038

3039
        rows, err := db.ListChannelsByNodeID(
×
3040
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3041
                        Version: int16(ProtocolV1),
×
3042
                        NodeID1: dbID,
×
3043
                },
×
3044
        )
×
3045
        if err != nil {
×
3046
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3047
        }
×
3048

3049
        // Exit early if there are no channels for this node so we don't
3050
        // do the unnecessary feature fetching.
3051
        if len(rows) == 0 {
×
3052
                return nil
×
3053
        }
×
3054

3055
        features, err := getNodeFeatures(ctx, db, dbID)
×
3056
        if err != nil {
×
3057
                return fmt.Errorf("unable to fetch node features: %w", err)
×
3058
        }
×
3059

3060
        for _, row := range rows {
×
3061
                node1, node2, err := buildNodeVertices(
×
3062
                        row.Node1Pubkey, row.Node2Pubkey,
×
3063
                )
×
3064
                if err != nil {
×
3065
                        return fmt.Errorf("unable to build node vertices: %w",
×
3066
                                err)
×
3067
                }
×
3068

3069
                edge := buildCacheableChannelInfo(
×
3070
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
3071
                        node1, node2,
×
3072
                )
×
3073

×
3074
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3075
                if err != nil {
×
3076
                        return err
×
3077
                }
×
3078

3079
                var p1, p2 *models.CachedEdgePolicy
×
3080
                if dbPol1 != nil {
×
3081
                        policy1, err := buildChanPolicy(
×
3082
                                *dbPol1, edge.ChannelID, nil, node2,
×
3083
                        )
×
3084
                        if err != nil {
×
3085
                                return err
×
3086
                        }
×
3087

3088
                        p1 = models.NewCachedPolicy(policy1)
×
3089
                }
3090
                if dbPol2 != nil {
×
3091
                        policy2, err := buildChanPolicy(
×
3092
                                *dbPol2, edge.ChannelID, nil, node1,
×
3093
                        )
×
3094
                        if err != nil {
×
3095
                                return err
×
3096
                        }
×
3097

3098
                        p2 = models.NewCachedPolicy(policy2)
×
3099
                }
3100

3101
                // Determine the outgoing and incoming policy for this
3102
                // channel and node combo.
3103
                outPolicy, inPolicy := p1, p2
×
3104
                if p1 != nil && node2 == nodePub {
×
3105
                        outPolicy, inPolicy = p2, p1
×
3106
                } else if p2 != nil && node1 != nodePub {
×
3107
                        outPolicy, inPolicy = p2, p1
×
3108
                }
×
3109

3110
                var cachedInPolicy *models.CachedEdgePolicy
×
3111
                if inPolicy != nil {
×
3112
                        cachedInPolicy = inPolicy
×
3113
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
3114
                        cachedInPolicy.ToNodeFeatures = features
×
3115
                }
×
3116

3117
                directedChannel := &DirectedChannel{
×
3118
                        ChannelID:    edge.ChannelID,
×
3119
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
3120
                        OtherNode:    edge.NodeKey2Bytes,
×
3121
                        Capacity:     edge.Capacity,
×
3122
                        OutPolicySet: outPolicy != nil,
×
3123
                        InPolicy:     cachedInPolicy,
×
3124
                }
×
3125
                if outPolicy != nil {
×
3126
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3127
                                directedChannel.InboundFee = fee
×
3128
                        })
×
3129
                }
3130

3131
                if nodePub == edge.NodeKey2Bytes {
×
3132
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
3133
                }
×
3134

3135
                if err := cb(directedChannel); err != nil {
×
3136
                        return err
×
3137
                }
×
3138
        }
3139

3140
        return nil
×
3141
}
3142

3143
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
3144
// and executes the provided callback for each node.
3145
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
3146
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
3147

×
3148
        lastID := int64(-1)
×
3149

×
3150
        for {
×
3151
                nodes, err := db.ListNodeIDsAndPubKeys(
×
3152
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
3153
                                Version: int16(ProtocolV1),
×
3154
                                ID:      lastID,
×
3155
                                Limit:   pageSize,
×
3156
                        },
×
3157
                )
×
3158
                if err != nil {
×
3159
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
3160
                }
×
3161

3162
                if len(nodes) == 0 {
×
3163
                        break
×
3164
                }
3165

3166
                for _, node := range nodes {
×
3167
                        var pub route.Vertex
×
3168
                        copy(pub[:], node.PubKey)
×
3169

×
3170
                        if err := cb(node.ID, pub); err != nil {
×
3171
                                return fmt.Errorf("forEachNodeCacheable "+
×
3172
                                        "callback failed for node(id=%d): %w",
×
3173
                                        node.ID, err)
×
3174
                        }
×
3175

3176
                        lastID = node.ID
×
3177
                }
3178
        }
3179

3180
        return nil
×
3181
}
3182

3183
// forEachNodeChannel iterates through all channels of a node, executing
3184
// the passed callback on each. The call-back is provided with the channel's
3185
// edge information, the outgoing policy and the incoming policy for the
3186
// channel and node combo.
3187
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3188
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
3189
                *models.ChannelEdgePolicy,
3190
                *models.ChannelEdgePolicy) error) error {
×
3191

×
3192
        // Get all the V1 channels for this node.Add commentMore actions
×
3193
        rows, err := db.ListChannelsByNodeID(
×
3194
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3195
                        Version: int16(ProtocolV1),
×
3196
                        NodeID1: id,
×
3197
                },
×
3198
        )
×
3199
        if err != nil {
×
3200
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3201
        }
×
3202

3203
        // Call the call-back for each channel and its known policies.
3204
        for _, row := range rows {
×
3205
                node1, node2, err := buildNodeVertices(
×
3206
                        row.Node1Pubkey, row.Node2Pubkey,
×
3207
                )
×
3208
                if err != nil {
×
3209
                        return fmt.Errorf("unable to build node vertices: %w",
×
3210
                                err)
×
3211
                }
×
3212

3213
                edge, err := getAndBuildEdgeInfo(
×
3214
                        ctx, db, chain, row.GraphChannel, node1, node2,
×
3215
                )
×
3216
                if err != nil {
×
3217
                        return fmt.Errorf("unable to build channel info: %w",
×
3218
                                err)
×
3219
                }
×
3220

3221
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3222
                if err != nil {
×
3223
                        return fmt.Errorf("unable to extract channel "+
×
3224
                                "policies: %w", err)
×
3225
                }
×
3226

3227
                p1, p2, err := getAndBuildChanPolicies(
×
3228
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3229
                )
×
3230
                if err != nil {
×
3231
                        return fmt.Errorf("unable to build channel "+
×
3232
                                "policies: %w", err)
×
3233
                }
×
3234

3235
                // Determine the outgoing and incoming policy for this
3236
                // channel and node combo.
3237
                p1ToNode := row.GraphChannel.NodeID2
×
3238
                p2ToNode := row.GraphChannel.NodeID1
×
3239
                outPolicy, inPolicy := p1, p2
×
3240
                if (p1 != nil && p1ToNode == id) ||
×
3241
                        (p2 != nil && p2ToNode != id) {
×
3242

×
3243
                        outPolicy, inPolicy = p2, p1
×
3244
                }
×
3245

3246
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3247
                        return err
×
3248
                }
×
3249
        }
3250

3251
        return nil
×
3252
}
3253

3254
// updateChanEdgePolicy upserts the channel policy info we have stored for
3255
// a channel we already know of.
3256
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3257
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3258
        error) {
×
3259

×
3260
        var (
×
3261
                node1Pub, node2Pub route.Vertex
×
3262
                isNode1            bool
×
3263
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3264
        )
×
3265

×
3266
        // Check that this edge policy refers to a channel that we already
×
3267
        // know of. We do this explicitly so that we can return the appropriate
×
3268
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3269
        // abort the transaction which would abort the entire batch.
×
3270
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3271
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3272
                        Scid:    chanIDB,
×
3273
                        Version: int16(ProtocolV1),
×
3274
                },
×
3275
        )
×
3276
        if errors.Is(err, sql.ErrNoRows) {
×
3277
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3278
        } else if err != nil {
×
3279
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3280
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3281
        }
×
3282

3283
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3284
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3285

×
3286
        // Figure out which node this edge is from.
×
3287
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3288
        nodeID := dbChan.NodeID1
×
3289
        if !isNode1 {
×
3290
                nodeID = dbChan.NodeID2
×
3291
        }
×
3292

3293
        var (
×
3294
                inboundBase sql.NullInt64
×
3295
                inboundRate sql.NullInt64
×
3296
        )
×
3297
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3298
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3299
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3300
        })
×
3301

3302
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3303
                Version:     int16(ProtocolV1),
×
3304
                ChannelID:   dbChan.ID,
×
3305
                NodeID:      nodeID,
×
3306
                Timelock:    int32(edge.TimeLockDelta),
×
3307
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3308
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3309
                MinHtlcMsat: int64(edge.MinHTLC),
×
3310
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3311
                Disabled: sql.NullBool{
×
3312
                        Valid: true,
×
3313
                        Bool:  edge.IsDisabled(),
×
3314
                },
×
3315
                MaxHtlcMsat: sql.NullInt64{
×
3316
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3317
                        Int64: int64(edge.MaxHTLC),
×
3318
                },
×
3319
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3320
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3321
                InboundBaseFeeMsat:      inboundBase,
×
3322
                InboundFeeRateMilliMsat: inboundRate,
×
3323
                Signature:               edge.SigBytes,
×
3324
        })
×
3325
        if err != nil {
×
3326
                return node1Pub, node2Pub, isNode1,
×
3327
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3328
        }
×
3329

3330
        // Convert the flat extra opaque data into a map of TLV types to
3331
        // values.
3332
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3333
        if err != nil {
×
3334
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3335
                        "marshal extra opaque data: %w", err)
×
3336
        }
×
3337

3338
        // Update the channel policy's extra signed fields.
3339
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3340
        if err != nil {
×
3341
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3342
                        "policy extra TLVs: %w", err)
×
3343
        }
×
3344

3345
        return node1Pub, node2Pub, isNode1, nil
×
3346
}
3347

3348
// getNodeByPubKey attempts to look up a target node by its public key.
3349
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3350
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3351

×
3352
        dbNode, err := db.GetNodeByPubKey(
×
3353
                ctx, sqlc.GetNodeByPubKeyParams{
×
3354
                        Version: int16(ProtocolV1),
×
3355
                        PubKey:  pubKey[:],
×
3356
                },
×
3357
        )
×
3358
        if errors.Is(err, sql.ErrNoRows) {
×
3359
                return 0, nil, ErrGraphNodeNotFound
×
3360
        } else if err != nil {
×
3361
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3362
        }
×
3363

3364
        node, err := buildNode(ctx, db, &dbNode)
×
3365
        if err != nil {
×
3366
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3367
        }
×
3368

3369
        return dbNode.ID, node, nil
×
3370
}
3371

3372
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3373
// provided parameters.
3374
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
3375
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3376

×
3377
        return &models.CachedEdgeInfo{
×
3378
                ChannelID:     byteOrder.Uint64(scid),
×
3379
                NodeKey1Bytes: node1Pub,
×
3380
                NodeKey2Bytes: node2Pub,
×
3381
                Capacity:      btcutil.Amount(capacity),
×
3382
        }
×
3383
}
×
3384

3385
// buildNode constructs a LightningNode instance from the given database node
3386
// record. The node's features, addresses and extra signed fields are also
3387
// fetched from the database and set on the node.
3388
func buildNode(ctx context.Context, db SQLQueries,
3389
        dbNode *sqlc.GraphNode) (*models.LightningNode, error) {
×
3390

×
3391
        // NOTE: buildNode is only used to load the data for a single node, and
×
3392
        // so no paged queries will be performed. This means that it's ok to
×
3393
        // used pass in default config values here.
×
3394
        cfg := sqldb.DefaultPagedQueryConfig()
×
3395

×
3396
        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
×
3397
        if err != nil {
×
3398
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
3399
                        err)
×
3400
        }
×
3401

3402
        return buildNodeWithBatchData(dbNode, data)
×
3403
}
3404

3405
// buildNodeWithBatchData builds a models.LightningNode instance
3406
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
3407
// features/addresses/extra fields, then the corresponding fields are expected
3408
// to be present in the batchNodeData.
3409
func buildNodeWithBatchData(dbNode *sqlc.GraphNode,
3410
        batchData *batchNodeData) (*models.LightningNode, error) {
×
3411

×
3412
        if dbNode.Version != int16(ProtocolV1) {
×
3413
                return nil, fmt.Errorf("unsupported node version: %d",
×
3414
                        dbNode.Version)
×
3415
        }
×
3416

3417
        var pub [33]byte
×
3418
        copy(pub[:], dbNode.PubKey)
×
3419

×
3420
        node := &models.LightningNode{
×
3421
                PubKeyBytes: pub,
×
3422
                Features:    lnwire.EmptyFeatureVector(),
×
3423
                LastUpdate:  time.Unix(0, 0),
×
3424
        }
×
3425

×
3426
        if len(dbNode.Signature) == 0 {
×
3427
                return node, nil
×
3428
        }
×
3429

3430
        node.HaveNodeAnnouncement = true
×
3431
        node.AuthSigBytes = dbNode.Signature
×
3432
        node.Alias = dbNode.Alias.String
×
3433
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3434

×
3435
        var err error
×
3436
        if dbNode.Color.Valid {
×
3437
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3438
                if err != nil {
×
3439
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3440
                                err)
×
3441
                }
×
3442
        }
3443

3444
        // Use preloaded features.
3445
        if features, exists := batchData.features[dbNode.ID]; exists {
×
3446
                fv := lnwire.EmptyFeatureVector()
×
3447
                for _, bit := range features {
×
3448
                        fv.Set(lnwire.FeatureBit(bit))
×
3449
                }
×
3450
                node.Features = fv
×
3451
        }
3452

3453
        // Use preloaded addresses.
3454
        addresses, exists := batchData.addresses[dbNode.ID]
×
3455
        if exists && len(addresses) > 0 {
×
3456
                node.Addresses, err = buildNodeAddresses(addresses)
×
3457
                if err != nil {
×
3458
                        return nil, fmt.Errorf("unable to build addresses "+
×
3459
                                "for node(%d): %w", dbNode.ID, err)
×
3460
                }
×
3461
        }
3462

3463
        // Use preloaded extra fields.
3464
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
×
3465
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
×
3466
                if err != nil {
×
3467
                        return nil, fmt.Errorf("unable to serialize extra "+
×
3468
                                "signed fields: %w", err)
×
3469
                }
×
3470
                if len(recs) != 0 {
×
3471
                        node.ExtraOpaqueData = recs
×
3472
                }
×
3473
        }
3474

3475
        return node, nil
×
3476
}
3477

3478
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
3479
// with the preloaded data, and executes the provided callback for each node.
3480
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.PagedQueryConfig,
3481
        db SQLQueries, nodes []sqlc.GraphNode,
3482
        cb func(dbID int64, node *models.LightningNode) error) error {
×
3483

×
3484
        // Extract node IDs for batch loading.
×
3485
        nodeIDs := make([]int64, len(nodes))
×
3486
        for i, node := range nodes {
×
3487
                nodeIDs[i] = node.ID
×
3488
        }
×
3489

3490
        // Batch load all related data for this page.
3491
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
3492
        if err != nil {
×
3493
                return fmt.Errorf("unable to batch load node data: %w", err)
×
3494
        }
×
3495

3496
        for _, dbNode := range nodes {
×
3497
                node, err := buildNodeWithBatchData(&dbNode, batchData)
×
3498
                if err != nil {
×
3499
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
3500
                                dbNode.ID, err)
×
3501
                }
×
3502

3503
                if err := cb(dbNode.ID, node); err != nil {
×
3504
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
3505
                                dbNode.ID, err)
×
3506
                }
×
3507
        }
3508

3509
        return nil
×
3510
}
3511

3512
// getNodeFeatures fetches the feature bits and constructs the feature vector
3513
// for a node with the given DB ID.
3514
func getNodeFeatures(ctx context.Context, db SQLQueries,
3515
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3516

×
3517
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3518
        if err != nil {
×
3519
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3520
                        nodeID, err)
×
3521
        }
×
3522

3523
        features := lnwire.EmptyFeatureVector()
×
3524
        for _, feature := range rows {
×
3525
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3526
        }
×
3527

3528
        return features, nil
×
3529
}
3530

3531
// upsertNode upserts the node record into the database. If the node already
3532
// exists, then the node's information is updated. If the node doesn't exist,
3533
// then a new node is created. The node's features, addresses and extra TLV
3534
// types are also updated. The node's DB ID is returned.
3535
func upsertNode(ctx context.Context, db SQLQueries,
3536
        node *models.LightningNode) (int64, error) {
×
3537

×
3538
        params := sqlc.UpsertNodeParams{
×
3539
                Version: int16(ProtocolV1),
×
3540
                PubKey:  node.PubKeyBytes[:],
×
3541
        }
×
3542

×
3543
        if node.HaveNodeAnnouncement {
×
3544
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3545
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3546
                params.Alias = sqldb.SQLStr(node.Alias)
×
3547
                params.Signature = node.AuthSigBytes
×
3548
        }
×
3549

3550
        nodeID, err := db.UpsertNode(ctx, params)
×
3551
        if err != nil {
×
3552
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3553
                        err)
×
3554
        }
×
3555

3556
        // We can exit here if we don't have the announcement yet.
3557
        if !node.HaveNodeAnnouncement {
×
3558
                return nodeID, nil
×
3559
        }
×
3560

3561
        // Update the node's features.
3562
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3563
        if err != nil {
×
3564
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3565
        }
×
3566

3567
        // Update the node's addresses.
3568
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3569
        if err != nil {
×
3570
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3571
        }
×
3572

3573
        // Convert the flat extra opaque data into a map of TLV types to
3574
        // values.
3575
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3576
        if err != nil {
×
3577
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3578
                        err)
×
3579
        }
×
3580

3581
        // Update the node's extra signed fields.
3582
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3583
        if err != nil {
×
3584
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3585
        }
×
3586

3587
        return nodeID, nil
×
3588
}
3589

3590
// upsertNodeFeatures updates the node's features node_features table. This
3591
// includes deleting any feature bits no longer present and inserting any new
3592
// feature bits. If the feature bit does not yet exist in the features table,
3593
// then an entry is created in that table first.
3594
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3595
        features *lnwire.FeatureVector) error {
×
3596

×
3597
        // Get any existing features for the node.
×
3598
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3599
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3600
                return err
×
3601
        }
×
3602

3603
        // Copy the nodes latest set of feature bits.
3604
        newFeatures := make(map[int32]struct{})
×
3605
        if features != nil {
×
3606
                for feature := range features.Features() {
×
3607
                        newFeatures[int32(feature)] = struct{}{}
×
3608
                }
×
3609
        }
3610

3611
        // For any current feature that already exists in the DB, remove it from
3612
        // the in-memory map. For any existing feature that does not exist in
3613
        // the in-memory map, delete it from the database.
3614
        for _, feature := range existingFeatures {
×
3615
                // The feature is still present, so there are no updates to be
×
3616
                // made.
×
3617
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3618
                        delete(newFeatures, feature.FeatureBit)
×
3619
                        continue
×
3620
                }
3621

3622
                // The feature is no longer present, so we remove it from the
3623
                // database.
3624
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3625
                        NodeID:     nodeID,
×
3626
                        FeatureBit: feature.FeatureBit,
×
3627
                })
×
3628
                if err != nil {
×
3629
                        return fmt.Errorf("unable to delete node(%d) "+
×
3630
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3631
                                err)
×
3632
                }
×
3633
        }
3634

3635
        // Any remaining entries in newFeatures are new features that need to be
3636
        // added to the database for the first time.
3637
        for feature := range newFeatures {
×
3638
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3639
                        NodeID:     nodeID,
×
3640
                        FeatureBit: feature,
×
3641
                })
×
3642
                if err != nil {
×
3643
                        return fmt.Errorf("unable to insert node(%d) "+
×
3644
                                "feature(%v): %w", nodeID, feature, err)
×
3645
                }
×
3646
        }
3647

3648
        return nil
×
3649
}
3650

3651
// fetchNodeFeatures fetches the features for a node with the given public key.
3652
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3653
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3654

×
3655
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3656
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3657
                        PubKey:  nodePub[:],
×
3658
                        Version: int16(ProtocolV1),
×
3659
                },
×
3660
        )
×
3661
        if err != nil {
×
3662
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3663
                        nodePub, err)
×
3664
        }
×
3665

3666
        features := lnwire.EmptyFeatureVector()
×
3667
        for _, bit := range rows {
×
3668
                features.Set(lnwire.FeatureBit(bit))
×
3669
        }
×
3670

3671
        return features, nil
×
3672
}
3673

3674
// dbAddressType is an enum identifying the kind of address stored in a row of
// the node_addresses table. The address type determines how the address is
// serialized when written and how it is parsed back when read.
//
// NOTE: these values are persisted in the database, so existing values must
// never be renumbered.
type dbAddressType uint8

const (
	// addressTypeIPv4 marks a TCP address with an IPv4 IP.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 marks a TCP address with an IPv6 IP.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 marks a Tor v2 onion address.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 marks a Tor v3 onion address.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeDNS marks a DNS host name address.
	addressTypeDNS dbAddressType = 5

	// addressTypeOpaque marks an lnwire.OpaqueAddrs value.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3687

3688
// upsertNodeAddresses updates the node's addresses in the database. This
3689
// includes deleting any existing addresses and inserting the new set of
3690
// addresses. The deletion is necessary since the ordering of the addresses may
3691
// change, and we need to ensure that the database reflects the latest set of
3692
// addresses so that at the time of reconstructing the node announcement, the
3693
// order is preserved and the signature over the message remains valid.
3694
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3695
        addresses []net.Addr) error {
×
3696

×
3697
        // Delete any existing addresses for the node. This is required since
×
3698
        // even if the new set of addresses is the same, the ordering may have
×
3699
        // changed for a given address type.
×
3700
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3701
        if err != nil {
×
3702
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3703
                        nodeID, err)
×
3704
        }
×
3705

3706
        // Copy the nodes latest set of addresses.
3707
        newAddresses := map[dbAddressType][]string{
×
3708
                addressTypeIPv4:   {},
×
3709
                addressTypeIPv6:   {},
×
3710
                addressTypeTorV2:  {},
×
3711
                addressTypeTorV3:  {},
×
NEW
3712
                addressTypeDNS:    {},
×
3713
                addressTypeOpaque: {},
×
3714
        }
×
3715
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3716
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3717
        }
×
3718

3719
        for _, address := range addresses {
×
3720
                switch addr := address.(type) {
×
3721
                case *net.TCPAddr:
×
3722
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3723
                                addAddr(addressTypeIPv4, addr)
×
3724
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3725
                                addAddr(addressTypeIPv6, addr)
×
3726
                        } else {
×
3727
                                return fmt.Errorf("unhandled IP address: %v",
×
3728
                                        addr)
×
3729
                        }
×
3730

3731
                case *tor.OnionAddr:
×
3732
                        switch len(addr.OnionService) {
×
3733
                        case tor.V2Len:
×
3734
                                addAddr(addressTypeTorV2, addr)
×
3735
                        case tor.V3Len:
×
3736
                                addAddr(addressTypeTorV3, addr)
×
3737
                        default:
×
3738
                                return fmt.Errorf("invalid length for a tor " +
×
3739
                                        "address")
×
3740
                        }
3741

NEW
3742
                case *lnwire.DNSAddr:
×
NEW
3743
                        // Validate it is a valid DNS address.
×
NEW
3744
                        if err := addr.Validate(); err != nil {
×
NEW
3745
                                return err
×
NEW
3746
                        }
×
NEW
3747
                        addAddr(addressTypeDNS, addr)
×
3748

3749
                case *lnwire.OpaqueAddrs:
×
3750
                        addAddr(addressTypeOpaque, addr)
×
3751

3752
                default:
×
3753
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3754
                }
3755
        }
3756

3757
        // Any remaining entries in newAddresses are new addresses that need to
3758
        // be added to the database for the first time.
3759
        for addrType, addrList := range newAddresses {
×
3760
                for position, addr := range addrList {
×
3761
                        err := db.InsertNodeAddress(
×
3762
                                ctx, sqlc.InsertNodeAddressParams{
×
3763
                                        NodeID:   nodeID,
×
3764
                                        Type:     int16(addrType),
×
3765
                                        Address:  addr,
×
3766
                                        Position: int32(position),
×
3767
                                },
×
3768
                        )
×
3769
                        if err != nil {
×
3770
                                return fmt.Errorf("unable to insert "+
×
3771
                                        "node(%d) address(%v): %w", nodeID,
×
3772
                                        addr, err)
×
3773
                        }
×
3774
                }
3775
        }
3776

3777
        return nil
×
3778
}
3779

3780
// getNodeAddresses fetches the addresses for a node with the given DB ID.
3781
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
3782
        error) {
×
3783

×
3784
        // GetNodeAddresses ensures that the addresses for a given type are
×
3785
        // returned in the same order as they were inserted.
×
3786
        rows, err := db.GetNodeAddresses(ctx, id)
×
3787
        if err != nil {
×
3788
                return nil, err
×
3789
        }
×
3790

3791
        addresses := make([]net.Addr, 0, len(rows))
×
3792
        for _, row := range rows {
×
3793
                address := row.Address
×
3794

×
3795
                addr, err := parseAddress(dbAddressType(row.Type), address)
×
3796
                if err != nil {
×
3797
                        return nil, fmt.Errorf("unable to parse address "+
×
3798
                                "for node(%d): %v: %w", id, address, err)
×
3799
                }
×
3800

3801
                addresses = append(addresses, addr)
×
3802
        }
3803

3804
        // If we have no addresses, then we'll return nil instead of an
3805
        // empty slice.
3806
        if len(addresses) == 0 {
×
3807
                addresses = nil
×
3808
        }
×
3809

3810
        return addresses, nil
×
3811
}
3812

3813
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3814
// database. This includes updating any existing types, inserting any new types,
3815
// and deleting any types that are no longer present.
3816
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3817
        nodeID int64, extraFields map[uint64][]byte) error {
×
3818

×
3819
        // Get any existing extra signed fields for the node.
×
3820
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3821
        if err != nil {
×
3822
                return err
×
3823
        }
×
3824

3825
        // Make a lookup map of the existing field types so that we can use it
3826
        // to keep track of any fields we should delete.
3827
        m := make(map[uint64]bool)
×
3828
        for _, field := range existingFields {
×
3829
                m[uint64(field.Type)] = true
×
3830
        }
×
3831

3832
        // For all the new fields, we'll upsert them and remove them from the
3833
        // map of existing fields.
3834
        for tlvType, value := range extraFields {
×
3835
                err = db.UpsertNodeExtraType(
×
3836
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3837
                                NodeID: nodeID,
×
3838
                                Type:   int64(tlvType),
×
3839
                                Value:  value,
×
3840
                        },
×
3841
                )
×
3842
                if err != nil {
×
3843
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3844
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3845
                }
×
3846

3847
                // Remove the field from the map of existing fields if it was
3848
                // present.
3849
                delete(m, tlvType)
×
3850
        }
3851

3852
        // For all the fields that are left in the map of existing fields, we'll
3853
        // delete them as they are no longer present in the new set of fields.
3854
        for tlvType := range m {
×
3855
                err = db.DeleteExtraNodeType(
×
3856
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3857
                                NodeID: nodeID,
×
3858
                                Type:   int64(tlvType),
×
3859
                        },
×
3860
                )
×
3861
                if err != nil {
×
3862
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3863
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3864
                }
×
3865
        }
3866

3867
        return nil
×
3868
}
3869

3870
// srcNodeInfo holds the information about the source node of the graph.
3871
type srcNodeInfo struct {
3872
        // id is the DB level ID of the source node entry in the "nodes" table.
3873
        id int64
3874

3875
        // pub is the public key of the source node.
3876
        pub route.Vertex
3877
}
3878

3879
// sourceNode returns the DB node ID and pub key of the source node for the
3880
// specified protocol version.
3881
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3882
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3883

×
3884
        s.srcNodeMu.Lock()
×
3885
        defer s.srcNodeMu.Unlock()
×
3886

×
3887
        // If we already have the source node ID and pub key cached, then
×
3888
        // return them.
×
3889
        if info, ok := s.srcNodes[version]; ok {
×
3890
                return info.id, info.pub, nil
×
3891
        }
×
3892

3893
        var pubKey route.Vertex
×
3894

×
3895
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3896
        if err != nil {
×
3897
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3898
                        err)
×
3899
        }
×
3900

3901
        if len(nodes) == 0 {
×
3902
                return 0, pubKey, ErrSourceNodeNotSet
×
3903
        } else if len(nodes) > 1 {
×
3904
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3905
                        "protocol %s found", version)
×
3906
        }
×
3907

3908
        copy(pubKey[:], nodes[0].PubKey)
×
3909

×
3910
        s.srcNodes[version] = &srcNodeInfo{
×
3911
                id:  nodes[0].NodeID,
×
3912
                pub: pubKey,
×
3913
        }
×
3914

×
3915
        return nodes[0].NodeID, pubKey, nil
×
3916
}
3917

3918
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3919
// This then produces a map from TLV type to value. If the input is not a
3920
// valid TLV stream, then an error is returned.
3921
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3922
        r := bytes.NewReader(data)
×
3923

×
3924
        tlvStream, err := tlv.NewStream()
×
3925
        if err != nil {
×
3926
                return nil, err
×
3927
        }
×
3928

3929
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3930
        // pass it into the P2P decoding variant.
3931
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3932
        if err != nil {
×
3933
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3934
        }
×
3935
        if len(parsedTypes) == 0 {
×
3936
                return nil, nil
×
3937
        }
×
3938

3939
        records := make(map[uint64][]byte)
×
3940
        for k, v := range parsedTypes {
×
3941
                records[uint64(k)] = v
×
3942
        }
×
3943

3944
        return records, nil
×
3945
}
3946

3947
// dbChanInfo bundles the DB IDs involved when a channel is inserted: the
// channel's own row and the rows of its two nodes.
type dbChanInfo struct {
	// channelID is the DB ID of the channel's row.
	channelID int64

	// node1ID and node2ID are the DB IDs of the channel's first and second
	// node respectively.
	node1ID, node2ID int64
}
3954

3955
// insertChannel inserts a new channel record into the database.
3956
func insertChannel(ctx context.Context, db SQLQueries,
3957
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
3958

×
3959
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3960

×
3961
        // Make sure that the channel doesn't already exist. We do this
×
3962
        // explicitly instead of relying on catching a unique constraint error
×
3963
        // because relying on SQL to throw that error would abort the entire
×
3964
        // batch of transactions.
×
3965
        _, err := db.GetChannelBySCID(
×
3966
                ctx, sqlc.GetChannelBySCIDParams{
×
3967
                        Scid:    chanIDB,
×
3968
                        Version: int16(ProtocolV1),
×
3969
                },
×
3970
        )
×
3971
        if err == nil {
×
3972
                return nil, ErrEdgeAlreadyExist
×
3973
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3974
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
×
3975
        }
×
3976

3977
        // Make sure that at least a "shell" entry for each node is present in
3978
        // the nodes table.
3979
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3980
        if err != nil {
×
3981
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3982
        }
×
3983

3984
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3985
        if err != nil {
×
3986
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3987
        }
×
3988

3989
        var capacity sql.NullInt64
×
3990
        if edge.Capacity != 0 {
×
3991
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3992
        }
×
3993

3994
        createParams := sqlc.CreateChannelParams{
×
3995
                Version:     int16(ProtocolV1),
×
3996
                Scid:        chanIDB,
×
3997
                NodeID1:     node1DBID,
×
3998
                NodeID2:     node2DBID,
×
3999
                Outpoint:    edge.ChannelPoint.String(),
×
4000
                Capacity:    capacity,
×
4001
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
4002
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
4003
        }
×
4004

×
4005
        if edge.AuthProof != nil {
×
4006
                proof := edge.AuthProof
×
4007

×
4008
                createParams.Node1Signature = proof.NodeSig1Bytes
×
4009
                createParams.Node2Signature = proof.NodeSig2Bytes
×
4010
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
4011
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
4012
        }
×
4013

4014
        // Insert the new channel record.
4015
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
4016
        if err != nil {
×
4017
                return nil, err
×
4018
        }
×
4019

4020
        // Insert any channel features.
4021
        for feature := range edge.Features.Features() {
×
4022
                err = db.InsertChannelFeature(
×
4023
                        ctx, sqlc.InsertChannelFeatureParams{
×
4024
                                ChannelID:  dbChanID,
×
4025
                                FeatureBit: int32(feature),
×
4026
                        },
×
4027
                )
×
4028
                if err != nil {
×
4029
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
×
4030
                                "feature(%v): %w", dbChanID, feature, err)
×
4031
                }
×
4032
        }
4033

4034
        // Finally, insert any extra TLV fields in the channel announcement.
4035
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
4036
        if err != nil {
×
4037
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
4038
                        "data: %w", err)
×
4039
        }
×
4040

4041
        for tlvType, value := range extra {
×
4042
                err := db.CreateChannelExtraType(
×
4043
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
4044
                                ChannelID: dbChanID,
×
4045
                                Type:      int64(tlvType),
×
4046
                                Value:     value,
×
4047
                        },
×
4048
                )
×
4049
                if err != nil {
×
4050
                        return nil, fmt.Errorf("unable to upsert "+
×
4051
                                "channel(%d) extra signed field(%v): %w",
×
4052
                                edge.ChannelID, tlvType, err)
×
4053
                }
×
4054
        }
4055

4056
        return &dbChanInfo{
×
4057
                channelID: dbChanID,
×
4058
                node1ID:   node1DBID,
×
4059
                node2ID:   node2DBID,
×
4060
        }, nil
×
4061
}
4062

4063
// maybeCreateShellNode checks if a shell node entry exists for the
4064
// given public key. If it does not exist, then a new shell node entry is
4065
// created. The ID of the node is returned. A shell node only has a protocol
4066
// version and public key persisted.
4067
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
4068
        pubKey route.Vertex) (int64, error) {
×
4069

×
4070
        dbNode, err := db.GetNodeByPubKey(
×
4071
                ctx, sqlc.GetNodeByPubKeyParams{
×
4072
                        PubKey:  pubKey[:],
×
4073
                        Version: int16(ProtocolV1),
×
4074
                },
×
4075
        )
×
4076
        // The node exists. Return the ID.
×
4077
        if err == nil {
×
4078
                return dbNode.ID, nil
×
4079
        } else if !errors.Is(err, sql.ErrNoRows) {
×
4080
                return 0, err
×
4081
        }
×
4082

4083
        // Otherwise, the node does not exist, so we create a shell entry for
4084
        // it.
4085
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
4086
                Version: int16(ProtocolV1),
×
4087
                PubKey:  pubKey[:],
×
4088
        })
×
4089
        if err != nil {
×
4090
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
4091
        }
×
4092

4093
        return id, nil
×
4094
}
4095

4096
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
4097
// the database. This includes deleting any existing types and then inserting
4098
// the new types.
4099
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
4100
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
4101

×
4102
        // Delete all existing extra signed fields for the channel policy.
×
4103
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
4104
        if err != nil {
×
4105
                return fmt.Errorf("unable to delete "+
×
4106
                        "existing policy extra signed fields for policy %d: %w",
×
4107
                        chanPolicyID, err)
×
4108
        }
×
4109

4110
        // Insert all new extra signed fields for the channel policy.
4111
        for tlvType, value := range extraFields {
×
4112
                err = db.InsertChanPolicyExtraType(
×
4113
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
4114
                                ChannelPolicyID: chanPolicyID,
×
4115
                                Type:            int64(tlvType),
×
4116
                                Value:           value,
×
4117
                        },
×
4118
                )
×
4119
                if err != nil {
×
4120
                        return fmt.Errorf("unable to insert "+
×
4121
                                "channel_policy(%d) extra signed field(%v): %w",
×
4122
                                chanPolicyID, tlvType, err)
×
4123
                }
×
4124
        }
4125

4126
        return nil
×
4127
}
4128

4129
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
4130
// provided dbChanRow and also fetches any other required information
4131
// to construct the edge info.
4132
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
4133
        chain chainhash.Hash, dbChan sqlc.GraphChannel, node1,
4134
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
4135

×
4136
        // NOTE: getAndBuildEdgeInfo is only used to load the data for a single
×
4137
        // edge, and so no paged queries will be performed. This means that
×
4138
        // it's ok to used pass in default config values here.
×
4139
        cfg := sqldb.DefaultPagedQueryConfig()
×
4140

×
4141
        data, err := batchLoadChannelData(ctx, cfg, db, []int64{dbChan.ID}, nil)
×
4142
        if err != nil {
×
4143
                return nil, fmt.Errorf("unable to batch load channel data: %w",
×
4144
                        err)
×
4145
        }
×
4146

4147
        return buildEdgeInfoWithBatchData(chain, dbChan, node1, node2, data)
×
4148
}
4149

4150
// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
4151
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
4152
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
4153
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {
×
4154

×
4155
        if dbChan.Version != int16(ProtocolV1) {
×
4156
                return nil, fmt.Errorf("unsupported channel version: %d",
×
4157
                        dbChan.Version)
×
4158
        }
×
4159

4160
        // Use pre-loaded features and extras types.
4161
        fv := lnwire.EmptyFeatureVector()
×
4162
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
×
4163
                for _, bit := range features {
×
4164
                        fv.Set(lnwire.FeatureBit(bit))
×
4165
                }
×
4166
        }
4167

4168
        var extras map[uint64][]byte
×
4169
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
×
4170
        if exists {
×
4171
                extras = channelExtras
×
4172
        } else {
×
4173
                extras = make(map[uint64][]byte)
×
4174
        }
×
4175

4176
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4177
        if err != nil {
×
4178
                return nil, err
×
4179
        }
×
4180

4181
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4182
        if err != nil {
×
4183
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4184
                        "fields: %w", err)
×
4185
        }
×
4186
        if recs == nil {
×
4187
                recs = make([]byte, 0)
×
4188
        }
×
4189

4190
        var btcKey1, btcKey2 route.Vertex
×
4191
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4192
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4193

×
4194
        channel := &models.ChannelEdgeInfo{
×
4195
                ChainHash:        chain,
×
4196
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4197
                NodeKey1Bytes:    node1,
×
4198
                NodeKey2Bytes:    node2,
×
4199
                BitcoinKey1Bytes: btcKey1,
×
4200
                BitcoinKey2Bytes: btcKey2,
×
4201
                ChannelPoint:     *op,
×
4202
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4203
                Features:         fv,
×
4204
                ExtraOpaqueData:  recs,
×
4205
        }
×
4206

×
4207
        // We always set all the signatures at the same time, so we can
×
4208
        // safely check if one signature is present to determine if we have the
×
4209
        // rest of the signatures for the auth proof.
×
4210
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4211
                channel.AuthProof = &models.ChannelAuthProof{
×
4212
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4213
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4214
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4215
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4216
                }
×
4217
        }
×
4218

4219
        return channel, nil
×
4220
}
4221

4222
// buildNodeVertices is a helper that converts raw node public keys
4223
// into route.Vertex instances.
4224
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4225
        route.Vertex, error) {
×
4226

×
4227
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4228
        if err != nil {
×
4229
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4230
                        "create vertex from node1 pubkey: %w", err)
×
4231
        }
×
4232

4233
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4234
        if err != nil {
×
4235
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4236
                        "create vertex from node2 pubkey: %w", err)
×
4237
        }
×
4238

4239
        return node1Vertex, node2Vertex, nil
×
4240
}
4241

4242
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4243
// retrieves all the extra info required to build the complete
4244
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4245
// the provided sqlc.GraphChannelPolicy records are nil.
4246
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4247
        dbPol1, dbPol2 *sqlc.GraphChannelPolicy, channelID uint64, node1,
4248
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4249
        *models.ChannelEdgePolicy, error) {
×
4250

×
4251
        if dbPol1 == nil && dbPol2 == nil {
×
4252
                return nil, nil, nil
×
4253
        }
×
4254

4255
        var policyIDs = make([]int64, 0, 2)
×
4256
        if dbPol1 != nil {
×
4257
                policyIDs = append(policyIDs, dbPol1.ID)
×
4258
        }
×
4259
        if dbPol2 != nil {
×
4260
                policyIDs = append(policyIDs, dbPol2.ID)
×
4261
        }
×
4262

4263
        // NOTE: getAndBuildChanPolicies is only used to load the data for
4264
        // a maximum of two policies, and so no paged queries will be
4265
        // performed (unless the page size is one). So it's ok to use
4266
        // the default config values here.
4267
        cfg := sqldb.DefaultPagedQueryConfig()
×
4268

×
4269
        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
×
4270
        if err != nil {
×
4271
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
4272
                        "data: %w", err)
×
4273
        }
×
4274

4275
        pol1, err := buildChanPolicyWithBatchData(
×
4276
                dbPol1, channelID, node2, batchData,
×
4277
        )
×
4278
        if err != nil {
×
4279
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4280
        }
×
4281

4282
        pol2, err := buildChanPolicyWithBatchData(
×
4283
                dbPol2, channelID, node1, batchData,
×
4284
        )
×
4285
        if err != nil {
×
4286
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4287
        }
×
4288

4289
        return pol1, pol2, nil
×
4290
}
4291

4292
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4293
// provided sqlc.GraphChannelPolicy and other required information.
4294
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4295
        extras map[uint64][]byte,
4296
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4297

×
4298
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4299
        if err != nil {
×
4300
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4301
                        "fields: %w", err)
×
4302
        }
×
4303

4304
        var inboundFee fn.Option[lnwire.Fee]
×
4305
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4306
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4307

×
4308
                inboundFee = fn.Some(lnwire.Fee{
×
4309
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4310
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4311
                })
×
4312
        }
×
4313

4314
        return &models.ChannelEdgePolicy{
×
4315
                SigBytes:  dbPolicy.Signature,
×
4316
                ChannelID: channelID,
×
4317
                LastUpdate: time.Unix(
×
4318
                        dbPolicy.LastUpdate.Int64, 0,
×
4319
                ),
×
4320
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4321
                        dbPolicy.MessageFlags,
×
4322
                ),
×
4323
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4324
                        dbPolicy.ChannelFlags,
×
4325
                ),
×
4326
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4327
                MinHTLC: lnwire.MilliSatoshi(
×
4328
                        dbPolicy.MinHtlcMsat,
×
4329
                ),
×
4330
                MaxHTLC: lnwire.MilliSatoshi(
×
4331
                        dbPolicy.MaxHtlcMsat.Int64,
×
4332
                ),
×
4333
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4334
                        dbPolicy.BaseFeeMsat,
×
4335
                ),
×
4336
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4337
                ToNode:                    toNode,
×
4338
                InboundFee:                inboundFee,
×
4339
                ExtraOpaqueData:           recs,
×
4340
        }, nil
×
4341
}
4342

4343
// buildNodes builds the models.LightningNode instances for the
4344
// given row which is expected to be a sqlc type that contains node information.
4345
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4346
        dbNode2 sqlc.GraphNode) (*models.LightningNode, *models.LightningNode,
4347
        error) {
×
4348

×
4349
        node1, err := buildNode(ctx, db, &dbNode1)
×
4350
        if err != nil {
×
4351
                return nil, nil, err
×
4352
        }
×
4353

4354
        node2, err := buildNode(ctx, db, &dbNode2)
×
4355
        if err != nil {
×
4356
                return nil, nil, err
×
4357
        }
×
4358

4359
        return node1, node2, nil
×
4360
}
4361

4362
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the give
4363
// row which is expected to be a sqlc type that contains channel policy
4364
// information. It returns two policies, which may be nil if the policy
4365
// information is not present in the row.
4366
//
4367
//nolint:ll,dupl,funlen
4368
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4369
        *sqlc.GraphChannelPolicy, error) {
×
4370

×
4371
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4372
        switch r := row.(type) {
×
4373
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
×
4374
                if r.Policy1Timelock.Valid {
×
4375
                        policy1 = &sqlc.GraphChannelPolicy{
×
4376
                                Timelock:                r.Policy1Timelock.Int32,
×
4377
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4378
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4379
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4380
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4381
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4382
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4383
                                Disabled:                r.Policy1Disabled,
×
4384
                                MessageFlags:            r.Policy1MessageFlags,
×
4385
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4386
                        }
×
4387
                }
×
4388
                if r.Policy2Timelock.Valid {
×
4389
                        policy2 = &sqlc.GraphChannelPolicy{
×
4390
                                Timelock:                r.Policy2Timelock.Int32,
×
4391
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4392
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4393
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4394
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4395
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4396
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4397
                                Disabled:                r.Policy2Disabled,
×
4398
                                MessageFlags:            r.Policy2MessageFlags,
×
4399
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4400
                        }
×
4401
                }
×
4402

4403
                return policy1, policy2, nil
×
4404

4405
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
4406
                if r.Policy1ID.Valid {
×
4407
                        policy1 = &sqlc.GraphChannelPolicy{
×
4408
                                ID:                      r.Policy1ID.Int64,
×
4409
                                Version:                 r.Policy1Version.Int16,
×
4410
                                ChannelID:               r.GraphChannel.ID,
×
4411
                                NodeID:                  r.Policy1NodeID.Int64,
×
4412
                                Timelock:                r.Policy1Timelock.Int32,
×
4413
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4414
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4415
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4416
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4417
                                LastUpdate:              r.Policy1LastUpdate,
×
4418
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4419
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4420
                                Disabled:                r.Policy1Disabled,
×
4421
                                MessageFlags:            r.Policy1MessageFlags,
×
4422
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4423
                                Signature:               r.Policy1Signature,
×
4424
                        }
×
4425
                }
×
4426
                if r.Policy2ID.Valid {
×
4427
                        policy2 = &sqlc.GraphChannelPolicy{
×
4428
                                ID:                      r.Policy2ID.Int64,
×
4429
                                Version:                 r.Policy2Version.Int16,
×
4430
                                ChannelID:               r.GraphChannel.ID,
×
4431
                                NodeID:                  r.Policy2NodeID.Int64,
×
4432
                                Timelock:                r.Policy2Timelock.Int32,
×
4433
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4434
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4435
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4436
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4437
                                LastUpdate:              r.Policy2LastUpdate,
×
4438
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4439
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4440
                                Disabled:                r.Policy2Disabled,
×
4441
                                MessageFlags:            r.Policy2MessageFlags,
×
4442
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4443
                                Signature:               r.Policy2Signature,
×
4444
                        }
×
4445
                }
×
4446

4447
                return policy1, policy2, nil
×
4448

4449
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4450
                if r.Policy1ID.Valid {
×
4451
                        policy1 = &sqlc.GraphChannelPolicy{
×
4452
                                ID:                      r.Policy1ID.Int64,
×
4453
                                Version:                 r.Policy1Version.Int16,
×
4454
                                ChannelID:               r.GraphChannel.ID,
×
4455
                                NodeID:                  r.Policy1NodeID.Int64,
×
4456
                                Timelock:                r.Policy1Timelock.Int32,
×
4457
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4458
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4459
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4460
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4461
                                LastUpdate:              r.Policy1LastUpdate,
×
4462
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4463
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4464
                                Disabled:                r.Policy1Disabled,
×
4465
                                MessageFlags:            r.Policy1MessageFlags,
×
4466
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4467
                                Signature:               r.Policy1Signature,
×
4468
                        }
×
4469
                }
×
4470
                if r.Policy2ID.Valid {
×
4471
                        policy2 = &sqlc.GraphChannelPolicy{
×
4472
                                ID:                      r.Policy2ID.Int64,
×
4473
                                Version:                 r.Policy2Version.Int16,
×
4474
                                ChannelID:               r.GraphChannel.ID,
×
4475
                                NodeID:                  r.Policy2NodeID.Int64,
×
4476
                                Timelock:                r.Policy2Timelock.Int32,
×
4477
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4478
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4479
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4480
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4481
                                LastUpdate:              r.Policy2LastUpdate,
×
4482
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4483
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4484
                                Disabled:                r.Policy2Disabled,
×
4485
                                MessageFlags:            r.Policy2MessageFlags,
×
4486
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4487
                                Signature:               r.Policy2Signature,
×
4488
                        }
×
4489
                }
×
4490

4491
                return policy1, policy2, nil
×
4492

4493
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4494
                if r.Policy1ID.Valid {
×
4495
                        policy1 = &sqlc.GraphChannelPolicy{
×
4496
                                ID:                      r.Policy1ID.Int64,
×
4497
                                Version:                 r.Policy1Version.Int16,
×
4498
                                ChannelID:               r.GraphChannel.ID,
×
4499
                                NodeID:                  r.Policy1NodeID.Int64,
×
4500
                                Timelock:                r.Policy1Timelock.Int32,
×
4501
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4502
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4503
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4504
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4505
                                LastUpdate:              r.Policy1LastUpdate,
×
4506
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4507
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4508
                                Disabled:                r.Policy1Disabled,
×
4509
                                MessageFlags:            r.Policy1MessageFlags,
×
4510
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4511
                                Signature:               r.Policy1Signature,
×
4512
                        }
×
4513
                }
×
4514
                if r.Policy2ID.Valid {
×
4515
                        policy2 = &sqlc.GraphChannelPolicy{
×
4516
                                ID:                      r.Policy2ID.Int64,
×
4517
                                Version:                 r.Policy2Version.Int16,
×
4518
                                ChannelID:               r.GraphChannel.ID,
×
4519
                                NodeID:                  r.Policy2NodeID.Int64,
×
4520
                                Timelock:                r.Policy2Timelock.Int32,
×
4521
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4522
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4523
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4524
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4525
                                LastUpdate:              r.Policy2LastUpdate,
×
4526
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4527
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4528
                                Disabled:                r.Policy2Disabled,
×
4529
                                MessageFlags:            r.Policy2MessageFlags,
×
4530
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4531
                                Signature:               r.Policy2Signature,
×
4532
                        }
×
4533
                }
×
4534

4535
                return policy1, policy2, nil
×
4536

4537
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4538
                if r.Policy1ID.Valid {
×
4539
                        policy1 = &sqlc.GraphChannelPolicy{
×
4540
                                ID:                      r.Policy1ID.Int64,
×
4541
                                Version:                 r.Policy1Version.Int16,
×
4542
                                ChannelID:               r.GraphChannel.ID,
×
4543
                                NodeID:                  r.Policy1NodeID.Int64,
×
4544
                                Timelock:                r.Policy1Timelock.Int32,
×
4545
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4546
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4547
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4548
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4549
                                LastUpdate:              r.Policy1LastUpdate,
×
4550
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4551
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4552
                                Disabled:                r.Policy1Disabled,
×
4553
                                MessageFlags:            r.Policy1MessageFlags,
×
4554
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4555
                                Signature:               r.Policy1Signature,
×
4556
                        }
×
4557
                }
×
4558
                if r.Policy2ID.Valid {
×
4559
                        policy2 = &sqlc.GraphChannelPolicy{
×
4560
                                ID:                      r.Policy2ID.Int64,
×
4561
                                Version:                 r.Policy2Version.Int16,
×
4562
                                ChannelID:               r.GraphChannel.ID,
×
4563
                                NodeID:                  r.Policy2NodeID.Int64,
×
4564
                                Timelock:                r.Policy2Timelock.Int32,
×
4565
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4566
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4567
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4568
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4569
                                LastUpdate:              r.Policy2LastUpdate,
×
4570
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4571
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4572
                                Disabled:                r.Policy2Disabled,
×
4573
                                MessageFlags:            r.Policy2MessageFlags,
×
4574
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4575
                                Signature:               r.Policy2Signature,
×
4576
                        }
×
4577
                }
×
4578

4579
                return policy1, policy2, nil
×
4580

4581
        case sqlc.ListChannelsByNodeIDRow:
×
4582
                if r.Policy1ID.Valid {
×
4583
                        policy1 = &sqlc.GraphChannelPolicy{
×
4584
                                ID:                      r.Policy1ID.Int64,
×
4585
                                Version:                 r.Policy1Version.Int16,
×
4586
                                ChannelID:               r.GraphChannel.ID,
×
4587
                                NodeID:                  r.Policy1NodeID.Int64,
×
4588
                                Timelock:                r.Policy1Timelock.Int32,
×
4589
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4590
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4591
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4592
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4593
                                LastUpdate:              r.Policy1LastUpdate,
×
4594
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4595
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4596
                                Disabled:                r.Policy1Disabled,
×
4597
                                MessageFlags:            r.Policy1MessageFlags,
×
4598
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4599
                                Signature:               r.Policy1Signature,
×
4600
                        }
×
4601
                }
×
4602
                if r.Policy2ID.Valid {
×
4603
                        policy2 = &sqlc.GraphChannelPolicy{
×
4604
                                ID:                      r.Policy2ID.Int64,
×
4605
                                Version:                 r.Policy2Version.Int16,
×
4606
                                ChannelID:               r.GraphChannel.ID,
×
4607
                                NodeID:                  r.Policy2NodeID.Int64,
×
4608
                                Timelock:                r.Policy2Timelock.Int32,
×
4609
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4610
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4611
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4612
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4613
                                LastUpdate:              r.Policy2LastUpdate,
×
4614
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4615
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4616
                                Disabled:                r.Policy2Disabled,
×
4617
                                MessageFlags:            r.Policy2MessageFlags,
×
4618
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4619
                                Signature:               r.Policy2Signature,
×
4620
                        }
×
4621
                }
×
4622

4623
                return policy1, policy2, nil
×
4624

4625
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4626
                if r.Policy1ID.Valid {
×
4627
                        policy1 = &sqlc.GraphChannelPolicy{
×
4628
                                ID:                      r.Policy1ID.Int64,
×
4629
                                Version:                 r.Policy1Version.Int16,
×
4630
                                ChannelID:               r.GraphChannel.ID,
×
4631
                                NodeID:                  r.Policy1NodeID.Int64,
×
4632
                                Timelock:                r.Policy1Timelock.Int32,
×
4633
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4634
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4635
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4636
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4637
                                LastUpdate:              r.Policy1LastUpdate,
×
4638
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4639
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4640
                                Disabled:                r.Policy1Disabled,
×
4641
                                MessageFlags:            r.Policy1MessageFlags,
×
4642
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4643
                                Signature:               r.Policy1Signature,
×
4644
                        }
×
4645
                }
×
4646
                if r.Policy2ID.Valid {
×
4647
                        policy2 = &sqlc.GraphChannelPolicy{
×
4648
                                ID:                      r.Policy2ID.Int64,
×
4649
                                Version:                 r.Policy2Version.Int16,
×
4650
                                ChannelID:               r.GraphChannel.ID,
×
4651
                                NodeID:                  r.Policy2NodeID.Int64,
×
4652
                                Timelock:                r.Policy2Timelock.Int32,
×
4653
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4654
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4655
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4656
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4657
                                LastUpdate:              r.Policy2LastUpdate,
×
4658
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4659
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4660
                                Disabled:                r.Policy2Disabled,
×
4661
                                MessageFlags:            r.Policy2MessageFlags,
×
4662
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4663
                                Signature:               r.Policy2Signature,
×
4664
                        }
×
4665
                }
×
4666

4667
                return policy1, policy2, nil
×
4668
        default:
×
4669
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4670
                        "extractChannelPolicies: %T", r)
×
4671
        }
4672
}
4673

4674
// channelIDToBytes converts a channel ID (SCID) to a byte array
4675
// representation.
4676
func channelIDToBytes(channelID uint64) []byte {
×
4677
        var chanIDB [8]byte
×
4678
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4679

×
4680
        return chanIDB[:]
×
4681
}
×
4682

4683
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
4684
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
×
4685
        if len(addresses) == 0 {
×
4686
                return nil, nil
×
4687
        }
×
4688

4689
        result := make([]net.Addr, 0, len(addresses))
×
4690
        for _, addr := range addresses {
×
4691
                netAddr, err := parseAddress(addr.addrType, addr.address)
×
4692
                if err != nil {
×
4693
                        return nil, fmt.Errorf("unable to parse address %s "+
×
4694
                                "of type %d: %w", addr.address, addr.addrType,
×
4695
                                err)
×
4696
                }
×
4697
                if netAddr != nil {
×
4698
                        result = append(result, netAddr)
×
4699
                }
×
4700
        }
4701

4702
        // If we have no valid addresses, return nil instead of empty slice.
4703
        if len(result) == 0 {
×
4704
                return nil, nil
×
4705
        }
×
4706

4707
        return result, nil
×
4708
}
4709

4710
// parseAddress parses the given address string based on the address type
4711
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
4712
// and opaque addresses.
4713
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
×
4714
        switch addrType {
×
4715
        case addressTypeIPv4:
×
4716
                tcp, err := net.ResolveTCPAddr("tcp4", address)
×
4717
                if err != nil {
×
4718
                        return nil, err
×
4719
                }
×
4720

4721
                tcp.IP = tcp.IP.To4()
×
4722

×
4723
                return tcp, nil
×
4724

4725
        case addressTypeIPv6:
×
4726
                tcp, err := net.ResolveTCPAddr("tcp6", address)
×
4727
                if err != nil {
×
4728
                        return nil, err
×
4729
                }
×
4730

4731
                return tcp, nil
×
4732

4733
        case addressTypeTorV3, addressTypeTorV2:
×
4734
                service, portStr, err := net.SplitHostPort(address)
×
4735
                if err != nil {
×
4736
                        return nil, fmt.Errorf("unable to split tor "+
×
4737
                                "address: %v", address)
×
4738
                }
×
4739

4740
                port, err := strconv.Atoi(portStr)
×
4741
                if err != nil {
×
4742
                        return nil, err
×
4743
                }
×
4744

4745
                return &tor.OnionAddr{
×
4746
                        OnionService: service,
×
4747
                        Port:         port,
×
4748
                }, nil
×
4749

NEW
4750
        case addressTypeDNS:
×
NEW
4751
                host, portStr, err := net.SplitHostPort(address)
×
NEW
4752
                if err != nil {
×
NEW
4753
                        return nil, fmt.Errorf("unable to "+
×
NEW
4754
                                "split tor dns address: %v",
×
NEW
4755
                                address)
×
NEW
4756
                }
×
4757

NEW
4758
                port, err := strconv.Atoi(portStr)
×
NEW
4759
                if err != nil {
×
NEW
4760
                        return nil, err
×
NEW
4761
                }
×
4762

NEW
4763
                return &lnwire.DNSAddr{
×
NEW
4764
                        Hostname: host,
×
NEW
4765
                        Port:     uint16(port),
×
NEW
4766
                }, nil
×
4767

4768
        case addressTypeOpaque:
×
4769
                opaque, err := hex.DecodeString(address)
×
4770
                if err != nil {
×
4771
                        return nil, fmt.Errorf("unable to decode opaque "+
×
4772
                                "address: %v", address)
×
4773
                }
×
4774

4775
                return &lnwire.OpaqueAddrs{
×
4776
                        Payload: opaque,
×
4777
                }, nil
×
4778

4779
        default:
×
4780
                return nil, fmt.Errorf("unknown address type: %v", addrType)
×
4781
        }
4782
}
4783

4784
// batchNodeData holds all the related data for a batch of nodes, as assembled
// by batchLoadNodeData. Every map is keyed by the node's DB ID (the SQL row ID,
// not the node's public key). A node that has no rows of a given kind is simply
// absent from the corresponding map.
4785
type batchNodeData struct {
4786
        // features is a map from a DB node ID to the feature bits for that
4787
        // node, in the order the batch query returned them.
4788
        features map[int64][]int
4789

4790
        // addresses is a map from a DB node ID to the node's addresses.
        // Each entry still carries its raw DB type and string encoding and
        // must be decoded before use.
4791
        addresses map[int64][]nodeAddress
4792

4793
        // extraFields is a map from a DB node ID to the extra signed fields
4794
        // for that node, keyed by TLV type and holding the raw value bytes.
4795
        extraFields map[int64]map[uint64][]byte
4796
}
4797

4798
// nodeAddress holds the address type, position and address string for a
4799
// node. This is used to batch the fetching of node addresses: one value is
// produced per row returned by the GetNodeAddressesBatch query.
4800
type nodeAddress struct {
4801
        // addrType is the DB-level address type. It selects how address must
        // be decoded, e.g. "host:port" for the Tor and DNS types, or
        // hex-encoded bytes for opaque addresses.
        addrType dbAddressType
4802
        // position is copied verbatim from the DB row. Presumably it is the
        // address's index within the node's announced address list; confirm
        // against the schema before relying on it for ordering.
        position int32
4803
        // address is the address in its DB string encoding (see addrType).
        address  string
4804
}
4805

4806
// batchLoadNodeData loads all related data for a batch of node IDs using the
4807
// provided SQLQueries interface. It returns a batchNodeData instance containing
4808
// the node features, addresses and extra signed fields.
4809
func batchLoadNodeData(ctx context.Context, cfg *sqldb.PagedQueryConfig,
4810
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {
×
4811

×
4812
        // Batch load the node features.
×
4813
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
4814
        if err != nil {
×
4815
                return nil, fmt.Errorf("unable to batch load node "+
×
4816
                        "features: %w", err)
×
4817
        }
×
4818

4819
        // Batch load the node addresses.
4820
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
×
4821
        if err != nil {
×
4822
                return nil, fmt.Errorf("unable to batch load node "+
×
4823
                        "addresses: %w", err)
×
4824
        }
×
4825

4826
        // Batch load the node extra signed fields.
4827
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
×
4828
        if err != nil {
×
4829
                return nil, fmt.Errorf("unable to batch load node extra "+
×
4830
                        "signed fields: %w", err)
×
4831
        }
×
4832

4833
        return &batchNodeData{
×
4834
                features:    features,
×
4835
                addresses:   addrs,
×
4836
                extraFields: extraTypes,
×
4837
        }, nil
×
4838
}
4839

4840
// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
4841
// using ExecutePagedQuery wrapper around the GetNodeFeaturesBatch query.
4842
func batchLoadNodeFeaturesHelper(ctx context.Context,
4843
        cfg *sqldb.PagedQueryConfig, db SQLQueries,
4844
        nodeIDs []int64) (map[int64][]int, error) {
×
4845

×
4846
        features := make(map[int64][]int)
×
4847

×
4848
        return features, sqldb.ExecutePagedQuery(
×
4849
                ctx, cfg, nodeIDs,
×
4850
                func(id int64) int64 {
×
4851
                        return id
×
4852
                },
×
4853
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
4854
                        error) {
×
4855

×
4856
                        return db.GetNodeFeaturesBatch(ctx, ids)
×
4857
                },
×
4858
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
×
4859
                        features[feature.NodeID] = append(
×
4860
                                features[feature.NodeID],
×
4861
                                int(feature.FeatureBit),
×
4862
                        )
×
4863

×
4864
                        return nil
×
4865
                },
×
4866
        )
4867
}
4868

4869
// batchLoadNodeAddressesHelper loads node addresses using ExecutePagedQuery
4870
// wrapper around the GetNodeAddressesBatch query. It returns a map from
4871
// node ID to a slice of nodeAddress structs.
4872
func batchLoadNodeAddressesHelper(ctx context.Context,
4873
        cfg *sqldb.PagedQueryConfig, db SQLQueries,
4874
        nodeIDs []int64) (map[int64][]nodeAddress, error) {
×
4875

×
4876
        addrs := make(map[int64][]nodeAddress)
×
4877

×
4878
        return addrs, sqldb.ExecutePagedQuery(
×
4879
                ctx, cfg, nodeIDs,
×
4880
                func(id int64) int64 {
×
4881
                        return id
×
4882
                },
×
4883
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
4884
                        error) {
×
4885

×
4886
                        return db.GetNodeAddressesBatch(ctx, ids)
×
4887
                },
×
4888
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
×
4889
                        addrs[addr.NodeID] = append(
×
4890
                                addrs[addr.NodeID], nodeAddress{
×
4891
                                        addrType: dbAddressType(addr.Type),
×
4892
                                        position: addr.Position,
×
4893
                                        address:  addr.Address,
×
4894
                                },
×
4895
                        )
×
4896

×
4897
                        return nil
×
4898
                },
×
4899
        )
4900
}
4901

4902
// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
4903
// node IDs using ExecutePagedQuery wrapper around the GetNodeExtraTypesBatch
4904
// query.
4905
func batchLoadNodeExtraTypesHelper(ctx context.Context,
4906
        cfg *sqldb.PagedQueryConfig, db SQLQueries,
4907
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {
×
4908

×
4909
        extraFields := make(map[int64]map[uint64][]byte)
×
4910

×
4911
        callback := func(ctx context.Context,
×
4912
                field sqlc.GraphNodeExtraType) error {
×
4913

×
4914
                if extraFields[field.NodeID] == nil {
×
4915
                        extraFields[field.NodeID] = make(map[uint64][]byte)
×
4916
                }
×
4917
                extraFields[field.NodeID][uint64(field.Type)] = field.Value
×
4918

×
4919
                return nil
×
4920
        }
4921

4922
        return extraFields, sqldb.ExecutePagedQuery(
×
4923
                ctx, cfg, nodeIDs,
×
4924
                func(id int64) int64 {
×
4925
                        return id
×
4926
                },
×
4927
                func(ctx context.Context, ids []int64) (
4928
                        []sqlc.GraphNodeExtraType, error) {
×
4929

×
4930
                        return db.GetNodeExtraTypesBatch(ctx, ids)
×
4931
                },
×
4932
                callback,
4933
        )
4934
}
4935

4936
// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
4937
// from the provided sqlc.GraphChannelPolicy records and the
4938
// provided batchChannelData.
4939
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4940
        channelID uint64, node1, node2 route.Vertex,
4941
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
4942
        *models.ChannelEdgePolicy, error) {
×
4943

×
4944
        pol1, err := buildChanPolicyWithBatchData(
×
4945
                dbPol1, channelID, node2, batchData,
×
4946
        )
×
4947
        if err != nil {
×
4948
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4949
        }
×
4950

4951
        pol2, err := buildChanPolicyWithBatchData(
×
4952
                dbPol2, channelID, node1, batchData,
×
4953
        )
×
4954
        if err != nil {
×
4955
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4956
        }
×
4957

4958
        return pol1, pol2, nil
×
4959
}
4960

4961
// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
4962
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
4963
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
4964
        channelID uint64, toNode route.Vertex,
4965
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {
×
4966

×
4967
        if dbPol == nil {
×
4968
                return nil, nil
×
4969
        }
×
4970

4971
        var dbPol1Extras map[uint64][]byte
×
4972
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
×
4973
                dbPol1Extras = extras
×
4974
        } else {
×
4975
                dbPol1Extras = make(map[uint64][]byte)
×
4976
        }
×
4977

4978
        return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode)
×
4979
}
4980

4981
// batchChannelData holds all the related data for a batch of channels and
// their policies, as assembled by batchLoadChannelData. Channel-level maps are
// keyed by DB channel ID, and policyExtras is keyed by DB policy ID.
4982
type batchChannelData struct {
4983
        // chanfeatures is a map from DB channel ID to a slice of feature bits.
4984
        chanfeatures map[int64][]int
4985

4986
        // chanExtraTypes is a map from DB channel ID to a map of TLV type to
4987
        // extra signed field bytes.
4988
        chanExtraTypes map[int64]map[uint64][]byte
4989

4990
        // policyExtras is a map from DB channel policy ID to a map of TLV type
4991
        // to extra signed field bytes.
4992
        policyExtras map[int64]map[uint64][]byte
4993
}
4994

4995
// batchLoadChannelData loads all related data for batches of channels and
4996
// policies.
4997
func batchLoadChannelData(ctx context.Context, cfg *sqldb.PagedQueryConfig,
4998
        db SQLQueries, channelIDs []int64,
4999
        policyIDs []int64) (*batchChannelData, error) {
×
5000

×
5001
        batchData := &batchChannelData{
×
5002
                chanfeatures:   make(map[int64][]int),
×
5003
                chanExtraTypes: make(map[int64]map[uint64][]byte),
×
5004
                policyExtras:   make(map[int64]map[uint64][]byte),
×
5005
        }
×
5006

×
5007
        // Batch load channel features and extras
×
5008
        var err error
×
5009
        if len(channelIDs) > 0 {
×
5010
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
×
5011
                        ctx, cfg, db, channelIDs,
×
5012
                )
×
5013
                if err != nil {
×
5014
                        return nil, fmt.Errorf("unable to batch load "+
×
5015
                                "channel features: %w", err)
×
5016
                }
×
5017

5018
                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
×
5019
                        ctx, cfg, db, channelIDs,
×
5020
                )
×
5021
                if err != nil {
×
5022
                        return nil, fmt.Errorf("unable to batch load "+
×
5023
                                "channel extras: %w", err)
×
5024
                }
×
5025
        }
5026

5027
        if len(policyIDs) > 0 {
×
5028
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
×
5029
                        ctx, cfg, db, policyIDs,
×
5030
                )
×
5031
                if err != nil {
×
5032
                        return nil, fmt.Errorf("unable to batch load "+
×
5033
                                "policy extras: %w", err)
×
5034
                }
×
5035
                batchData.policyExtras = policyExtras
×
5036
        }
5037

5038
        return batchData, nil
×
5039
}
5040

5041
// batchLoadChannelFeaturesHelper loads channel features for a batch of
5042
// channel IDs using ExecutePagedQuery wrapper around the
5043
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
5044
// slice of feature bits.
5045
func batchLoadChannelFeaturesHelper(ctx context.Context,
5046
        cfg *sqldb.PagedQueryConfig, db SQLQueries,
5047
        channelIDs []int64) (map[int64][]int, error) {
×
5048

×
5049
        features := make(map[int64][]int)
×
5050

×
5051
        return features, sqldb.ExecutePagedQuery(
×
5052
                ctx, cfg, channelIDs,
×
5053
                func(id int64) int64 {
×
5054
                        return id
×
5055
                },
×
5056
                func(ctx context.Context,
5057
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {
×
5058

×
5059
                        return db.GetChannelFeaturesBatch(ctx, ids)
×
5060
                },
×
5061
                func(ctx context.Context,
5062
                        feature sqlc.GraphChannelFeature) error {
×
5063

×
5064
                        features[feature.ChannelID] = append(
×
5065
                                features[feature.ChannelID],
×
5066
                                int(feature.FeatureBit),
×
5067
                        )
×
5068

×
5069
                        return nil
×
5070
                },
×
5071
        )
5072
}
5073

5074
// batchLoadChannelExtrasHelper loads channel extra types for a batch of
5075
// channel IDs using ExecutePagedQuery wrapper around the GetChannelExtrasBatch
5076
// query. It returns a map from DB channel ID to a map of TLV type to extra
5077
// signed field bytes.
5078
func batchLoadChannelExtrasHelper(ctx context.Context,
5079
        cfg *sqldb.PagedQueryConfig, db SQLQueries,
5080
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5081

×
5082
        extras := make(map[int64]map[uint64][]byte)
×
5083

×
5084
        cb := func(ctx context.Context,
×
5085
                extra sqlc.GraphChannelExtraType) error {
×
5086

×
5087
                if extras[extra.ChannelID] == nil {
×
5088
                        extras[extra.ChannelID] = make(map[uint64][]byte)
×
5089
                }
×
5090
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value
×
5091

×
5092
                return nil
×
5093
        }
5094

5095
        return extras, sqldb.ExecutePagedQuery(
×
5096
                ctx, cfg, channelIDs,
×
5097
                func(id int64) int64 {
×
5098
                        return id
×
5099
                },
×
5100
                func(ctx context.Context,
5101
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {
×
5102

×
5103
                        return db.GetChannelExtrasBatch(ctx, ids)
×
5104
                }, cb,
×
5105
        )
5106
}
5107

5108
// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
5109
// batch of policy IDs using ExecutePagedQuery wrapper around the
5110
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
5111
// a map of TLV type to extra signed field bytes.
5112
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
5113
        cfg *sqldb.PagedQueryConfig, db SQLQueries,
5114
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5115

×
5116
        extras := make(map[int64]map[uint64][]byte)
×
5117

×
5118
        return extras, sqldb.ExecutePagedQuery(
×
5119
                ctx, cfg, policyIDs,
×
5120
                func(id int64) int64 {
×
5121
                        return id
×
5122
                },
×
5123
                func(ctx context.Context, ids []int64) (
5124
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {
×
5125

×
5126
                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
×
5127
                },
×
5128
                func(ctx context.Context,
5129
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {
×
5130

×
5131
                        if extras[row.PolicyID] == nil {
×
5132
                                extras[row.PolicyID] = make(map[uint64][]byte)
×
5133
                        }
×
5134
                        extras[row.PolicyID][uint64(row.Type)] = row.Value
×
5135

×
5136
                        return nil
×
5137
                },
5138
        )
5139
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc