• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 16291326623

15 Jul 2025 10:54AM UTC coverage: 67.338% (-0.01%) from 67.349%
16291326623

push

github

web-flow
Merge pull request #10068 from ellemouton/graphResetForCallbacks

multi: let all V1Store `ForEach*` methods take a `reset` call-back

133 of 184 new or added lines in 18 files covered. (72.28%)

102 existing lines in 21 files now uncovered.

135417 of 201100 relevant lines covered (67.34%)

21783.4 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// pageSize caps how many rows a single paginated query may return. The value
// is a starting point and may be revisited once benchmarks are available.
const pageSize = 2000

// ProtocolVersion identifies which gossip protocol version a message belongs
// to.
type ProtocolVersion uint8

const (
	// ProtocolV1 is the gossip protocol version specified in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String renders the protocol version as the letter "V" followed by its
// decimal value, e.g. "V1".
func (v ProtocolVersion) String() string {
	return "V" + strconv.Itoa(int(v))
}
×
51

52
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
53
// execute queries against the SQL graph tables.
54
//
55
//nolint:ll,interfacebloat
56
type SQLQueries interface {
57
        /*
58
                Node queries.
59
        */
60
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
61
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
62
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
63
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
64
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error)
65
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
66
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
67
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
68
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
69
        DeleteNode(ctx context.Context, id int64) error
70

71
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
72
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
73
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
74

75
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
76
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
77
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
78

79
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
80
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
81
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
82
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
83

84
        /*
85
                Source node queries.
86
        */
87
        AddSourceNode(ctx context.Context, nodeID int64) error
88
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
89

90
        /*
91
                Channel queries.
92
        */
93
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
94
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
95
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
96
        GetChannelByOutpoint(ctx context.Context, outpoint string) (sqlc.GetChannelByOutpointRow, error)
97
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
98
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
99
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
100
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
101
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
102
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
103
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
104
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
105
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
106
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
107
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
108
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
109
        DeleteChannel(ctx context.Context, id int64) error
110

111
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
112
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
113

114
        /*
115
                Channel Policy table queries.
116
        */
117
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
118
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.ChannelPolicy, error)
119
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
120

121
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
122
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
123
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
124

125
        /*
126
                Zombie index queries.
127
        */
128
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
129
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.ZombieChannel, error)
130
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
131
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
132
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
133

134
        /*
135
                Prune log table queries.
136
        */
137
        GetPruneTip(ctx context.Context) (sqlc.PruneLog, error)
138
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
139
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
140
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
141

142
        /*
143
                Closed SCID table queries.
144
        */
145
        InsertClosedChannel(ctx context.Context, scid []byte) error
146
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
147
}
148

149
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
150
// database operations.
151
type BatchedSQLQueries interface {
152
        SQLQueries
153
        sqldb.BatchedTx[SQLQueries]
154
}
155

156
// SQLStore is an implementation of the V1Store interface that uses a SQL
157
// database as the backend.
158
type SQLStore struct {
159
        cfg *SQLStoreConfig
160
        db  BatchedSQLQueries
161

162
        // cacheMu guards all caches (rejectCache and chanCache). If
163
        // this mutex will be acquired at the same time as the DB mutex then
164
        // the cacheMu MUST be acquired first to prevent deadlock.
165
        cacheMu     sync.RWMutex
166
        rejectCache *rejectCache
167
        chanCache   *channelCache
168

169
        chanScheduler batch.Scheduler[SQLQueries]
170
        nodeScheduler batch.Scheduler[SQLQueries]
171

172
        srcNodes  map[ProtocolVersion]*srcNodeInfo
173
        srcNodeMu sync.Mutex
174
}
175

176
// A compile-time assertion to ensure that SQLStore implements the V1Store
177
// interface.
178
var _ V1Store = (*SQLStore)(nil)
179

180
// SQLStoreConfig holds the configuration for the SQLStore.
181
type SQLStoreConfig struct {
182
        // ChainHash is the genesis hash for the chain that all the gossip
183
        // messages in this store are aimed at.
184
        ChainHash chainhash.Hash
185
}
186

187
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
188
// storage backend.
189
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
190
        options ...StoreOptionModifier) (*SQLStore, error) {
×
191

×
192
        opts := DefaultOptions()
×
193
        for _, o := range options {
×
194
                o(opts)
×
195
        }
×
196

197
        if opts.NoMigration {
×
198
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
199
                        "supported for SQL stores")
×
200
        }
×
201

202
        s := &SQLStore{
×
203
                cfg:         cfg,
×
204
                db:          db,
×
205
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
206
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
207
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
208
        }
×
209

×
210
        s.chanScheduler = batch.NewTimeScheduler(
×
211
                db, &s.cacheMu, opts.BatchCommitInterval,
×
212
        )
×
213
        s.nodeScheduler = batch.NewTimeScheduler(
×
214
                db, nil, opts.BatchCommitInterval,
×
215
        )
×
216

×
217
        return s, nil
×
218
}
219

220
// AddLightningNode adds a vertex/node to the graph database. If the node is not
221
// in the database from before, this will add a new, unconnected one to the
222
// graph. If it is present from before, this will update that node's
223
// information.
224
//
225
// NOTE: part of the V1Store interface.
226
func (s *SQLStore) AddLightningNode(ctx context.Context,
227
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
228

×
229
        r := &batch.Request[SQLQueries]{
×
230
                Opts: batch.NewSchedulerOptions(opts...),
×
231
                Do: func(queries SQLQueries) error {
×
232
                        _, err := upsertNode(ctx, queries, node)
×
233
                        return err
×
234
                },
×
235
        }
236

237
        return s.nodeScheduler.Execute(ctx, r)
×
238
}
239

240
// FetchLightningNode attempts to look up a target node by its identity public
241
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
242
// returned.
243
//
244
// NOTE: part of the V1Store interface.
245
func (s *SQLStore) FetchLightningNode(ctx context.Context,
246
        pubKey route.Vertex) (*models.LightningNode, error) {
×
247

×
248
        var node *models.LightningNode
×
249
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
250
                var err error
×
251
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
252

×
253
                return err
×
254
        }, sqldb.NoOpReset)
×
255
        if err != nil {
×
256
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
257
        }
×
258

259
        return node, nil
×
260
}
261

262
// HasLightningNode determines if the graph has a vertex identified by the
263
// target node identity public key. If the node exists in the database, a
264
// timestamp of when the data for the node was lasted updated is returned along
265
// with a true boolean. Otherwise, an empty time.Time is returned with a false
266
// boolean.
267
//
268
// NOTE: part of the V1Store interface.
269
func (s *SQLStore) HasLightningNode(ctx context.Context,
270
        pubKey [33]byte) (time.Time, bool, error) {
×
271

×
272
        var (
×
273
                exists     bool
×
274
                lastUpdate time.Time
×
275
        )
×
276
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
277
                dbNode, err := db.GetNodeByPubKey(
×
278
                        ctx, sqlc.GetNodeByPubKeyParams{
×
279
                                Version: int16(ProtocolV1),
×
280
                                PubKey:  pubKey[:],
×
281
                        },
×
282
                )
×
283
                if errors.Is(err, sql.ErrNoRows) {
×
284
                        return nil
×
285
                } else if err != nil {
×
286
                        return fmt.Errorf("unable to fetch node: %w", err)
×
287
                }
×
288

289
                exists = true
×
290

×
291
                if dbNode.LastUpdate.Valid {
×
292
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
293
                }
×
294

295
                return nil
×
296
        }, sqldb.NoOpReset)
297
        if err != nil {
×
298
                return time.Time{}, false,
×
299
                        fmt.Errorf("unable to fetch node: %w", err)
×
300
        }
×
301

302
        return lastUpdate, exists, nil
×
303
}
304

305
// AddrsForNode returns all known addresses for the target node public key
306
// that the graph DB is aware of. The returned boolean indicates if the
307
// given node is unknown to the graph DB or not.
308
//
309
// NOTE: part of the V1Store interface.
310
func (s *SQLStore) AddrsForNode(ctx context.Context,
311
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
312

×
313
        var (
×
314
                addresses []net.Addr
×
315
                known     bool
×
316
        )
×
317
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
318
                var err error
×
319
                known, addresses, err = getNodeAddresses(
×
320
                        ctx, db, nodePub.SerializeCompressed(),
×
321
                )
×
322
                if err != nil {
×
323
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
324
                                err)
×
325
                }
×
326

327
                return nil
×
328
        }, sqldb.NoOpReset)
329
        if err != nil {
×
330
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
331
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
332
        }
×
333

334
        return known, addresses, nil
×
335
}
336

337
// DeleteLightningNode starts a new database transaction to remove a vertex/node
338
// from the database according to the node's public key.
339
//
340
// NOTE: part of the V1Store interface.
341
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
342
        pubKey route.Vertex) error {
×
343

×
344
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
345
                res, err := db.DeleteNodeByPubKey(
×
346
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
347
                                Version: int16(ProtocolV1),
×
348
                                PubKey:  pubKey[:],
×
349
                        },
×
350
                )
×
351
                if err != nil {
×
352
                        return err
×
353
                }
×
354

355
                rows, err := res.RowsAffected()
×
356
                if err != nil {
×
357
                        return err
×
358
                }
×
359

360
                if rows == 0 {
×
361
                        return ErrGraphNodeNotFound
×
362
                } else if rows > 1 {
×
363
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
364
                }
×
365

366
                return err
×
367
        }, sqldb.NoOpReset)
368
        if err != nil {
×
369
                return fmt.Errorf("unable to delete node: %w", err)
×
370
        }
×
371

372
        return nil
×
373
}
374

375
// FetchNodeFeatures returns the features of the given node. If no features are
376
// known for the node, an empty feature vector is returned.
377
//
378
// NOTE: this is part of the graphdb.NodeTraverser interface.
379
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
380
        *lnwire.FeatureVector, error) {
×
381

×
382
        ctx := context.TODO()
×
383

×
384
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
385
}
×
386

387
// DisabledChannelIDs returns the channel ids of disabled channels.
388
// A channel is disabled when two of the associated ChanelEdgePolicies
389
// have their disabled bit on.
390
//
391
// NOTE: part of the V1Store interface.
392
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
393
        var (
×
394
                ctx     = context.TODO()
×
395
                chanIDs []uint64
×
396
        )
×
397
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
398
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
399
                if err != nil {
×
400
                        return fmt.Errorf("unable to fetch disabled "+
×
401
                                "channels: %w", err)
×
402
                }
×
403

404
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
405

×
406
                return nil
×
407
        }, sqldb.NoOpReset)
408
        if err != nil {
×
409
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
410
                        err)
×
411
        }
×
412

413
        return chanIDs, nil
×
414
}
415

416
// LookupAlias attempts to return the alias as advertised by the target node.
417
//
418
// NOTE: part of the V1Store interface.
419
func (s *SQLStore) LookupAlias(ctx context.Context,
420
        pub *btcec.PublicKey) (string, error) {
×
421

×
422
        var alias string
×
423
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
424
                dbNode, err := db.GetNodeByPubKey(
×
425
                        ctx, sqlc.GetNodeByPubKeyParams{
×
426
                                Version: int16(ProtocolV1),
×
427
                                PubKey:  pub.SerializeCompressed(),
×
428
                        },
×
429
                )
×
430
                if errors.Is(err, sql.ErrNoRows) {
×
431
                        return ErrNodeAliasNotFound
×
432
                } else if err != nil {
×
433
                        return fmt.Errorf("unable to fetch node: %w", err)
×
434
                }
×
435

436
                if !dbNode.Alias.Valid {
×
437
                        return ErrNodeAliasNotFound
×
438
                }
×
439

440
                alias = dbNode.Alias.String
×
441

×
442
                return nil
×
443
        }, sqldb.NoOpReset)
444
        if err != nil {
×
445
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
446
        }
×
447

448
        return alias, nil
×
449
}
450

451
// SourceNode returns the source node of the graph. The source node is treated
452
// as the center node within a star-graph. This method may be used to kick off
453
// a path finding algorithm in order to explore the reachability of another
454
// node based off the source node.
455
//
456
// NOTE: part of the V1Store interface.
457
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
458
        error) {
×
459

×
460
        var node *models.LightningNode
×
461
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
462
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
463
                if err != nil {
×
464
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
465
                                err)
×
466
                }
×
467

468
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
469

×
470
                return err
×
471
        }, sqldb.NoOpReset)
472
        if err != nil {
×
473
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
474
        }
×
475

476
        return node, nil
×
477
}
478

479
// SetSourceNode sets the source node within the graph database. The source
480
// node is to be used as the center of a star-graph within path finding
481
// algorithms.
482
//
483
// NOTE: part of the V1Store interface.
484
func (s *SQLStore) SetSourceNode(ctx context.Context,
485
        node *models.LightningNode) error {
×
486

×
487
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
488
                id, err := upsertNode(ctx, db, node)
×
489
                if err != nil {
×
490
                        return fmt.Errorf("unable to upsert source node: %w",
×
491
                                err)
×
492
                }
×
493

494
                // Make sure that if a source node for this version is already
495
                // set, then the ID is the same as the one we are about to set.
496
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
497
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
498
                        return fmt.Errorf("unable to fetch source node: %w",
×
499
                                err)
×
500
                } else if err == nil {
×
501
                        if dbSourceNodeID != id {
×
502
                                return fmt.Errorf("v1 source node already "+
×
503
                                        "set to a different node: %d vs %d",
×
504
                                        dbSourceNodeID, id)
×
505
                        }
×
506

507
                        return nil
×
508
                }
509

510
                return db.AddSourceNode(ctx, id)
×
511
        }, sqldb.NoOpReset)
512
}
513

514
// NodeUpdatesInHorizon returns all the known lightning node which have an
515
// update timestamp within the passed range. This method can be used by two
516
// nodes to quickly determine if they have the same set of up to date node
517
// announcements.
518
//
519
// NOTE: This is part of the V1Store interface.
520
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
521
        endTime time.Time) ([]models.LightningNode, error) {
×
522

×
523
        ctx := context.TODO()
×
524

×
525
        var nodes []models.LightningNode
×
526
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
527
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
528
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
529
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
530
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
531
                        },
×
532
                )
×
533
                if err != nil {
×
534
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
535
                }
×
536

537
                for _, dbNode := range dbNodes {
×
538
                        node, err := buildNode(ctx, db, &dbNode)
×
539
                        if err != nil {
×
540
                                return fmt.Errorf("unable to build node: %w",
×
541
                                        err)
×
542
                        }
×
543

544
                        nodes = append(nodes, *node)
×
545
                }
546

547
                return nil
×
548
        }, sqldb.NoOpReset)
549
        if err != nil {
×
550
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
551
        }
×
552

553
        return nodes, nil
×
554
}
555

556
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
557
// undirected edge from the two target nodes are created. The information stored
558
// denotes the static attributes of the channel, such as the channelID, the keys
559
// involved in creation of the channel, and the set of features that the channel
560
// supports. The chanPoint and chanID are used to uniquely identify the edge
561
// globally within the database.
562
//
563
// NOTE: part of the V1Store interface.
564
func (s *SQLStore) AddChannelEdge(ctx context.Context,
565
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
566

×
567
        var alreadyExists bool
×
568
        r := &batch.Request[SQLQueries]{
×
569
                Opts: batch.NewSchedulerOptions(opts...),
×
570
                Reset: func() {
×
571
                        alreadyExists = false
×
572
                },
×
573
                Do: func(tx SQLQueries) error {
×
574
                        _, err := insertChannel(ctx, tx, edge)
×
575

×
576
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
577
                        // succeed, but propagate the error via local state.
×
578
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
579
                                alreadyExists = true
×
580
                                return nil
×
581
                        }
×
582

583
                        return err
×
584
                },
585
                OnCommit: func(err error) error {
×
586
                        switch {
×
587
                        case err != nil:
×
588
                                return err
×
589
                        case alreadyExists:
×
590
                                return ErrEdgeAlreadyExist
×
591
                        default:
×
592
                                s.rejectCache.remove(edge.ChannelID)
×
593
                                s.chanCache.remove(edge.ChannelID)
×
594
                                return nil
×
595
                        }
596
                },
597
        }
598

599
        return s.chanScheduler.Execute(ctx, r)
×
600
}
601

602
// HighestChanID returns the "highest" known channel ID in the channel graph.
603
// This represents the "newest" channel from the PoV of the chain. This method
604
// can be used by peers to quickly determine if their graphs are in sync.
605
//
606
// NOTE: This is part of the V1Store interface.
607
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
608
        var highestChanID uint64
×
609
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
610
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
611
                if errors.Is(err, sql.ErrNoRows) {
×
612
                        return nil
×
613
                } else if err != nil {
×
614
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
615
                                err)
×
616
                }
×
617

618
                highestChanID = byteOrder.Uint64(chanID)
×
619

×
620
                return nil
×
621
        }, sqldb.NoOpReset)
622
        if err != nil {
×
623
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
624
        }
×
625

626
        return highestChanID, nil
×
627
}
628

629
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
630
// within the database for the referenced channel. The `flags` attribute within
631
// the ChannelEdgePolicy determines which of the directed edges are being
632
// updated. If the flag is 1, then the first node's information is being
633
// updated, otherwise it's the second node's information. The node ordering is
634
// determined by the lexicographical ordering of the identity public keys of the
635
// nodes on either side of the channel.
636
//
637
// NOTE: part of the V1Store interface.
638
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
639
        edge *models.ChannelEdgePolicy,
640
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
641

×
642
        var (
×
643
                isUpdate1    bool
×
644
                edgeNotFound bool
×
645
                from, to     route.Vertex
×
646
        )
×
647

×
648
        r := &batch.Request[SQLQueries]{
×
649
                Opts: batch.NewSchedulerOptions(opts...),
×
650
                Reset: func() {
×
651
                        isUpdate1 = false
×
652
                        edgeNotFound = false
×
653
                },
×
654
                Do: func(tx SQLQueries) error {
×
655
                        var err error
×
656
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
657
                                ctx, tx, edge,
×
658
                        )
×
659
                        if err != nil {
×
660
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
661
                        }
×
662

663
                        // Silence ErrEdgeNotFound so that the batch can
664
                        // succeed, but propagate the error via local state.
665
                        if errors.Is(err, ErrEdgeNotFound) {
×
666
                                edgeNotFound = true
×
667
                                return nil
×
668
                        }
×
669

670
                        return err
×
671
                },
672
                OnCommit: func(err error) error {
×
673
                        switch {
×
674
                        case err != nil:
×
675
                                return err
×
676
                        case edgeNotFound:
×
677
                                return ErrEdgeNotFound
×
678
                        default:
×
679
                                s.updateEdgeCache(edge, isUpdate1)
×
680
                                return nil
×
681
                        }
682
                },
683
        }
684

685
        err := s.chanScheduler.Execute(ctx, r)
×
686

×
687
        return from, to, err
×
688
}
689

690
// updateEdgeCache updates our reject and channel caches with the new
691
// edge policy information.
692
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
693
        isUpdate1 bool) {
×
694

×
695
        // If an entry for this channel is found in reject cache, we'll modify
×
696
        // the entry with the updated timestamp for the direction that was just
×
697
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
698
        // during the next query for this edge.
×
699
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
700
                if isUpdate1 {
×
701
                        entry.upd1Time = e.LastUpdate.Unix()
×
702
                } else {
×
703
                        entry.upd2Time = e.LastUpdate.Unix()
×
704
                }
×
705
                s.rejectCache.insert(e.ChannelID, entry)
×
706
        }
707

708
        // If an entry for this channel is found in channel cache, we'll modify
709
        // the entry with the updated policy for the direction that was just
710
        // written. If the edge doesn't exist, we'll defer loading the info and
711
        // policies and lazily read from disk during the next query.
712
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
713
                if isUpdate1 {
×
714
                        channel.Policy1 = e
×
715
                } else {
×
716
                        channel.Policy2 = e
×
717
                }
×
718
                s.chanCache.insert(e.ChannelID, channel)
×
719
        }
720
}
721

722
// ForEachSourceNodeChannel iterates through all channels of the source node,
723
// executing the passed callback on each. The call-back is provided with the
724
// channel's outpoint, whether we have a policy for the channel and the channel
725
// peer's node information.
726
//
727
// NOTE: part of the V1Store interface.
728
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
729
        cb func(chanPoint wire.OutPoint, havePolicy bool,
NEW
730
                otherNode *models.LightningNode) error, reset func()) error {
×
731

×
732
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
733
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
734
                if err != nil {
×
735
                        return fmt.Errorf("unable to fetch source node: %w",
×
736
                                err)
×
737
                }
×
738

739
                return forEachNodeChannel(
×
740
                        ctx, db, s.cfg.ChainHash, nodeID,
×
741
                        func(info *models.ChannelEdgeInfo,
×
742
                                outPolicy *models.ChannelEdgePolicy,
×
743
                                _ *models.ChannelEdgePolicy) error {
×
744

×
745
                                // Fetch the other node.
×
746
                                var (
×
747
                                        otherNodePub [33]byte
×
748
                                        node1        = info.NodeKey1Bytes
×
749
                                        node2        = info.NodeKey2Bytes
×
750
                                )
×
751
                                switch {
×
752
                                case bytes.Equal(node1[:], nodePub[:]):
×
753
                                        otherNodePub = node2
×
754
                                case bytes.Equal(node2[:], nodePub[:]):
×
755
                                        otherNodePub = node1
×
756
                                default:
×
757
                                        return fmt.Errorf("node not " +
×
758
                                                "participating in this channel")
×
759
                                }
760

761
                                _, otherNode, err := getNodeByPubKey(
×
762
                                        ctx, db, otherNodePub,
×
763
                                )
×
764
                                if err != nil {
×
765
                                        return fmt.Errorf("unable to fetch "+
×
766
                                                "other node(%x): %w",
×
767
                                                otherNodePub, err)
×
768
                                }
×
769

770
                                return cb(
×
771
                                        info.ChannelPoint, outPolicy != nil,
×
772
                                        otherNode,
×
773
                                )
×
774
                        },
775
                )
776
        }, reset)
777
}
778

779
// ForEachNode iterates through all the stored vertices/nodes in the graph,
780
// executing the passed callback with each node encountered. If the callback
781
// returns an error, then the transaction is aborted and the iteration stops
782
// early. Any operations performed on the NodeTx passed to the call-back are
783
// executed under the same read transaction and so, methods on the NodeTx object
784
// _MUST_ only be called from within the call-back.
785
//
786
// NOTE: part of the V1Store interface.
787
func (s *SQLStore) ForEachNode(ctx context.Context,
NEW
788
        cb func(tx NodeRTx) error, reset func()) error {
×
789

×
790
        var lastID int64 = 0
×
791
        handleNode := func(db SQLQueries, dbNode sqlc.Node) error {
×
792
                node, err := buildNode(ctx, db, &dbNode)
×
793
                if err != nil {
×
794
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
795
                                dbNode.ID, err)
×
796
                }
×
797

798
                err = cb(
×
799
                        newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
×
800
                )
×
801
                if err != nil {
×
802
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
803
                                dbNode.ID, err)
×
804
                }
×
805

806
                return nil
×
807
        }
808

809
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
810
                for {
×
811
                        nodes, err := db.ListNodesPaginated(
×
812
                                ctx, sqlc.ListNodesPaginatedParams{
×
813
                                        Version: int16(ProtocolV1),
×
814
                                        ID:      lastID,
×
815
                                        Limit:   pageSize,
×
816
                                },
×
817
                        )
×
818
                        if err != nil {
×
819
                                return fmt.Errorf("unable to fetch nodes: %w",
×
820
                                        err)
×
821
                        }
×
822

823
                        if len(nodes) == 0 {
×
824
                                break
×
825
                        }
826

827
                        for _, dbNode := range nodes {
×
828
                                err = handleNode(db, dbNode)
×
829
                                if err != nil {
×
830
                                        return err
×
831
                                }
×
832

833
                                lastID = dbNode.ID
×
834
                        }
835
                }
836

837
                return nil
×
838
        }, reset)
839
}
840

841
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
842
// SQLStore and a SQL transaction.
843
type sqlGraphNodeTx struct {
844
        db    SQLQueries
845
        id    int64
846
        node  *models.LightningNode
847
        chain chainhash.Hash
848
}
849

850
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
851
// interface.
852
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
853

854
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
855
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
856

×
857
        return &sqlGraphNodeTx{
×
858
                db:    db,
×
859
                chain: chain,
×
860
                id:    id,
×
861
                node:  node,
×
862
        }
×
863
}
×
864

865
// Node returns the raw information of the node.
866
//
867
// NOTE: This is a part of the NodeRTx interface.
868
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
869
        return s.node
×
870
}
×
871

872
// ForEachChannel can be used to iterate over the node's channels under the same
873
// transaction used to fetch the node.
874
//
875
// NOTE: This is a part of the NodeRTx interface.
876
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
877
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
878

×
879
        ctx := context.TODO()
×
880

×
881
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
882
}
×
883

884
// FetchNode fetches the node with the given pub key under the same transaction
885
// used to fetch the current node. The returned node is also a NodeRTx and any
886
// operations on that NodeRTx will also be done under the same transaction.
887
//
888
// NOTE: This is a part of the NodeRTx interface.
889
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
890
        ctx := context.TODO()
×
891

×
892
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
893
        if err != nil {
×
894
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
895
                        nodePub, err)
×
896
        }
×
897

898
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
899
}
900

901
// ForEachNodeDirectedChannel iterates through all channels of a given node,
902
// executing the passed callback on the directed edge representing the channel
903
// and its incoming policy. If the callback returns an error, then the iteration
904
// is halted with the error propagated back up to the caller.
905
//
906
// Unknown policies are passed into the callback as nil values.
907
//
908
// NOTE: this is part of the graphdb.NodeTraverser interface.
909
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
NEW
910
        cb func(channel *DirectedChannel) error, reset func()) error {
×
911

×
912
        var ctx = context.TODO()
×
913

×
914
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
915
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
NEW
916
        }, reset)
×
917
}
918

919
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
920
// graph, executing the passed callback with each node encountered. If the
921
// callback returns an error, then the transaction is aborted and the iteration
922
// stops early.
923
//
924
// NOTE: This is a part of the V1Store interface.
925
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
926
        cb func(route.Vertex, *lnwire.FeatureVector) error,
NEW
927
        reset func()) error {
×
928

×
929
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
930
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
931
                        nodePub route.Vertex) error {
×
932

×
933
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
934
                        if err != nil {
×
935
                                return fmt.Errorf("unable to fetch node "+
×
936
                                        "features: %w", err)
×
937
                        }
×
938

939
                        return cb(nodePub, features)
×
940
                })
941
        }, reset)
942
        if err != nil {
×
943
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
944
        }
×
945

946
        return nil
×
947
}
948

949
// ForEachNodeChannel iterates through all channels of the given node,
950
// executing the passed callback with an edge info structure and the policies
951
// of each end of the channel. The first edge policy is the outgoing edge *to*
952
// the connecting node, while the second is the incoming edge *from* the
953
// connecting node. If the callback returns an error, then the iteration is
954
// halted with the error propagated back up to the caller.
955
//
956
// Unknown policies are passed into the callback as nil values.
957
//
958
// NOTE: part of the V1Store interface.
959
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
960
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
NEW
961
                *models.ChannelEdgePolicy) error, reset func()) error {
×
962

×
963
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
964
                dbNode, err := db.GetNodeByPubKey(
×
965
                        ctx, sqlc.GetNodeByPubKeyParams{
×
966
                                Version: int16(ProtocolV1),
×
967
                                PubKey:  nodePub[:],
×
968
                        },
×
969
                )
×
970
                if errors.Is(err, sql.ErrNoRows) {
×
971
                        return nil
×
972
                } else if err != nil {
×
973
                        return fmt.Errorf("unable to fetch node: %w", err)
×
974
                }
×
975

976
                return forEachNodeChannel(
×
977
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
978
                )
×
979
        }, reset)
980
}
981

982
// ChanUpdatesInHorizon returns all the known channel edges which have at least
983
// one edge that has an update timestamp within the specified horizon.
984
//
985
// NOTE: This is part of the V1Store interface.
986
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
987
        endTime time.Time) ([]ChannelEdge, error) {
×
988

×
989
        s.cacheMu.Lock()
×
990
        defer s.cacheMu.Unlock()
×
991

×
992
        var (
×
993
                ctx = context.TODO()
×
994
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
995
                // an additional map to keep track of the edges already seen to
×
996
                // prevent re-adding it.
×
997
                edgesSeen    = make(map[uint64]struct{})
×
998
                edgesToCache = make(map[uint64]ChannelEdge)
×
999
                edges        []ChannelEdge
×
1000
                hits         int
×
1001
        )
×
1002
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1003
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1004
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1005
                                Version:   int16(ProtocolV1),
×
1006
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
1007
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
1008
                        },
×
1009
                )
×
1010
                if err != nil {
×
1011
                        return err
×
1012
                }
×
1013

1014
                for _, row := range rows {
×
1015
                        // If we've already retrieved the info and policies for
×
1016
                        // this edge, then we can skip it as we don't need to do
×
1017
                        // so again.
×
1018
                        chanIDInt := byteOrder.Uint64(row.Channel.Scid)
×
1019
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1020
                                continue
×
1021
                        }
1022

1023
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1024
                                hits++
×
1025
                                edgesSeen[chanIDInt] = struct{}{}
×
1026
                                edges = append(edges, channel)
×
1027

×
1028
                                continue
×
1029
                        }
1030

1031
                        node1, node2, err := buildNodes(
×
1032
                                ctx, db, row.Node, row.Node_2,
×
1033
                        )
×
1034
                        if err != nil {
×
1035
                                return err
×
1036
                        }
×
1037

1038
                        channel, err := getAndBuildEdgeInfo(
×
1039
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1040
                                row.Channel, node1.PubKeyBytes,
×
1041
                                node2.PubKeyBytes,
×
1042
                        )
×
1043
                        if err != nil {
×
1044
                                return fmt.Errorf("unable to build channel "+
×
1045
                                        "info: %w", err)
×
1046
                        }
×
1047

1048
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1049
                        if err != nil {
×
1050
                                return fmt.Errorf("unable to extract channel "+
×
1051
                                        "policies: %w", err)
×
1052
                        }
×
1053

1054
                        p1, p2, err := getAndBuildChanPolicies(
×
1055
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1056
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1057
                        )
×
1058
                        if err != nil {
×
1059
                                return fmt.Errorf("unable to build channel "+
×
1060
                                        "policies: %w", err)
×
1061
                        }
×
1062

1063
                        edgesSeen[chanIDInt] = struct{}{}
×
1064
                        chanEdge := ChannelEdge{
×
1065
                                Info:    channel,
×
1066
                                Policy1: p1,
×
1067
                                Policy2: p2,
×
1068
                                Node1:   node1,
×
1069
                                Node2:   node2,
×
1070
                        }
×
1071
                        edges = append(edges, chanEdge)
×
1072
                        edgesToCache[chanIDInt] = chanEdge
×
1073
                }
1074

1075
                return nil
×
1076
        }, func() {
×
1077
                edgesSeen = make(map[uint64]struct{})
×
1078
                edgesToCache = make(map[uint64]ChannelEdge)
×
1079
                edges = nil
×
1080
        })
×
1081
        if err != nil {
×
1082
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1083
        }
×
1084

1085
        // Insert any edges loaded from disk into the cache.
1086
        for chanid, channel := range edgesToCache {
×
1087
                s.chanCache.insert(chanid, channel)
×
1088
        }
×
1089

1090
        if len(edges) > 0 {
×
1091
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
1092
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1093
        } else {
×
1094
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1095
                        "horizon (%s, %s)", startTime, endTime)
×
1096
        }
×
1097

1098
        return edges, nil
×
1099
}
1100

1101
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1102
// data to the call-back.
1103
//
1104
// NOTE: The callback contents MUST not be modified.
1105
//
1106
// NOTE: part of the V1Store interface.
1107
func (s *SQLStore) ForEachNodeCached(ctx context.Context,
1108
        cb func(node route.Vertex, chans map[uint64]*DirectedChannel) error,
NEW
1109
        reset func()) error {
×
1110

×
1111
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1112
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
1113
                        nodePub route.Vertex) error {
×
1114

×
1115
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
1116
                        if err != nil {
×
1117
                                return fmt.Errorf("unable to fetch "+
×
1118
                                        "node(id=%d) features: %w", nodeID, err)
×
1119
                        }
×
1120

1121
                        toNodeCallback := func() route.Vertex {
×
1122
                                return nodePub
×
1123
                        }
×
1124

1125
                        rows, err := db.ListChannelsByNodeID(
×
1126
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
1127
                                        Version: int16(ProtocolV1),
×
1128
                                        NodeID1: nodeID,
×
1129
                                },
×
1130
                        )
×
1131
                        if err != nil {
×
1132
                                return fmt.Errorf("unable to fetch channels "+
×
1133
                                        "of node(id=%d): %w", nodeID, err)
×
1134
                        }
×
1135

1136
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
1137
                        for _, row := range rows {
×
1138
                                node1, node2, err := buildNodeVertices(
×
1139
                                        row.Node1Pubkey, row.Node2Pubkey,
×
1140
                                )
×
1141
                                if err != nil {
×
1142
                                        return err
×
1143
                                }
×
1144

1145
                                e, err := getAndBuildEdgeInfo(
×
1146
                                        ctx, db, s.cfg.ChainHash,
×
1147
                                        row.Channel.ID, row.Channel, node1,
×
1148
                                        node2,
×
1149
                                )
×
1150
                                if err != nil {
×
1151
                                        return fmt.Errorf("unable to build "+
×
1152
                                                "channel info: %w", err)
×
1153
                                }
×
1154

1155
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1156
                                        row,
×
1157
                                )
×
1158
                                if err != nil {
×
1159
                                        return fmt.Errorf("unable to "+
×
1160
                                                "extract channel "+
×
1161
                                                "policies: %w", err)
×
1162
                                }
×
1163

1164
                                p1, p2, err := getAndBuildChanPolicies(
×
1165
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
1166
                                        node1, node2,
×
1167
                                )
×
1168
                                if err != nil {
×
1169
                                        return fmt.Errorf("unable to "+
×
1170
                                                "build channel policies: %w",
×
1171
                                                err)
×
1172
                                }
×
1173

1174
                                // Determine the outgoing and incoming policy
1175
                                // for this channel and node combo.
1176
                                outPolicy, inPolicy := p1, p2
×
1177
                                if p1 != nil && p1.ToNode == nodePub {
×
1178
                                        outPolicy, inPolicy = p2, p1
×
1179
                                } else if p2 != nil && p2.ToNode != nodePub {
×
1180
                                        outPolicy, inPolicy = p2, p1
×
1181
                                }
×
1182

1183
                                var cachedInPolicy *models.CachedEdgePolicy
×
1184
                                if inPolicy != nil {
×
1185
                                        cachedInPolicy = models.NewCachedPolicy(
×
1186
                                                p2,
×
1187
                                        )
×
1188
                                        cachedInPolicy.ToNodePubKey =
×
1189
                                                toNodeCallback
×
1190
                                        cachedInPolicy.ToNodeFeatures =
×
1191
                                                features
×
1192
                                }
×
1193

1194
                                var inboundFee lnwire.Fee
×
1195
                                outPolicy.InboundFee.WhenSome(
×
1196
                                        func(fee lnwire.Fee) {
×
1197
                                                inboundFee = fee
×
1198
                                        },
×
1199
                                )
1200

1201
                                directedChannel := &DirectedChannel{
×
1202
                                        ChannelID: e.ChannelID,
×
1203
                                        IsNode1: nodePub ==
×
1204
                                                e.NodeKey1Bytes,
×
1205
                                        OtherNode:    e.NodeKey2Bytes,
×
1206
                                        Capacity:     e.Capacity,
×
1207
                                        OutPolicySet: p1 != nil,
×
1208
                                        InPolicy:     cachedInPolicy,
×
1209
                                        InboundFee:   inboundFee,
×
1210
                                }
×
1211

×
1212
                                if nodePub == e.NodeKey2Bytes {
×
1213
                                        directedChannel.OtherNode =
×
1214
                                                e.NodeKey1Bytes
×
1215
                                }
×
1216

1217
                                channels[e.ChannelID] = directedChannel
×
1218
                        }
1219

1220
                        return cb(nodePub, channels)
×
1221
                })
1222
        }, reset)
1223
}
1224

1225
// ForEachChannelCacheable iterates through all the channel edges stored
1226
// within the graph and invokes the passed callback for each edge. The
1227
// callback takes two edges as since this is a directed graph, both the
1228
// in/out edges are visited. If the callback returns an error, then the
1229
// transaction is aborted and the iteration stops early.
1230
//
1231
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1232
// pointer for that particular channel edge routing policy will be
1233
// passed into the callback.
1234
//
1235
// NOTE: this method is like ForEachChannel but fetches only the data
1236
// required for the graph cache.
1237
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1238
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
NEW
1239
        reset func()) error {
×
1240

×
1241
        ctx := context.TODO()
×
1242

×
1243
        handleChannel := func(db SQLQueries,
×
1244
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1245

×
1246
                node1, node2, err := buildNodeVertices(
×
1247
                        row.Node1Pubkey, row.Node2Pubkey,
×
1248
                )
×
1249
                if err != nil {
×
1250
                        return err
×
1251
                }
×
1252

1253
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
1254

×
1255
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1256
                if err != nil {
×
1257
                        return err
×
1258
                }
×
1259

1260
                var pol1, pol2 *models.CachedEdgePolicy
×
1261
                if dbPol1 != nil {
×
1262
                        policy1, err := buildChanPolicy(
×
1263
                                *dbPol1, edge.ChannelID, nil, node2,
×
1264
                        )
×
1265
                        if err != nil {
×
1266
                                return err
×
1267
                        }
×
1268

1269
                        pol1 = models.NewCachedPolicy(policy1)
×
1270
                }
1271
                if dbPol2 != nil {
×
1272
                        policy2, err := buildChanPolicy(
×
1273
                                *dbPol2, edge.ChannelID, nil, node1,
×
1274
                        )
×
1275
                        if err != nil {
×
1276
                                return err
×
1277
                        }
×
1278

1279
                        pol2 = models.NewCachedPolicy(policy2)
×
1280
                }
1281

1282
                if err := cb(edge, pol1, pol2); err != nil {
×
1283
                        return err
×
1284
                }
×
1285

1286
                return nil
×
1287
        }
1288

1289
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1290
                lastID := int64(-1)
×
1291
                for {
×
1292
                        //nolint:ll
×
1293
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1294
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1295
                                        Version: int16(ProtocolV1),
×
1296
                                        ID:      lastID,
×
1297
                                        Limit:   pageSize,
×
1298
                                },
×
1299
                        )
×
1300
                        if err != nil {
×
1301
                                return err
×
1302
                        }
×
1303

1304
                        if len(rows) == 0 {
×
1305
                                break
×
1306
                        }
1307

1308
                        for _, row := range rows {
×
1309
                                err := handleChannel(db, row)
×
1310
                                if err != nil {
×
1311
                                        return err
×
1312
                                }
×
1313

1314
                                lastID = row.Channel.ID
×
1315
                        }
1316
                }
1317

1318
                return nil
×
1319
        }, reset)
1320
}
1321

1322
// ForEachChannel iterates through all the channel edges stored within the
1323
// graph and invokes the passed callback for each edge. The callback takes two
1324
// edges as since this is a directed graph, both the in/out edges are visited.
1325
// If the callback returns an error, then the transaction is aborted and the
1326
// iteration stops early.
1327
//
1328
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1329
// for that particular channel edge routing policy will be passed into the
1330
// callback.
1331
//
1332
// NOTE: part of the V1Store interface.
1333
func (s *SQLStore) ForEachChannel(ctx context.Context,
1334
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
NEW
1335
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1336

×
1337
        handleChannel := func(db SQLQueries,
×
1338
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1339

×
1340
                node1, node2, err := buildNodeVertices(
×
1341
                        row.Node1Pubkey, row.Node2Pubkey,
×
1342
                )
×
1343
                if err != nil {
×
1344
                        return fmt.Errorf("unable to build node vertices: %w",
×
1345
                                err)
×
1346
                }
×
1347

1348
                edge, err := getAndBuildEdgeInfo(
×
1349
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1350
                        node1, node2,
×
1351
                )
×
1352
                if err != nil {
×
1353
                        return fmt.Errorf("unable to build channel info: %w",
×
1354
                                err)
×
1355
                }
×
1356

1357
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1358
                if err != nil {
×
1359
                        return fmt.Errorf("unable to extract channel "+
×
1360
                                "policies: %w", err)
×
1361
                }
×
1362

1363
                p1, p2, err := getAndBuildChanPolicies(
×
1364
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1365
                )
×
1366
                if err != nil {
×
1367
                        return fmt.Errorf("unable to build channel "+
×
1368
                                "policies: %w", err)
×
1369
                }
×
1370

1371
                err = cb(edge, p1, p2)
×
1372
                if err != nil {
×
1373
                        return fmt.Errorf("callback failed for channel "+
×
1374
                                "id=%d: %w", edge.ChannelID, err)
×
1375
                }
×
1376

1377
                return nil
×
1378
        }
1379

1380
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1381
                lastID := int64(-1)
×
1382
                for {
×
1383
                        //nolint:ll
×
1384
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1385
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1386
                                        Version: int16(ProtocolV1),
×
1387
                                        ID:      lastID,
×
1388
                                        Limit:   pageSize,
×
1389
                                },
×
1390
                        )
×
1391
                        if err != nil {
×
1392
                                return err
×
1393
                        }
×
1394

1395
                        if len(rows) == 0 {
×
1396
                                break
×
1397
                        }
1398

1399
                        for _, row := range rows {
×
1400
                                err := handleChannel(db, row)
×
1401
                                if err != nil {
×
1402
                                        return err
×
1403
                                }
×
1404

1405
                                lastID = row.Channel.ID
×
1406
                        }
1407
                }
1408

1409
                return nil
×
1410
        }, reset)
1411
}
1412

1413
// FilterChannelRange returns the channel ID's of all known channels which were
1414
// mined in a block height within the passed range. The channel IDs are grouped
1415
// by their common block height. This method can be used to quickly share with a
1416
// peer the set of channels we know of within a particular range to catch them
1417
// up after a period of time offline. If withTimestamps is true then the
1418
// timestamp info of the latest received channel update messages of the channel
1419
// will be included in the response.
1420
//
1421
// NOTE: This is part of the V1Store interface.
1422
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1423
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1424

×
1425
        var (
×
1426
                ctx       = context.TODO()
×
1427
                startSCID = &lnwire.ShortChannelID{
×
1428
                        BlockHeight: startHeight,
×
1429
                }
×
1430
                endSCID = lnwire.ShortChannelID{
×
1431
                        BlockHeight: endHeight,
×
1432
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1433
                        TxPosition:  math.MaxUint16,
×
1434
                }
×
1435
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1436
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1437
        )
×
1438

×
1439
        // 1) get all channels where channelID is between start and end chan ID.
×
1440
        // 2) skip if not public (ie, no channel_proof)
×
1441
        // 3) collect that channel.
×
1442
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1443
        //    and add those timestamps to the collected channel.
×
1444
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1445
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1446
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1447
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1448
                                StartScid: chanIDStart,
×
1449
                                EndScid:   chanIDEnd,
×
1450
                        },
×
1451
                )
×
1452
                if err != nil {
×
1453
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1454
                                err)
×
1455
                }
×
1456

1457
                for _, dbChan := range dbChans {
×
1458
                        cid := lnwire.NewShortChanIDFromInt(
×
1459
                                byteOrder.Uint64(dbChan.Scid),
×
1460
                        )
×
1461
                        chanInfo := NewChannelUpdateInfo(
×
1462
                                cid, time.Time{}, time.Time{},
×
1463
                        )
×
1464

×
1465
                        if !withTimestamps {
×
1466
                                channelsPerBlock[cid.BlockHeight] = append(
×
1467
                                        channelsPerBlock[cid.BlockHeight],
×
1468
                                        chanInfo,
×
1469
                                )
×
1470

×
1471
                                continue
×
1472
                        }
1473

1474
                        //nolint:ll
1475
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1476
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1477
                                        Version:   int16(ProtocolV1),
×
1478
                                        ChannelID: dbChan.ID,
×
1479
                                        NodeID:    dbChan.NodeID1,
×
1480
                                },
×
1481
                        )
×
1482
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1483
                                return fmt.Errorf("unable to fetch node1 "+
×
1484
                                        "policy: %w", err)
×
1485
                        } else if err == nil {
×
1486
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1487
                                        node1Policy.LastUpdate.Int64, 0,
×
1488
                                )
×
1489
                        }
×
1490

1491
                        //nolint:ll
1492
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1493
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1494
                                        Version:   int16(ProtocolV1),
×
1495
                                        ChannelID: dbChan.ID,
×
1496
                                        NodeID:    dbChan.NodeID2,
×
1497
                                },
×
1498
                        )
×
1499
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1500
                                return fmt.Errorf("unable to fetch node2 "+
×
1501
                                        "policy: %w", err)
×
1502
                        } else if err == nil {
×
1503
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1504
                                        node2Policy.LastUpdate.Int64, 0,
×
1505
                                )
×
1506
                        }
×
1507

1508
                        channelsPerBlock[cid.BlockHeight] = append(
×
1509
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1510
                        )
×
1511
                }
1512

1513
                return nil
×
1514
        }, func() {
×
1515
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1516
        })
×
1517
        if err != nil {
×
1518
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1519
        }
×
1520

1521
        if len(channelsPerBlock) == 0 {
×
1522
                return nil, nil
×
1523
        }
×
1524

1525
        // Return the channel ranges in ascending block height order.
1526
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1527
        slices.Sort(blocks)
×
1528

×
1529
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1530
                return BlockChannelRange{
×
1531
                        Height:   block,
×
1532
                        Channels: channelsPerBlock[block],
×
1533
                }
×
1534
        }), nil
×
1535
}
1536

1537
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1538
// zombie. This method is used on an ad-hoc basis, when channels need to be
1539
// marked as zombies outside the normal pruning cycle.
1540
//
1541
// NOTE: part of the V1Store interface.
1542
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1543
        pubKey1, pubKey2 [33]byte) error {
×
1544

×
1545
        ctx := context.TODO()
×
1546

×
1547
        s.cacheMu.Lock()
×
1548
        defer s.cacheMu.Unlock()
×
1549

×
1550
        chanIDB := channelIDToBytes(chanID)
×
1551

×
1552
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1553
                return db.UpsertZombieChannel(
×
1554
                        ctx, sqlc.UpsertZombieChannelParams{
×
1555
                                Version:  int16(ProtocolV1),
×
1556
                                Scid:     chanIDB,
×
1557
                                NodeKey1: pubKey1[:],
×
1558
                                NodeKey2: pubKey2[:],
×
1559
                        },
×
1560
                )
×
1561
        }, sqldb.NoOpReset)
×
1562
        if err != nil {
×
1563
                return fmt.Errorf("unable to upsert zombie channel "+
×
1564
                        "(channel_id=%d): %w", chanID, err)
×
1565
        }
×
1566

1567
        s.rejectCache.remove(chanID)
×
1568
        s.chanCache.remove(chanID)
×
1569

×
1570
        return nil
×
1571
}
1572

1573
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1574
//
1575
// NOTE: part of the V1Store interface.
1576
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1577
        s.cacheMu.Lock()
×
1578
        defer s.cacheMu.Unlock()
×
1579

×
1580
        var (
×
1581
                ctx     = context.TODO()
×
1582
                chanIDB = channelIDToBytes(chanID)
×
1583
        )
×
1584

×
1585
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1586
                res, err := db.DeleteZombieChannel(
×
1587
                        ctx, sqlc.DeleteZombieChannelParams{
×
1588
                                Scid:    chanIDB,
×
1589
                                Version: int16(ProtocolV1),
×
1590
                        },
×
1591
                )
×
1592
                if err != nil {
×
1593
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1594
                                err)
×
1595
                }
×
1596

1597
                rows, err := res.RowsAffected()
×
1598
                if err != nil {
×
1599
                        return err
×
1600
                }
×
1601

1602
                if rows == 0 {
×
1603
                        return ErrZombieEdgeNotFound
×
1604
                } else if rows > 1 {
×
1605
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1606
                                "expected 1", rows)
×
1607
                }
×
1608

1609
                return nil
×
1610
        }, sqldb.NoOpReset)
1611
        if err != nil {
×
1612
                return fmt.Errorf("unable to mark edge live "+
×
1613
                        "(channel_id=%d): %w", chanID, err)
×
1614
        }
×
1615

1616
        s.rejectCache.remove(chanID)
×
1617
        s.chanCache.remove(chanID)
×
1618

×
1619
        return err
×
1620
}
1621

1622
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1623
// zombie, then the two node public keys corresponding to this edge are also
1624
// returned.
1625
//
1626
// NOTE: part of the V1Store interface.
1627
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1628
        error) {
×
1629

×
1630
        var (
×
1631
                ctx              = context.TODO()
×
1632
                isZombie         bool
×
1633
                pubKey1, pubKey2 route.Vertex
×
1634
                chanIDB          = channelIDToBytes(chanID)
×
1635
        )
×
1636

×
1637
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1638
                zombie, err := db.GetZombieChannel(
×
1639
                        ctx, sqlc.GetZombieChannelParams{
×
1640
                                Scid:    chanIDB,
×
1641
                                Version: int16(ProtocolV1),
×
1642
                        },
×
1643
                )
×
1644
                if errors.Is(err, sql.ErrNoRows) {
×
1645
                        return nil
×
1646
                }
×
1647
                if err != nil {
×
1648
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1649
                                err)
×
1650
                }
×
1651

1652
                copy(pubKey1[:], zombie.NodeKey1)
×
1653
                copy(pubKey2[:], zombie.NodeKey2)
×
1654
                isZombie = true
×
1655

×
1656
                return nil
×
1657
        }, sqldb.NoOpReset)
1658
        if err != nil {
×
1659
                return false, route.Vertex{}, route.Vertex{},
×
1660
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1661
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1662
        }
×
1663

1664
        return isZombie, pubKey1, pubKey2, nil
×
1665
}
1666

1667
// NumZombies returns the current number of zombie channels in the graph.
1668
//
1669
// NOTE: part of the V1Store interface.
1670
func (s *SQLStore) NumZombies() (uint64, error) {
×
1671
        var (
×
1672
                ctx        = context.TODO()
×
1673
                numZombies uint64
×
1674
        )
×
1675
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1676
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1677
                if err != nil {
×
1678
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1679
                                err)
×
1680
                }
×
1681

1682
                numZombies = uint64(count)
×
1683

×
1684
                return nil
×
1685
        }, sqldb.NoOpReset)
1686
        if err != nil {
×
1687
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1688
        }
×
1689

1690
        return numZombies, nil
×
1691
}
1692

1693
// DeleteChannelEdges removes edges with the given channel IDs from the
1694
// database and marks them as zombies. This ensures that we're unable to re-add
1695
// it to our database once again. If an edge does not exist within the
1696
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1697
// true, then when we mark these edges as zombies, we'll set up the keys such
1698
// that we require the node that failed to send the fresh update to be the one
1699
// that resurrects the channel from its zombie state. The markZombie bool
1700
// denotes whether to mark the channel as a zombie.
1701
//
1702
// NOTE: part of the V1Store interface.
1703
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1704
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1705

×
1706
        s.cacheMu.Lock()
×
1707
        defer s.cacheMu.Unlock()
×
1708

×
1709
        var (
×
1710
                ctx     = context.TODO()
×
1711
                deleted []*models.ChannelEdgeInfo
×
1712
        )
×
1713
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1714
                for _, chanID := range chanIDs {
×
1715
                        chanIDB := channelIDToBytes(chanID)
×
1716

×
1717
                        row, err := db.GetChannelBySCIDWithPolicies(
×
1718
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1719
                                        Scid:    chanIDB,
×
1720
                                        Version: int16(ProtocolV1),
×
1721
                                },
×
1722
                        )
×
1723
                        if errors.Is(err, sql.ErrNoRows) {
×
1724
                                return ErrEdgeNotFound
×
1725
                        } else if err != nil {
×
1726
                                return fmt.Errorf("unable to fetch channel: %w",
×
1727
                                        err)
×
1728
                        }
×
1729

1730
                        node1, node2, err := buildNodeVertices(
×
1731
                                row.Node.PubKey, row.Node_2.PubKey,
×
1732
                        )
×
1733
                        if err != nil {
×
1734
                                return err
×
1735
                        }
×
1736

1737
                        info, err := getAndBuildEdgeInfo(
×
1738
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1739
                                row.Channel, node1, node2,
×
1740
                        )
×
1741
                        if err != nil {
×
1742
                                return err
×
1743
                        }
×
1744

1745
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
1746
                        if err != nil {
×
1747
                                return fmt.Errorf("unable to delete "+
×
1748
                                        "channel: %w", err)
×
1749
                        }
×
1750

1751
                        deleted = append(deleted, info)
×
1752

×
1753
                        if !markZombie {
×
1754
                                continue
×
1755
                        }
1756

1757
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1758
                                info.NodeKey2Bytes
×
1759
                        if strictZombiePruning {
×
1760
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1761
                                if row.Policy1LastUpdate.Valid {
×
1762
                                        e1Time := time.Unix(
×
1763
                                                row.Policy1LastUpdate.Int64, 0,
×
1764
                                        )
×
1765
                                        e1UpdateTime = &e1Time
×
1766
                                }
×
1767
                                if row.Policy2LastUpdate.Valid {
×
1768
                                        e2Time := time.Unix(
×
1769
                                                row.Policy2LastUpdate.Int64, 0,
×
1770
                                        )
×
1771
                                        e2UpdateTime = &e2Time
×
1772
                                }
×
1773

1774
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1775
                                        info, e1UpdateTime, e2UpdateTime,
×
1776
                                )
×
1777
                        }
1778

1779
                        err = db.UpsertZombieChannel(
×
1780
                                ctx, sqlc.UpsertZombieChannelParams{
×
1781
                                        Version:  int16(ProtocolV1),
×
1782
                                        Scid:     chanIDB,
×
1783
                                        NodeKey1: nodeKey1[:],
×
1784
                                        NodeKey2: nodeKey2[:],
×
1785
                                },
×
1786
                        )
×
1787
                        if err != nil {
×
1788
                                return fmt.Errorf("unable to mark channel as "+
×
1789
                                        "zombie: %w", err)
×
1790
                        }
×
1791
                }
1792

1793
                return nil
×
1794
        }, func() {
×
1795
                deleted = nil
×
1796
        })
×
1797
        if err != nil {
×
1798
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1799
                        err)
×
1800
        }
×
1801

1802
        for _, chanID := range chanIDs {
×
1803
                s.rejectCache.remove(chanID)
×
1804
                s.chanCache.remove(chanID)
×
1805
        }
×
1806

1807
        return deleted, nil
×
1808
}
1809

1810
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1811
// channel identified by the channel ID. If the channel can't be found, then
1812
// ErrEdgeNotFound is returned. A struct which houses the general information
1813
// for the channel itself is returned as well as two structs that contain the
1814
// routing policies for the channel in either direction.
1815
//
1816
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1817
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1818
// the ChannelEdgeInfo will only include the public keys of each node.
1819
//
1820
// NOTE: part of the V1Store interface.
1821
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1822
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1823
        *models.ChannelEdgePolicy, error) {
×
1824

×
1825
        var (
×
1826
                ctx              = context.TODO()
×
1827
                edge             *models.ChannelEdgeInfo
×
1828
                policy1, policy2 *models.ChannelEdgePolicy
×
1829
                chanIDB          = channelIDToBytes(chanID)
×
1830
        )
×
1831
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1832
                row, err := db.GetChannelBySCIDWithPolicies(
×
1833
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1834
                                Scid:    chanIDB,
×
1835
                                Version: int16(ProtocolV1),
×
1836
                        },
×
1837
                )
×
1838
                if errors.Is(err, sql.ErrNoRows) {
×
1839
                        // First check if this edge is perhaps in the zombie
×
1840
                        // index.
×
1841
                        zombie, err := db.GetZombieChannel(
×
1842
                                ctx, sqlc.GetZombieChannelParams{
×
1843
                                        Scid:    chanIDB,
×
1844
                                        Version: int16(ProtocolV1),
×
1845
                                },
×
1846
                        )
×
1847
                        if errors.Is(err, sql.ErrNoRows) {
×
1848
                                return ErrEdgeNotFound
×
1849
                        } else if err != nil {
×
1850
                                return fmt.Errorf("unable to check if "+
×
1851
                                        "channel is zombie: %w", err)
×
1852
                        }
×
1853

1854
                        // At this point, we know the channel is a zombie, so
1855
                        // we'll return an error indicating this, and we will
1856
                        // populate the edge info with the public keys of each
1857
                        // party as this is the only information we have about
1858
                        // it.
1859
                        edge = &models.ChannelEdgeInfo{}
×
1860
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1861
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1862

×
1863
                        return ErrZombieEdge
×
1864
                } else if err != nil {
×
1865
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1866
                }
×
1867

1868
                node1, node2, err := buildNodeVertices(
×
1869
                        row.Node.PubKey, row.Node_2.PubKey,
×
1870
                )
×
1871
                if err != nil {
×
1872
                        return err
×
1873
                }
×
1874

1875
                edge, err = getAndBuildEdgeInfo(
×
1876
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1877
                        node1, node2,
×
1878
                )
×
1879
                if err != nil {
×
1880
                        return fmt.Errorf("unable to build channel info: %w",
×
1881
                                err)
×
1882
                }
×
1883

1884
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1885
                if err != nil {
×
1886
                        return fmt.Errorf("unable to extract channel "+
×
1887
                                "policies: %w", err)
×
1888
                }
×
1889

1890
                policy1, policy2, err = getAndBuildChanPolicies(
×
1891
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1892
                )
×
1893
                if err != nil {
×
1894
                        return fmt.Errorf("unable to build channel "+
×
1895
                                "policies: %w", err)
×
1896
                }
×
1897

1898
                return nil
×
1899
        }, sqldb.NoOpReset)
1900
        if err != nil {
×
1901
                // If we are returning the ErrZombieEdge, then we also need to
×
1902
                // return the edge info as the method comment indicates that
×
1903
                // this will be populated when the edge is a zombie.
×
1904
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1905
                        err)
×
1906
        }
×
1907

1908
        return edge, policy1, policy2, nil
×
1909
}
1910

1911
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1912
// the channel identified by the funding outpoint. If the channel can't be
1913
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1914
// information for the channel itself is returned as well as two structs that
1915
// contain the routing policies for the channel in either direction.
1916
//
1917
// NOTE: part of the V1Store interface.
1918
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1919
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1920
        *models.ChannelEdgePolicy, error) {
×
1921

×
1922
        var (
×
1923
                ctx              = context.TODO()
×
1924
                edge             *models.ChannelEdgeInfo
×
1925
                policy1, policy2 *models.ChannelEdgePolicy
×
1926
        )
×
1927
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1928
                row, err := db.GetChannelByOutpointWithPolicies(
×
1929
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1930
                                Outpoint: op.String(),
×
1931
                                Version:  int16(ProtocolV1),
×
1932
                        },
×
1933
                )
×
1934
                if errors.Is(err, sql.ErrNoRows) {
×
1935
                        return ErrEdgeNotFound
×
1936
                } else if err != nil {
×
1937
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1938
                }
×
1939

1940
                node1, node2, err := buildNodeVertices(
×
1941
                        row.Node1Pubkey, row.Node2Pubkey,
×
1942
                )
×
1943
                if err != nil {
×
1944
                        return err
×
1945
                }
×
1946

1947
                edge, err = getAndBuildEdgeInfo(
×
1948
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1949
                        node1, node2,
×
1950
                )
×
1951
                if err != nil {
×
1952
                        return fmt.Errorf("unable to build channel info: %w",
×
1953
                                err)
×
1954
                }
×
1955

1956
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1957
                if err != nil {
×
1958
                        return fmt.Errorf("unable to extract channel "+
×
1959
                                "policies: %w", err)
×
1960
                }
×
1961

1962
                policy1, policy2, err = getAndBuildChanPolicies(
×
1963
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1964
                )
×
1965
                if err != nil {
×
1966
                        return fmt.Errorf("unable to build channel "+
×
1967
                                "policies: %w", err)
×
1968
                }
×
1969

1970
                return nil
×
1971
        }, sqldb.NoOpReset)
1972
        if err != nil {
×
1973
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1974
                        err)
×
1975
        }
×
1976

1977
        return edge, policy1, policy2, nil
×
1978
}
1979

1980
// HasChannelEdge returns true if the database knows of a channel edge with the
1981
// passed channel ID, and false otherwise. If an edge with that ID is found
1982
// within the graph, then two time stamps representing the last time the edge
1983
// was updated for both directed edges are returned along with the boolean. If
1984
// it is not found, then the zombie index is checked and its result is returned
1985
// as the second boolean.
1986
//
1987
// NOTE: part of the V1Store interface.
1988
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1989
        bool, error) {
×
1990

×
1991
        ctx := context.TODO()
×
1992

×
1993
        var (
×
1994
                exists          bool
×
1995
                isZombie        bool
×
1996
                node1LastUpdate time.Time
×
1997
                node2LastUpdate time.Time
×
1998
        )
×
1999

×
2000
        // We'll query the cache with the shared lock held to allow multiple
×
2001
        // readers to access values in the cache concurrently if they exist.
×
2002
        s.cacheMu.RLock()
×
2003
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2004
                s.cacheMu.RUnlock()
×
2005
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2006
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2007
                exists, isZombie = entry.flags.unpack()
×
2008

×
2009
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2010
        }
×
2011
        s.cacheMu.RUnlock()
×
2012

×
2013
        s.cacheMu.Lock()
×
2014
        defer s.cacheMu.Unlock()
×
2015

×
2016
        // The item was not found with the shared lock, so we'll acquire the
×
2017
        // exclusive lock and check the cache again in case another method added
×
2018
        // the entry to the cache while no lock was held.
×
2019
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2020
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2021
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2022
                exists, isZombie = entry.flags.unpack()
×
2023

×
2024
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2025
        }
×
2026

2027
        chanIDB := channelIDToBytes(chanID)
×
2028
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2029
                channel, err := db.GetChannelBySCID(
×
2030
                        ctx, sqlc.GetChannelBySCIDParams{
×
2031
                                Scid:    chanIDB,
×
2032
                                Version: int16(ProtocolV1),
×
2033
                        },
×
2034
                )
×
2035
                if errors.Is(err, sql.ErrNoRows) {
×
2036
                        // Check if it is a zombie channel.
×
2037
                        isZombie, err = db.IsZombieChannel(
×
2038
                                ctx, sqlc.IsZombieChannelParams{
×
2039
                                        Scid:    chanIDB,
×
2040
                                        Version: int16(ProtocolV1),
×
2041
                                },
×
2042
                        )
×
2043
                        if err != nil {
×
2044
                                return fmt.Errorf("could not check if channel "+
×
2045
                                        "is zombie: %w", err)
×
2046
                        }
×
2047

2048
                        return nil
×
2049
                } else if err != nil {
×
2050
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2051
                }
×
2052

2053
                exists = true
×
2054

×
2055
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2056
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2057
                                Version:   int16(ProtocolV1),
×
2058
                                ChannelID: channel.ID,
×
2059
                                NodeID:    channel.NodeID1,
×
2060
                        },
×
2061
                )
×
2062
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2063
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2064
                                err)
×
2065
                } else if err == nil {
×
2066
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2067
                }
×
2068

2069
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2070
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2071
                                Version:   int16(ProtocolV1),
×
2072
                                ChannelID: channel.ID,
×
2073
                                NodeID:    channel.NodeID2,
×
2074
                        },
×
2075
                )
×
2076
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2077
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2078
                                err)
×
2079
                } else if err == nil {
×
2080
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2081
                }
×
2082

2083
                return nil
×
2084
        }, sqldb.NoOpReset)
2085
        if err != nil {
×
2086
                return time.Time{}, time.Time{}, false, false,
×
2087
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2088
        }
×
2089

2090
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2091
                upd1Time: node1LastUpdate.Unix(),
×
2092
                upd2Time: node2LastUpdate.Unix(),
×
2093
                flags:    packRejectFlags(exists, isZombie),
×
2094
        })
×
2095

×
2096
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2097
}
2098

2099
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2100
// passed channel point (outpoint). If the passed channel doesn't exist within
2101
// the database, then ErrEdgeNotFound is returned.
2102
//
2103
// NOTE: part of the V1Store interface.
2104
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2105
        var (
×
2106
                ctx       = context.TODO()
×
2107
                channelID uint64
×
2108
        )
×
2109
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2110
                chanID, err := db.GetSCIDByOutpoint(
×
2111
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2112
                                Outpoint: chanPoint.String(),
×
2113
                                Version:  int16(ProtocolV1),
×
2114
                        },
×
2115
                )
×
2116
                if errors.Is(err, sql.ErrNoRows) {
×
2117
                        return ErrEdgeNotFound
×
2118
                } else if err != nil {
×
2119
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2120
                                err)
×
2121
                }
×
2122

2123
                channelID = byteOrder.Uint64(chanID)
×
2124

×
2125
                return nil
×
2126
        }, sqldb.NoOpReset)
2127
        if err != nil {
×
2128
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2129
        }
×
2130

2131
        return channelID, nil
×
2132
}
2133

2134
// IsPublicNode is a helper method that determines whether the node with the
2135
// given public key is seen as a public node in the graph from the graph's
2136
// source node's point of view.
2137
//
2138
// NOTE: part of the V1Store interface.
2139
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2140
        ctx := context.TODO()
×
2141

×
2142
        var isPublic bool
×
2143
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2144
                var err error
×
2145
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2146

×
2147
                return err
×
2148
        }, sqldb.NoOpReset)
×
2149
        if err != nil {
×
2150
                return false, fmt.Errorf("unable to check if node is "+
×
2151
                        "public: %w", err)
×
2152
        }
×
2153

2154
        return isPublic, nil
×
2155
}
2156

2157
// FetchChanInfos returns the set of channel edges that correspond to the passed
2158
// channel ID's. If an edge is the query is unknown to the database, it will
2159
// skipped and the result will contain only those edges that exist at the time
2160
// of the query. This can be used to respond to peer queries that are seeking to
2161
// fill in gaps in their view of the channel graph.
2162
//
2163
// NOTE: part of the V1Store interface.
2164
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2165
        var (
×
2166
                ctx   = context.TODO()
×
2167
                edges []ChannelEdge
×
2168
        )
×
2169
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2170
                for _, chanID := range chanIDs {
×
2171
                        chanIDB := channelIDToBytes(chanID)
×
2172

×
2173
                        // TODO(elle): potentially optimize this by using
×
2174
                        //  sqlc.slice() once that works for both SQLite and
×
2175
                        //  Postgres.
×
2176
                        row, err := db.GetChannelBySCIDWithPolicies(
×
2177
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
2178
                                        Scid:    chanIDB,
×
2179
                                        Version: int16(ProtocolV1),
×
2180
                                },
×
2181
                        )
×
2182
                        if errors.Is(err, sql.ErrNoRows) {
×
2183
                                continue
×
2184
                        } else if err != nil {
×
2185
                                return fmt.Errorf("unable to fetch channel: %w",
×
2186
                                        err)
×
2187
                        }
×
2188

2189
                        node1, node2, err := buildNodes(
×
2190
                                ctx, db, row.Node, row.Node_2,
×
2191
                        )
×
2192
                        if err != nil {
×
2193
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2194
                                        err)
×
2195
                        }
×
2196

2197
                        edge, err := getAndBuildEdgeInfo(
×
2198
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2199
                                row.Channel, node1.PubKeyBytes,
×
2200
                                node2.PubKeyBytes,
×
2201
                        )
×
2202
                        if err != nil {
×
2203
                                return fmt.Errorf("unable to build "+
×
2204
                                        "channel info: %w", err)
×
2205
                        }
×
2206

2207
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2208
                        if err != nil {
×
2209
                                return fmt.Errorf("unable to extract channel "+
×
2210
                                        "policies: %w", err)
×
2211
                        }
×
2212

2213
                        p1, p2, err := getAndBuildChanPolicies(
×
2214
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2215
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2216
                        )
×
2217
                        if err != nil {
×
2218
                                return fmt.Errorf("unable to build channel "+
×
2219
                                        "policies: %w", err)
×
2220
                        }
×
2221

2222
                        edges = append(edges, ChannelEdge{
×
2223
                                Info:    edge,
×
2224
                                Policy1: p1,
×
2225
                                Policy2: p2,
×
2226
                                Node1:   node1,
×
2227
                                Node2:   node2,
×
2228
                        })
×
2229
                }
2230

2231
                return nil
×
2232
        }, func() {
×
2233
                edges = nil
×
2234
        })
×
2235
        if err != nil {
×
2236
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2237
        }
×
2238

2239
        return edges, nil
×
2240
}
2241

2242
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2243
// ID's that we don't know and are not known zombies of the passed set. In other
2244
// words, we perform a set difference of our set of chan ID's and the ones
2245
// passed in. This method can be used by callers to determine the set of
2246
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2247
// known zombies is also returned.
2248
//
2249
// NOTE: part of the V1Store interface.
2250
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2251
        []ChannelUpdateInfo, error) {
×
2252

×
2253
        var (
×
2254
                ctx          = context.TODO()
×
2255
                newChanIDs   []uint64
×
2256
                knownZombies []ChannelUpdateInfo
×
2257
        )
×
2258
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2259
                for _, chanInfo := range chansInfo {
×
2260
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2261
                        chanIDB := channelIDToBytes(channelID)
×
2262

×
2263
                        // TODO(elle): potentially optimize this by using
×
2264
                        //  sqlc.slice() once that works for both SQLite and
×
2265
                        //  Postgres.
×
2266
                        _, err := db.GetChannelBySCID(
×
2267
                                ctx, sqlc.GetChannelBySCIDParams{
×
2268
                                        Version: int16(ProtocolV1),
×
2269
                                        Scid:    chanIDB,
×
2270
                                },
×
2271
                        )
×
2272
                        if err == nil {
×
2273
                                continue
×
2274
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
2275
                                return fmt.Errorf("unable to fetch channel: %w",
×
2276
                                        err)
×
2277
                        }
×
2278

2279
                        isZombie, err := db.IsZombieChannel(
×
2280
                                ctx, sqlc.IsZombieChannelParams{
×
2281
                                        Scid:    chanIDB,
×
2282
                                        Version: int16(ProtocolV1),
×
2283
                                },
×
2284
                        )
×
2285
                        if err != nil {
×
2286
                                return fmt.Errorf("unable to fetch zombie "+
×
2287
                                        "channel: %w", err)
×
2288
                        }
×
2289

2290
                        if isZombie {
×
2291
                                knownZombies = append(knownZombies, chanInfo)
×
2292

×
2293
                                continue
×
2294
                        }
2295

2296
                        newChanIDs = append(newChanIDs, channelID)
×
2297
                }
2298

2299
                return nil
×
2300
        }, func() {
×
2301
                newChanIDs = nil
×
2302
                knownZombies = nil
×
2303
        })
×
2304
        if err != nil {
×
2305
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2306
        }
×
2307

2308
        return newChanIDs, knownZombies, nil
×
2309
}
2310

2311
// PruneGraphNodes is a garbage collection method which attempts to prune out
2312
// any nodes from the channel graph that are currently unconnected. This ensure
2313
// that we only maintain a graph of reachable nodes. In the event that a pruned
2314
// node gains more channels, it will be re-added back to the graph.
2315
//
2316
// NOTE: this prunes nodes across protocol versions. It will never prune the
2317
// source nodes.
2318
//
2319
// NOTE: part of the V1Store interface.
2320
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2321
        var ctx = context.TODO()
×
2322

×
2323
        var prunedNodes []route.Vertex
×
2324
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2325
                var err error
×
2326
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2327

×
2328
                return err
×
2329
        }, func() {
×
2330
                prunedNodes = nil
×
2331
        })
×
2332
        if err != nil {
×
2333
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2334
        }
×
2335

2336
        return prunedNodes, nil
×
2337
}
2338

2339
// PruneGraph prunes newly closed channels from the channel graph in response
2340
// to a new block being solved on the network. Any transactions which spend the
2341
// funding output of any known channels within he graph will be deleted.
2342
// Additionally, the "prune tip", or the last block which has been used to
2343
// prune the graph is stored so callers can ensure the graph is fully in sync
2344
// with the current UTXO state. A slice of channels that have been closed by
2345
// the target block along with any pruned nodes are returned if the function
2346
// succeeds without error.
2347
//
2348
// NOTE: part of the V1Store interface.
2349
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2350
        blockHash *chainhash.Hash, blockHeight uint32) (
2351
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2352

×
2353
        ctx := context.TODO()
×
2354

×
2355
        s.cacheMu.Lock()
×
2356
        defer s.cacheMu.Unlock()
×
2357

×
2358
        var (
×
2359
                closedChans []*models.ChannelEdgeInfo
×
2360
                prunedNodes []route.Vertex
×
2361
        )
×
2362
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2363
                for _, outpoint := range spentOutputs {
×
2364
                        // TODO(elle): potentially optimize this by using
×
2365
                        //  sqlc.slice() once that works for both SQLite and
×
2366
                        //  Postgres.
×
2367
                        //
×
2368
                        // NOTE: this fetches channels for all protocol
×
2369
                        // versions.
×
2370
                        row, err := db.GetChannelByOutpoint(
×
2371
                                ctx, outpoint.String(),
×
2372
                        )
×
2373
                        if errors.Is(err, sql.ErrNoRows) {
×
2374
                                continue
×
2375
                        } else if err != nil {
×
2376
                                return fmt.Errorf("unable to fetch channel: %w",
×
2377
                                        err)
×
2378
                        }
×
2379

2380
                        node1, node2, err := buildNodeVertices(
×
2381
                                row.Node1Pubkey, row.Node2Pubkey,
×
2382
                        )
×
2383
                        if err != nil {
×
2384
                                return err
×
2385
                        }
×
2386

2387
                        info, err := getAndBuildEdgeInfo(
×
2388
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2389
                                row.Channel, node1, node2,
×
2390
                        )
×
2391
                        if err != nil {
×
2392
                                return err
×
2393
                        }
×
2394

2395
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2396
                        if err != nil {
×
2397
                                return fmt.Errorf("unable to delete "+
×
2398
                                        "channel: %w", err)
×
2399
                        }
×
2400

2401
                        closedChans = append(closedChans, info)
×
2402
                }
2403

2404
                err := db.UpsertPruneLogEntry(
×
2405
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2406
                                BlockHash:   blockHash[:],
×
2407
                                BlockHeight: int64(blockHeight),
×
2408
                        },
×
2409
                )
×
2410
                if err != nil {
×
2411
                        return fmt.Errorf("unable to insert prune log "+
×
2412
                                "entry: %w", err)
×
2413
                }
×
2414

2415
                // Now that we've pruned some channels, we'll also prune any
2416
                // nodes that no longer have any channels.
2417
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2418
                if err != nil {
×
2419
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2420
                                err)
×
2421
                }
×
2422

2423
                return nil
×
2424
        }, func() {
×
2425
                prunedNodes = nil
×
2426
                closedChans = nil
×
2427
        })
×
2428
        if err != nil {
×
2429
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2430
        }
×
2431

2432
        for _, channel := range closedChans {
×
2433
                s.rejectCache.remove(channel.ChannelID)
×
2434
                s.chanCache.remove(channel.ChannelID)
×
2435
        }
×
2436

2437
        return closedChans, prunedNodes, nil
×
2438
}
2439

2440
// ChannelView returns the verifiable edge information for each active channel
2441
// within the known channel graph. The set of UTXOs (along with their scripts)
2442
// returned are the ones that need to be watched on chain to detect channel
2443
// closes on the resident blockchain.
2444
//
2445
// NOTE: part of the V1Store interface.
2446
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2447
        var (
×
2448
                ctx        = context.TODO()
×
2449
                edgePoints []EdgePoint
×
2450
        )
×
2451

×
2452
        handleChannel := func(db SQLQueries,
×
2453
                channel sqlc.ListChannelsPaginatedRow) error {
×
2454

×
2455
                pkScript, err := genMultiSigP2WSH(
×
2456
                        channel.BitcoinKey1, channel.BitcoinKey2,
×
2457
                )
×
2458
                if err != nil {
×
2459
                        return err
×
2460
                }
×
2461

2462
                op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2463
                if err != nil {
×
2464
                        return err
×
2465
                }
×
2466

2467
                edgePoints = append(edgePoints, EdgePoint{
×
2468
                        FundingPkScript: pkScript,
×
2469
                        OutPoint:        *op,
×
2470
                })
×
2471

×
2472
                return nil
×
2473
        }
2474

2475
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2476
                lastID := int64(-1)
×
2477
                for {
×
2478
                        rows, err := db.ListChannelsPaginated(
×
2479
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2480
                                        Version: int16(ProtocolV1),
×
2481
                                        ID:      lastID,
×
2482
                                        Limit:   pageSize,
×
2483
                                },
×
2484
                        )
×
2485
                        if err != nil {
×
2486
                                return err
×
2487
                        }
×
2488

2489
                        if len(rows) == 0 {
×
2490
                                break
×
2491
                        }
2492

2493
                        for _, row := range rows {
×
2494
                                err := handleChannel(db, row)
×
2495
                                if err != nil {
×
2496
                                        return err
×
2497
                                }
×
2498

2499
                                lastID = row.ID
×
2500
                        }
2501
                }
2502

2503
                return nil
×
2504
        }, func() {
×
2505
                edgePoints = nil
×
2506
        })
×
2507
        if err != nil {
×
2508
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2509
        }
×
2510

2511
        return edgePoints, nil
×
2512
}
2513

2514
// PruneTip returns the block height and hash of the latest block that has been
2515
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2516
// to tell if the graph is currently in sync with the current best known UTXO
2517
// state.
2518
//
2519
// NOTE: part of the V1Store interface.
2520
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2521
        var (
×
2522
                ctx       = context.TODO()
×
2523
                tipHash   chainhash.Hash
×
2524
                tipHeight uint32
×
2525
        )
×
2526
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2527
                pruneTip, err := db.GetPruneTip(ctx)
×
2528
                if errors.Is(err, sql.ErrNoRows) {
×
2529
                        return ErrGraphNeverPruned
×
2530
                } else if err != nil {
×
2531
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2532
                }
×
2533

2534
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2535
                tipHeight = uint32(pruneTip.BlockHeight)
×
2536

×
2537
                return nil
×
2538
        }, sqldb.NoOpReset)
2539
        if err != nil {
×
2540
                return nil, 0, err
×
2541
        }
×
2542

2543
        return &tipHash, tipHeight, nil
×
2544
}
2545

2546
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2547
//
2548
// NOTE: this prunes nodes across protocol versions. It will never prune the
2549
// source nodes.
2550
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2551
        db SQLQueries) ([]route.Vertex, error) {
×
2552

×
2553
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2554
        if err != nil {
×
2555
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2556
                        "nodes: %w", err)
×
2557
        }
×
2558

2559
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2560
        for i, nodeKey := range nodeKeys {
×
2561
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2562
                if err != nil {
×
2563
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2564
                                "from bytes: %w", err)
×
2565
                }
×
2566

2567
                prunedNodes[i] = pub
×
2568
        }
2569

2570
        return prunedNodes, nil
×
2571
}
2572

2573
// DisconnectBlockAtHeight is used to indicate that the block specified
2574
// by the passed height has been disconnected from the main chain. This
2575
// will "rewind" the graph back to the height below, deleting channels
2576
// that are no longer confirmed from the graph. The prune log will be
2577
// set to the last prune height valid for the remaining chain.
2578
// Channels that were removed from the graph resulting from the
2579
// disconnected block are returned.
2580
//
2581
// NOTE: part of the V1Store interface.
2582
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2583
        []*models.ChannelEdgeInfo, error) {
×
2584

×
2585
        ctx := context.TODO()
×
2586

×
2587
        var (
×
2588
                // Every channel having a ShortChannelID starting at 'height'
×
2589
                // will no longer be confirmed.
×
2590
                startShortChanID = lnwire.ShortChannelID{
×
2591
                        BlockHeight: height,
×
2592
                }
×
2593

×
2594
                // Delete everything after this height from the db up until the
×
2595
                // SCID alias range.
×
2596
                endShortChanID = aliasmgr.StartingAlias
×
2597

×
2598
                removedChans []*models.ChannelEdgeInfo
×
2599

×
2600
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2601
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2602
        )
×
2603

×
2604
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2605
                rows, err := db.GetChannelsBySCIDRange(
×
2606
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2607
                                StartScid: chanIDStart,
×
2608
                                EndScid:   chanIDEnd,
×
2609
                        },
×
2610
                )
×
2611
                if err != nil {
×
2612
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2613
                }
×
2614

2615
                for _, row := range rows {
×
2616
                        node1, node2, err := buildNodeVertices(
×
2617
                                row.Node1PubKey, row.Node2PubKey,
×
2618
                        )
×
2619
                        if err != nil {
×
2620
                                return err
×
2621
                        }
×
2622

2623
                        channel, err := getAndBuildEdgeInfo(
×
2624
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2625
                                row.Channel, node1, node2,
×
2626
                        )
×
2627
                        if err != nil {
×
2628
                                return err
×
2629
                        }
×
2630

2631
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2632
                        if err != nil {
×
2633
                                return fmt.Errorf("unable to delete "+
×
2634
                                        "channel: %w", err)
×
2635
                        }
×
2636

2637
                        removedChans = append(removedChans, channel)
×
2638
                }
2639

2640
                return db.DeletePruneLogEntriesInRange(
×
2641
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2642
                                StartHeight: int64(height),
×
2643
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2644
                        },
×
2645
                )
×
2646
        }, func() {
×
2647
                removedChans = nil
×
2648
        })
×
2649
        if err != nil {
×
2650
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2651
                        "height: %w", err)
×
2652
        }
×
2653

2654
        for _, channel := range removedChans {
×
2655
                s.rejectCache.remove(channel.ChannelID)
×
2656
                s.chanCache.remove(channel.ChannelID)
×
2657
        }
×
2658

2659
        return removedChans, nil
×
2660
}
2661

2662
// AddEdgeProof sets the proof of an existing edge in the graph database.
2663
//
2664
// NOTE: part of the V1Store interface.
2665
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2666
        proof *models.ChannelAuthProof) error {
×
2667

×
2668
        var (
×
2669
                ctx       = context.TODO()
×
2670
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2671
        )
×
2672

×
2673
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2674
                res, err := db.AddV1ChannelProof(
×
2675
                        ctx, sqlc.AddV1ChannelProofParams{
×
2676
                                Scid:              scidBytes,
×
2677
                                Node1Signature:    proof.NodeSig1Bytes,
×
2678
                                Node2Signature:    proof.NodeSig2Bytes,
×
2679
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2680
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2681
                        },
×
2682
                )
×
2683
                if err != nil {
×
2684
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2685
                }
×
2686

2687
                n, err := res.RowsAffected()
×
2688
                if err != nil {
×
2689
                        return err
×
2690
                }
×
2691

2692
                if n == 0 {
×
2693
                        return fmt.Errorf("no rows affected when adding edge "+
×
2694
                                "proof for SCID %v", scid)
×
2695
                } else if n > 1 {
×
2696
                        return fmt.Errorf("multiple rows affected when adding "+
×
2697
                                "edge proof for SCID %v: %d rows affected",
×
2698
                                scid, n)
×
2699
                }
×
2700

2701
                return nil
×
2702
        }, sqldb.NoOpReset)
2703
        if err != nil {
×
2704
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2705
        }
×
2706

2707
        return nil
×
2708
}
2709

2710
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2711
// that we can ignore channel announcements that we know to be closed without
2712
// having to validate them and fetch a block.
2713
//
2714
// NOTE: part of the V1Store interface.
2715
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2716
        var (
×
2717
                ctx     = context.TODO()
×
2718
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2719
        )
×
2720

×
2721
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2722
                return db.InsertClosedChannel(ctx, chanIDB)
×
2723
        }, sqldb.NoOpReset)
×
2724
}
2725

2726
// IsClosedScid checks whether a channel identified by the passed in scid is
2727
// closed. This helps avoid having to perform expensive validation checks.
2728
//
2729
// NOTE: part of the V1Store interface.
2730
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2731
        var (
×
2732
                ctx      = context.TODO()
×
2733
                isClosed bool
×
2734
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2735
        )
×
2736
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2737
                var err error
×
2738
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2739
                if err != nil {
×
2740
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2741
                                err)
×
2742
                }
×
2743

2744
                return nil
×
2745
        }, sqldb.NoOpReset)
2746
        if err != nil {
×
2747
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2748
                        err)
×
2749
        }
×
2750

2751
        return isClosed, nil
×
2752
}
2753

2754
// GraphSession will provide the call-back with access to a NodeTraverser
2755
// instance which can be used to perform queries against the channel graph.
2756
//
2757
// NOTE: part of the V1Store interface.
2758
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
NEW
2759
        reset func()) error {
×
NEW
2760

×
2761
        var ctx = context.TODO()
×
2762

×
2763
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2764
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
NEW
2765
        }, reset)
×
2766
}
2767

2768
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2769
// read only transaction for a consistent view of the graph.
2770
type sqlNodeTraverser struct {
2771
        db    SQLQueries
2772
        chain chainhash.Hash
2773
}
2774

2775
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2776
// NodeTraverser interface.
2777
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2778

2779
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2780
func newSQLNodeTraverser(db SQLQueries,
2781
        chain chainhash.Hash) *sqlNodeTraverser {
×
2782

×
2783
        return &sqlNodeTraverser{
×
2784
                db:    db,
×
2785
                chain: chain,
×
2786
        }
×
2787
}
×
2788

2789
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2790
// node.
2791
//
2792
// NOTE: Part of the NodeTraverser interface.
2793
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
NEW
2794
        cb func(channel *DirectedChannel) error, _ func()) error {
×
2795

×
2796
        ctx := context.TODO()
×
2797

×
2798
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2799
}
×
2800

2801
// FetchNodeFeatures returns the features of the given node. If the node is
2802
// unknown, assume no additional features are supported.
2803
//
2804
// NOTE: Part of the NodeTraverser interface.
2805
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2806
        *lnwire.FeatureVector, error) {
×
2807

×
2808
        ctx := context.TODO()
×
2809

×
2810
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2811
}
×
2812

2813
// forEachNodeDirectedChannel iterates through all channels of a given
2814
// node, executing the passed callback on the directed edge representing the
2815
// channel and its incoming policy. If the node is not found, no error is
2816
// returned.
2817
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2818
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2819

×
2820
        toNodeCallback := func() route.Vertex {
×
2821
                return nodePub
×
2822
        }
×
2823

2824
        dbID, err := db.GetNodeIDByPubKey(
×
2825
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2826
                        Version: int16(ProtocolV1),
×
2827
                        PubKey:  nodePub[:],
×
2828
                },
×
2829
        )
×
2830
        if errors.Is(err, sql.ErrNoRows) {
×
2831
                return nil
×
2832
        } else if err != nil {
×
2833
                return fmt.Errorf("unable to fetch node: %w", err)
×
2834
        }
×
2835

2836
        rows, err := db.ListChannelsByNodeID(
×
2837
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2838
                        Version: int16(ProtocolV1),
×
2839
                        NodeID1: dbID,
×
2840
                },
×
2841
        )
×
2842
        if err != nil {
×
2843
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2844
        }
×
2845

2846
        // Exit early if there are no channels for this node so we don't
2847
        // do the unnecessary feature fetching.
2848
        if len(rows) == 0 {
×
2849
                return nil
×
2850
        }
×
2851

2852
        features, err := getNodeFeatures(ctx, db, dbID)
×
2853
        if err != nil {
×
2854
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2855
        }
×
2856

2857
        for _, row := range rows {
×
2858
                node1, node2, err := buildNodeVertices(
×
2859
                        row.Node1Pubkey, row.Node2Pubkey,
×
2860
                )
×
2861
                if err != nil {
×
2862
                        return fmt.Errorf("unable to build node vertices: %w",
×
2863
                                err)
×
2864
                }
×
2865

2866
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
2867

×
2868
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2869
                if err != nil {
×
2870
                        return err
×
2871
                }
×
2872

2873
                var p1, p2 *models.CachedEdgePolicy
×
2874
                if dbPol1 != nil {
×
2875
                        policy1, err := buildChanPolicy(
×
2876
                                *dbPol1, edge.ChannelID, nil, node2,
×
2877
                        )
×
2878
                        if err != nil {
×
2879
                                return err
×
2880
                        }
×
2881

2882
                        p1 = models.NewCachedPolicy(policy1)
×
2883
                }
2884
                if dbPol2 != nil {
×
2885
                        policy2, err := buildChanPolicy(
×
2886
                                *dbPol2, edge.ChannelID, nil, node1,
×
2887
                        )
×
2888
                        if err != nil {
×
2889
                                return err
×
2890
                        }
×
2891

2892
                        p2 = models.NewCachedPolicy(policy2)
×
2893
                }
2894

2895
                // Determine the outgoing and incoming policy for this
2896
                // channel and node combo.
2897
                outPolicy, inPolicy := p1, p2
×
2898
                if p1 != nil && node2 == nodePub {
×
2899
                        outPolicy, inPolicy = p2, p1
×
2900
                } else if p2 != nil && node1 != nodePub {
×
2901
                        outPolicy, inPolicy = p2, p1
×
2902
                }
×
2903

2904
                var cachedInPolicy *models.CachedEdgePolicy
×
2905
                if inPolicy != nil {
×
2906
                        cachedInPolicy = inPolicy
×
2907
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
2908
                        cachedInPolicy.ToNodeFeatures = features
×
2909
                }
×
2910

2911
                directedChannel := &DirectedChannel{
×
2912
                        ChannelID:    edge.ChannelID,
×
2913
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2914
                        OtherNode:    edge.NodeKey2Bytes,
×
2915
                        Capacity:     edge.Capacity,
×
2916
                        OutPolicySet: outPolicy != nil,
×
2917
                        InPolicy:     cachedInPolicy,
×
2918
                }
×
2919
                if outPolicy != nil {
×
2920
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2921
                                directedChannel.InboundFee = fee
×
2922
                        })
×
2923
                }
2924

2925
                if nodePub == edge.NodeKey2Bytes {
×
2926
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
2927
                }
×
2928

2929
                if err := cb(directedChannel); err != nil {
×
2930
                        return err
×
2931
                }
×
2932
        }
2933

2934
        return nil
×
2935
}
2936

2937
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
2938
// and executes the provided callback for each node.
2939
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
2940
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
2941

×
2942
        lastID := int64(-1)
×
2943

×
2944
        for {
×
2945
                nodes, err := db.ListNodeIDsAndPubKeys(
×
2946
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
2947
                                Version: int16(ProtocolV1),
×
2948
                                ID:      lastID,
×
2949
                                Limit:   pageSize,
×
2950
                        },
×
2951
                )
×
2952
                if err != nil {
×
2953
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
2954
                }
×
2955

2956
                if len(nodes) == 0 {
×
2957
                        break
×
2958
                }
2959

2960
                for _, node := range nodes {
×
2961
                        var pub route.Vertex
×
2962
                        copy(pub[:], node.PubKey)
×
2963

×
2964
                        if err := cb(node.ID, pub); err != nil {
×
2965
                                return fmt.Errorf("forEachNodeCacheable "+
×
2966
                                        "callback failed for node(id=%d): %w",
×
2967
                                        node.ID, err)
×
2968
                        }
×
2969

2970
                        lastID = node.ID
×
2971
                }
2972
        }
2973

2974
        return nil
×
2975
}
2976

2977
// forEachNodeChannel iterates through all channels of a node, executing
2978
// the passed callback on each. The call-back is provided with the channel's
2979
// edge information, the outgoing policy and the incoming policy for the
2980
// channel and node combo.
2981
func forEachNodeChannel(ctx context.Context, db SQLQueries,
2982
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
2983
                *models.ChannelEdgePolicy,
2984
                *models.ChannelEdgePolicy) error) error {
×
2985

×
2986
        // Get all the V1 channels for this node.Add commentMore actions
×
2987
        rows, err := db.ListChannelsByNodeID(
×
2988
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2989
                        Version: int16(ProtocolV1),
×
2990
                        NodeID1: id,
×
2991
                },
×
2992
        )
×
2993
        if err != nil {
×
2994
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2995
        }
×
2996

2997
        // Call the call-back for each channel and its known policies.
2998
        for _, row := range rows {
×
2999
                node1, node2, err := buildNodeVertices(
×
3000
                        row.Node1Pubkey, row.Node2Pubkey,
×
3001
                )
×
3002
                if err != nil {
×
3003
                        return fmt.Errorf("unable to build node vertices: %w",
×
3004
                                err)
×
3005
                }
×
3006

3007
                edge, err := getAndBuildEdgeInfo(
×
3008
                        ctx, db, chain, row.Channel.ID, row.Channel, node1,
×
3009
                        node2,
×
3010
                )
×
3011
                if err != nil {
×
3012
                        return fmt.Errorf("unable to build channel info: %w",
×
3013
                                err)
×
3014
                }
×
3015

3016
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3017
                if err != nil {
×
3018
                        return fmt.Errorf("unable to extract channel "+
×
3019
                                "policies: %w", err)
×
3020
                }
×
3021

3022
                p1, p2, err := getAndBuildChanPolicies(
×
3023
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3024
                )
×
3025
                if err != nil {
×
3026
                        return fmt.Errorf("unable to build channel "+
×
3027
                                "policies: %w", err)
×
3028
                }
×
3029

3030
                // Determine the outgoing and incoming policy for this
3031
                // channel and node combo.
3032
                p1ToNode := row.Channel.NodeID2
×
3033
                p2ToNode := row.Channel.NodeID1
×
3034
                outPolicy, inPolicy := p1, p2
×
3035
                if (p1 != nil && p1ToNode == id) ||
×
3036
                        (p2 != nil && p2ToNode != id) {
×
3037

×
3038
                        outPolicy, inPolicy = p2, p1
×
3039
                }
×
3040

3041
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3042
                        return err
×
3043
                }
×
3044
        }
3045

3046
        return nil
×
3047
}
3048

3049
// updateChanEdgePolicy upserts the channel policy info we have stored for
3050
// a channel we already know of.
3051
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3052
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3053
        error) {
×
3054

×
3055
        var (
×
3056
                node1Pub, node2Pub route.Vertex
×
3057
                isNode1            bool
×
3058
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3059
        )
×
3060

×
3061
        // Check that this edge policy refers to a channel that we already
×
3062
        // know of. We do this explicitly so that we can return the appropriate
×
3063
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3064
        // abort the transaction which would abort the entire batch.
×
3065
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3066
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3067
                        Scid:    chanIDB,
×
3068
                        Version: int16(ProtocolV1),
×
3069
                },
×
3070
        )
×
3071
        if errors.Is(err, sql.ErrNoRows) {
×
3072
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3073
        } else if err != nil {
×
3074
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3075
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3076
        }
×
3077

3078
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3079
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3080

×
3081
        // Figure out which node this edge is from.
×
3082
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3083
        nodeID := dbChan.NodeID1
×
3084
        if !isNode1 {
×
3085
                nodeID = dbChan.NodeID2
×
3086
        }
×
3087

3088
        var (
×
3089
                inboundBase sql.NullInt64
×
3090
                inboundRate sql.NullInt64
×
3091
        )
×
3092
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3093
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3094
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3095
        })
×
3096

3097
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3098
                Version:     int16(ProtocolV1),
×
3099
                ChannelID:   dbChan.ID,
×
3100
                NodeID:      nodeID,
×
3101
                Timelock:    int32(edge.TimeLockDelta),
×
3102
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3103
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3104
                MinHtlcMsat: int64(edge.MinHTLC),
×
3105
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3106
                Disabled: sql.NullBool{
×
3107
                        Valid: true,
×
3108
                        Bool:  edge.IsDisabled(),
×
3109
                },
×
3110
                MaxHtlcMsat: sql.NullInt64{
×
3111
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3112
                        Int64: int64(edge.MaxHTLC),
×
3113
                },
×
3114
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3115
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3116
                InboundBaseFeeMsat:      inboundBase,
×
3117
                InboundFeeRateMilliMsat: inboundRate,
×
3118
                Signature:               edge.SigBytes,
×
3119
        })
×
3120
        if err != nil {
×
3121
                return node1Pub, node2Pub, isNode1,
×
3122
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3123
        }
×
3124

3125
        // Convert the flat extra opaque data into a map of TLV types to
3126
        // values.
3127
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3128
        if err != nil {
×
3129
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3130
                        "marshal extra opaque data: %w", err)
×
3131
        }
×
3132

3133
        // Update the channel policy's extra signed fields.
3134
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3135
        if err != nil {
×
3136
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3137
                        "policy extra TLVs: %w", err)
×
3138
        }
×
3139

3140
        return node1Pub, node2Pub, isNode1, nil
×
3141
}
3142

3143
// getNodeByPubKey attempts to look up a target node by its public key.
3144
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3145
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3146

×
3147
        dbNode, err := db.GetNodeByPubKey(
×
3148
                ctx, sqlc.GetNodeByPubKeyParams{
×
3149
                        Version: int16(ProtocolV1),
×
3150
                        PubKey:  pubKey[:],
×
3151
                },
×
3152
        )
×
3153
        if errors.Is(err, sql.ErrNoRows) {
×
3154
                return 0, nil, ErrGraphNodeNotFound
×
3155
        } else if err != nil {
×
3156
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3157
        }
×
3158

3159
        node, err := buildNode(ctx, db, &dbNode)
×
3160
        if err != nil {
×
3161
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3162
        }
×
3163

3164
        return dbNode.ID, node, nil
×
3165
}
3166

3167
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3168
// provided database channel row and the public keys of the two nodes
3169
// involved in the channel.
3170
func buildCacheableChannelInfo(dbChan sqlc.Channel, node1Pub,
3171
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3172

×
3173
        return &models.CachedEdgeInfo{
×
3174
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
3175
                NodeKey1Bytes: node1Pub,
×
3176
                NodeKey2Bytes: node2Pub,
×
3177
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
3178
        }
×
3179
}
×
3180

3181
// buildNode constructs a LightningNode instance from the given database node
3182
// record. The node's features, addresses and extra signed fields are also
3183
// fetched from the database and set on the node.
3184
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
3185
        *models.LightningNode, error) {
×
3186

×
3187
        if dbNode.Version != int16(ProtocolV1) {
×
3188
                return nil, fmt.Errorf("unsupported node version: %d",
×
3189
                        dbNode.Version)
×
3190
        }
×
3191

3192
        var pub [33]byte
×
3193
        copy(pub[:], dbNode.PubKey)
×
3194

×
3195
        node := &models.LightningNode{
×
3196
                PubKeyBytes: pub,
×
3197
                Features:    lnwire.EmptyFeatureVector(),
×
3198
                LastUpdate:  time.Unix(0, 0),
×
3199
        }
×
3200

×
3201
        if len(dbNode.Signature) == 0 {
×
3202
                return node, nil
×
3203
        }
×
3204

3205
        node.HaveNodeAnnouncement = true
×
3206
        node.AuthSigBytes = dbNode.Signature
×
3207
        node.Alias = dbNode.Alias.String
×
3208
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3209

×
3210
        var err error
×
3211
        if dbNode.Color.Valid {
×
3212
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3213
                if err != nil {
×
3214
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3215
                                err)
×
3216
                }
×
3217
        }
3218

3219
        // Fetch the node's features.
3220
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
3221
        if err != nil {
×
3222
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3223
                        "features: %w", dbNode.ID, err)
×
3224
        }
×
3225

3226
        // Fetch the node's addresses.
3227
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
3228
        if err != nil {
×
3229
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3230
                        "addresses: %w", dbNode.ID, err)
×
3231
        }
×
3232

3233
        // Fetch the node's extra signed fields.
3234
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
3235
        if err != nil {
×
3236
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3237
                        "extra signed fields: %w", dbNode.ID, err)
×
3238
        }
×
3239

3240
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
3241
        if err != nil {
×
3242
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3243
                        "fields: %w", err)
×
3244
        }
×
3245

3246
        if len(recs) != 0 {
×
3247
                node.ExtraOpaqueData = recs
×
3248
        }
×
3249

3250
        return node, nil
×
3251
}
3252

3253
// getNodeFeatures fetches the feature bits and constructs the feature vector
3254
// for a node with the given DB ID.
3255
func getNodeFeatures(ctx context.Context, db SQLQueries,
3256
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3257

×
3258
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3259
        if err != nil {
×
3260
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3261
                        nodeID, err)
×
3262
        }
×
3263

3264
        features := lnwire.EmptyFeatureVector()
×
3265
        for _, feature := range rows {
×
3266
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3267
        }
×
3268

3269
        return features, nil
×
3270
}
3271

3272
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
3273
// given DB ID.
3274
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3275
        nodeID int64) (map[uint64][]byte, error) {
×
3276

×
3277
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3278
        if err != nil {
×
3279
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
3280
                        "signed fields: %w", nodeID, err)
×
3281
        }
×
3282

3283
        extraFields := make(map[uint64][]byte)
×
3284
        for _, field := range fields {
×
3285
                extraFields[uint64(field.Type)] = field.Value
×
3286
        }
×
3287

3288
        return extraFields, nil
×
3289
}
3290

3291
// upsertNode upserts the node record into the database. If the node already
3292
// exists, then the node's information is updated. If the node doesn't exist,
3293
// then a new node is created. The node's features, addresses and extra TLV
3294
// types are also updated. The node's DB ID is returned.
3295
func upsertNode(ctx context.Context, db SQLQueries,
3296
        node *models.LightningNode) (int64, error) {
×
3297

×
3298
        params := sqlc.UpsertNodeParams{
×
3299
                Version: int16(ProtocolV1),
×
3300
                PubKey:  node.PubKeyBytes[:],
×
3301
        }
×
3302

×
3303
        if node.HaveNodeAnnouncement {
×
3304
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3305
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3306
                params.Alias = sqldb.SQLStr(node.Alias)
×
3307
                params.Signature = node.AuthSigBytes
×
3308
        }
×
3309

3310
        nodeID, err := db.UpsertNode(ctx, params)
×
3311
        if err != nil {
×
3312
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3313
                        err)
×
3314
        }
×
3315

3316
        // We can exit here if we don't have the announcement yet.
3317
        if !node.HaveNodeAnnouncement {
×
3318
                return nodeID, nil
×
3319
        }
×
3320

3321
        // Update the node's features.
3322
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3323
        if err != nil {
×
3324
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3325
        }
×
3326

3327
        // Update the node's addresses.
3328
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3329
        if err != nil {
×
3330
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3331
        }
×
3332

3333
        // Convert the flat extra opaque data into a map of TLV types to
3334
        // values.
3335
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3336
        if err != nil {
×
3337
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3338
                        err)
×
3339
        }
×
3340

3341
        // Update the node's extra signed fields.
3342
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3343
        if err != nil {
×
3344
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3345
        }
×
3346

3347
        return nodeID, nil
×
3348
}
3349

3350
// upsertNodeFeatures updates the node's features node_features table. This
3351
// includes deleting any feature bits no longer present and inserting any new
3352
// feature bits. If the feature bit does not yet exist in the features table,
3353
// then an entry is created in that table first.
3354
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3355
        features *lnwire.FeatureVector) error {
×
3356

×
3357
        // Get any existing features for the node.
×
3358
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3359
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3360
                return err
×
3361
        }
×
3362

3363
        // Copy the nodes latest set of feature bits.
3364
        newFeatures := make(map[int32]struct{})
×
3365
        if features != nil {
×
3366
                for feature := range features.Features() {
×
3367
                        newFeatures[int32(feature)] = struct{}{}
×
3368
                }
×
3369
        }
3370

3371
        // For any current feature that already exists in the DB, remove it from
3372
        // the in-memory map. For any existing feature that does not exist in
3373
        // the in-memory map, delete it from the database.
3374
        for _, feature := range existingFeatures {
×
3375
                // The feature is still present, so there are no updates to be
×
3376
                // made.
×
3377
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3378
                        delete(newFeatures, feature.FeatureBit)
×
3379
                        continue
×
3380
                }
3381

3382
                // The feature is no longer present, so we remove it from the
3383
                // database.
3384
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3385
                        NodeID:     nodeID,
×
3386
                        FeatureBit: feature.FeatureBit,
×
3387
                })
×
3388
                if err != nil {
×
3389
                        return fmt.Errorf("unable to delete node(%d) "+
×
3390
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3391
                                err)
×
3392
                }
×
3393
        }
3394

3395
        // Any remaining entries in newFeatures are new features that need to be
3396
        // added to the database for the first time.
3397
        for feature := range newFeatures {
×
3398
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3399
                        NodeID:     nodeID,
×
3400
                        FeatureBit: feature,
×
3401
                })
×
3402
                if err != nil {
×
3403
                        return fmt.Errorf("unable to insert node(%d) "+
×
3404
                                "feature(%v): %w", nodeID, feature, err)
×
3405
                }
×
3406
        }
3407

3408
        return nil
×
3409
}
3410

3411
// fetchNodeFeatures fetches the features for a node with the given public key.
3412
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3413
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3414

×
3415
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3416
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3417
                        PubKey:  nodePub[:],
×
3418
                        Version: int16(ProtocolV1),
×
3419
                },
×
3420
        )
×
3421
        if err != nil {
×
3422
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3423
                        nodePub, err)
×
3424
        }
×
3425

3426
        features := lnwire.EmptyFeatureVector()
×
3427
        for _, bit := range rows {
×
3428
                features.Set(lnwire.FeatureBit(bit))
×
3429
        }
×
3430

3431
        return features, nil
×
3432
}
3433

3434
// dbAddressType enumerates the kinds of address stored in the node_addresses
// table. The type determines how an address string was serialised and
// therefore how it must be deserialised when read back.
type dbAddressType uint8

// The numeric values below are what gets written to the database, so they are
// spelled out explicitly and must never be renumbered.
const (
	// addressTypeIPv4 marks a TCP address whose IP has a 4-byte form.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 marks a TCP address whose IP has only a 16-byte
	// form.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 marks a Tor v2 onion service address.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 marks a Tor v3 onion service address.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeOpaque marks an opaque address payload, stored as a hex
	// string.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3446

3447
// upsertNodeAddresses updates the node's addresses in the database. This
3448
// includes deleting any existing addresses and inserting the new set of
3449
// addresses. The deletion is necessary since the ordering of the addresses may
3450
// change, and we need to ensure that the database reflects the latest set of
3451
// addresses so that at the time of reconstructing the node announcement, the
3452
// order is preserved and the signature over the message remains valid.
3453
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3454
        addresses []net.Addr) error {
×
3455

×
3456
        // Delete any existing addresses for the node. This is required since
×
3457
        // even if the new set of addresses is the same, the ordering may have
×
3458
        // changed for a given address type.
×
3459
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3460
        if err != nil {
×
3461
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3462
                        nodeID, err)
×
3463
        }
×
3464

3465
        // Copy the nodes latest set of addresses.
3466
        newAddresses := map[dbAddressType][]string{
×
3467
                addressTypeIPv4:   {},
×
3468
                addressTypeIPv6:   {},
×
3469
                addressTypeTorV2:  {},
×
3470
                addressTypeTorV3:  {},
×
3471
                addressTypeOpaque: {},
×
3472
        }
×
3473
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3474
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3475
        }
×
3476

3477
        for _, address := range addresses {
×
3478
                switch addr := address.(type) {
×
3479
                case *net.TCPAddr:
×
3480
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3481
                                addAddr(addressTypeIPv4, addr)
×
3482
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3483
                                addAddr(addressTypeIPv6, addr)
×
3484
                        } else {
×
3485
                                return fmt.Errorf("unhandled IP address: %v",
×
3486
                                        addr)
×
3487
                        }
×
3488

3489
                case *tor.OnionAddr:
×
3490
                        switch len(addr.OnionService) {
×
3491
                        case tor.V2Len:
×
3492
                                addAddr(addressTypeTorV2, addr)
×
3493
                        case tor.V3Len:
×
3494
                                addAddr(addressTypeTorV3, addr)
×
3495
                        default:
×
3496
                                return fmt.Errorf("invalid length for a tor " +
×
3497
                                        "address")
×
3498
                        }
3499

3500
                case *lnwire.OpaqueAddrs:
×
3501
                        addAddr(addressTypeOpaque, addr)
×
3502

3503
                default:
×
3504
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3505
                }
3506
        }
3507

3508
        // Any remaining entries in newAddresses are new addresses that need to
3509
        // be added to the database for the first time.
3510
        for addrType, addrList := range newAddresses {
×
3511
                for position, addr := range addrList {
×
3512
                        err := db.InsertNodeAddress(
×
3513
                                ctx, sqlc.InsertNodeAddressParams{
×
3514
                                        NodeID:   nodeID,
×
3515
                                        Type:     int16(addrType),
×
3516
                                        Address:  addr,
×
3517
                                        Position: int32(position),
×
3518
                                },
×
3519
                        )
×
3520
                        if err != nil {
×
3521
                                return fmt.Errorf("unable to insert "+
×
3522
                                        "node(%d) address(%v): %w", nodeID,
×
3523
                                        addr, err)
×
3524
                        }
×
3525
                }
3526
        }
3527

3528
        return nil
×
3529
}
3530

3531
// getNodeAddresses fetches the addresses for a node with the given public key.
3532
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3533
        []net.Addr, error) {
×
3534

×
3535
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
3536
        // are returned in the same order as they were inserted.
×
3537
        rows, err := db.GetNodeAddressesByPubKey(
×
3538
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
3539
                        Version: int16(ProtocolV1),
×
3540
                        PubKey:  nodePub,
×
3541
                },
×
3542
        )
×
3543
        if err != nil {
×
3544
                return false, nil, err
×
3545
        }
×
3546

3547
        // GetNodeAddressesByPubKey uses a left join so there should always be
3548
        // at least one row returned if the node exists even if it has no
3549
        // addresses.
3550
        if len(rows) == 0 {
×
3551
                return false, nil, nil
×
3552
        }
×
3553

3554
        addresses := make([]net.Addr, 0, len(rows))
×
3555
        for _, addr := range rows {
×
3556
                if !(addr.Type.Valid && addr.Address.Valid) {
×
3557
                        continue
×
3558
                }
3559

3560
                address := addr.Address.String
×
3561

×
3562
                switch dbAddressType(addr.Type.Int16) {
×
3563
                case addressTypeIPv4:
×
3564
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
3565
                        if err != nil {
×
3566
                                return false, nil, nil
×
3567
                        }
×
3568
                        tcp.IP = tcp.IP.To4()
×
3569

×
3570
                        addresses = append(addresses, tcp)
×
3571

3572
                case addressTypeIPv6:
×
3573
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
3574
                        if err != nil {
×
3575
                                return false, nil, nil
×
3576
                        }
×
3577
                        addresses = append(addresses, tcp)
×
3578

3579
                case addressTypeTorV3, addressTypeTorV2:
×
3580
                        service, portStr, err := net.SplitHostPort(address)
×
3581
                        if err != nil {
×
3582
                                return false, nil, fmt.Errorf("unable to "+
×
3583
                                        "split tor v3 address: %v",
×
3584
                                        addr.Address)
×
3585
                        }
×
3586

3587
                        port, err := strconv.Atoi(portStr)
×
3588
                        if err != nil {
×
3589
                                return false, nil, err
×
3590
                        }
×
3591

3592
                        addresses = append(addresses, &tor.OnionAddr{
×
3593
                                OnionService: service,
×
3594
                                Port:         port,
×
3595
                        })
×
3596

3597
                case addressTypeOpaque:
×
3598
                        opaque, err := hex.DecodeString(address)
×
3599
                        if err != nil {
×
3600
                                return false, nil, fmt.Errorf("unable to "+
×
3601
                                        "decode opaque address: %v", addr)
×
3602
                        }
×
3603

3604
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
3605
                                Payload: opaque,
×
3606
                        })
×
3607

3608
                default:
×
3609
                        return false, nil, fmt.Errorf("unknown address "+
×
3610
                                "type: %v", addr.Type)
×
3611
                }
3612
        }
3613

3614
        // If we have no addresses, then we'll return nil instead of an
3615
        // empty slice.
3616
        if len(addresses) == 0 {
×
3617
                addresses = nil
×
3618
        }
×
3619

3620
        return true, addresses, nil
×
3621
}
3622

3623
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3624
// database. This includes updating any existing types, inserting any new types,
3625
// and deleting any types that are no longer present.
3626
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3627
        nodeID int64, extraFields map[uint64][]byte) error {
×
3628

×
3629
        // Get any existing extra signed fields for the node.
×
3630
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3631
        if err != nil {
×
3632
                return err
×
3633
        }
×
3634

3635
        // Make a lookup map of the existing field types so that we can use it
3636
        // to keep track of any fields we should delete.
3637
        m := make(map[uint64]bool)
×
3638
        for _, field := range existingFields {
×
3639
                m[uint64(field.Type)] = true
×
3640
        }
×
3641

3642
        // For all the new fields, we'll upsert them and remove them from the
3643
        // map of existing fields.
3644
        for tlvType, value := range extraFields {
×
3645
                err = db.UpsertNodeExtraType(
×
3646
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3647
                                NodeID: nodeID,
×
3648
                                Type:   int64(tlvType),
×
3649
                                Value:  value,
×
3650
                        },
×
3651
                )
×
3652
                if err != nil {
×
3653
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3654
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3655
                }
×
3656

3657
                // Remove the field from the map of existing fields if it was
3658
                // present.
3659
                delete(m, tlvType)
×
3660
        }
3661

3662
        // For all the fields that are left in the map of existing fields, we'll
3663
        // delete them as they are no longer present in the new set of fields.
3664
        for tlvType := range m {
×
3665
                err = db.DeleteExtraNodeType(
×
3666
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3667
                                NodeID: nodeID,
×
3668
                                Type:   int64(tlvType),
×
3669
                        },
×
3670
                )
×
3671
                if err != nil {
×
3672
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3673
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3674
                }
×
3675
        }
3676

3677
        return nil
×
3678
}
3679

3680
// srcNodeInfo holds the information about the source node of the graph.
3681
type srcNodeInfo struct {
3682
        // id is the DB level ID of the source node entry in the "nodes" table.
3683
        id int64
3684

3685
        // pub is the public key of the source node.
3686
        pub route.Vertex
3687
}
3688

3689
// sourceNode returns the DB node ID and pub key of the source node for the
3690
// specified protocol version.
3691
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3692
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3693

×
3694
        s.srcNodeMu.Lock()
×
3695
        defer s.srcNodeMu.Unlock()
×
3696

×
3697
        // If we already have the source node ID and pub key cached, then
×
3698
        // return them.
×
3699
        if info, ok := s.srcNodes[version]; ok {
×
3700
                return info.id, info.pub, nil
×
3701
        }
×
3702

3703
        var pubKey route.Vertex
×
3704

×
3705
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3706
        if err != nil {
×
3707
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3708
                        err)
×
3709
        }
×
3710

3711
        if len(nodes) == 0 {
×
3712
                return 0, pubKey, ErrSourceNodeNotSet
×
3713
        } else if len(nodes) > 1 {
×
3714
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3715
                        "protocol %s found", version)
×
3716
        }
×
3717

3718
        copy(pubKey[:], nodes[0].PubKey)
×
3719

×
3720
        s.srcNodes[version] = &srcNodeInfo{
×
3721
                id:  nodes[0].NodeID,
×
3722
                pub: pubKey,
×
3723
        }
×
3724

×
3725
        return nodes[0].NodeID, pubKey, nil
×
3726
}
3727

3728
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3729
// This then produces a map from TLV type to value. If the input is not a
3730
// valid TLV stream, then an error is returned.
3731
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3732
        r := bytes.NewReader(data)
×
3733

×
3734
        tlvStream, err := tlv.NewStream()
×
3735
        if err != nil {
×
3736
                return nil, err
×
3737
        }
×
3738

3739
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3740
        // pass it into the P2P decoding variant.
3741
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3742
        if err != nil {
×
3743
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3744
        }
×
3745
        if len(parsedTypes) == 0 {
×
3746
                return nil, nil
×
3747
        }
×
3748

3749
        records := make(map[uint64][]byte)
×
3750
        for k, v := range parsedTypes {
×
3751
                records[uint64(k)] = v
×
3752
        }
×
3753

3754
        return records, nil
×
3755
}
3756

3757
// dbChanInfo carries the DB-level IDs that insertChannel produces: the ID of
// the new channel row and the IDs of the rows of the channel's two nodes.
type dbChanInfo struct {
	// channelID is the DB-level ID of the channel row.
	channelID int64

	// node1ID and node2ID are the DB-level IDs of the channel's first and
	// second node, respectively.
	node1ID, node2ID int64
}
3764

3765
// insertChannel inserts a new channel record into the database.
3766
func insertChannel(ctx context.Context, db SQLQueries,
3767
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
3768

×
3769
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3770

×
3771
        // Make sure that the channel doesn't already exist. We do this
×
3772
        // explicitly instead of relying on catching a unique constraint error
×
3773
        // because relying on SQL to throw that error would abort the entire
×
3774
        // batch of transactions.
×
3775
        _, err := db.GetChannelBySCID(
×
3776
                ctx, sqlc.GetChannelBySCIDParams{
×
3777
                        Scid:    chanIDB,
×
3778
                        Version: int16(ProtocolV1),
×
3779
                },
×
3780
        )
×
3781
        if err == nil {
×
3782
                return nil, ErrEdgeAlreadyExist
×
3783
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3784
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
×
3785
        }
×
3786

3787
        // Make sure that at least a "shell" entry for each node is present in
3788
        // the nodes table.
3789
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3790
        if err != nil {
×
3791
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3792
        }
×
3793

3794
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3795
        if err != nil {
×
3796
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3797
        }
×
3798

3799
        var capacity sql.NullInt64
×
3800
        if edge.Capacity != 0 {
×
3801
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3802
        }
×
3803

3804
        createParams := sqlc.CreateChannelParams{
×
3805
                Version:     int16(ProtocolV1),
×
3806
                Scid:        chanIDB,
×
3807
                NodeID1:     node1DBID,
×
3808
                NodeID2:     node2DBID,
×
3809
                Outpoint:    edge.ChannelPoint.String(),
×
3810
                Capacity:    capacity,
×
3811
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3812
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3813
        }
×
3814

×
3815
        if edge.AuthProof != nil {
×
3816
                proof := edge.AuthProof
×
3817

×
3818
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3819
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3820
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3821
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3822
        }
×
3823

3824
        // Insert the new channel record.
3825
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3826
        if err != nil {
×
3827
                return nil, err
×
3828
        }
×
3829

3830
        // Insert any channel features.
3831
        for feature := range edge.Features.Features() {
×
3832
                err = db.InsertChannelFeature(
×
3833
                        ctx, sqlc.InsertChannelFeatureParams{
×
3834
                                ChannelID:  dbChanID,
×
3835
                                FeatureBit: int32(feature),
×
3836
                        },
×
3837
                )
×
3838
                if err != nil {
×
3839
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
×
3840
                                "feature(%v): %w", dbChanID, feature, err)
×
3841
                }
×
3842
        }
3843

3844
        // Finally, insert any extra TLV fields in the channel announcement.
3845
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3846
        if err != nil {
×
3847
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
3848
                        "data: %w", err)
×
3849
        }
×
3850

3851
        for tlvType, value := range extra {
×
3852
                err := db.CreateChannelExtraType(
×
3853
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3854
                                ChannelID: dbChanID,
×
3855
                                Type:      int64(tlvType),
×
3856
                                Value:     value,
×
3857
                        },
×
3858
                )
×
3859
                if err != nil {
×
3860
                        return nil, fmt.Errorf("unable to upsert "+
×
3861
                                "channel(%d) extra signed field(%v): %w",
×
3862
                                edge.ChannelID, tlvType, err)
×
3863
                }
×
3864
        }
3865

3866
        return &dbChanInfo{
×
3867
                channelID: dbChanID,
×
3868
                node1ID:   node1DBID,
×
3869
                node2ID:   node2DBID,
×
3870
        }, nil
×
3871
}
3872

3873
// maybeCreateShellNode checks if a shell node entry exists for the
3874
// given public key. If it does not exist, then a new shell node entry is
3875
// created. The ID of the node is returned. A shell node only has a protocol
3876
// version and public key persisted.
3877
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3878
        pubKey route.Vertex) (int64, error) {
×
3879

×
3880
        dbNode, err := db.GetNodeByPubKey(
×
3881
                ctx, sqlc.GetNodeByPubKeyParams{
×
3882
                        PubKey:  pubKey[:],
×
3883
                        Version: int16(ProtocolV1),
×
3884
                },
×
3885
        )
×
3886
        // The node exists. Return the ID.
×
3887
        if err == nil {
×
3888
                return dbNode.ID, nil
×
3889
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3890
                return 0, err
×
3891
        }
×
3892

3893
        // Otherwise, the node does not exist, so we create a shell entry for
3894
        // it.
3895
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3896
                Version: int16(ProtocolV1),
×
3897
                PubKey:  pubKey[:],
×
3898
        })
×
3899
        if err != nil {
×
3900
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3901
        }
×
3902

3903
        return id, nil
×
3904
}
3905

3906
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3907
// the database. This includes deleting any existing types and then inserting
3908
// the new types.
3909
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3910
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3911

×
3912
        // Delete all existing extra signed fields for the channel policy.
×
3913
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
3914
        if err != nil {
×
3915
                return fmt.Errorf("unable to delete "+
×
3916
                        "existing policy extra signed fields for policy %d: %w",
×
3917
                        chanPolicyID, err)
×
3918
        }
×
3919

3920
        // Insert all new extra signed fields for the channel policy.
3921
        for tlvType, value := range extraFields {
×
3922
                err = db.InsertChanPolicyExtraType(
×
3923
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
3924
                                ChannelPolicyID: chanPolicyID,
×
3925
                                Type:            int64(tlvType),
×
3926
                                Value:           value,
×
3927
                        },
×
3928
                )
×
3929
                if err != nil {
×
3930
                        return fmt.Errorf("unable to insert "+
×
3931
                                "channel_policy(%d) extra signed field(%v): %w",
×
3932
                                chanPolicyID, tlvType, err)
×
3933
                }
×
3934
        }
3935

3936
        return nil
×
3937
}
3938

3939
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3940
// provided dbChanRow and also fetches any other required information
3941
// to construct the edge info.
3942
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
3943
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.Channel, node1,
3944
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
3945

×
3946
        if dbChan.Version != int16(ProtocolV1) {
×
3947
                return nil, fmt.Errorf("unsupported channel version: %d",
×
3948
                        dbChan.Version)
×
3949
        }
×
3950

3951
        fv, extras, err := getChanFeaturesAndExtras(
×
3952
                ctx, db, dbChanID,
×
3953
        )
×
3954
        if err != nil {
×
3955
                return nil, err
×
3956
        }
×
3957

3958
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
3959
        if err != nil {
×
3960
                return nil, err
×
3961
        }
×
3962

3963
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
3964
        if err != nil {
×
3965
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3966
                        "fields: %w", err)
×
3967
        }
×
3968
        if recs == nil {
×
3969
                recs = make([]byte, 0)
×
3970
        }
×
3971

3972
        var btcKey1, btcKey2 route.Vertex
×
3973
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
3974
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
3975

×
3976
        channel := &models.ChannelEdgeInfo{
×
3977
                ChainHash:        chain,
×
3978
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
3979
                NodeKey1Bytes:    node1,
×
3980
                NodeKey2Bytes:    node2,
×
3981
                BitcoinKey1Bytes: btcKey1,
×
3982
                BitcoinKey2Bytes: btcKey2,
×
3983
                ChannelPoint:     *op,
×
3984
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
3985
                Features:         fv,
×
3986
                ExtraOpaqueData:  recs,
×
3987
        }
×
3988

×
3989
        // We always set all the signatures at the same time, so we can
×
3990
        // safely check if one signature is present to determine if we have the
×
3991
        // rest of the signatures for the auth proof.
×
3992
        if len(dbChan.Bitcoin1Signature) > 0 {
×
3993
                channel.AuthProof = &models.ChannelAuthProof{
×
3994
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
3995
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
3996
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
3997
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
3998
                }
×
3999
        }
×
4000

4001
        return channel, nil
×
4002
}
4003

4004
// buildNodeVertices is a helper that converts raw node public keys
4005
// into route.Vertex instances.
4006
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4007
        route.Vertex, error) {
×
4008

×
4009
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4010
        if err != nil {
×
4011
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4012
                        "create vertex from node1 pubkey: %w", err)
×
4013
        }
×
4014

4015
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4016
        if err != nil {
×
4017
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4018
                        "create vertex from node2 pubkey: %w", err)
×
4019
        }
×
4020

4021
        return node1Vertex, node2Vertex, nil
×
4022
}
4023

4024
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
4025
// for a channel with the given ID.
4026
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
4027
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
4028

×
4029
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
4030
        if err != nil {
×
4031
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
4032
                        "features and extras: %w", err)
×
4033
        }
×
4034

4035
        var (
×
4036
                fv     = lnwire.EmptyFeatureVector()
×
4037
                extras = make(map[uint64][]byte)
×
4038
        )
×
4039
        for _, row := range rows {
×
4040
                if row.IsFeature {
×
4041
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
4042

×
4043
                        continue
×
4044
                }
4045

4046
                tlvType, ok := row.ExtraKey.(int64)
×
4047
                if !ok {
×
4048
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4049
                                "TLV type: %T", row.ExtraKey)
×
4050
                }
×
4051

4052
                valueBytes, ok := row.Value.([]byte)
×
4053
                if !ok {
×
4054
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4055
                                "Value: %T", row.Value)
×
4056
                }
×
4057

4058
                extras[uint64(tlvType)] = valueBytes
×
4059
        }
4060

4061
        return fv, extras, nil
×
4062
}
4063

4064
// getAndBuildChanPolicies uses the given sqlc.ChannelPolicy and also retrieves
4065
// all the extra info required to build the complete models.ChannelEdgePolicy
4066
// types. It returns two policies, which may be nil if the provided
4067
// sqlc.ChannelPolicy records are nil.
4068
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4069
        dbPol1, dbPol2 *sqlc.ChannelPolicy, channelID uint64, node1,
4070
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4071
        *models.ChannelEdgePolicy, error) {
×
4072

×
4073
        if dbPol1 == nil && dbPol2 == nil {
×
4074
                return nil, nil, nil
×
4075
        }
×
4076

4077
        var (
×
4078
                policy1ID int64
×
4079
                policy2ID int64
×
4080
        )
×
4081
        if dbPol1 != nil {
×
4082
                policy1ID = dbPol1.ID
×
4083
        }
×
4084
        if dbPol2 != nil {
×
4085
                policy2ID = dbPol2.ID
×
4086
        }
×
4087
        rows, err := db.GetChannelPolicyExtraTypes(
×
4088
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
4089
                        ID:   policy1ID,
×
4090
                        ID_2: policy2ID,
×
4091
                },
×
4092
        )
×
4093
        if err != nil {
×
4094
                return nil, nil, err
×
4095
        }
×
4096

4097
        var (
×
4098
                dbPol1Extras = make(map[uint64][]byte)
×
4099
                dbPol2Extras = make(map[uint64][]byte)
×
4100
        )
×
4101
        for _, row := range rows {
×
4102
                switch row.PolicyID {
×
4103
                case policy1ID:
×
4104
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
4105
                case policy2ID:
×
4106
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
4107
                default:
×
4108
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
4109
                                "in row: %v", row.PolicyID, row)
×
4110
                }
4111
        }
4112

4113
        var pol1, pol2 *models.ChannelEdgePolicy
×
4114
        if dbPol1 != nil {
×
4115
                pol1, err = buildChanPolicy(
×
4116
                        *dbPol1, channelID, dbPol1Extras, node2,
×
4117
                )
×
4118
                if err != nil {
×
4119
                        return nil, nil, err
×
4120
                }
×
4121
        }
4122
        if dbPol2 != nil {
×
4123
                pol2, err = buildChanPolicy(
×
4124
                        *dbPol2, channelID, dbPol2Extras, node1,
×
4125
                )
×
4126
                if err != nil {
×
4127
                        return nil, nil, err
×
4128
                }
×
4129
        }
4130

4131
        return pol1, pol2, nil
×
4132
}
4133

4134
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4135
// provided sqlc.ChannelPolicy and other required information.
4136
func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
4137
        extras map[uint64][]byte,
4138
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4139

×
4140
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4141
        if err != nil {
×
4142
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4143
                        "fields: %w", err)
×
4144
        }
×
4145

4146
        var inboundFee fn.Option[lnwire.Fee]
×
4147
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4148
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4149

×
4150
                inboundFee = fn.Some(lnwire.Fee{
×
4151
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4152
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4153
                })
×
4154
        }
×
4155

4156
        return &models.ChannelEdgePolicy{
×
4157
                SigBytes:  dbPolicy.Signature,
×
4158
                ChannelID: channelID,
×
4159
                LastUpdate: time.Unix(
×
4160
                        dbPolicy.LastUpdate.Int64, 0,
×
4161
                ),
×
4162
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4163
                        dbPolicy.MessageFlags,
×
4164
                ),
×
4165
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4166
                        dbPolicy.ChannelFlags,
×
4167
                ),
×
4168
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4169
                MinHTLC: lnwire.MilliSatoshi(
×
4170
                        dbPolicy.MinHtlcMsat,
×
4171
                ),
×
4172
                MaxHTLC: lnwire.MilliSatoshi(
×
4173
                        dbPolicy.MaxHtlcMsat.Int64,
×
4174
                ),
×
4175
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4176
                        dbPolicy.BaseFeeMsat,
×
4177
                ),
×
4178
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4179
                ToNode:                    toNode,
×
4180
                InboundFee:                inboundFee,
×
4181
                ExtraOpaqueData:           recs,
×
4182
        }, nil
×
4183
}
4184

4185
// buildNodes builds the models.LightningNode instances for the
4186
// given row which is expected to be a sqlc type that contains node information.
4187
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4188
        dbNode2 sqlc.Node) (*models.LightningNode, *models.LightningNode,
4189
        error) {
×
4190

×
4191
        node1, err := buildNode(ctx, db, &dbNode1)
×
4192
        if err != nil {
×
4193
                return nil, nil, err
×
4194
        }
×
4195

4196
        node2, err := buildNode(ctx, db, &dbNode2)
×
4197
        if err != nil {
×
4198
                return nil, nil, err
×
4199
        }
×
4200

4201
        return node1, node2, nil
×
4202
}
4203

4204
// extractChannelPolicies extracts the sqlc.ChannelPolicy records from the give
4205
// row which is expected to be a sqlc type that contains channel policy
4206
// information. It returns two policies, which may be nil if the policy
4207
// information is not present in the row.
4208
//
4209
//nolint:ll,dupl,funlen
4210
func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
4211
        error) {
×
4212

×
4213
        var policy1, policy2 *sqlc.ChannelPolicy
×
4214
        switch r := row.(type) {
×
4215
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4216
                if r.Policy1ID.Valid {
×
4217
                        policy1 = &sqlc.ChannelPolicy{
×
4218
                                ID:                      r.Policy1ID.Int64,
×
4219
                                Version:                 r.Policy1Version.Int16,
×
4220
                                ChannelID:               r.Channel.ID,
×
4221
                                NodeID:                  r.Policy1NodeID.Int64,
×
4222
                                Timelock:                r.Policy1Timelock.Int32,
×
4223
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4224
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4225
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4226
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4227
                                LastUpdate:              r.Policy1LastUpdate,
×
4228
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4229
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4230
                                Disabled:                r.Policy1Disabled,
×
4231
                                MessageFlags:            r.Policy1MessageFlags,
×
4232
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4233
                                Signature:               r.Policy1Signature,
×
4234
                        }
×
4235
                }
×
4236
                if r.Policy2ID.Valid {
×
4237
                        policy2 = &sqlc.ChannelPolicy{
×
4238
                                ID:                      r.Policy2ID.Int64,
×
4239
                                Version:                 r.Policy2Version.Int16,
×
4240
                                ChannelID:               r.Channel.ID,
×
4241
                                NodeID:                  r.Policy2NodeID.Int64,
×
4242
                                Timelock:                r.Policy2Timelock.Int32,
×
4243
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4244
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4245
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4246
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4247
                                LastUpdate:              r.Policy2LastUpdate,
×
4248
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4249
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4250
                                Disabled:                r.Policy2Disabled,
×
4251
                                MessageFlags:            r.Policy2MessageFlags,
×
4252
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4253
                                Signature:               r.Policy2Signature,
×
4254
                        }
×
4255
                }
×
4256

4257
                return policy1, policy2, nil
×
4258

4259
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4260
                if r.Policy1ID.Valid {
×
4261
                        policy1 = &sqlc.ChannelPolicy{
×
4262
                                ID:                      r.Policy1ID.Int64,
×
4263
                                Version:                 r.Policy1Version.Int16,
×
4264
                                ChannelID:               r.Channel.ID,
×
4265
                                NodeID:                  r.Policy1NodeID.Int64,
×
4266
                                Timelock:                r.Policy1Timelock.Int32,
×
4267
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4268
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4269
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4270
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4271
                                LastUpdate:              r.Policy1LastUpdate,
×
4272
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4273
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4274
                                Disabled:                r.Policy1Disabled,
×
4275
                                MessageFlags:            r.Policy1MessageFlags,
×
4276
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4277
                                Signature:               r.Policy1Signature,
×
4278
                        }
×
4279
                }
×
4280
                if r.Policy2ID.Valid {
×
4281
                        policy2 = &sqlc.ChannelPolicy{
×
4282
                                ID:                      r.Policy2ID.Int64,
×
4283
                                Version:                 r.Policy2Version.Int16,
×
4284
                                ChannelID:               r.Channel.ID,
×
4285
                                NodeID:                  r.Policy2NodeID.Int64,
×
4286
                                Timelock:                r.Policy2Timelock.Int32,
×
4287
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4288
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4289
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4290
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4291
                                LastUpdate:              r.Policy2LastUpdate,
×
4292
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4293
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4294
                                Disabled:                r.Policy2Disabled,
×
4295
                                MessageFlags:            r.Policy2MessageFlags,
×
4296
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4297
                                Signature:               r.Policy2Signature,
×
4298
                        }
×
4299
                }
×
4300

4301
                return policy1, policy2, nil
×
4302

4303
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4304
                if r.Policy1ID.Valid {
×
4305
                        policy1 = &sqlc.ChannelPolicy{
×
4306
                                ID:                      r.Policy1ID.Int64,
×
4307
                                Version:                 r.Policy1Version.Int16,
×
4308
                                ChannelID:               r.Channel.ID,
×
4309
                                NodeID:                  r.Policy1NodeID.Int64,
×
4310
                                Timelock:                r.Policy1Timelock.Int32,
×
4311
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4312
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4313
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4314
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4315
                                LastUpdate:              r.Policy1LastUpdate,
×
4316
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4317
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4318
                                Disabled:                r.Policy1Disabled,
×
4319
                                MessageFlags:            r.Policy1MessageFlags,
×
4320
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4321
                                Signature:               r.Policy1Signature,
×
4322
                        }
×
4323
                }
×
4324
                if r.Policy2ID.Valid {
×
4325
                        policy2 = &sqlc.ChannelPolicy{
×
4326
                                ID:                      r.Policy2ID.Int64,
×
4327
                                Version:                 r.Policy2Version.Int16,
×
4328
                                ChannelID:               r.Channel.ID,
×
4329
                                NodeID:                  r.Policy2NodeID.Int64,
×
4330
                                Timelock:                r.Policy2Timelock.Int32,
×
4331
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4332
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4333
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4334
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4335
                                LastUpdate:              r.Policy2LastUpdate,
×
4336
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4337
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4338
                                Disabled:                r.Policy2Disabled,
×
4339
                                MessageFlags:            r.Policy2MessageFlags,
×
4340
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4341
                                Signature:               r.Policy2Signature,
×
4342
                        }
×
4343
                }
×
4344

4345
                return policy1, policy2, nil
×
4346

4347
        case sqlc.ListChannelsByNodeIDRow:
×
4348
                if r.Policy1ID.Valid {
×
4349
                        policy1 = &sqlc.ChannelPolicy{
×
4350
                                ID:                      r.Policy1ID.Int64,
×
4351
                                Version:                 r.Policy1Version.Int16,
×
4352
                                ChannelID:               r.Channel.ID,
×
4353
                                NodeID:                  r.Policy1NodeID.Int64,
×
4354
                                Timelock:                r.Policy1Timelock.Int32,
×
4355
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4356
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4357
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4358
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4359
                                LastUpdate:              r.Policy1LastUpdate,
×
4360
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4361
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4362
                                Disabled:                r.Policy1Disabled,
×
4363
                                MessageFlags:            r.Policy1MessageFlags,
×
4364
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4365
                                Signature:               r.Policy1Signature,
×
4366
                        }
×
4367
                }
×
4368
                if r.Policy2ID.Valid {
×
4369
                        policy2 = &sqlc.ChannelPolicy{
×
4370
                                ID:                      r.Policy2ID.Int64,
×
4371
                                Version:                 r.Policy2Version.Int16,
×
4372
                                ChannelID:               r.Channel.ID,
×
4373
                                NodeID:                  r.Policy2NodeID.Int64,
×
4374
                                Timelock:                r.Policy2Timelock.Int32,
×
4375
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4376
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4377
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4378
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4379
                                LastUpdate:              r.Policy2LastUpdate,
×
4380
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4381
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4382
                                Disabled:                r.Policy2Disabled,
×
4383
                                MessageFlags:            r.Policy2MessageFlags,
×
4384
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4385
                                Signature:               r.Policy2Signature,
×
4386
                        }
×
4387
                }
×
4388

4389
                return policy1, policy2, nil
×
4390

4391
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4392
                if r.Policy1ID.Valid {
×
4393
                        policy1 = &sqlc.ChannelPolicy{
×
4394
                                ID:                      r.Policy1ID.Int64,
×
4395
                                Version:                 r.Policy1Version.Int16,
×
4396
                                ChannelID:               r.Channel.ID,
×
4397
                                NodeID:                  r.Policy1NodeID.Int64,
×
4398
                                Timelock:                r.Policy1Timelock.Int32,
×
4399
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4400
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4401
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4402
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4403
                                LastUpdate:              r.Policy1LastUpdate,
×
4404
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4405
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4406
                                Disabled:                r.Policy1Disabled,
×
4407
                                MessageFlags:            r.Policy1MessageFlags,
×
4408
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4409
                                Signature:               r.Policy1Signature,
×
4410
                        }
×
4411
                }
×
4412
                if r.Policy2ID.Valid {
×
4413
                        policy2 = &sqlc.ChannelPolicy{
×
4414
                                ID:                      r.Policy2ID.Int64,
×
4415
                                Version:                 r.Policy2Version.Int16,
×
4416
                                ChannelID:               r.Channel.ID,
×
4417
                                NodeID:                  r.Policy2NodeID.Int64,
×
4418
                                Timelock:                r.Policy2Timelock.Int32,
×
4419
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4420
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4421
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4422
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4423
                                LastUpdate:              r.Policy2LastUpdate,
×
4424
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4425
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4426
                                Disabled:                r.Policy2Disabled,
×
4427
                                MessageFlags:            r.Policy2MessageFlags,
×
4428
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4429
                                Signature:               r.Policy2Signature,
×
4430
                        }
×
4431
                }
×
4432

4433
                return policy1, policy2, nil
×
4434
        default:
×
4435
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4436
                        "extractChannelPolicies: %T", r)
×
4437
        }
4438
}
4439

4440
// channelIDToBytes converts a channel ID (SCID) to a byte array
4441
// representation.
4442
func channelIDToBytes(channelID uint64) []byte {
×
4443
        var chanIDB [8]byte
×
4444
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4445

×
4446
        return chanIDB[:]
×
4447
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc