• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15880798456

25 Jun 2025 03:38PM UTC coverage: 57.8% (+0.5%) from 57.316%
15880798456

Pull #9972

github

web-flow
Merge c9776fe87 into fb1fef9e6
Pull Request #9972: [17] multi: run all graph unit tests against SQL backends & run itest suite against SQL graph backend

10 of 45 new or added lines in 5 files covered. (22.22%)

66 existing lines in 11 files now uncovered.

98368 of 170187 relevant lines covered (57.8%)

1.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// pageSize caps how many records a single paginated query may return. The
// value is provisional and may be tuned once benchmarks are available.
const pageSize = 2000
37

38
// ProtocolVersion enumerates the gossip protocol versions that a message can
// belong to.
type ProtocolVersion uint8

// ProtocolV1 identifies the gossip protocol version specified in BOLT #7.
const ProtocolV1 ProtocolVersion = 1

// String renders the protocol version as the letter "V" followed by the
// version's decimal value, e.g. "V1".
func (v ProtocolVersion) String() string {
	return "V" + strconv.FormatUint(uint64(v), 10)
}
×
51

52
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
53
// execute queries against the SQL graph tables.
54
//
55
//nolint:ll,interfacebloat
56
type SQLQueries interface {
57
        /*
58
                Node queries.
59
        */
60
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
61
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
62
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
63
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
64
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error)
65
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
66
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
67
        GetUnconnectedNodes(ctx context.Context) ([]sqlc.GetUnconnectedNodesRow, error)
68
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
69
        DeleteNode(ctx context.Context, id int64) error
70

71
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
72
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
73
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
74

75
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
76
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
77
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
78

79
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
80
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
81
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
82
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
83

84
        /*
85
                Source node queries.
86
        */
87
        AddSourceNode(ctx context.Context, nodeID int64) error
88
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
89

90
        /*
91
                Channel queries.
92
        */
93
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
94
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
95
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
96
        GetChannelByOutpoint(ctx context.Context, outpoint string) (sqlc.GetChannelByOutpointRow, error)
97
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
98
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
99
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
100
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
101
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
102
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
103
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
104
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
105
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
106
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
107
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
108
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
109
        DeleteChannel(ctx context.Context, id int64) error
110

111
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
112
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
113

114
        /*
115
                Channel Policy table queries.
116
        */
117
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
118
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.ChannelPolicy, error)
119
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
120

121
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
122
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
123
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
124

125
        /*
126
                Zombie index queries.
127
        */
128
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
129
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.ZombieChannel, error)
130
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
131
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
132
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
133

134
        /*
135
                Prune log table queries.
136
        */
137
        GetPruneTip(ctx context.Context) (sqlc.PruneLog, error)
138
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
139
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
140

141
        /*
142
                Closed SCID table queries.
143
        */
144
        InsertClosedChannel(ctx context.Context, scid []byte) error
145
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
146
}
147

148
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
149
// database operations.
150
type BatchedSQLQueries interface {
151
        SQLQueries
152
        sqldb.BatchedTx[SQLQueries]
153
}
154

155
// SQLStore is an implementation of the V1Store interface that uses a SQL
156
// database as the backend.
157
//
158
// NOTE: currently, this temporarily embeds the KVStore struct so that we can
159
// implement the V1Store interface incrementally. For any method not
160
// implemented,  things will fall back to the KVStore. This is ONLY the case
161
// for the time being while this struct is purely used in unit tests only.
162
type SQLStore struct {
163
        cfg *SQLStoreConfig
164
        db  BatchedSQLQueries
165

166
        // cacheMu guards all caches (rejectCache and chanCache). If
167
        // this mutex will be acquired at the same time as the DB mutex then
168
        // the cacheMu MUST be acquired first to prevent deadlock.
169
        cacheMu     sync.RWMutex
170
        rejectCache *rejectCache
171
        chanCache   *channelCache
172

173
        chanScheduler batch.Scheduler[SQLQueries]
174
        nodeScheduler batch.Scheduler[SQLQueries]
175

176
        srcNodes  map[ProtocolVersion]*srcNodeInfo
177
        srcNodeMu sync.Mutex
178
}
179

180
// A compile-time assertion to ensure that SQLStore implements the V1Store
181
// interface.
182
var _ V1Store = (*SQLStore)(nil)
183

184
// SQLStoreConfig holds the configuration for the SQLStore.
185
type SQLStoreConfig struct {
186
        // ChainHash is the genesis hash for the chain that all the gossip
187
        // messages in this store are aimed at.
188
        ChainHash chainhash.Hash
189
}
190

191
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
192
// storage backend.
193
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
194
        options ...StoreOptionModifier) (*SQLStore, error) {
×
195

×
196
        opts := DefaultOptions()
×
197
        for _, o := range options {
×
198
                o(opts)
×
199
        }
×
200

201
        if opts.NoMigration {
×
202
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
203
                        "supported for SQL stores")
×
204
        }
×
205

206
        s := &SQLStore{
×
207
                cfg:         cfg,
×
208
                db:          db,
×
209
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
210
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
211
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
212
        }
×
213

×
214
        s.chanScheduler = batch.NewTimeScheduler(
×
215
                db, &s.cacheMu, opts.BatchCommitInterval,
×
216
        )
×
217
        s.nodeScheduler = batch.NewTimeScheduler(
×
218
                db, nil, opts.BatchCommitInterval,
×
219
        )
×
220

×
221
        return s, nil
×
222
}
223

224
// AddLightningNode adds a vertex/node to the graph database. If the node is not
225
// in the database from before, this will add a new, unconnected one to the
226
// graph. If it is present from before, this will update that node's
227
// information.
228
//
229
// NOTE: part of the V1Store interface.
230
func (s *SQLStore) AddLightningNode(ctx context.Context,
231
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
232

×
233
        r := &batch.Request[SQLQueries]{
×
234
                Opts: batch.NewSchedulerOptions(opts...),
×
235
                Do: func(queries SQLQueries) error {
×
236
                        _, err := upsertNode(ctx, queries, node)
×
237
                        return err
×
238
                },
×
239
        }
240

241
        return s.nodeScheduler.Execute(ctx, r)
×
242
}
243

244
// FetchLightningNode attempts to look up a target node by its identity public
245
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
246
// returned.
247
//
248
// NOTE: part of the V1Store interface.
249
func (s *SQLStore) FetchLightningNode(ctx context.Context,
250
        pubKey route.Vertex) (*models.LightningNode, error) {
×
251

×
252
        var node *models.LightningNode
×
253
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
254
                var err error
×
255
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
256

×
257
                return err
×
258
        }, sqldb.NoOpReset)
×
259
        if err != nil {
×
260
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
261
        }
×
262

263
        return node, nil
×
264
}
265

266
// HasLightningNode determines if the graph has a vertex identified by the
267
// target node identity public key. If the node exists in the database, a
268
// timestamp of when the data for the node was lasted updated is returned along
269
// with a true boolean. Otherwise, an empty time.Time is returned with a false
270
// boolean.
271
//
272
// NOTE: part of the V1Store interface.
273
func (s *SQLStore) HasLightningNode(ctx context.Context,
274
        pubKey [33]byte) (time.Time, bool, error) {
×
275

×
276
        var (
×
277
                exists     bool
×
278
                lastUpdate time.Time
×
279
        )
×
280
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
281
                dbNode, err := db.GetNodeByPubKey(
×
282
                        ctx, sqlc.GetNodeByPubKeyParams{
×
283
                                Version: int16(ProtocolV1),
×
284
                                PubKey:  pubKey[:],
×
285
                        },
×
286
                )
×
287
                if errors.Is(err, sql.ErrNoRows) {
×
288
                        return nil
×
289
                } else if err != nil {
×
290
                        return fmt.Errorf("unable to fetch node: %w", err)
×
291
                }
×
292

293
                exists = true
×
294

×
295
                if dbNode.LastUpdate.Valid {
×
296
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
297
                }
×
298

299
                return nil
×
300
        }, sqldb.NoOpReset)
301
        if err != nil {
×
302
                return time.Time{}, false,
×
303
                        fmt.Errorf("unable to fetch node: %w", err)
×
304
        }
×
305

306
        return lastUpdate, exists, nil
×
307
}
308

309
// AddrsForNode returns all known addresses for the target node public key
310
// that the graph DB is aware of. The returned boolean indicates if the
311
// given node is unknown to the graph DB or not.
312
//
313
// NOTE: part of the V1Store interface.
314
func (s *SQLStore) AddrsForNode(ctx context.Context,
315
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
316

×
317
        var (
×
318
                addresses []net.Addr
×
319
                known     bool
×
320
        )
×
321
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
322
                var err error
×
323
                known, addresses, err = getNodeAddresses(
×
324
                        ctx, db, nodePub.SerializeCompressed(),
×
325
                )
×
326
                if err != nil {
×
327
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
328
                                err)
×
329
                }
×
330

331
                return nil
×
332
        }, sqldb.NoOpReset)
333
        if err != nil {
×
334
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
335
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
336
        }
×
337

338
        return known, addresses, nil
×
339
}
340

341
// DeleteLightningNode starts a new database transaction to remove a vertex/node
342
// from the database according to the node's public key.
343
//
344
// NOTE: part of the V1Store interface.
345
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
346
        pubKey route.Vertex) error {
×
347

×
348
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
349
                res, err := db.DeleteNodeByPubKey(
×
350
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
351
                                Version: int16(ProtocolV1),
×
352
                                PubKey:  pubKey[:],
×
353
                        },
×
354
                )
×
355
                if err != nil {
×
356
                        return err
×
357
                }
×
358

359
                rows, err := res.RowsAffected()
×
360
                if err != nil {
×
361
                        return err
×
362
                }
×
363

364
                if rows == 0 {
×
365
                        return ErrGraphNodeNotFound
×
366
                } else if rows > 1 {
×
367
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
368
                }
×
369

370
                return err
×
371
        }, sqldb.NoOpReset)
372
        if err != nil {
×
373
                return fmt.Errorf("unable to delete node: %w", err)
×
374
        }
×
375

376
        return nil
×
377
}
378

379
// FetchNodeFeatures returns the features of the given node. If no features are
380
// known for the node, an empty feature vector is returned.
381
//
382
// NOTE: this is part of the graphdb.NodeTraverser interface.
383
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
384
        *lnwire.FeatureVector, error) {
×
385

×
386
        ctx := context.TODO()
×
387

×
388
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
389
}
×
390

391
// DisabledChannelIDs returns the channel ids of disabled channels.
392
// A channel is disabled when two of the associated ChanelEdgePolicies
393
// have their disabled bit on.
394
//
395
// NOTE: part of the V1Store interface.
396
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
397
        var (
×
398
                ctx     = context.TODO()
×
399
                chanIDs []uint64
×
400
        )
×
401
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
402
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
403
                if err != nil {
×
404
                        return fmt.Errorf("unable to fetch disabled "+
×
405
                                "channels: %w", err)
×
406
                }
×
407

408
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
409

×
410
                return nil
×
411
        }, sqldb.NoOpReset)
412
        if err != nil {
×
413
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
414
                        err)
×
415
        }
×
416

417
        return chanIDs, nil
×
418
}
419

420
// LookupAlias attempts to return the alias as advertised by the target node.
421
//
422
// NOTE: part of the V1Store interface.
423
func (s *SQLStore) LookupAlias(ctx context.Context,
424
        pub *btcec.PublicKey) (string, error) {
×
425

×
426
        var alias string
×
427
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
428
                dbNode, err := db.GetNodeByPubKey(
×
429
                        ctx, sqlc.GetNodeByPubKeyParams{
×
430
                                Version: int16(ProtocolV1),
×
431
                                PubKey:  pub.SerializeCompressed(),
×
432
                        },
×
433
                )
×
434
                if errors.Is(err, sql.ErrNoRows) {
×
435
                        return ErrNodeAliasNotFound
×
436
                } else if err != nil {
×
437
                        return fmt.Errorf("unable to fetch node: %w", err)
×
438
                }
×
439

440
                if !dbNode.Alias.Valid {
×
441
                        return ErrNodeAliasNotFound
×
442
                }
×
443

444
                alias = dbNode.Alias.String
×
445

×
446
                return nil
×
447
        }, sqldb.NoOpReset)
448
        if err != nil {
×
449
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
450
        }
×
451

452
        return alias, nil
×
453
}
454

455
// SourceNode returns the source node of the graph. The source node is treated
456
// as the center node within a star-graph. This method may be used to kick off
457
// a path finding algorithm in order to explore the reachability of another
458
// node based off the source node.
459
//
460
// NOTE: part of the V1Store interface.
461
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
462
        error) {
×
463

×
464
        var node *models.LightningNode
×
465
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
466
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
467
                if err != nil {
×
468
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
469
                                err)
×
470
                }
×
471

472
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
473

×
474
                return err
×
475
        }, sqldb.NoOpReset)
476
        if err != nil {
×
477
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
478
        }
×
479

480
        return node, nil
×
481
}
482

483
// SetSourceNode sets the source node within the graph database. The source
484
// node is to be used as the center of a star-graph within path finding
485
// algorithms.
486
//
487
// NOTE: part of the V1Store interface.
488
func (s *SQLStore) SetSourceNode(ctx context.Context,
489
        node *models.LightningNode) error {
×
490

×
491
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
492
                id, err := upsertNode(ctx, db, node)
×
493
                if err != nil {
×
494
                        return fmt.Errorf("unable to upsert source node: %w",
×
495
                                err)
×
496
                }
×
497

498
                // Make sure that if a source node for this version is already
499
                // set, then the ID is the same as the one we are about to set.
500
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
501
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
502
                        return fmt.Errorf("unable to fetch source node: %w",
×
503
                                err)
×
504
                } else if err == nil {
×
505
                        if dbSourceNodeID != id {
×
506
                                return fmt.Errorf("v1 source node already "+
×
507
                                        "set to a different node: %d vs %d",
×
508
                                        dbSourceNodeID, id)
×
509
                        }
×
510

511
                        return nil
×
512
                }
513

514
                return db.AddSourceNode(ctx, id)
×
515
        }, sqldb.NoOpReset)
516
}
517

518
// NodeUpdatesInHorizon returns all the known lightning node which have an
519
// update timestamp within the passed range. This method can be used by two
520
// nodes to quickly determine if they have the same set of up to date node
521
// announcements.
522
//
523
// NOTE: This is part of the V1Store interface.
524
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
525
        endTime time.Time) ([]models.LightningNode, error) {
×
526

×
527
        ctx := context.TODO()
×
528

×
529
        var nodes []models.LightningNode
×
530
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
531
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
532
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
533
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
534
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
535
                        },
×
536
                )
×
537
                if err != nil {
×
538
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
539
                }
×
540

541
                for _, dbNode := range dbNodes {
×
542
                        node, err := buildNode(ctx, db, &dbNode)
×
543
                        if err != nil {
×
544
                                return fmt.Errorf("unable to build node: %w",
×
545
                                        err)
×
546
                        }
×
547

548
                        nodes = append(nodes, *node)
×
549
                }
550

551
                return nil
×
552
        }, sqldb.NoOpReset)
553
        if err != nil {
×
554
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
555
        }
×
556

557
        return nodes, nil
×
558
}
559

560
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
561
// undirected edge from the two target nodes are created. The information stored
562
// denotes the static attributes of the channel, such as the channelID, the keys
563
// involved in creation of the channel, and the set of features that the channel
564
// supports. The chanPoint and chanID are used to uniquely identify the edge
565
// globally within the database.
566
//
567
// NOTE: part of the V1Store interface.
568
func (s *SQLStore) AddChannelEdge(ctx context.Context,
569
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
570

×
571
        var alreadyExists bool
×
572
        r := &batch.Request[SQLQueries]{
×
573
                Opts: batch.NewSchedulerOptions(opts...),
×
574
                Reset: func() {
×
575
                        alreadyExists = false
×
576
                },
×
577
                Do: func(tx SQLQueries) error {
×
578
                        err := insertChannel(ctx, tx, edge)
×
579

×
580
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
581
                        // succeed, but propagate the error via local state.
×
582
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
583
                                alreadyExists = true
×
584
                                return nil
×
585
                        }
×
586

587
                        return err
×
588
                },
589
                OnCommit: func(err error) error {
×
590
                        switch {
×
591
                        case err != nil:
×
592
                                return err
×
593
                        case alreadyExists:
×
594
                                return ErrEdgeAlreadyExist
×
595
                        default:
×
596
                                s.rejectCache.remove(edge.ChannelID)
×
597
                                s.chanCache.remove(edge.ChannelID)
×
598
                                return nil
×
599
                        }
600
                },
601
        }
602

603
        return s.chanScheduler.Execute(ctx, r)
×
604
}
605

606
// HighestChanID returns the "highest" known channel ID in the channel graph.
607
// This represents the "newest" channel from the PoV of the chain. This method
608
// can be used by peers to quickly determine if their graphs are in sync.
609
//
610
// NOTE: This is part of the V1Store interface.
611
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
612
        var highestChanID uint64
×
613
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
614
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
615
                if errors.Is(err, sql.ErrNoRows) {
×
616
                        return nil
×
617
                } else if err != nil {
×
618
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
619
                                err)
×
620
                }
×
621

622
                highestChanID = byteOrder.Uint64(chanID)
×
623

×
624
                return nil
×
625
        }, sqldb.NoOpReset)
626
        if err != nil {
×
627
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
628
        }
×
629

630
        return highestChanID, nil
×
631
}
632

633
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
634
// within the database for the referenced channel. The `flags` attribute within
635
// the ChannelEdgePolicy determines which of the directed edges are being
636
// updated. If the flag is 1, then the first node's information is being
637
// updated, otherwise it's the second node's information. The node ordering is
638
// determined by the lexicographical ordering of the identity public keys of the
639
// nodes on either side of the channel.
640
//
641
// NOTE: part of the V1Store interface.
642
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
643
        edge *models.ChannelEdgePolicy,
644
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
645

×
646
        var (
×
647
                isUpdate1    bool
×
648
                edgeNotFound bool
×
649
                from, to     route.Vertex
×
650
        )
×
651

×
652
        r := &batch.Request[SQLQueries]{
×
653
                Opts: batch.NewSchedulerOptions(opts...),
×
654
                Reset: func() {
×
655
                        isUpdate1 = false
×
656
                        edgeNotFound = false
×
657
                },
×
658
                Do: func(tx SQLQueries) error {
×
659
                        var err error
×
660
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
661
                                ctx, tx, edge,
×
662
                        )
×
663
                        if err != nil {
×
664
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
665
                        }
×
666

667
                        // Silence ErrEdgeNotFound so that the batch can
668
                        // succeed, but propagate the error via local state.
669
                        if errors.Is(err, ErrEdgeNotFound) {
×
670
                                edgeNotFound = true
×
671
                                return nil
×
672
                        }
×
673

674
                        return err
×
675
                },
676
                OnCommit: func(err error) error {
×
677
                        switch {
×
678
                        case err != nil:
×
679
                                return err
×
680
                        case edgeNotFound:
×
681
                                return ErrEdgeNotFound
×
682
                        default:
×
683
                                s.updateEdgeCache(edge, isUpdate1)
×
684
                                return nil
×
685
                        }
686
                },
687
        }
688

689
        err := s.chanScheduler.Execute(ctx, r)
×
690

×
691
        return from, to, err
×
692
}
693

694
// updateEdgeCache updates our reject and channel caches with the new
695
// edge policy information.
696
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
697
        isUpdate1 bool) {
×
698

×
699
        // If an entry for this channel is found in reject cache, we'll modify
×
700
        // the entry with the updated timestamp for the direction that was just
×
701
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
702
        // during the next query for this edge.
×
703
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
704
                if isUpdate1 {
×
705
                        entry.upd1Time = e.LastUpdate.Unix()
×
706
                } else {
×
707
                        entry.upd2Time = e.LastUpdate.Unix()
×
708
                }
×
709
                s.rejectCache.insert(e.ChannelID, entry)
×
710
        }
711

712
        // If an entry for this channel is found in channel cache, we'll modify
713
        // the entry with the updated policy for the direction that was just
714
        // written. If the edge doesn't exist, we'll defer loading the info and
715
        // policies and lazily read from disk during the next query.
716
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
717
                if isUpdate1 {
×
718
                        channel.Policy1 = e
×
719
                } else {
×
720
                        channel.Policy2 = e
×
721
                }
×
722
                s.chanCache.insert(e.ChannelID, channel)
×
723
        }
724
}
725

726
// ForEachSourceNodeChannel iterates through all channels of the source node,
727
// executing the passed callback on each. The call-back is provided with the
728
// channel's outpoint, whether we have a policy for the channel and the channel
729
// peer's node information.
730
//
731
// NOTE: part of the V1Store interface.
732
func (s *SQLStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint,
733
        havePolicy bool, otherNode *models.LightningNode) error) error {
×
734

×
735
        var ctx = context.TODO()
×
736

×
737
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
738
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
739
                if err != nil {
×
740
                        return fmt.Errorf("unable to fetch source node: %w",
×
741
                                err)
×
742
                }
×
743

744
                return forEachNodeChannel(
×
745
                        ctx, db, s.cfg.ChainHash, nodeID,
×
746
                        func(info *models.ChannelEdgeInfo,
×
747
                                outPolicy *models.ChannelEdgePolicy,
×
748
                                _ *models.ChannelEdgePolicy) error {
×
749

×
750
                                // Fetch the other node.
×
751
                                var (
×
752
                                        otherNodePub [33]byte
×
753
                                        node1        = info.NodeKey1Bytes
×
754
                                        node2        = info.NodeKey2Bytes
×
755
                                )
×
756
                                switch {
×
757
                                case bytes.Equal(node1[:], nodePub[:]):
×
758
                                        otherNodePub = node2
×
759
                                case bytes.Equal(node2[:], nodePub[:]):
×
760
                                        otherNodePub = node1
×
761
                                default:
×
762
                                        return fmt.Errorf("node not " +
×
763
                                                "participating in this channel")
×
764
                                }
765

766
                                _, otherNode, err := getNodeByPubKey(
×
767
                                        ctx, db, otherNodePub,
×
768
                                )
×
769
                                if err != nil {
×
770
                                        return fmt.Errorf("unable to fetch "+
×
771
                                                "other node(%x): %w",
×
772
                                                otherNodePub, err)
×
773
                                }
×
774

775
                                return cb(
×
776
                                        info.ChannelPoint, outPolicy != nil,
×
777
                                        otherNode,
×
778
                                )
×
779
                        },
780
                )
781
        }, sqldb.NoOpReset)
782
}
783

784
// ForEachNode iterates through all the stored vertices/nodes in the graph,
785
// executing the passed callback with each node encountered. If the callback
786
// returns an error, then the transaction is aborted and the iteration stops
787
// early. Any operations performed on the NodeTx passed to the call-back are
788
// executed under the same read transaction and so, methods on the NodeTx object
789
// _MUST_ only be called from within the call-back.
790
//
791
// NOTE: part of the V1Store interface.
792
func (s *SQLStore) ForEachNode(cb func(tx NodeRTx) error) error {
×
793
        var (
×
794
                ctx          = context.TODO()
×
795
                lastID int64 = 0
×
796
        )
×
797

×
798
        handleNode := func(db SQLQueries, dbNode sqlc.Node) error {
×
799
                node, err := buildNode(ctx, db, &dbNode)
×
800
                if err != nil {
×
801
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
802
                                dbNode.ID, err)
×
803
                }
×
804

805
                err = cb(
×
806
                        newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
×
807
                )
×
808
                if err != nil {
×
809
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
810
                                dbNode.ID, err)
×
811
                }
×
812

813
                return nil
×
814
        }
815

816
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
817
                for {
×
818
                        nodes, err := db.ListNodesPaginated(
×
819
                                ctx, sqlc.ListNodesPaginatedParams{
×
820
                                        Version: int16(ProtocolV1),
×
821
                                        ID:      lastID,
×
822
                                        Limit:   pageSize,
×
823
                                },
×
824
                        )
×
825
                        if err != nil {
×
826
                                return fmt.Errorf("unable to fetch nodes: %w",
×
827
                                        err)
×
828
                        }
×
829

830
                        if len(nodes) == 0 {
×
831
                                break
×
832
                        }
833

834
                        for _, dbNode := range nodes {
×
835
                                err = handleNode(db, dbNode)
×
836
                                if err != nil {
×
837
                                        return err
×
838
                                }
×
839

840
                                lastID = dbNode.ID
×
841
                        }
842
                }
843

844
                return nil
×
845
        }, sqldb.NoOpReset)
846
}
847

848
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
849
// SQLStore and a SQL transaction.
850
type sqlGraphNodeTx struct {
851
        db    SQLQueries
852
        id    int64
853
        node  *models.LightningNode
854
        chain chainhash.Hash
855
}
856

857
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
858
// interface.
859
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
860

861
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
862
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
863

×
864
        return &sqlGraphNodeTx{
×
865
                db:    db,
×
866
                chain: chain,
×
867
                id:    id,
×
868
                node:  node,
×
869
        }
×
870
}
×
871

872
// Node returns the raw information of the node.
873
//
874
// NOTE: This is a part of the NodeRTx interface.
875
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
876
        return s.node
×
877
}
×
878

879
// ForEachChannel can be used to iterate over the node's channels under the same
880
// transaction used to fetch the node.
881
//
882
// NOTE: This is a part of the NodeRTx interface.
883
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
884
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
885

×
886
        ctx := context.TODO()
×
887

×
888
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
889
}
×
890

891
// FetchNode fetches the node with the given pub key under the same transaction
892
// used to fetch the current node. The returned node is also a NodeRTx and any
893
// operations on that NodeRTx will also be done under the same transaction.
894
//
895
// NOTE: This is a part of the NodeRTx interface.
896
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
897
        ctx := context.TODO()
×
898

×
899
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
900
        if err != nil {
×
901
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
902
                        nodePub, err)
×
903
        }
×
904

905
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
906
}
907

908
// ForEachNodeDirectedChannel iterates through all channels of a given node,
909
// executing the passed callback on the directed edge representing the channel
910
// and its incoming policy. If the callback returns an error, then the iteration
911
// is halted with the error propagated back up to the caller.
912
//
913
// Unknown policies are passed into the callback as nil values.
914
//
915
// NOTE: this is part of the graphdb.NodeTraverser interface.
916
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
917
        cb func(channel *DirectedChannel) error) error {
×
918

×
919
        var ctx = context.TODO()
×
920

×
921
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
922
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
923
        }, sqldb.NoOpReset)
×
924
}
925

926
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
927
// graph, executing the passed callback with each node encountered. If the
928
// callback returns an error, then the transaction is aborted and the iteration
929
// stops early.
930
//
931
// NOTE: This is a part of the V1Store interface.
932
func (s *SQLStore) ForEachNodeCacheable(cb func(route.Vertex,
933
        *lnwire.FeatureVector) error) error {
×
934

×
935
        ctx := context.TODO()
×
936

×
937
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
938
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
939
                        nodePub route.Vertex) error {
×
940

×
941
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
942
                        if err != nil {
×
943
                                return fmt.Errorf("unable to fetch node "+
×
944
                                        "features: %w", err)
×
945
                        }
×
946

947
                        return cb(nodePub, features)
×
948
                })
949
        }, sqldb.NoOpReset)
950
        if err != nil {
×
951
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
952
        }
×
953

954
        return nil
×
955
}
956

957
// ForEachNodeChannel iterates through all channels of the given node,
958
// executing the passed callback with an edge info structure and the policies
959
// of each end of the channel. The first edge policy is the outgoing edge *to*
960
// the connecting node, while the second is the incoming edge *from* the
961
// connecting node. If the callback returns an error, then the iteration is
962
// halted with the error propagated back up to the caller.
963
//
964
// Unknown policies are passed into the callback as nil values.
965
//
966
// NOTE: part of the V1Store interface.
967
func (s *SQLStore) ForEachNodeChannel(nodePub route.Vertex,
968
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
969
                *models.ChannelEdgePolicy) error) error {
×
970

×
971
        var ctx = context.TODO()
×
972

×
973
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
974
                dbNode, err := db.GetNodeByPubKey(
×
975
                        ctx, sqlc.GetNodeByPubKeyParams{
×
976
                                Version: int16(ProtocolV1),
×
977
                                PubKey:  nodePub[:],
×
978
                        },
×
979
                )
×
980
                if errors.Is(err, sql.ErrNoRows) {
×
981
                        return nil
×
982
                } else if err != nil {
×
983
                        return fmt.Errorf("unable to fetch node: %w", err)
×
984
                }
×
985

986
                return forEachNodeChannel(
×
987
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
988
                )
×
989
        }, sqldb.NoOpReset)
990
}
991

992
// ChanUpdatesInHorizon returns all the known channel edges which have at least
993
// one edge that has an update timestamp within the specified horizon.
994
//
995
// NOTE: This is part of the V1Store interface.
996
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
997
        endTime time.Time) ([]ChannelEdge, error) {
×
998

×
999
        s.cacheMu.Lock()
×
1000
        defer s.cacheMu.Unlock()
×
1001

×
1002
        var (
×
1003
                ctx = context.TODO()
×
1004
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
1005
                // an additional map to keep track of the edges already seen to
×
1006
                // prevent re-adding it.
×
1007
                edgesSeen    = make(map[uint64]struct{})
×
1008
                edgesToCache = make(map[uint64]ChannelEdge)
×
1009
                edges        []ChannelEdge
×
1010
                hits         int
×
1011
        )
×
1012
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1013
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1014
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1015
                                Version:   int16(ProtocolV1),
×
1016
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
1017
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
1018
                        },
×
1019
                )
×
1020
                if err != nil {
×
1021
                        return err
×
1022
                }
×
1023

1024
                for _, row := range rows {
×
1025
                        // If we've already retrieved the info and policies for
×
1026
                        // this edge, then we can skip it as we don't need to do
×
1027
                        // so again.
×
1028
                        chanIDInt := byteOrder.Uint64(row.Channel.Scid)
×
1029
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1030
                                continue
×
1031
                        }
1032

1033
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1034
                                hits++
×
1035
                                edgesSeen[chanIDInt] = struct{}{}
×
1036
                                edges = append(edges, channel)
×
1037

×
1038
                                continue
×
1039
                        }
1040

1041
                        node1, node2, err := buildNodes(
×
1042
                                ctx, db, row.Node, row.Node_2,
×
1043
                        )
×
1044
                        if err != nil {
×
1045
                                return err
×
1046
                        }
×
1047

1048
                        channel, err := getAndBuildEdgeInfo(
×
1049
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1050
                                row.Channel, node1.PubKeyBytes,
×
1051
                                node2.PubKeyBytes,
×
1052
                        )
×
1053
                        if err != nil {
×
1054
                                return fmt.Errorf("unable to build channel "+
×
1055
                                        "info: %w", err)
×
1056
                        }
×
1057

1058
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1059
                        if err != nil {
×
1060
                                return fmt.Errorf("unable to extract channel "+
×
1061
                                        "policies: %w", err)
×
1062
                        }
×
1063

1064
                        p1, p2, err := getAndBuildChanPolicies(
×
1065
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1066
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1067
                        )
×
1068
                        if err != nil {
×
1069
                                return fmt.Errorf("unable to build channel "+
×
1070
                                        "policies: %w", err)
×
1071
                        }
×
1072

1073
                        edgesSeen[chanIDInt] = struct{}{}
×
1074
                        chanEdge := ChannelEdge{
×
1075
                                Info:    channel,
×
1076
                                Policy1: p1,
×
1077
                                Policy2: p2,
×
1078
                                Node1:   node1,
×
1079
                                Node2:   node2,
×
1080
                        }
×
1081
                        edges = append(edges, chanEdge)
×
1082
                        edgesToCache[chanIDInt] = chanEdge
×
1083
                }
1084

1085
                return nil
×
1086
        }, func() {
×
1087
                edgesSeen = make(map[uint64]struct{})
×
1088
                edgesToCache = make(map[uint64]ChannelEdge)
×
1089
                edges = nil
×
1090
        })
×
1091
        if err != nil {
×
1092
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1093
        }
×
1094

1095
        // Insert any edges loaded from disk into the cache.
1096
        for chanid, channel := range edgesToCache {
×
1097
                s.chanCache.insert(chanid, channel)
×
1098
        }
×
1099

1100
        if len(edges) > 0 {
×
1101
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
1102
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1103
        } else {
×
1104
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1105
                        "horizon (%s, %s)", startTime, endTime)
×
1106
        }
×
1107

1108
        return edges, nil
×
1109
}
1110

1111
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1112
// data to the call-back.
1113
//
1114
// NOTE: The callback contents MUST not be modified.
1115
//
1116
// NOTE: part of the V1Store interface.
1117
func (s *SQLStore) ForEachNodeCached(cb func(node route.Vertex,
1118
        chans map[uint64]*DirectedChannel) error) error {
×
1119

×
1120
        var ctx = context.TODO()
×
1121

×
1122
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1123
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
1124
                        nodePub route.Vertex) error {
×
1125

×
1126
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
1127
                        if err != nil {
×
1128
                                return fmt.Errorf("unable to fetch "+
×
1129
                                        "node(id=%d) features: %w", nodeID, err)
×
1130
                        }
×
1131

1132
                        toNodeCallback := func() route.Vertex {
×
1133
                                return nodePub
×
1134
                        }
×
1135

1136
                        rows, err := db.ListChannelsByNodeID(
×
1137
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
1138
                                        Version: int16(ProtocolV1),
×
1139
                                        NodeID1: nodeID,
×
1140
                                },
×
1141
                        )
×
1142
                        if err != nil {
×
1143
                                return fmt.Errorf("unable to fetch channels "+
×
1144
                                        "of node(id=%d): %w", nodeID, err)
×
1145
                        }
×
1146

1147
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
1148
                        for _, row := range rows {
×
1149
                                node1, node2, err := buildNodeVertices(
×
1150
                                        row.Node1Pubkey, row.Node2Pubkey,
×
1151
                                )
×
1152
                                if err != nil {
×
1153
                                        return err
×
1154
                                }
×
1155

1156
                                e, err := getAndBuildEdgeInfo(
×
1157
                                        ctx, db, s.cfg.ChainHash,
×
1158
                                        row.Channel.ID, row.Channel, node1,
×
1159
                                        node2,
×
1160
                                )
×
1161
                                if err != nil {
×
1162
                                        return fmt.Errorf("unable to build "+
×
1163
                                                "channel info: %w", err)
×
1164
                                }
×
1165

1166
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1167
                                        row,
×
1168
                                )
×
1169
                                if err != nil {
×
1170
                                        return fmt.Errorf("unable to "+
×
1171
                                                "extract channel "+
×
1172
                                                "policies: %w", err)
×
1173
                                }
×
1174

1175
                                p1, p2, err := getAndBuildChanPolicies(
×
1176
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
1177
                                        node1, node2,
×
1178
                                )
×
1179
                                if err != nil {
×
1180
                                        return fmt.Errorf("unable to "+
×
1181
                                                "build channel policies: %w",
×
1182
                                                err)
×
1183
                                }
×
1184

1185
                                // Determine the outgoing and incoming policy
1186
                                // for this channel and node combo.
1187
                                outPolicy, inPolicy := p1, p2
×
1188
                                if p1 != nil && p1.ToNode == nodePub {
×
1189
                                        outPolicy, inPolicy = p2, p1
×
1190
                                } else if p2 != nil && p2.ToNode != nodePub {
×
1191
                                        outPolicy, inPolicy = p2, p1
×
1192
                                }
×
1193

1194
                                var cachedInPolicy *models.CachedEdgePolicy
×
1195
                                if inPolicy != nil {
×
1196
                                        cachedInPolicy = models.NewCachedPolicy(
×
1197
                                                p2,
×
1198
                                        )
×
1199
                                        cachedInPolicy.ToNodePubKey =
×
1200
                                                toNodeCallback
×
1201
                                        cachedInPolicy.ToNodeFeatures =
×
1202
                                                features
×
1203
                                }
×
1204

1205
                                var inboundFee lnwire.Fee
×
1206
                                outPolicy.InboundFee.WhenSome(
×
1207
                                        func(fee lnwire.Fee) {
×
1208
                                                inboundFee = fee
×
1209
                                        },
×
1210
                                )
1211

1212
                                directedChannel := &DirectedChannel{
×
1213
                                        ChannelID: e.ChannelID,
×
1214
                                        IsNode1: nodePub ==
×
1215
                                                e.NodeKey1Bytes,
×
1216
                                        OtherNode:    e.NodeKey2Bytes,
×
1217
                                        Capacity:     e.Capacity,
×
1218
                                        OutPolicySet: p1 != nil,
×
1219
                                        InPolicy:     cachedInPolicy,
×
1220
                                        InboundFee:   inboundFee,
×
1221
                                }
×
1222

×
1223
                                if nodePub == e.NodeKey2Bytes {
×
1224
                                        directedChannel.OtherNode =
×
1225
                                                e.NodeKey1Bytes
×
1226
                                }
×
1227

1228
                                channels[e.ChannelID] = directedChannel
×
1229
                        }
1230

1231
                        return cb(nodePub, channels)
×
1232
                })
1233
        }, sqldb.NoOpReset)
1234
}
1235

1236
// ForEachChannelCacheable iterates through all the channel edges stored
1237
// within the graph and invokes the passed callback for each edge. The
1238
// callback takes two edges as since this is a directed graph, both the
1239
// in/out edges are visited. If the callback returns an error, then the
1240
// transaction is aborted and the iteration stops early.
1241
//
1242
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1243
// pointer for that particular channel edge routing policy will be
1244
// passed into the callback.
1245
//
1246
// NOTE: this method is like ForEachChannel but fetches only the data
1247
// required for the graph cache.
1248
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1249
        *models.CachedEdgePolicy,
1250
        *models.CachedEdgePolicy) error) error {
×
1251

×
1252
        ctx := context.TODO()
×
1253

×
1254
        handleChannel := func(db SQLQueries,
×
1255
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1256

×
1257
                node1, node2, err := buildNodeVertices(
×
1258
                        row.Node1Pubkey, row.Node2Pubkey,
×
1259
                )
×
1260
                if err != nil {
×
1261
                        return err
×
1262
                }
×
1263

1264
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
1265

×
1266
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1267
                if err != nil {
×
1268
                        return err
×
1269
                }
×
1270

1271
                var pol1, pol2 *models.CachedEdgePolicy
×
1272
                if dbPol1 != nil {
×
1273
                        policy1, err := buildChanPolicy(
×
1274
                                *dbPol1, edge.ChannelID, nil, node2, true,
×
1275
                        )
×
1276
                        if err != nil {
×
1277
                                return err
×
1278
                        }
×
1279

1280
                        pol1 = models.NewCachedPolicy(policy1)
×
1281
                }
1282
                if dbPol2 != nil {
×
1283
                        policy2, err := buildChanPolicy(
×
1284
                                *dbPol2, edge.ChannelID, nil, node1, false,
×
1285
                        )
×
1286
                        if err != nil {
×
1287
                                return err
×
1288
                        }
×
1289

1290
                        pol2 = models.NewCachedPolicy(policy2)
×
1291
                }
1292

1293
                if err := cb(edge, pol1, pol2); err != nil {
×
1294
                        return err
×
1295
                }
×
1296

1297
                return nil
×
1298
        }
1299

1300
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1301
                lastID := int64(-1)
×
1302
                for {
×
1303
                        //nolint:ll
×
1304
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1305
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1306
                                        Version: int16(ProtocolV1),
×
1307
                                        ID:      lastID,
×
1308
                                        Limit:   pageSize,
×
1309
                                },
×
1310
                        )
×
1311
                        if err != nil {
×
1312
                                return err
×
1313
                        }
×
1314

1315
                        if len(rows) == 0 {
×
1316
                                break
×
1317
                        }
1318

1319
                        for _, row := range rows {
×
1320
                                err := handleChannel(db, row)
×
1321
                                if err != nil {
×
1322
                                        return err
×
1323
                                }
×
1324

1325
                                lastID = row.Channel.ID
×
1326
                        }
1327
                }
1328

1329
                return nil
×
1330
        }, sqldb.NoOpReset)
1331
}
1332

1333
// ForEachChannel iterates through all the channel edges stored within the
1334
// graph and invokes the passed callback for each edge. The callback takes two
1335
// edges as since this is a directed graph, both the in/out edges are visited.
1336
// If the callback returns an error, then the transaction is aborted and the
1337
// iteration stops early.
1338
//
1339
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1340
// for that particular channel edge routing policy will be passed into the
1341
// callback.
1342
//
1343
// NOTE: part of the V1Store interface.
1344
func (s *SQLStore) ForEachChannel(cb func(*models.ChannelEdgeInfo,
1345
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
1346

×
1347
        ctx := context.TODO()
×
1348

×
1349
        handleChannel := func(db SQLQueries,
×
1350
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1351

×
1352
                node1, node2, err := buildNodeVertices(
×
1353
                        row.Node1Pubkey, row.Node2Pubkey,
×
1354
                )
×
1355
                if err != nil {
×
1356
                        return fmt.Errorf("unable to build node vertices: %w",
×
1357
                                err)
×
1358
                }
×
1359

1360
                edge, err := getAndBuildEdgeInfo(
×
1361
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1362
                        node1, node2,
×
1363
                )
×
1364
                if err != nil {
×
1365
                        return fmt.Errorf("unable to build channel info: %w",
×
1366
                                err)
×
1367
                }
×
1368

1369
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1370
                if err != nil {
×
1371
                        return fmt.Errorf("unable to extract channel "+
×
1372
                                "policies: %w", err)
×
1373
                }
×
1374

1375
                p1, p2, err := getAndBuildChanPolicies(
×
1376
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1377
                )
×
1378
                if err != nil {
×
1379
                        return fmt.Errorf("unable to build channel "+
×
1380
                                "policies: %w", err)
×
1381
                }
×
1382

1383
                err = cb(edge, p1, p2)
×
1384
                if err != nil {
×
1385
                        return fmt.Errorf("callback failed for channel "+
×
1386
                                "id=%d: %w", edge.ChannelID, err)
×
1387
                }
×
1388

1389
                return nil
×
1390
        }
1391

1392
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1393
                lastID := int64(-1)
×
1394
                for {
×
1395
                        //nolint:ll
×
1396
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1397
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1398
                                        Version: int16(ProtocolV1),
×
1399
                                        ID:      lastID,
×
1400
                                        Limit:   pageSize,
×
1401
                                },
×
1402
                        )
×
1403
                        if err != nil {
×
1404
                                return err
×
1405
                        }
×
1406

1407
                        if len(rows) == 0 {
×
1408
                                break
×
1409
                        }
1410

1411
                        for _, row := range rows {
×
1412
                                err := handleChannel(db, row)
×
1413
                                if err != nil {
×
1414
                                        return err
×
1415
                                }
×
1416

1417
                                lastID = row.Channel.ID
×
1418
                        }
1419
                }
1420

1421
                return nil
×
1422
        }, sqldb.NoOpReset)
1423
}
1424

1425
// FilterChannelRange returns the channel ID's of all known channels which were
1426
// mined in a block height within the passed range. The channel IDs are grouped
1427
// by their common block height. This method can be used to quickly share with a
1428
// peer the set of channels we know of within a particular range to catch them
1429
// up after a period of time offline. If withTimestamps is true then the
1430
// timestamp info of the latest received channel update messages of the channel
1431
// will be included in the response.
1432
//
1433
// NOTE: This is part of the V1Store interface.
1434
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1435
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1436

×
1437
        var (
×
1438
                ctx       = context.TODO()
×
1439
                startSCID = &lnwire.ShortChannelID{
×
1440
                        BlockHeight: startHeight,
×
1441
                }
×
1442
                endSCID = lnwire.ShortChannelID{
×
1443
                        BlockHeight: endHeight,
×
1444
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1445
                        TxPosition:  math.MaxUint16,
×
1446
                }
×
1447
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1448
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1449
        )
×
1450

×
1451
        // 1) get all channels where channelID is between start and end chan ID.
×
1452
        // 2) skip if not public (ie, no channel_proof)
×
1453
        // 3) collect that channel.
×
1454
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1455
        //    and add those timestamps to the collected channel.
×
1456
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1457
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1458
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1459
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1460
                                StartScid: chanIDStart[:],
×
1461
                                EndScid:   chanIDEnd[:],
×
1462
                        },
×
1463
                )
×
1464
                if err != nil {
×
1465
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1466
                                err)
×
1467
                }
×
1468

1469
                for _, dbChan := range dbChans {
×
1470
                        cid := lnwire.NewShortChanIDFromInt(
×
1471
                                byteOrder.Uint64(dbChan.Scid),
×
1472
                        )
×
1473
                        chanInfo := NewChannelUpdateInfo(
×
1474
                                cid, time.Time{}, time.Time{},
×
1475
                        )
×
1476

×
1477
                        if !withTimestamps {
×
1478
                                channelsPerBlock[cid.BlockHeight] = append(
×
1479
                                        channelsPerBlock[cid.BlockHeight],
×
1480
                                        chanInfo,
×
1481
                                )
×
1482

×
1483
                                continue
×
1484
                        }
1485

1486
                        //nolint:ll
1487
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1488
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1489
                                        Version:   int16(ProtocolV1),
×
1490
                                        ChannelID: dbChan.ID,
×
1491
                                        NodeID:    dbChan.NodeID1,
×
1492
                                },
×
1493
                        )
×
1494
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1495
                                return fmt.Errorf("unable to fetch node1 "+
×
1496
                                        "policy: %w", err)
×
1497
                        } else if err == nil {
×
1498
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1499
                                        node1Policy.LastUpdate.Int64, 0,
×
1500
                                )
×
1501
                        }
×
1502

1503
                        //nolint:ll
1504
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1505
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1506
                                        Version:   int16(ProtocolV1),
×
1507
                                        ChannelID: dbChan.ID,
×
1508
                                        NodeID:    dbChan.NodeID2,
×
1509
                                },
×
1510
                        )
×
1511
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1512
                                return fmt.Errorf("unable to fetch node2 "+
×
1513
                                        "policy: %w", err)
×
1514
                        } else if err == nil {
×
1515
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1516
                                        node2Policy.LastUpdate.Int64, 0,
×
1517
                                )
×
1518
                        }
×
1519

1520
                        channelsPerBlock[cid.BlockHeight] = append(
×
1521
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1522
                        )
×
1523
                }
1524

1525
                return nil
×
1526
        }, func() {
×
1527
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1528
        })
×
1529
        if err != nil {
×
1530
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1531
        }
×
1532

1533
        if len(channelsPerBlock) == 0 {
×
1534
                return nil, nil
×
1535
        }
×
1536

1537
        // Return the channel ranges in ascending block height order.
1538
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1539
        slices.Sort(blocks)
×
1540

×
1541
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1542
                return BlockChannelRange{
×
1543
                        Height:   block,
×
1544
                        Channels: channelsPerBlock[block],
×
1545
                }
×
1546
        }), nil
×
1547
}
1548

1549
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1550
// zombie. This method is used on an ad-hoc basis, when channels need to be
1551
// marked as zombies outside the normal pruning cycle.
1552
//
1553
// NOTE: part of the V1Store interface.
1554
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1555
        pubKey1, pubKey2 [33]byte) error {
×
1556

×
1557
        ctx := context.TODO()
×
1558

×
1559
        s.cacheMu.Lock()
×
1560
        defer s.cacheMu.Unlock()
×
1561

×
1562
        chanIDB := channelIDToBytes(chanID)
×
1563

×
1564
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1565
                return db.UpsertZombieChannel(
×
1566
                        ctx, sqlc.UpsertZombieChannelParams{
×
1567
                                Version:  int16(ProtocolV1),
×
1568
                                Scid:     chanIDB[:],
×
1569
                                NodeKey1: pubKey1[:],
×
1570
                                NodeKey2: pubKey2[:],
×
1571
                        },
×
1572
                )
×
1573
        }, sqldb.NoOpReset)
×
1574
        if err != nil {
×
1575
                return fmt.Errorf("unable to upsert zombie channel "+
×
1576
                        "(channel_id=%d): %w", chanID, err)
×
1577
        }
×
1578

1579
        s.rejectCache.remove(chanID)
×
1580
        s.chanCache.remove(chanID)
×
1581

×
1582
        return nil
×
1583
}
1584

1585
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1586
//
1587
// NOTE: part of the V1Store interface.
1588
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1589
        s.cacheMu.Lock()
×
1590
        defer s.cacheMu.Unlock()
×
1591

×
1592
        var (
×
1593
                ctx     = context.TODO()
×
1594
                chanIDB = channelIDToBytes(chanID)
×
1595
        )
×
1596

×
1597
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1598
                res, err := db.DeleteZombieChannel(
×
1599
                        ctx, sqlc.DeleteZombieChannelParams{
×
1600
                                Scid:    chanIDB[:],
×
1601
                                Version: int16(ProtocolV1),
×
1602
                        },
×
1603
                )
×
1604
                if err != nil {
×
1605
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1606
                                err)
×
1607
                }
×
1608

1609
                rows, err := res.RowsAffected()
×
1610
                if err != nil {
×
1611
                        return err
×
1612
                }
×
1613

1614
                if rows == 0 {
×
1615
                        return ErrZombieEdgeNotFound
×
1616
                } else if rows > 1 {
×
1617
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1618
                                "expected 1", rows)
×
1619
                }
×
1620

1621
                return nil
×
1622
        }, sqldb.NoOpReset)
1623
        if err != nil {
×
1624
                return fmt.Errorf("unable to mark edge live "+
×
1625
                        "(channel_id=%d): %w", chanID, err)
×
1626
        }
×
1627

1628
        s.rejectCache.remove(chanID)
×
1629
        s.chanCache.remove(chanID)
×
1630

×
1631
        return err
×
1632
}
1633

1634
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1635
// zombie, then the two node public keys corresponding to this edge are also
1636
// returned.
1637
//
1638
// NOTE: part of the V1Store interface.
1639
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
NEW
1640
        error) {
×
NEW
1641

×
1642
        var (
×
1643
                ctx              = context.TODO()
×
1644
                isZombie         bool
×
1645
                pubKey1, pubKey2 route.Vertex
×
1646
                chanIDB          = channelIDToBytes(chanID)
×
1647
        )
×
1648

×
1649
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1650
                zombie, err := db.GetZombieChannel(
×
1651
                        ctx, sqlc.GetZombieChannelParams{
×
1652
                                Scid:    chanIDB[:],
×
1653
                                Version: int16(ProtocolV1),
×
1654
                        },
×
1655
                )
×
1656
                if errors.Is(err, sql.ErrNoRows) {
×
1657
                        return nil
×
1658
                }
×
1659
                if err != nil {
×
1660
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1661
                                err)
×
1662
                }
×
1663

1664
                copy(pubKey1[:], zombie.NodeKey1)
×
1665
                copy(pubKey2[:], zombie.NodeKey2)
×
1666
                isZombie = true
×
1667

×
1668
                return nil
×
1669
        }, sqldb.NoOpReset)
1670
        if err != nil {
×
NEW
1671
                return false, route.Vertex{}, route.Vertex{},
×
NEW
1672
                        fmt.Errorf("unable to check if edge is zombie "+
×
NEW
1673
                                "(channel_id=%d): %w", chanID, err)
×
UNCOV
1674
        }
×
1675

NEW
1676
        return isZombie, pubKey1, pubKey2, nil
×
1677
}
1678

1679
// NumZombies returns the current number of zombie channels in the graph.
1680
//
1681
// NOTE: part of the V1Store interface.
1682
func (s *SQLStore) NumZombies() (uint64, error) {
×
1683
        var (
×
1684
                ctx        = context.TODO()
×
1685
                numZombies uint64
×
1686
        )
×
1687
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1688
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1689
                if err != nil {
×
1690
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1691
                                err)
×
1692
                }
×
1693

1694
                numZombies = uint64(count)
×
1695

×
1696
                return nil
×
1697
        }, sqldb.NoOpReset)
1698
        if err != nil {
×
1699
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1700
        }
×
1701

1702
        return numZombies, nil
×
1703
}
1704

1705
// DeleteChannelEdges removes edges with the given channel IDs from the
1706
// database and marks them as zombies. This ensures that we're unable to re-add
1707
// it to our database once again. If an edge does not exist within the
1708
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1709
// true, then when we mark these edges as zombies, we'll set up the keys such
1710
// that we require the node that failed to send the fresh update to be the one
1711
// that resurrects the channel from its zombie state. The markZombie bool
1712
// denotes whether to mark the channel as a zombie.
1713
//
1714
// NOTE: part of the V1Store interface.
1715
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1716
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1717

×
1718
        s.cacheMu.Lock()
×
1719
        defer s.cacheMu.Unlock()
×
1720

×
1721
        var (
×
1722
                ctx     = context.TODO()
×
1723
                deleted []*models.ChannelEdgeInfo
×
1724
        )
×
1725
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1726
                for _, chanID := range chanIDs {
×
1727
                        chanIDB := channelIDToBytes(chanID)
×
1728

×
1729
                        row, err := db.GetChannelBySCIDWithPolicies(
×
1730
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1731
                                        Scid:    chanIDB[:],
×
1732
                                        Version: int16(ProtocolV1),
×
1733
                                },
×
1734
                        )
×
1735
                        if errors.Is(err, sql.ErrNoRows) {
×
1736
                                return ErrEdgeNotFound
×
1737
                        } else if err != nil {
×
1738
                                return fmt.Errorf("unable to fetch channel: %w",
×
1739
                                        err)
×
1740
                        }
×
1741

1742
                        node1, node2, err := buildNodeVertices(
×
1743
                                row.Node.PubKey, row.Node_2.PubKey,
×
1744
                        )
×
1745
                        if err != nil {
×
1746
                                return err
×
1747
                        }
×
1748

1749
                        info, err := getAndBuildEdgeInfo(
×
1750
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1751
                                row.Channel, node1, node2,
×
1752
                        )
×
1753
                        if err != nil {
×
1754
                                return err
×
1755
                        }
×
1756

1757
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
1758
                        if err != nil {
×
1759
                                return fmt.Errorf("unable to delete "+
×
1760
                                        "channel: %w", err)
×
1761
                        }
×
1762

1763
                        deleted = append(deleted, info)
×
1764

×
1765
                        if !markZombie {
×
1766
                                continue
×
1767
                        }
1768

1769
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1770
                                info.NodeKey2Bytes
×
1771
                        if strictZombiePruning {
×
1772
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1773
                                if row.Policy1LastUpdate.Valid {
×
1774
                                        e1Time := time.Unix(
×
1775
                                                row.Policy1LastUpdate.Int64, 0,
×
1776
                                        )
×
1777
                                        e1UpdateTime = &e1Time
×
1778
                                }
×
1779
                                if row.Policy2LastUpdate.Valid {
×
1780
                                        e2Time := time.Unix(
×
1781
                                                row.Policy2LastUpdate.Int64, 0,
×
1782
                                        )
×
1783
                                        e2UpdateTime = &e2Time
×
1784
                                }
×
1785

1786
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1787
                                        info, e1UpdateTime, e2UpdateTime,
×
1788
                                )
×
1789
                        }
1790

1791
                        err = db.UpsertZombieChannel(
×
1792
                                ctx, sqlc.UpsertZombieChannelParams{
×
1793
                                        Version:  int16(ProtocolV1),
×
1794
                                        Scid:     chanIDB[:],
×
1795
                                        NodeKey1: nodeKey1[:],
×
1796
                                        NodeKey2: nodeKey2[:],
×
1797
                                },
×
1798
                        )
×
1799
                        if err != nil {
×
1800
                                return fmt.Errorf("unable to mark channel as "+
×
1801
                                        "zombie: %w", err)
×
1802
                        }
×
1803
                }
1804

1805
                return nil
×
1806
        }, func() {
×
1807
                deleted = nil
×
1808
        })
×
1809
        if err != nil {
×
1810
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1811
                        err)
×
1812
        }
×
1813

1814
        for _, chanID := range chanIDs {
×
1815
                s.rejectCache.remove(chanID)
×
1816
                s.chanCache.remove(chanID)
×
1817
        }
×
1818

1819
        return deleted, nil
×
1820
}
1821

1822
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1823
// channel identified by the channel ID. If the channel can't be found, then
1824
// ErrEdgeNotFound is returned. A struct which houses the general information
1825
// for the channel itself is returned as well as two structs that contain the
1826
// routing policies for the channel in either direction.
1827
//
1828
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1829
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1830
// the ChannelEdgeInfo will only include the public keys of each node.
1831
//
1832
// NOTE: part of the V1Store interface.
1833
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1834
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1835
        *models.ChannelEdgePolicy, error) {
×
1836

×
1837
        var (
×
1838
                ctx              = context.TODO()
×
1839
                edge             *models.ChannelEdgeInfo
×
1840
                policy1, policy2 *models.ChannelEdgePolicy
×
1841
        )
×
1842
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1843
                var chanIDB [8]byte
×
1844
                byteOrder.PutUint64(chanIDB[:], chanID)
×
1845

×
1846
                row, err := db.GetChannelBySCIDWithPolicies(
×
1847
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1848
                                Scid:    chanIDB[:],
×
1849
                                Version: int16(ProtocolV1),
×
1850
                        },
×
1851
                )
×
1852
                if errors.Is(err, sql.ErrNoRows) {
×
1853
                        // First check if this edge is perhaps in the zombie
×
1854
                        // index.
×
1855
                        isZombie, err := db.IsZombieChannel(
×
1856
                                ctx, sqlc.IsZombieChannelParams{
×
1857
                                        Scid:    chanIDB[:],
×
1858
                                        Version: int16(ProtocolV1),
×
1859
                                },
×
1860
                        )
×
1861
                        if err != nil {
×
1862
                                return fmt.Errorf("unable to check if "+
×
1863
                                        "channel is zombie: %w", err)
×
1864
                        } else if isZombie {
×
1865
                                return ErrZombieEdge
×
1866
                        }
×
1867

1868
                        return ErrEdgeNotFound
×
1869
                } else if err != nil {
×
1870
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1871
                }
×
1872

1873
                node1, node2, err := buildNodeVertices(
×
1874
                        row.Node.PubKey, row.Node_2.PubKey,
×
1875
                )
×
1876
                if err != nil {
×
1877
                        return err
×
1878
                }
×
1879

1880
                edge, err = getAndBuildEdgeInfo(
×
1881
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1882
                        node1, node2,
×
1883
                )
×
1884
                if err != nil {
×
1885
                        return fmt.Errorf("unable to build channel info: %w",
×
1886
                                err)
×
1887
                }
×
1888

1889
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1890
                if err != nil {
×
1891
                        return fmt.Errorf("unable to extract channel "+
×
1892
                                "policies: %w", err)
×
1893
                }
×
1894

1895
                policy1, policy2, err = getAndBuildChanPolicies(
×
1896
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1897
                )
×
1898
                if err != nil {
×
1899
                        return fmt.Errorf("unable to build channel "+
×
1900
                                "policies: %w", err)
×
1901
                }
×
1902

1903
                return nil
×
1904
        }, sqldb.NoOpReset)
1905
        if err != nil {
×
1906
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1907
                        err)
×
1908
        }
×
1909

1910
        return edge, policy1, policy2, nil
×
1911
}
1912

1913
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1914
// the channel identified by the funding outpoint. If the channel can't be
1915
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1916
// information for the channel itself is returned as well as two structs that
1917
// contain the routing policies for the channel in either direction.
1918
//
1919
// NOTE: part of the V1Store interface.
1920
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1921
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1922
        *models.ChannelEdgePolicy, error) {
×
1923

×
1924
        var (
×
1925
                ctx              = context.TODO()
×
1926
                edge             *models.ChannelEdgeInfo
×
1927
                policy1, policy2 *models.ChannelEdgePolicy
×
1928
        )
×
1929
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1930
                row, err := db.GetChannelByOutpointWithPolicies(
×
1931
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1932
                                Outpoint: op.String(),
×
1933
                                Version:  int16(ProtocolV1),
×
1934
                        },
×
1935
                )
×
1936
                if errors.Is(err, sql.ErrNoRows) {
×
1937
                        return ErrEdgeNotFound
×
1938
                } else if err != nil {
×
1939
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1940
                }
×
1941

1942
                node1, node2, err := buildNodeVertices(
×
1943
                        row.Node1Pubkey, row.Node2Pubkey,
×
1944
                )
×
1945
                if err != nil {
×
1946
                        return err
×
1947
                }
×
1948

1949
                edge, err = getAndBuildEdgeInfo(
×
1950
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1951
                        node1, node2,
×
1952
                )
×
1953
                if err != nil {
×
1954
                        return fmt.Errorf("unable to build channel info: %w",
×
1955
                                err)
×
1956
                }
×
1957

1958
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1959
                if err != nil {
×
1960
                        return fmt.Errorf("unable to extract channel "+
×
1961
                                "policies: %w", err)
×
1962
                }
×
1963

1964
                policy1, policy2, err = getAndBuildChanPolicies(
×
1965
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1966
                )
×
1967
                if err != nil {
×
1968
                        return fmt.Errorf("unable to build channel "+
×
1969
                                "policies: %w", err)
×
1970
                }
×
1971

1972
                return nil
×
1973
        }, sqldb.NoOpReset)
1974
        if err != nil {
×
1975
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1976
                        err)
×
1977
        }
×
1978

1979
        return edge, policy1, policy2, nil
×
1980
}
1981

1982
// HasChannelEdge returns true if the database knows of a channel edge with the
1983
// passed channel ID, and false otherwise. If an edge with that ID is found
1984
// within the graph, then two time stamps representing the last time the edge
1985
// was updated for both directed edges are returned along with the boolean. If
1986
// it is not found, then the zombie index is checked and its result is returned
1987
// as the second boolean.
1988
//
1989
// NOTE: part of the V1Store interface.
1990
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
1991
        bool, error) {
×
1992

×
1993
        ctx := context.TODO()
×
1994

×
1995
        var (
×
1996
                exists          bool
×
1997
                isZombie        bool
×
1998
                node1LastUpdate time.Time
×
1999
                node2LastUpdate time.Time
×
2000
        )
×
2001

×
2002
        // We'll query the cache with the shared lock held to allow multiple
×
2003
        // readers to access values in the cache concurrently if they exist.
×
2004
        s.cacheMu.RLock()
×
2005
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2006
                s.cacheMu.RUnlock()
×
2007
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2008
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2009
                exists, isZombie = entry.flags.unpack()
×
2010

×
2011
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2012
        }
×
2013
        s.cacheMu.RUnlock()
×
2014

×
2015
        s.cacheMu.Lock()
×
2016
        defer s.cacheMu.Unlock()
×
2017

×
2018
        // The item was not found with the shared lock, so we'll acquire the
×
2019
        // exclusive lock and check the cache again in case another method added
×
2020
        // the entry to the cache while no lock was held.
×
2021
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2022
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2023
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2024
                exists, isZombie = entry.flags.unpack()
×
2025

×
2026
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2027
        }
×
2028

2029
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2030
                var chanIDB [8]byte
×
2031
                byteOrder.PutUint64(chanIDB[:], chanID)
×
2032

×
2033
                channel, err := db.GetChannelBySCID(
×
2034
                        ctx, sqlc.GetChannelBySCIDParams{
×
2035
                                Scid:    chanIDB[:],
×
2036
                                Version: int16(ProtocolV1),
×
2037
                        },
×
2038
                )
×
2039
                if errors.Is(err, sql.ErrNoRows) {
×
2040
                        // Check if it is a zombie channel.
×
2041
                        isZombie, err = db.IsZombieChannel(
×
2042
                                ctx, sqlc.IsZombieChannelParams{
×
2043
                                        Scid:    chanIDB[:],
×
2044
                                        Version: int16(ProtocolV1),
×
2045
                                },
×
2046
                        )
×
2047
                        if err != nil {
×
2048
                                return fmt.Errorf("could not check if channel "+
×
2049
                                        "is zombie: %w", err)
×
2050
                        }
×
2051

2052
                        return nil
×
2053
                } else if err != nil {
×
2054
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2055
                }
×
2056

2057
                exists = true
×
2058

×
2059
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2060
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2061
                                Version:   int16(ProtocolV1),
×
2062
                                ChannelID: channel.ID,
×
2063
                                NodeID:    channel.NodeID1,
×
2064
                        },
×
2065
                )
×
2066
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2067
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2068
                                err)
×
2069
                } else if err == nil {
×
2070
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2071
                }
×
2072

2073
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2074
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2075
                                Version:   int16(ProtocolV1),
×
2076
                                ChannelID: channel.ID,
×
2077
                                NodeID:    channel.NodeID2,
×
2078
                        },
×
2079
                )
×
2080
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2081
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2082
                                err)
×
2083
                } else if err == nil {
×
2084
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2085
                }
×
2086

2087
                return nil
×
2088
        }, sqldb.NoOpReset)
2089
        if err != nil {
×
2090
                return time.Time{}, time.Time{}, false, false,
×
2091
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2092
        }
×
2093

2094
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2095
                upd1Time: node1LastUpdate.Unix(),
×
2096
                upd2Time: node2LastUpdate.Unix(),
×
2097
                flags:    packRejectFlags(exists, isZombie),
×
2098
        })
×
2099

×
2100
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2101
}
2102

2103
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2104
// passed channel point (outpoint). If the passed channel doesn't exist within
2105
// the database, then ErrEdgeNotFound is returned.
2106
//
2107
// NOTE: part of the V1Store interface.
2108
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2109
        var (
×
2110
                ctx       = context.TODO()
×
2111
                channelID uint64
×
2112
        )
×
2113
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2114
                chanID, err := db.GetSCIDByOutpoint(
×
2115
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2116
                                Outpoint: chanPoint.String(),
×
2117
                                Version:  int16(ProtocolV1),
×
2118
                        },
×
2119
                )
×
2120
                if errors.Is(err, sql.ErrNoRows) {
×
2121
                        return ErrEdgeNotFound
×
2122
                } else if err != nil {
×
2123
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2124
                                err)
×
2125
                }
×
2126

2127
                channelID = byteOrder.Uint64(chanID)
×
2128

×
2129
                return nil
×
2130
        }, sqldb.NoOpReset)
2131
        if err != nil {
×
2132
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2133
        }
×
2134

2135
        return channelID, nil
×
2136
}
2137

2138
// IsPublicNode is a helper method that determines whether the node with the
2139
// given public key is seen as a public node in the graph from the graph's
2140
// source node's point of view.
2141
//
2142
// NOTE: part of the V1Store interface.
2143
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2144
        ctx := context.TODO()
×
2145

×
2146
        var isPublic bool
×
2147
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2148
                var err error
×
2149
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2150

×
2151
                return err
×
2152
        }, sqldb.NoOpReset)
×
2153
        if err != nil {
×
2154
                return false, fmt.Errorf("unable to check if node is "+
×
2155
                        "public: %w", err)
×
2156
        }
×
2157

2158
        return isPublic, nil
×
2159
}
2160

2161
// FetchChanInfos returns the set of channel edges that correspond to the passed
2162
// channel ID's. If an edge is the query is unknown to the database, it will
2163
// skipped and the result will contain only those edges that exist at the time
2164
// of the query. This can be used to respond to peer queries that are seeking to
2165
// fill in gaps in their view of the channel graph.
2166
//
2167
// NOTE: part of the V1Store interface.
2168
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2169
        var (
×
2170
                ctx   = context.TODO()
×
2171
                edges []ChannelEdge
×
2172
        )
×
2173
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2174
                for _, chanID := range chanIDs {
×
2175
                        var chanIDB [8]byte
×
2176
                        byteOrder.PutUint64(chanIDB[:], chanID)
×
2177

×
2178
                        // TODO(elle): potentially optimize this by using
×
2179
                        //  sqlc.slice() once that works for both SQLite and
×
2180
                        //  Postgres.
×
2181
                        row, err := db.GetChannelBySCIDWithPolicies(
×
2182
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
2183
                                        Scid:    chanIDB[:],
×
2184
                                        Version: int16(ProtocolV1),
×
2185
                                },
×
2186
                        )
×
2187
                        if errors.Is(err, sql.ErrNoRows) {
×
2188
                                continue
×
2189
                        } else if err != nil {
×
2190
                                return fmt.Errorf("unable to fetch channel: %w",
×
2191
                                        err)
×
2192
                        }
×
2193

2194
                        node1, node2, err := buildNodes(
×
2195
                                ctx, db, row.Node, row.Node_2,
×
2196
                        )
×
2197
                        if err != nil {
×
2198
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2199
                                        err)
×
2200
                        }
×
2201

2202
                        edge, err := getAndBuildEdgeInfo(
×
2203
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2204
                                row.Channel, node1.PubKeyBytes,
×
2205
                                node2.PubKeyBytes,
×
2206
                        )
×
2207
                        if err != nil {
×
2208
                                return fmt.Errorf("unable to build "+
×
2209
                                        "channel info: %w", err)
×
2210
                        }
×
2211

2212
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2213
                        if err != nil {
×
2214
                                return fmt.Errorf("unable to extract channel "+
×
2215
                                        "policies: %w", err)
×
2216
                        }
×
2217

2218
                        p1, p2, err := getAndBuildChanPolicies(
×
2219
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2220
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2221
                        )
×
2222
                        if err != nil {
×
2223
                                return fmt.Errorf("unable to build channel "+
×
2224
                                        "policies: %w", err)
×
2225
                        }
×
2226

2227
                        edges = append(edges, ChannelEdge{
×
2228
                                Info:    edge,
×
2229
                                Policy1: p1,
×
2230
                                Policy2: p2,
×
2231
                                Node1:   node1,
×
2232
                                Node2:   node2,
×
2233
                        })
×
2234
                }
2235

2236
                return nil
×
2237
        }, func() {
×
2238
                edges = nil
×
2239
        })
×
2240
        if err != nil {
×
2241
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2242
        }
×
2243

2244
        return edges, nil
×
2245
}
2246

2247
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2248
// ID's that we don't know and are not known zombies of the passed set. In other
2249
// words, we perform a set difference of our set of chan ID's and the ones
2250
// passed in. This method can be used by callers to determine the set of
2251
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2252
// known zombies is also returned.
2253
//
2254
// NOTE: part of the V1Store interface.
2255
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2256
        []ChannelUpdateInfo, error) {
×
2257

×
2258
        var (
×
2259
                ctx          = context.TODO()
×
2260
                newChanIDs   []uint64
×
2261
                knownZombies []ChannelUpdateInfo
×
2262
        )
×
2263
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2264
                for _, chanInfo := range chansInfo {
×
2265
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2266
                        var chanIDB [8]byte
×
2267
                        byteOrder.PutUint64(chanIDB[:], channelID)
×
2268

×
2269
                        // TODO(elle): potentially optimize this by using
×
2270
                        //  sqlc.slice() once that works for both SQLite and
×
2271
                        //  Postgres.
×
2272
                        _, err := db.GetChannelBySCID(
×
2273
                                ctx, sqlc.GetChannelBySCIDParams{
×
2274
                                        Version: int16(ProtocolV1),
×
2275
                                        Scid:    chanIDB[:],
×
2276
                                },
×
2277
                        )
×
2278
                        if err == nil {
×
2279
                                continue
×
2280
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
2281
                                return fmt.Errorf("unable to fetch channel: %w",
×
2282
                                        err)
×
2283
                        }
×
2284

2285
                        isZombie, err := db.IsZombieChannel(
×
2286
                                ctx, sqlc.IsZombieChannelParams{
×
2287
                                        Scid:    chanIDB[:],
×
2288
                                        Version: int16(ProtocolV1),
×
2289
                                },
×
2290
                        )
×
2291
                        if err != nil {
×
2292
                                return fmt.Errorf("unable to fetch zombie "+
×
2293
                                        "channel: %w", err)
×
2294
                        }
×
2295

2296
                        if isZombie {
×
2297
                                knownZombies = append(knownZombies, chanInfo)
×
2298

×
2299
                                continue
×
2300
                        }
2301

2302
                        newChanIDs = append(newChanIDs, channelID)
×
2303
                }
2304

2305
                return nil
×
2306
        }, func() {
×
2307
                newChanIDs = nil
×
2308
                knownZombies = nil
×
2309
        })
×
2310
        if err != nil {
×
2311
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2312
        }
×
2313

2314
        return newChanIDs, knownZombies, nil
×
2315
}
2316

2317
// PruneGraphNodes is a garbage collection method which attempts to prune out
2318
// any nodes from the channel graph that are currently unconnected. This ensure
2319
// that we only maintain a graph of reachable nodes. In the event that a pruned
2320
// node gains more channels, it will be re-added back to the graph.
2321
//
2322
// NOTE: this prunes nodes across protocol versions. It will never prune the
2323
// source nodes.
2324
//
2325
// NOTE: part of the V1Store interface.
2326
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2327
        var ctx = context.TODO()
×
2328

×
2329
        var prunedNodes []route.Vertex
×
2330
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2331
                var err error
×
2332
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2333

×
2334
                return err
×
2335
        }, func() {
×
2336
                prunedNodes = nil
×
2337
        })
×
2338
        if err != nil {
×
2339
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2340
        }
×
2341

2342
        return prunedNodes, nil
×
2343
}
2344

2345
// PruneGraph prunes newly closed channels from the channel graph in response
2346
// to a new block being solved on the network. Any transactions which spend the
2347
// funding output of any known channels within he graph will be deleted.
2348
// Additionally, the "prune tip", or the last block which has been used to
2349
// prune the graph is stored so callers can ensure the graph is fully in sync
2350
// with the current UTXO state. A slice of channels that have been closed by
2351
// the target block along with any pruned nodes are returned if the function
2352
// succeeds without error.
2353
//
2354
// NOTE: part of the V1Store interface.
2355
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2356
        blockHash *chainhash.Hash, blockHeight uint32) (
2357
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2358

×
2359
        ctx := context.TODO()
×
2360

×
2361
        s.cacheMu.Lock()
×
2362
        defer s.cacheMu.Unlock()
×
2363

×
2364
        var (
×
2365
                closedChans []*models.ChannelEdgeInfo
×
2366
                prunedNodes []route.Vertex
×
2367
        )
×
2368
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2369
                for _, outpoint := range spentOutputs {
×
2370
                        // TODO(elle): potentially optimize this by using
×
2371
                        //  sqlc.slice() once that works for both SQLite and
×
2372
                        //  Postgres.
×
2373
                        //
×
2374
                        // NOTE: this fetches channels for all protocol
×
2375
                        // versions.
×
2376
                        row, err := db.GetChannelByOutpoint(
×
2377
                                ctx, outpoint.String(),
×
2378
                        )
×
2379
                        if errors.Is(err, sql.ErrNoRows) {
×
2380
                                continue
×
2381
                        } else if err != nil {
×
2382
                                return fmt.Errorf("unable to fetch channel: %w",
×
2383
                                        err)
×
2384
                        }
×
2385

2386
                        node1, node2, err := buildNodeVertices(
×
2387
                                row.Node1Pubkey, row.Node2Pubkey,
×
2388
                        )
×
2389
                        if err != nil {
×
2390
                                return err
×
2391
                        }
×
2392

2393
                        info, err := getAndBuildEdgeInfo(
×
2394
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2395
                                row.Channel, node1, node2,
×
2396
                        )
×
2397
                        if err != nil {
×
2398
                                return err
×
2399
                        }
×
2400

2401
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2402
                        if err != nil {
×
2403
                                return fmt.Errorf("unable to delete "+
×
2404
                                        "channel: %w", err)
×
2405
                        }
×
2406

2407
                        closedChans = append(closedChans, info)
×
2408
                }
2409

2410
                err := db.UpsertPruneLogEntry(
×
2411
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2412
                                BlockHash:   blockHash[:],
×
2413
                                BlockHeight: int64(blockHeight),
×
2414
                        },
×
2415
                )
×
2416
                if err != nil {
×
2417
                        return fmt.Errorf("unable to insert prune log "+
×
2418
                                "entry: %w", err)
×
2419
                }
×
2420

2421
                // Now that we've pruned some channels, we'll also prune any
2422
                // nodes that no longer have any channels.
2423
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2424
                if err != nil {
×
2425
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2426
                                err)
×
2427
                }
×
2428

2429
                return nil
×
2430
        }, func() {
×
2431
                prunedNodes = nil
×
2432
                closedChans = nil
×
2433
        })
×
2434
        if err != nil {
×
2435
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2436
        }
×
2437

2438
        for _, channel := range closedChans {
×
2439
                s.rejectCache.remove(channel.ChannelID)
×
2440
                s.chanCache.remove(channel.ChannelID)
×
2441
        }
×
2442

2443
        return closedChans, prunedNodes, nil
×
2444
}
2445

2446
// ChannelView returns the verifiable edge information for each active channel
2447
// within the known channel graph. The set of UTXOs (along with their scripts)
2448
// returned are the ones that need to be watched on chain to detect channel
2449
// closes on the resident blockchain.
2450
//
2451
// NOTE: part of the V1Store interface.
2452
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2453
        var (
×
2454
                ctx        = context.TODO()
×
2455
                edgePoints []EdgePoint
×
2456
        )
×
2457

×
2458
        handleChannel := func(db SQLQueries,
×
2459
                channel sqlc.ListChannelsPaginatedRow) error {
×
2460

×
2461
                pkScript, err := genMultiSigP2WSH(
×
2462
                        channel.BitcoinKey1, channel.BitcoinKey2,
×
2463
                )
×
2464
                if err != nil {
×
2465
                        return err
×
2466
                }
×
2467

2468
                op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2469
                if err != nil {
×
2470
                        return err
×
2471
                }
×
2472

2473
                edgePoints = append(edgePoints, EdgePoint{
×
2474
                        FundingPkScript: pkScript,
×
2475
                        OutPoint:        *op,
×
2476
                })
×
2477

×
2478
                return nil
×
2479
        }
2480

2481
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2482
                lastID := int64(-1)
×
2483
                for {
×
2484
                        rows, err := db.ListChannelsPaginated(
×
2485
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2486
                                        Version: int16(ProtocolV1),
×
2487
                                        ID:      lastID,
×
2488
                                        Limit:   pageSize,
×
2489
                                },
×
2490
                        )
×
2491
                        if err != nil {
×
2492
                                return err
×
2493
                        }
×
2494

2495
                        if len(rows) == 0 {
×
2496
                                break
×
2497
                        }
2498

2499
                        for _, row := range rows {
×
2500
                                err := handleChannel(db, row)
×
2501
                                if err != nil {
×
2502
                                        return err
×
2503
                                }
×
2504

2505
                                lastID = row.ID
×
2506
                        }
2507
                }
2508

2509
                return nil
×
2510
        }, func() {
×
2511
                edgePoints = nil
×
2512
        })
×
2513
        if err != nil {
×
2514
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2515
        }
×
2516

2517
        return edgePoints, nil
×
2518
}
2519

2520
// PruneTip returns the block height and hash of the latest block that has been
2521
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2522
// to tell if the graph is currently in sync with the current best known UTXO
2523
// state.
2524
//
2525
// NOTE: part of the V1Store interface.
2526
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2527
        var (
×
2528
                ctx       = context.TODO()
×
2529
                tipHash   chainhash.Hash
×
2530
                tipHeight uint32
×
2531
        )
×
2532
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2533
                pruneTip, err := db.GetPruneTip(ctx)
×
2534
                if errors.Is(err, sql.ErrNoRows) {
×
2535
                        return ErrGraphNeverPruned
×
2536
                } else if err != nil {
×
2537
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2538
                }
×
2539

2540
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2541
                tipHeight = uint32(pruneTip.BlockHeight)
×
2542

×
2543
                return nil
×
2544
        }, sqldb.NoOpReset)
2545
        if err != nil {
×
2546
                return nil, 0, err
×
2547
        }
×
2548

2549
        return &tipHash, tipHeight, nil
×
2550
}
2551

2552
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2553
//
2554
// NOTE: this prunes nodes across protocol versions. It will never prune the
2555
// source nodes.
2556
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2557
        db SQLQueries) ([]route.Vertex, error) {
×
2558

×
2559
        // Fetch all un-connected nodes from the database.
×
2560
        // NOTE: this will not include any nodes that are listed in the
×
2561
        // source table.
×
2562
        nodes, err := db.GetUnconnectedNodes(ctx)
×
2563
        if err != nil {
×
2564
                return nil, fmt.Errorf("unable to fetch unconnected nodes: %w",
×
2565
                        err)
×
2566
        }
×
2567

2568
        prunedNodes := make([]route.Vertex, 0, len(nodes))
×
2569
        for _, node := range nodes {
×
2570
                // TODO(elle): update to use sqlc.slice() once that works.
×
2571
                if err = db.DeleteNode(ctx, node.ID); err != nil {
×
2572
                        return nil, fmt.Errorf("unable to delete "+
×
2573
                                "node(id=%d): %w", node.ID, err)
×
2574
                }
×
2575

2576
                pubKey, err := route.NewVertexFromBytes(node.PubKey)
×
2577
                if err != nil {
×
2578
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2579
                                "for node(id=%d): %w", node.ID, err)
×
2580
                }
×
2581

2582
                prunedNodes = append(prunedNodes, pubKey)
×
2583
        }
2584

2585
        return prunedNodes, nil
×
2586
}
2587

2588
// DisconnectBlockAtHeight is used to indicate that the block specified
2589
// by the passed height has been disconnected from the main chain. This
2590
// will "rewind" the graph back to the height below, deleting channels
2591
// that are no longer confirmed from the graph. The prune log will be
2592
// set to the last prune height valid for the remaining chain.
2593
// Channels that were removed from the graph resulting from the
2594
// disconnected block are returned.
2595
//
2596
// NOTE: part of the V1Store interface.
2597
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2598
        []*models.ChannelEdgeInfo, error) {
×
2599

×
2600
        ctx := context.TODO()
×
2601

×
2602
        var (
×
2603
                // Every channel having a ShortChannelID starting at 'height'
×
2604
                // will no longer be confirmed.
×
2605
                startShortChanID = lnwire.ShortChannelID{
×
2606
                        BlockHeight: height,
×
2607
                }
×
2608

×
2609
                // Delete everything after this height from the db up until the
×
2610
                // SCID alias range.
×
2611
                endShortChanID = aliasmgr.StartingAlias
×
2612

×
2613
                removedChans []*models.ChannelEdgeInfo
×
2614
        )
×
2615

×
2616
        var chanIDStart [8]byte
×
2617
        byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64())
×
2618
        var chanIDEnd [8]byte
×
2619
        byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64())
×
2620

×
2621
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2622
                rows, err := db.GetChannelsBySCIDRange(
×
2623
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2624
                                StartScid: chanIDStart[:],
×
2625
                                EndScid:   chanIDEnd[:],
×
2626
                        },
×
2627
                )
×
2628
                if err != nil {
×
2629
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2630
                }
×
2631

2632
                for _, row := range rows {
×
2633
                        node1, node2, err := buildNodeVertices(
×
2634
                                row.Node1PubKey, row.Node2PubKey,
×
2635
                        )
×
2636
                        if err != nil {
×
2637
                                return err
×
2638
                        }
×
2639

2640
                        channel, err := getAndBuildEdgeInfo(
×
2641
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2642
                                row.Channel, node1, node2,
×
2643
                        )
×
2644
                        if err != nil {
×
2645
                                return err
×
2646
                        }
×
2647

2648
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
2649
                        if err != nil {
×
2650
                                return fmt.Errorf("unable to delete "+
×
2651
                                        "channel: %w", err)
×
2652
                        }
×
2653

2654
                        removedChans = append(removedChans, channel)
×
2655
                }
2656

2657
                return db.DeletePruneLogEntriesInRange(
×
2658
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2659
                                StartHeight: int64(height),
×
2660
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2661
                        },
×
2662
                )
×
2663
        }, func() {
×
2664
                removedChans = nil
×
2665
        })
×
2666
        if err != nil {
×
2667
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2668
                        "height: %w", err)
×
2669
        }
×
2670

2671
        for _, channel := range removedChans {
×
2672
                s.rejectCache.remove(channel.ChannelID)
×
2673
                s.chanCache.remove(channel.ChannelID)
×
2674
        }
×
2675

2676
        return removedChans, nil
×
2677
}
2678

2679
// AddEdgeProof sets the proof of an existing edge in the graph database.
2680
//
2681
// NOTE: part of the V1Store interface.
2682
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2683
        proof *models.ChannelAuthProof) error {
×
2684

×
2685
        var (
×
2686
                ctx       = context.TODO()
×
2687
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2688
        )
×
2689

×
2690
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2691
                res, err := db.AddV1ChannelProof(
×
2692
                        ctx, sqlc.AddV1ChannelProofParams{
×
2693
                                Scid:              scidBytes[:],
×
2694
                                Node1Signature:    proof.NodeSig1Bytes,
×
2695
                                Node2Signature:    proof.NodeSig2Bytes,
×
2696
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2697
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2698
                        },
×
2699
                )
×
2700
                if err != nil {
×
2701
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2702
                }
×
2703

2704
                n, err := res.RowsAffected()
×
2705
                if err != nil {
×
2706
                        return err
×
2707
                }
×
2708

2709
                if n == 0 {
×
2710
                        return fmt.Errorf("no rows affected when adding edge "+
×
2711
                                "proof for SCID %v", scid)
×
2712
                } else if n > 1 {
×
2713
                        return fmt.Errorf("multiple rows affected when adding "+
×
2714
                                "edge proof for SCID %v: %d rows affected",
×
2715
                                scid, n)
×
2716
                }
×
2717

2718
                return nil
×
2719
        }, sqldb.NoOpReset)
2720
        if err != nil {
×
2721
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2722
        }
×
2723

2724
        return nil
×
2725
}
2726

2727
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2728
// that we can ignore channel announcements that we know to be closed without
2729
// having to validate them and fetch a block.
2730
//
2731
// NOTE: part of the V1Store interface.
2732
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2733
        var (
×
2734
                ctx     = context.TODO()
×
2735
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2736
        )
×
2737

×
2738
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2739
                return db.InsertClosedChannel(ctx, chanIDB[:])
×
2740
        }, sqldb.NoOpReset)
×
2741
}
2742

2743
// IsClosedScid checks whether a channel identified by the passed in scid is
2744
// closed. This helps avoid having to perform expensive validation checks.
2745
//
2746
// NOTE: part of the V1Store interface.
2747
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2748
        var (
×
2749
                ctx      = context.TODO()
×
2750
                isClosed bool
×
2751
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2752
        )
×
2753
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2754
                var err error
×
2755
                isClosed, err = db.IsClosedChannel(ctx, chanIDB[:])
×
2756
                if err != nil {
×
2757
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2758
                                err)
×
2759
                }
×
2760

2761
                return nil
×
2762
        }, sqldb.NoOpReset)
2763
        if err != nil {
×
2764
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2765
                        err)
×
2766
        }
×
2767

2768
        return isClosed, nil
×
2769
}
2770

2771
// GraphSession will provide the call-back with access to a NodeTraverser
2772
// instance which can be used to perform queries against the channel graph.
2773
//
2774
// NOTE: part of the V1Store interface.
2775
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error) error {
×
2776
        var ctx = context.TODO()
×
2777

×
2778
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2779
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2780
        }, sqldb.NoOpReset)
×
2781
}
2782

2783
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2784
// read only transaction for a consistent view of the graph.
2785
type sqlNodeTraverser struct {
2786
        db    SQLQueries
2787
        chain chainhash.Hash
2788
}
2789

2790
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2791
// NodeTraverser interface.
2792
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2793

2794
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2795
func newSQLNodeTraverser(db SQLQueries,
2796
        chain chainhash.Hash) *sqlNodeTraverser {
×
2797

×
2798
        return &sqlNodeTraverser{
×
2799
                db:    db,
×
2800
                chain: chain,
×
2801
        }
×
2802
}
×
2803

2804
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2805
// node.
2806
//
2807
// NOTE: Part of the NodeTraverser interface.
2808
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2809
        cb func(channel *DirectedChannel) error) error {
×
2810

×
2811
        ctx := context.TODO()
×
2812

×
2813
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2814
}
×
2815

2816
// FetchNodeFeatures returns the features of the given node. If the node is
2817
// unknown, assume no additional features are supported.
2818
//
2819
// NOTE: Part of the NodeTraverser interface.
2820
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2821
        *lnwire.FeatureVector, error) {
×
2822

×
2823
        ctx := context.TODO()
×
2824

×
2825
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2826
}
×
2827

2828
// forEachNodeDirectedChannel iterates through all channels of a given
2829
// node, executing the passed callback on the directed edge representing the
2830
// channel and its incoming policy. If the node is not found, no error is
2831
// returned.
2832
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2833
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2834

×
2835
        toNodeCallback := func() route.Vertex {
×
2836
                return nodePub
×
2837
        }
×
2838

2839
        dbID, err := db.GetNodeIDByPubKey(
×
2840
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2841
                        Version: int16(ProtocolV1),
×
2842
                        PubKey:  nodePub[:],
×
2843
                },
×
2844
        )
×
2845
        if errors.Is(err, sql.ErrNoRows) {
×
2846
                return nil
×
2847
        } else if err != nil {
×
2848
                return fmt.Errorf("unable to fetch node: %w", err)
×
2849
        }
×
2850

2851
        rows, err := db.ListChannelsByNodeID(
×
2852
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2853
                        Version: int16(ProtocolV1),
×
2854
                        NodeID1: dbID,
×
2855
                },
×
2856
        )
×
2857
        if err != nil {
×
2858
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2859
        }
×
2860

2861
        // Exit early if there are no channels for this node so we don't
2862
        // do the unnecessary feature fetching.
2863
        if len(rows) == 0 {
×
2864
                return nil
×
2865
        }
×
2866

2867
        features, err := getNodeFeatures(ctx, db, dbID)
×
2868
        if err != nil {
×
2869
                return fmt.Errorf("unable to fetch node features: %w", err)
×
2870
        }
×
2871

2872
        for _, row := range rows {
×
2873
                node1, node2, err := buildNodeVertices(
×
2874
                        row.Node1Pubkey, row.Node2Pubkey,
×
2875
                )
×
2876
                if err != nil {
×
2877
                        return fmt.Errorf("unable to build node vertices: %w",
×
2878
                                err)
×
2879
                }
×
2880

2881
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
2882

×
2883
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2884
                if err != nil {
×
2885
                        return err
×
2886
                }
×
2887

2888
                var p1, p2 *models.CachedEdgePolicy
×
2889
                if dbPol1 != nil {
×
2890
                        policy1, err := buildChanPolicy(
×
2891
                                *dbPol1, edge.ChannelID, nil, node2, true,
×
2892
                        )
×
2893
                        if err != nil {
×
2894
                                return err
×
2895
                        }
×
2896

2897
                        p1 = models.NewCachedPolicy(policy1)
×
2898
                }
2899
                if dbPol2 != nil {
×
2900
                        policy2, err := buildChanPolicy(
×
2901
                                *dbPol2, edge.ChannelID, nil, node1, false,
×
2902
                        )
×
2903
                        if err != nil {
×
2904
                                return err
×
2905
                        }
×
2906

2907
                        p2 = models.NewCachedPolicy(policy2)
×
2908
                }
2909

2910
                // Determine the outgoing and incoming policy for this
2911
                // channel and node combo.
2912
                outPolicy, inPolicy := p1, p2
×
2913
                if p1 != nil && node2 == nodePub {
×
2914
                        outPolicy, inPolicy = p2, p1
×
2915
                } else if p2 != nil && node1 != nodePub {
×
2916
                        outPolicy, inPolicy = p2, p1
×
2917
                }
×
2918

2919
                var cachedInPolicy *models.CachedEdgePolicy
×
2920
                if inPolicy != nil {
×
2921
                        cachedInPolicy = inPolicy
×
2922
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
2923
                        cachedInPolicy.ToNodeFeatures = features
×
2924
                }
×
2925

2926
                directedChannel := &DirectedChannel{
×
2927
                        ChannelID:    edge.ChannelID,
×
2928
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2929
                        OtherNode:    edge.NodeKey2Bytes,
×
2930
                        Capacity:     edge.Capacity,
×
2931
                        OutPolicySet: outPolicy != nil,
×
2932
                        InPolicy:     cachedInPolicy,
×
2933
                }
×
2934
                if outPolicy != nil {
×
2935
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2936
                                directedChannel.InboundFee = fee
×
2937
                        })
×
2938
                }
2939

2940
                if nodePub == edge.NodeKey2Bytes {
×
2941
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
2942
                }
×
2943

2944
                if err := cb(directedChannel); err != nil {
×
2945
                        return err
×
2946
                }
×
2947
        }
2948

2949
        return nil
×
2950
}
2951

2952
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
2953
// and executes the provided callback for each node.
2954
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
2955
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
2956

×
2957
        lastID := int64(-1)
×
2958

×
2959
        for {
×
2960
                nodes, err := db.ListNodeIDsAndPubKeys(
×
2961
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
2962
                                Version: int16(ProtocolV1),
×
2963
                                ID:      lastID,
×
2964
                                Limit:   pageSize,
×
2965
                        },
×
2966
                )
×
2967
                if err != nil {
×
2968
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
2969
                }
×
2970

2971
                if len(nodes) == 0 {
×
2972
                        break
×
2973
                }
2974

2975
                for _, node := range nodes {
×
2976
                        var pub route.Vertex
×
2977
                        copy(pub[:], node.PubKey)
×
2978

×
2979
                        if err := cb(node.ID, pub); err != nil {
×
2980
                                return fmt.Errorf("forEachNodeCacheable "+
×
2981
                                        "callback failed for node(id=%d): %w",
×
2982
                                        node.ID, err)
×
2983
                        }
×
2984

2985
                        lastID = node.ID
×
2986
                }
2987
        }
2988

2989
        return nil
×
2990
}
2991

2992
// forEachNodeChannel iterates through all channels of a node, executing
2993
// the passed callback on each. The call-back is provided with the channel's
2994
// edge information, the outgoing policy and the incoming policy for the
2995
// channel and node combo.
2996
func forEachNodeChannel(ctx context.Context, db SQLQueries,
2997
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
2998
                *models.ChannelEdgePolicy,
2999
                *models.ChannelEdgePolicy) error) error {
×
3000

×
3001
        // Get all the V1 channels for this node.Add commentMore actions
×
3002
        rows, err := db.ListChannelsByNodeID(
×
3003
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3004
                        Version: int16(ProtocolV1),
×
3005
                        NodeID1: id,
×
3006
                },
×
3007
        )
×
3008
        if err != nil {
×
3009
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3010
        }
×
3011

3012
        // Call the call-back for each channel and its known policies.
3013
        for _, row := range rows {
×
3014
                node1, node2, err := buildNodeVertices(
×
3015
                        row.Node1Pubkey, row.Node2Pubkey,
×
3016
                )
×
3017
                if err != nil {
×
3018
                        return fmt.Errorf("unable to build node vertices: %w",
×
3019
                                err)
×
3020
                }
×
3021

3022
                edge, err := getAndBuildEdgeInfo(
×
3023
                        ctx, db, chain, row.Channel.ID, row.Channel, node1,
×
3024
                        node2,
×
3025
                )
×
3026
                if err != nil {
×
3027
                        return fmt.Errorf("unable to build channel info: %w",
×
3028
                                err)
×
3029
                }
×
3030

3031
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3032
                if err != nil {
×
3033
                        return fmt.Errorf("unable to extract channel "+
×
3034
                                "policies: %w", err)
×
3035
                }
×
3036

3037
                p1, p2, err := getAndBuildChanPolicies(
×
3038
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3039
                )
×
3040
                if err != nil {
×
3041
                        return fmt.Errorf("unable to build channel "+
×
3042
                                "policies: %w", err)
×
3043
                }
×
3044

3045
                // Determine the outgoing and incoming policy for this
3046
                // channel and node combo.
3047
                p1ToNode := row.Channel.NodeID2
×
3048
                p2ToNode := row.Channel.NodeID1
×
3049
                outPolicy, inPolicy := p1, p2
×
3050
                if (p1 != nil && p1ToNode == id) ||
×
3051
                        (p2 != nil && p2ToNode != id) {
×
3052

×
3053
                        outPolicy, inPolicy = p2, p1
×
3054
                }
×
3055

3056
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3057
                        return err
×
3058
                }
×
3059
        }
3060

3061
        return nil
×
3062
}
3063

3064
// updateChanEdgePolicy upserts the channel policy info we have stored for
3065
// a channel we already know of.
3066
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3067
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3068
        error) {
×
3069

×
3070
        var (
×
3071
                node1Pub, node2Pub route.Vertex
×
3072
                isNode1            bool
×
3073
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3074
        )
×
3075

×
3076
        // Check that this edge policy refers to a channel that we already
×
3077
        // know of. We do this explicitly so that we can return the appropriate
×
3078
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3079
        // abort the transaction which would abort the entire batch.
×
3080
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3081
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3082
                        Scid:    chanIDB[:],
×
3083
                        Version: int16(ProtocolV1),
×
3084
                },
×
3085
        )
×
3086
        if errors.Is(err, sql.ErrNoRows) {
×
3087
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3088
        } else if err != nil {
×
3089
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3090
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3091
        }
×
3092

3093
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3094
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3095

×
3096
        // Figure out which node this edge is from.
×
3097
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3098
        nodeID := dbChan.NodeID1
×
3099
        if !isNode1 {
×
3100
                nodeID = dbChan.NodeID2
×
3101
        }
×
3102

3103
        var (
×
3104
                inboundBase sql.NullInt64
×
3105
                inboundRate sql.NullInt64
×
3106
        )
×
3107
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3108
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3109
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3110
        })
×
3111

3112
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3113
                Version:     int16(ProtocolV1),
×
3114
                ChannelID:   dbChan.ID,
×
3115
                NodeID:      nodeID,
×
3116
                Timelock:    int32(edge.TimeLockDelta),
×
3117
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3118
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3119
                MinHtlcMsat: int64(edge.MinHTLC),
×
3120
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3121
                Disabled: sql.NullBool{
×
3122
                        Valid: true,
×
3123
                        Bool:  edge.IsDisabled(),
×
3124
                },
×
3125
                MaxHtlcMsat: sql.NullInt64{
×
3126
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3127
                        Int64: int64(edge.MaxHTLC),
×
3128
                },
×
3129
                InboundBaseFeeMsat:      inboundBase,
×
3130
                InboundFeeRateMilliMsat: inboundRate,
×
3131
                Signature:               edge.SigBytes,
×
3132
        })
×
3133
        if err != nil {
×
3134
                return node1Pub, node2Pub, isNode1,
×
3135
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3136
        }
×
3137

3138
        // Convert the flat extra opaque data into a map of TLV types to
3139
        // values.
3140
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3141
        if err != nil {
×
3142
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3143
                        "marshal extra opaque data: %w", err)
×
3144
        }
×
3145

3146
        // Update the channel policy's extra signed fields.
3147
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3148
        if err != nil {
×
3149
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3150
                        "policy extra TLVs: %w", err)
×
3151
        }
×
3152

3153
        return node1Pub, node2Pub, isNode1, nil
×
3154
}
3155

3156
// getNodeByPubKey attempts to look up a target node by its public key.
3157
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3158
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3159

×
3160
        dbNode, err := db.GetNodeByPubKey(
×
3161
                ctx, sqlc.GetNodeByPubKeyParams{
×
3162
                        Version: int16(ProtocolV1),
×
3163
                        PubKey:  pubKey[:],
×
3164
                },
×
3165
        )
×
3166
        if errors.Is(err, sql.ErrNoRows) {
×
3167
                return 0, nil, ErrGraphNodeNotFound
×
3168
        } else if err != nil {
×
3169
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3170
        }
×
3171

3172
        node, err := buildNode(ctx, db, &dbNode)
×
3173
        if err != nil {
×
3174
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3175
        }
×
3176

3177
        return dbNode.ID, node, nil
×
3178
}
3179

3180
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3181
// provided database channel row and the public keys of the two nodes
3182
// involved in the channel.
3183
func buildCacheableChannelInfo(dbChan sqlc.Channel, node1Pub,
3184
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3185

×
3186
        return &models.CachedEdgeInfo{
×
3187
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
3188
                NodeKey1Bytes: node1Pub,
×
3189
                NodeKey2Bytes: node2Pub,
×
3190
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
3191
        }
×
3192
}
×
3193

3194
// buildNode constructs a LightningNode instance from the given database node
3195
// record. The node's features, addresses and extra signed fields are also
3196
// fetched from the database and set on the node.
3197
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
3198
        *models.LightningNode, error) {
×
3199

×
3200
        if dbNode.Version != int16(ProtocolV1) {
×
3201
                return nil, fmt.Errorf("unsupported node version: %d",
×
3202
                        dbNode.Version)
×
3203
        }
×
3204

3205
        var pub [33]byte
×
3206
        copy(pub[:], dbNode.PubKey)
×
3207

×
3208
        node := &models.LightningNode{
×
3209
                PubKeyBytes: pub,
×
3210
                Features:    lnwire.EmptyFeatureVector(),
×
3211
                LastUpdate:  time.Unix(0, 0),
×
3212
        }
×
3213

×
3214
        if len(dbNode.Signature) == 0 {
×
3215
                return node, nil
×
3216
        }
×
3217

3218
        node.HaveNodeAnnouncement = true
×
3219
        node.AuthSigBytes = dbNode.Signature
×
3220
        node.Alias = dbNode.Alias.String
×
3221
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3222

×
3223
        var err error
×
3224
        if dbNode.Color.Valid {
×
3225
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3226
                if err != nil {
×
3227
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3228
                                err)
×
3229
                }
×
3230
        }
3231

3232
        // Fetch the node's features.
3233
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
3234
        if err != nil {
×
3235
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3236
                        "features: %w", dbNode.ID, err)
×
3237
        }
×
3238

3239
        // Fetch the node's addresses.
3240
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
3241
        if err != nil {
×
3242
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3243
                        "addresses: %w", dbNode.ID, err)
×
3244
        }
×
3245

3246
        // Fetch the node's extra signed fields.
3247
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
3248
        if err != nil {
×
3249
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
3250
                        "extra signed fields: %w", dbNode.ID, err)
×
3251
        }
×
3252

3253
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
3254
        if err != nil {
×
3255
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3256
                        "fields: %w", err)
×
3257
        }
×
3258

3259
        if len(recs) != 0 {
×
3260
                node.ExtraOpaqueData = recs
×
3261
        }
×
3262

3263
        return node, nil
×
3264
}
3265

3266
// getNodeFeatures fetches the feature bits and constructs the feature vector
3267
// for a node with the given DB ID.
3268
func getNodeFeatures(ctx context.Context, db SQLQueries,
3269
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3270

×
3271
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3272
        if err != nil {
×
3273
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3274
                        nodeID, err)
×
3275
        }
×
3276

3277
        features := lnwire.EmptyFeatureVector()
×
3278
        for _, feature := range rows {
×
3279
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3280
        }
×
3281

3282
        return features, nil
×
3283
}
3284

3285
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
3286
// given DB ID.
3287
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3288
        nodeID int64) (map[uint64][]byte, error) {
×
3289

×
3290
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3291
        if err != nil {
×
3292
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
3293
                        "signed fields: %w", nodeID, err)
×
3294
        }
×
3295

3296
        extraFields := make(map[uint64][]byte)
×
3297
        for _, field := range fields {
×
3298
                extraFields[uint64(field.Type)] = field.Value
×
3299
        }
×
3300

3301
        return extraFields, nil
×
3302
}
3303

3304
// upsertNode upserts the node record into the database. If the node already
3305
// exists, then the node's information is updated. If the node doesn't exist,
3306
// then a new node is created. The node's features, addresses and extra TLV
3307
// types are also updated. The node's DB ID is returned.
3308
func upsertNode(ctx context.Context, db SQLQueries,
3309
        node *models.LightningNode) (int64, error) {
×
3310

×
3311
        params := sqlc.UpsertNodeParams{
×
3312
                Version: int16(ProtocolV1),
×
3313
                PubKey:  node.PubKeyBytes[:],
×
3314
        }
×
3315

×
3316
        if node.HaveNodeAnnouncement {
×
3317
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3318
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3319
                params.Alias = sqldb.SQLStr(node.Alias)
×
3320
                params.Signature = node.AuthSigBytes
×
3321
        }
×
3322

3323
        nodeID, err := db.UpsertNode(ctx, params)
×
3324
        if err != nil {
×
3325
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3326
                        err)
×
3327
        }
×
3328

3329
        // We can exit here if we don't have the announcement yet.
3330
        if !node.HaveNodeAnnouncement {
×
3331
                return nodeID, nil
×
3332
        }
×
3333

3334
        // Update the node's features.
3335
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3336
        if err != nil {
×
3337
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3338
        }
×
3339

3340
        // Update the node's addresses.
3341
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3342
        if err != nil {
×
3343
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3344
        }
×
3345

3346
        // Convert the flat extra opaque data into a map of TLV types to
3347
        // values.
3348
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3349
        if err != nil {
×
3350
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3351
                        err)
×
3352
        }
×
3353

3354
        // Update the node's extra signed fields.
3355
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3356
        if err != nil {
×
3357
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3358
        }
×
3359

3360
        return nodeID, nil
×
3361
}
3362

3363
// upsertNodeFeatures updates the node's features node_features table. This
3364
// includes deleting any feature bits no longer present and inserting any new
3365
// feature bits. If the feature bit does not yet exist in the features table,
3366
// then an entry is created in that table first.
3367
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3368
        features *lnwire.FeatureVector) error {
×
3369

×
3370
        // Get any existing features for the node.
×
3371
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3372
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3373
                return err
×
3374
        }
×
3375

3376
        // Copy the nodes latest set of feature bits.
3377
        newFeatures := make(map[int32]struct{})
×
3378
        if features != nil {
×
3379
                for feature := range features.Features() {
×
3380
                        newFeatures[int32(feature)] = struct{}{}
×
3381
                }
×
3382
        }
3383

3384
        // For any current feature that already exists in the DB, remove it from
3385
        // the in-memory map. For any existing feature that does not exist in
3386
        // the in-memory map, delete it from the database.
3387
        for _, feature := range existingFeatures {
×
3388
                // The feature is still present, so there are no updates to be
×
3389
                // made.
×
3390
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3391
                        delete(newFeatures, feature.FeatureBit)
×
3392
                        continue
×
3393
                }
3394

3395
                // The feature is no longer present, so we remove it from the
3396
                // database.
3397
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3398
                        NodeID:     nodeID,
×
3399
                        FeatureBit: feature.FeatureBit,
×
3400
                })
×
3401
                if err != nil {
×
3402
                        return fmt.Errorf("unable to delete node(%d) "+
×
3403
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3404
                                err)
×
3405
                }
×
3406
        }
3407

3408
        // Any remaining entries in newFeatures are new features that need to be
3409
        // added to the database for the first time.
3410
        for feature := range newFeatures {
×
3411
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3412
                        NodeID:     nodeID,
×
3413
                        FeatureBit: feature,
×
3414
                })
×
3415
                if err != nil {
×
3416
                        return fmt.Errorf("unable to insert node(%d) "+
×
3417
                                "feature(%v): %w", nodeID, feature, err)
×
3418
                }
×
3419
        }
3420

3421
        return nil
×
3422
}
3423

3424
// fetchNodeFeatures fetches the features for a node with the given public key.
3425
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3426
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3427

×
3428
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3429
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3430
                        PubKey:  nodePub[:],
×
3431
                        Version: int16(ProtocolV1),
×
3432
                },
×
3433
        )
×
3434
        if err != nil {
×
3435
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3436
                        nodePub, err)
×
3437
        }
×
3438

3439
        features := lnwire.EmptyFeatureVector()
×
3440
        for _, bit := range rows {
×
3441
                features.Set(lnwire.FeatureBit(bit))
×
3442
        }
×
3443

3444
        return features, nil
×
3445
}
3446

3447
// dbAddressType is an enum type that represents the different address types
// that we store in the node_addresses table. The address type determines how
// the address is to be serialised/deserialised.
type dbAddressType uint8

// The numeric values below are written to the database as the address type,
// so they are spelled out explicitly rather than derived.
const (
	addressTypeIPv4   = dbAddressType(1)
	addressTypeIPv6   = dbAddressType(2)
	addressTypeTorV2  = dbAddressType(3)
	addressTypeTorV3  = dbAddressType(4)
	addressTypeOpaque = dbAddressType(math.MaxInt8)
)
3459

3460
// upsertNodeAddresses updates the node's addresses in the database. This
3461
// includes deleting any existing addresses and inserting the new set of
3462
// addresses. The deletion is necessary since the ordering of the addresses may
3463
// change, and we need to ensure that the database reflects the latest set of
3464
// addresses so that at the time of reconstructing the node announcement, the
3465
// order is preserved and the signature over the message remains valid.
3466
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3467
        addresses []net.Addr) error {
×
3468

×
3469
        // Delete any existing addresses for the node. This is required since
×
3470
        // even if the new set of addresses is the same, the ordering may have
×
3471
        // changed for a given address type.
×
3472
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3473
        if err != nil {
×
3474
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3475
                        nodeID, err)
×
3476
        }
×
3477

3478
        // Copy the nodes latest set of addresses.
3479
        newAddresses := map[dbAddressType][]string{
×
3480
                addressTypeIPv4:   {},
×
3481
                addressTypeIPv6:   {},
×
3482
                addressTypeTorV2:  {},
×
3483
                addressTypeTorV3:  {},
×
3484
                addressTypeOpaque: {},
×
3485
        }
×
3486
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3487
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3488
        }
×
3489

3490
        for _, address := range addresses {
×
3491
                switch addr := address.(type) {
×
3492
                case *net.TCPAddr:
×
3493
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3494
                                addAddr(addressTypeIPv4, addr)
×
3495
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3496
                                addAddr(addressTypeIPv6, addr)
×
3497
                        } else {
×
3498
                                return fmt.Errorf("unhandled IP address: %v",
×
3499
                                        addr)
×
3500
                        }
×
3501

3502
                case *tor.OnionAddr:
×
3503
                        switch len(addr.OnionService) {
×
3504
                        case tor.V2Len:
×
3505
                                addAddr(addressTypeTorV2, addr)
×
3506
                        case tor.V3Len:
×
3507
                                addAddr(addressTypeTorV3, addr)
×
3508
                        default:
×
3509
                                return fmt.Errorf("invalid length for a tor " +
×
3510
                                        "address")
×
3511
                        }
3512

3513
                case *lnwire.OpaqueAddrs:
×
3514
                        addAddr(addressTypeOpaque, addr)
×
3515

3516
                default:
×
3517
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3518
                }
3519
        }
3520

3521
        // Any remaining entries in newAddresses are new addresses that need to
3522
        // be added to the database for the first time.
3523
        for addrType, addrList := range newAddresses {
×
3524
                for position, addr := range addrList {
×
3525
                        err := db.InsertNodeAddress(
×
3526
                                ctx, sqlc.InsertNodeAddressParams{
×
3527
                                        NodeID:   nodeID,
×
3528
                                        Type:     int16(addrType),
×
3529
                                        Address:  addr,
×
3530
                                        Position: int32(position),
×
3531
                                },
×
3532
                        )
×
3533
                        if err != nil {
×
3534
                                return fmt.Errorf("unable to insert "+
×
3535
                                        "node(%d) address(%v): %w", nodeID,
×
3536
                                        addr, err)
×
3537
                        }
×
3538
                }
3539
        }
3540

3541
        return nil
×
3542
}
3543

3544
// getNodeAddresses fetches the addresses for a node with the given public key.
3545
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3546
        []net.Addr, error) {
×
3547

×
3548
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
3549
        // are returned in the same order as they were inserted.
×
3550
        rows, err := db.GetNodeAddressesByPubKey(
×
3551
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
3552
                        Version: int16(ProtocolV1),
×
3553
                        PubKey:  nodePub,
×
3554
                },
×
3555
        )
×
3556
        if err != nil {
×
3557
                return false, nil, err
×
3558
        }
×
3559

3560
        // GetNodeAddressesByPubKey uses a left join so there should always be
3561
        // at least one row returned if the node exists even if it has no
3562
        // addresses.
3563
        if len(rows) == 0 {
×
3564
                return false, nil, nil
×
3565
        }
×
3566

3567
        addresses := make([]net.Addr, 0, len(rows))
×
3568
        for _, addr := range rows {
×
3569
                if !(addr.Type.Valid && addr.Address.Valid) {
×
3570
                        continue
×
3571
                }
3572

3573
                address := addr.Address.String
×
3574

×
3575
                switch dbAddressType(addr.Type.Int16) {
×
3576
                case addressTypeIPv4:
×
3577
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
3578
                        if err != nil {
×
3579
                                return false, nil, nil
×
3580
                        }
×
3581
                        tcp.IP = tcp.IP.To4()
×
3582

×
3583
                        addresses = append(addresses, tcp)
×
3584

3585
                case addressTypeIPv6:
×
3586
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
3587
                        if err != nil {
×
3588
                                return false, nil, nil
×
3589
                        }
×
3590
                        addresses = append(addresses, tcp)
×
3591

3592
                case addressTypeTorV3, addressTypeTorV2:
×
3593
                        service, portStr, err := net.SplitHostPort(address)
×
3594
                        if err != nil {
×
3595
                                return false, nil, fmt.Errorf("unable to "+
×
3596
                                        "split tor v3 address: %v",
×
3597
                                        addr.Address)
×
3598
                        }
×
3599

3600
                        port, err := strconv.Atoi(portStr)
×
3601
                        if err != nil {
×
3602
                                return false, nil, err
×
3603
                        }
×
3604

3605
                        addresses = append(addresses, &tor.OnionAddr{
×
3606
                                OnionService: service,
×
3607
                                Port:         port,
×
3608
                        })
×
3609

3610
                case addressTypeOpaque:
×
3611
                        opaque, err := hex.DecodeString(address)
×
3612
                        if err != nil {
×
3613
                                return false, nil, fmt.Errorf("unable to "+
×
3614
                                        "decode opaque address: %v", addr)
×
3615
                        }
×
3616

3617
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
3618
                                Payload: opaque,
×
3619
                        })
×
3620

3621
                default:
×
3622
                        return false, nil, fmt.Errorf("unknown address "+
×
3623
                                "type: %v", addr.Type)
×
3624
                }
3625
        }
3626

3627
        return true, addresses, nil
×
3628
}
3629

3630
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3631
// database. This includes updating any existing types, inserting any new types,
3632
// and deleting any types that are no longer present.
3633
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3634
        nodeID int64, extraFields map[uint64][]byte) error {
×
3635

×
3636
        // Get any existing extra signed fields for the node.
×
3637
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3638
        if err != nil {
×
3639
                return err
×
3640
        }
×
3641

3642
        // Make a lookup map of the existing field types so that we can use it
3643
        // to keep track of any fields we should delete.
3644
        m := make(map[uint64]bool)
×
3645
        for _, field := range existingFields {
×
3646
                m[uint64(field.Type)] = true
×
3647
        }
×
3648

3649
        // For all the new fields, we'll upsert them and remove them from the
3650
        // map of existing fields.
3651
        for tlvType, value := range extraFields {
×
3652
                err = db.UpsertNodeExtraType(
×
3653
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3654
                                NodeID: nodeID,
×
3655
                                Type:   int64(tlvType),
×
3656
                                Value:  value,
×
3657
                        },
×
3658
                )
×
3659
                if err != nil {
×
3660
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3661
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3662
                }
×
3663

3664
                // Remove the field from the map of existing fields if it was
3665
                // present.
3666
                delete(m, tlvType)
×
3667
        }
3668

3669
        // For all the fields that are left in the map of existing fields, we'll
3670
        // delete them as they are no longer present in the new set of fields.
3671
        for tlvType := range m {
×
3672
                err = db.DeleteExtraNodeType(
×
3673
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3674
                                NodeID: nodeID,
×
3675
                                Type:   int64(tlvType),
×
3676
                        },
×
3677
                )
×
3678
                if err != nil {
×
3679
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3680
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3681
                }
×
3682
        }
3683

3684
        return nil
×
3685
}
3686

3687
// srcNodeInfo holds the information about the source node of the graph.
3688
type srcNodeInfo struct {
3689
        // id is the DB level ID of the source node entry in the "nodes" table.
3690
        id int64
3691

3692
        // pub is the public key of the source node.
3693
        pub route.Vertex
3694
}
3695

3696
// getSourceNode returns the DB node ID and pub key of the source node for the
3697
// specified protocol version.
3698
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3699
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3700

×
3701
        s.srcNodeMu.Lock()
×
3702
        defer s.srcNodeMu.Unlock()
×
3703

×
3704
        // If we already have the source node ID and pub key cached, then
×
3705
        // return them.
×
3706
        if info, ok := s.srcNodes[version]; ok {
×
3707
                return info.id, info.pub, nil
×
3708
        }
×
3709

3710
        var pubKey route.Vertex
×
3711

×
3712
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3713
        if err != nil {
×
3714
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3715
                        err)
×
3716
        }
×
3717

3718
        if len(nodes) == 0 {
×
3719
                return 0, pubKey, ErrSourceNodeNotSet
×
3720
        } else if len(nodes) > 1 {
×
3721
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3722
                        "protocol %s found", version)
×
3723
        }
×
3724

3725
        copy(pubKey[:], nodes[0].PubKey)
×
3726

×
3727
        s.srcNodes[version] = &srcNodeInfo{
×
3728
                id:  nodes[0].NodeID,
×
3729
                pub: pubKey,
×
3730
        }
×
3731

×
3732
        return nodes[0].NodeID, pubKey, nil
×
3733
}
3734

3735
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3736
// This then produces a map from TLV type to value. If the input is not a
3737
// valid TLV stream, then an error is returned.
3738
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3739
        r := bytes.NewReader(data)
×
3740

×
3741
        tlvStream, err := tlv.NewStream()
×
3742
        if err != nil {
×
3743
                return nil, err
×
3744
        }
×
3745

3746
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3747
        // pass it into the P2P decoding variant.
3748
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3749
        if err != nil {
×
3750
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3751
        }
×
3752
        if len(parsedTypes) == 0 {
×
3753
                return nil, nil
×
3754
        }
×
3755

3756
        records := make(map[uint64][]byte)
×
3757
        for k, v := range parsedTypes {
×
3758
                records[uint64(k)] = v
×
3759
        }
×
3760

3761
        return records, nil
×
3762
}
3763

3764
// insertChannel inserts a new channel record into the database.
3765
func insertChannel(ctx context.Context, db SQLQueries,
3766
        edge *models.ChannelEdgeInfo) error {
×
3767

×
3768
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3769

×
3770
        // Make sure that the channel doesn't already exist. We do this
×
3771
        // explicitly instead of relying on catching a unique constraint error
×
3772
        // because relying on SQL to throw that error would abort the entire
×
3773
        // batch of transactions.
×
3774
        _, err := db.GetChannelBySCID(
×
3775
                ctx, sqlc.GetChannelBySCIDParams{
×
3776
                        Scid:    chanIDB[:],
×
3777
                        Version: int16(ProtocolV1),
×
3778
                },
×
3779
        )
×
3780
        if err == nil {
×
3781
                return ErrEdgeAlreadyExist
×
3782
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3783
                return fmt.Errorf("unable to fetch channel: %w", err)
×
3784
        }
×
3785

3786
        // Make sure that at least a "shell" entry for each node is present in
3787
        // the nodes table.
3788
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3789
        if err != nil {
×
3790
                return fmt.Errorf("unable to create shell node: %w", err)
×
3791
        }
×
3792

3793
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3794
        if err != nil {
×
3795
                return fmt.Errorf("unable to create shell node: %w", err)
×
3796
        }
×
3797

3798
        var capacity sql.NullInt64
×
3799
        if edge.Capacity != 0 {
×
3800
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3801
        }
×
3802

3803
        createParams := sqlc.CreateChannelParams{
×
3804
                Version:     int16(ProtocolV1),
×
3805
                Scid:        chanIDB[:],
×
3806
                NodeID1:     node1DBID,
×
3807
                NodeID2:     node2DBID,
×
3808
                Outpoint:    edge.ChannelPoint.String(),
×
3809
                Capacity:    capacity,
×
3810
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3811
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3812
        }
×
3813

×
3814
        if edge.AuthProof != nil {
×
3815
                proof := edge.AuthProof
×
3816

×
3817
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3818
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3819
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3820
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3821
        }
×
3822

3823
        // Insert the new channel record.
3824
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3825
        if err != nil {
×
3826
                return err
×
3827
        }
×
3828

3829
        // Insert any channel features.
3830
        if len(edge.Features) != 0 {
×
3831
                chanFeatures := lnwire.NewRawFeatureVector()
×
3832
                err := chanFeatures.Decode(bytes.NewReader(edge.Features))
×
3833
                if err != nil {
×
3834
                        return err
×
3835
                }
×
3836

3837
                fv := lnwire.NewFeatureVector(chanFeatures, lnwire.Features)
×
3838
                for feature := range fv.Features() {
×
3839
                        err = db.InsertChannelFeature(
×
3840
                                ctx, sqlc.InsertChannelFeatureParams{
×
3841
                                        ChannelID:  dbChanID,
×
3842
                                        FeatureBit: int32(feature),
×
3843
                                },
×
3844
                        )
×
3845
                        if err != nil {
×
3846
                                return fmt.Errorf("unable to insert "+
×
3847
                                        "channel(%d) feature(%v): %w", dbChanID,
×
3848
                                        feature, err)
×
3849
                        }
×
3850
                }
3851
        }
3852

3853
        // Finally, insert any extra TLV fields in the channel announcement.
3854
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3855
        if err != nil {
×
3856
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
3857
                        err)
×
3858
        }
×
3859

3860
        for tlvType, value := range extra {
×
3861
                err := db.CreateChannelExtraType(
×
3862
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3863
                                ChannelID: dbChanID,
×
3864
                                Type:      int64(tlvType),
×
3865
                                Value:     value,
×
3866
                        },
×
3867
                )
×
3868
                if err != nil {
×
3869
                        return fmt.Errorf("unable to upsert channel(%d) extra "+
×
3870
                                "signed field(%v): %w", edge.ChannelID,
×
3871
                                tlvType, err)
×
3872
                }
×
3873
        }
3874

3875
        return nil
×
3876
}
3877

3878
// maybeCreateShellNode checks if a shell node entry exists for the
3879
// given public key. If it does not exist, then a new shell node entry is
3880
// created. The ID of the node is returned. A shell node only has a protocol
3881
// version and public key persisted.
3882
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
3883
        pubKey route.Vertex) (int64, error) {
×
3884

×
3885
        dbNode, err := db.GetNodeByPubKey(
×
3886
                ctx, sqlc.GetNodeByPubKeyParams{
×
3887
                        PubKey:  pubKey[:],
×
3888
                        Version: int16(ProtocolV1),
×
3889
                },
×
3890
        )
×
3891
        // The node exists. Return the ID.
×
3892
        if err == nil {
×
3893
                return dbNode.ID, nil
×
3894
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3895
                return 0, err
×
3896
        }
×
3897

3898
        // Otherwise, the node does not exist, so we create a shell entry for
3899
        // it.
3900
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3901
                Version: int16(ProtocolV1),
×
3902
                PubKey:  pubKey[:],
×
3903
        })
×
3904
        if err != nil {
×
3905
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3906
        }
×
3907

3908
        return id, nil
×
3909
}
3910

3911
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
3912
// the database. This includes deleting any existing types and then inserting
3913
// the new types.
3914
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3915
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
3916

×
3917
        // Delete all existing extra signed fields for the channel policy.
×
3918
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
3919
        if err != nil {
×
3920
                return fmt.Errorf("unable to delete "+
×
3921
                        "existing policy extra signed fields for policy %d: %w",
×
3922
                        chanPolicyID, err)
×
3923
        }
×
3924

3925
        // Insert all new extra signed fields for the channel policy.
3926
        for tlvType, value := range extraFields {
×
3927
                err = db.InsertChanPolicyExtraType(
×
3928
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
3929
                                ChannelPolicyID: chanPolicyID,
×
3930
                                Type:            int64(tlvType),
×
3931
                                Value:           value,
×
3932
                        },
×
3933
                )
×
3934
                if err != nil {
×
3935
                        return fmt.Errorf("unable to insert "+
×
3936
                                "channel_policy(%d) extra signed field(%v): %w",
×
3937
                                chanPolicyID, tlvType, err)
×
3938
                }
×
3939
        }
3940

3941
        return nil
×
3942
}
3943

3944
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
3945
// provided dbChanRow and also fetches any other required information
3946
// to construct the edge info.
3947
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
3948
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.Channel, node1,
3949
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
3950

×
3951
        if dbChan.Version != int16(ProtocolV1) {
×
3952
                return nil, fmt.Errorf("unsupported channel version: %d",
×
3953
                        dbChan.Version)
×
3954
        }
×
3955

3956
        fv, extras, err := getChanFeaturesAndExtras(
×
3957
                ctx, db, dbChanID,
×
3958
        )
×
3959
        if err != nil {
×
3960
                return nil, err
×
3961
        }
×
3962

3963
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
3964
        if err != nil {
×
3965
                return nil, err
×
3966
        }
×
3967

3968
        var featureBuf bytes.Buffer
×
3969
        if err := fv.Encode(&featureBuf); err != nil {
×
3970
                return nil, fmt.Errorf("unable to encode features: %w", err)
×
3971
        }
×
3972

3973
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
3974
        if err != nil {
×
3975
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3976
                        "fields: %w", err)
×
3977
        }
×
3978
        if recs == nil {
×
3979
                recs = make([]byte, 0)
×
3980
        }
×
3981

3982
        var btcKey1, btcKey2 route.Vertex
×
3983
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
3984
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
3985

×
3986
        channel := &models.ChannelEdgeInfo{
×
3987
                ChainHash:        chain,
×
3988
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
3989
                NodeKey1Bytes:    node1,
×
3990
                NodeKey2Bytes:    node2,
×
3991
                BitcoinKey1Bytes: btcKey1,
×
3992
                BitcoinKey2Bytes: btcKey2,
×
3993
                ChannelPoint:     *op,
×
3994
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
3995
                Features:         featureBuf.Bytes(),
×
3996
                ExtraOpaqueData:  recs,
×
3997
        }
×
3998

×
3999
        // We always set all the signatures at the same time, so we can
×
4000
        // safely check if one signature is present to determine if we have the
×
4001
        // rest of the signatures for the auth proof.
×
4002
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4003
                channel.AuthProof = &models.ChannelAuthProof{
×
4004
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4005
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4006
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4007
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4008
                }
×
4009
        }
×
4010

4011
        return channel, nil
×
4012
}
4013

4014
// buildNodeVertices is a helper that converts raw node public keys
4015
// into route.Vertex instances.
4016
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4017
        route.Vertex, error) {
×
4018

×
4019
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4020
        if err != nil {
×
4021
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4022
                        "create vertex from node1 pubkey: %w", err)
×
4023
        }
×
4024

4025
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4026
        if err != nil {
×
4027
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4028
                        "create vertex from node2 pubkey: %w", err)
×
4029
        }
×
4030

4031
        return node1Vertex, node2Vertex, nil
×
4032
}
4033

4034
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
4035
// for a channel with the given ID.
4036
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
4037
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
4038

×
4039
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
4040
        if err != nil {
×
4041
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
4042
                        "features and extras: %w", err)
×
4043
        }
×
4044

4045
        var (
×
4046
                fv     = lnwire.EmptyFeatureVector()
×
4047
                extras = make(map[uint64][]byte)
×
4048
        )
×
4049
        for _, row := range rows {
×
4050
                if row.IsFeature {
×
4051
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
4052

×
4053
                        continue
×
4054
                }
4055

4056
                tlvType, ok := row.ExtraKey.(int64)
×
4057
                if !ok {
×
4058
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4059
                                "TLV type: %T", row.ExtraKey)
×
4060
                }
×
4061

4062
                valueBytes, ok := row.Value.([]byte)
×
4063
                if !ok {
×
4064
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4065
                                "Value: %T", row.Value)
×
4066
                }
×
4067

4068
                extras[uint64(tlvType)] = valueBytes
×
4069
        }
4070

4071
        return fv, extras, nil
×
4072
}
4073

4074
// getAndBuildChanPolicies uses the given sqlc.ChannelPolicy and also retrieves
4075
// all the extra info required to build the complete models.ChannelEdgePolicy
4076
// types. It returns two policies, which may be nil if the provided
4077
// sqlc.ChannelPolicy records are nil.
4078
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4079
        dbPol1, dbPol2 *sqlc.ChannelPolicy, channelID uint64, node1,
4080
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4081
        *models.ChannelEdgePolicy, error) {
×
4082

×
4083
        if dbPol1 == nil && dbPol2 == nil {
×
4084
                return nil, nil, nil
×
4085
        }
×
4086

4087
        var (
×
4088
                policy1ID int64
×
4089
                policy2ID int64
×
4090
        )
×
4091
        if dbPol1 != nil {
×
4092
                policy1ID = dbPol1.ID
×
4093
        }
×
4094
        if dbPol2 != nil {
×
4095
                policy2ID = dbPol2.ID
×
4096
        }
×
4097
        rows, err := db.GetChannelPolicyExtraTypes(
×
4098
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
4099
                        ID:   policy1ID,
×
4100
                        ID_2: policy2ID,
×
4101
                },
×
4102
        )
×
4103
        if err != nil {
×
4104
                return nil, nil, err
×
4105
        }
×
4106

4107
        var (
×
4108
                dbPol1Extras = make(map[uint64][]byte)
×
4109
                dbPol2Extras = make(map[uint64][]byte)
×
4110
        )
×
4111
        for _, row := range rows {
×
4112
                switch row.PolicyID {
×
4113
                case policy1ID:
×
4114
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
4115
                case policy2ID:
×
4116
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
4117
                default:
×
4118
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
4119
                                "in row: %v", row.PolicyID, row)
×
4120
                }
4121
        }
4122

4123
        var pol1, pol2 *models.ChannelEdgePolicy
×
4124
        if dbPol1 != nil {
×
4125
                pol1, err = buildChanPolicy(
×
4126
                        *dbPol1, channelID, dbPol1Extras, node2, true,
×
4127
                )
×
4128
                if err != nil {
×
4129
                        return nil, nil, err
×
4130
                }
×
4131
        }
4132
        if dbPol2 != nil {
×
4133
                pol2, err = buildChanPolicy(
×
4134
                        *dbPol2, channelID, dbPol2Extras, node1, false,
×
4135
                )
×
4136
                if err != nil {
×
4137
                        return nil, nil, err
×
4138
                }
×
4139
        }
4140

4141
        return pol1, pol2, nil
×
4142
}
4143

4144
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4145
// provided sqlc.ChannelPolicy and other required information.
4146
func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
4147
        extras map[uint64][]byte, toNode route.Vertex,
4148
        isNode1 bool) (*models.ChannelEdgePolicy, error) {
×
4149

×
4150
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4151
        if err != nil {
×
4152
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4153
                        "fields: %w", err)
×
4154
        }
×
4155

4156
        var msgFlags lnwire.ChanUpdateMsgFlags
×
4157
        if dbPolicy.MaxHtlcMsat.Valid {
×
4158
                msgFlags |= lnwire.ChanUpdateRequiredMaxHtlc
×
4159
        }
×
4160

4161
        var chanFlags lnwire.ChanUpdateChanFlags
×
4162
        if !isNode1 {
×
4163
                chanFlags |= lnwire.ChanUpdateDirection
×
4164
        }
×
4165
        if dbPolicy.Disabled.Bool {
×
4166
                chanFlags |= lnwire.ChanUpdateDisabled
×
4167
        }
×
4168

4169
        var inboundFee fn.Option[lnwire.Fee]
×
4170
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4171
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4172

×
4173
                inboundFee = fn.Some(lnwire.Fee{
×
4174
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4175
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4176
                })
×
4177
        }
×
4178

4179
        return &models.ChannelEdgePolicy{
×
4180
                SigBytes:  dbPolicy.Signature,
×
4181
                ChannelID: channelID,
×
4182
                LastUpdate: time.Unix(
×
4183
                        dbPolicy.LastUpdate.Int64, 0,
×
4184
                ),
×
4185
                MessageFlags:  msgFlags,
×
4186
                ChannelFlags:  chanFlags,
×
4187
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4188
                MinHTLC: lnwire.MilliSatoshi(
×
4189
                        dbPolicy.MinHtlcMsat,
×
4190
                ),
×
4191
                MaxHTLC: lnwire.MilliSatoshi(
×
4192
                        dbPolicy.MaxHtlcMsat.Int64,
×
4193
                ),
×
4194
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4195
                        dbPolicy.BaseFeeMsat,
×
4196
                ),
×
4197
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4198
                ToNode:                    toNode,
×
4199
                InboundFee:                inboundFee,
×
4200
                ExtraOpaqueData:           recs,
×
4201
        }, nil
×
4202
}
4203

4204
// buildNodes builds the models.LightningNode instances for the
4205
// given row which is expected to be a sqlc type that contains node information.
4206
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4207
        dbNode2 sqlc.Node) (*models.LightningNode, *models.LightningNode,
4208
        error) {
×
4209

×
4210
        node1, err := buildNode(ctx, db, &dbNode1)
×
4211
        if err != nil {
×
4212
                return nil, nil, err
×
4213
        }
×
4214

4215
        node2, err := buildNode(ctx, db, &dbNode2)
×
4216
        if err != nil {
×
4217
                return nil, nil, err
×
4218
        }
×
4219

4220
        return node1, node2, nil
×
4221
}
4222

4223
// extractChannelPolicies extracts the sqlc.ChannelPolicy records from the give
4224
// row which is expected to be a sqlc type that contains channel policy
4225
// information. It returns two policies, which may be nil if the policy
4226
// information is not present in the row.
4227
//
4228
//nolint:ll,dupl,funlen
4229
func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
4230
        error) {
×
4231

×
4232
        var policy1, policy2 *sqlc.ChannelPolicy
×
4233
        switch r := row.(type) {
×
4234
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4235
                if r.Policy1ID.Valid {
×
4236
                        policy1 = &sqlc.ChannelPolicy{
×
4237
                                ID:                      r.Policy1ID.Int64,
×
4238
                                Version:                 r.Policy1Version.Int16,
×
4239
                                ChannelID:               r.Channel.ID,
×
4240
                                NodeID:                  r.Policy1NodeID.Int64,
×
4241
                                Timelock:                r.Policy1Timelock.Int32,
×
4242
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4243
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4244
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4245
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4246
                                LastUpdate:              r.Policy1LastUpdate,
×
4247
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4248
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4249
                                Disabled:                r.Policy1Disabled,
×
4250
                                Signature:               r.Policy1Signature,
×
4251
                        }
×
4252
                }
×
4253
                if r.Policy2ID.Valid {
×
4254
                        policy2 = &sqlc.ChannelPolicy{
×
4255
                                ID:                      r.Policy2ID.Int64,
×
4256
                                Version:                 r.Policy2Version.Int16,
×
4257
                                ChannelID:               r.Channel.ID,
×
4258
                                NodeID:                  r.Policy2NodeID.Int64,
×
4259
                                Timelock:                r.Policy2Timelock.Int32,
×
4260
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4261
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4262
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4263
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4264
                                LastUpdate:              r.Policy2LastUpdate,
×
4265
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4266
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4267
                                Disabled:                r.Policy2Disabled,
×
4268
                                Signature:               r.Policy2Signature,
×
4269
                        }
×
4270
                }
×
4271

4272
                return policy1, policy2, nil
×
4273

4274
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4275
                if r.Policy1ID.Valid {
×
4276
                        policy1 = &sqlc.ChannelPolicy{
×
4277
                                ID:                      r.Policy1ID.Int64,
×
4278
                                Version:                 r.Policy1Version.Int16,
×
4279
                                ChannelID:               r.Channel.ID,
×
4280
                                NodeID:                  r.Policy1NodeID.Int64,
×
4281
                                Timelock:                r.Policy1Timelock.Int32,
×
4282
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4283
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4284
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4285
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4286
                                LastUpdate:              r.Policy1LastUpdate,
×
4287
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4288
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4289
                                Disabled:                r.Policy1Disabled,
×
4290
                                Signature:               r.Policy1Signature,
×
4291
                        }
×
4292
                }
×
4293
                if r.Policy2ID.Valid {
×
4294
                        policy2 = &sqlc.ChannelPolicy{
×
4295
                                ID:                      r.Policy2ID.Int64,
×
4296
                                Version:                 r.Policy2Version.Int16,
×
4297
                                ChannelID:               r.Channel.ID,
×
4298
                                NodeID:                  r.Policy2NodeID.Int64,
×
4299
                                Timelock:                r.Policy2Timelock.Int32,
×
4300
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4301
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4302
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4303
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4304
                                LastUpdate:              r.Policy2LastUpdate,
×
4305
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4306
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4307
                                Disabled:                r.Policy2Disabled,
×
4308
                                Signature:               r.Policy2Signature,
×
4309
                        }
×
4310
                }
×
4311

4312
                return policy1, policy2, nil
×
4313

4314
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4315
                if r.Policy1ID.Valid {
×
4316
                        policy1 = &sqlc.ChannelPolicy{
×
4317
                                ID:                      r.Policy1ID.Int64,
×
4318
                                Version:                 r.Policy1Version.Int16,
×
4319
                                ChannelID:               r.Channel.ID,
×
4320
                                NodeID:                  r.Policy1NodeID.Int64,
×
4321
                                Timelock:                r.Policy1Timelock.Int32,
×
4322
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4323
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4324
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4325
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4326
                                LastUpdate:              r.Policy1LastUpdate,
×
4327
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4328
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4329
                                Disabled:                r.Policy1Disabled,
×
4330
                                Signature:               r.Policy1Signature,
×
4331
                        }
×
4332
                }
×
4333
                if r.Policy2ID.Valid {
×
4334
                        policy2 = &sqlc.ChannelPolicy{
×
4335
                                ID:                      r.Policy2ID.Int64,
×
4336
                                Version:                 r.Policy2Version.Int16,
×
4337
                                ChannelID:               r.Channel.ID,
×
4338
                                NodeID:                  r.Policy2NodeID.Int64,
×
4339
                                Timelock:                r.Policy2Timelock.Int32,
×
4340
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4341
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4342
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4343
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4344
                                LastUpdate:              r.Policy2LastUpdate,
×
4345
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4346
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4347
                                Disabled:                r.Policy2Disabled,
×
4348
                                Signature:               r.Policy2Signature,
×
4349
                        }
×
4350
                }
×
4351

4352
                return policy1, policy2, nil
×
4353

4354
        case sqlc.ListChannelsByNodeIDRow:
×
4355
                if r.Policy1ID.Valid {
×
4356
                        policy1 = &sqlc.ChannelPolicy{
×
4357
                                ID:                      r.Policy1ID.Int64,
×
4358
                                Version:                 r.Policy1Version.Int16,
×
4359
                                ChannelID:               r.Channel.ID,
×
4360
                                NodeID:                  r.Policy1NodeID.Int64,
×
4361
                                Timelock:                r.Policy1Timelock.Int32,
×
4362
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4363
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4364
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4365
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4366
                                LastUpdate:              r.Policy1LastUpdate,
×
4367
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4368
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4369
                                Disabled:                r.Policy1Disabled,
×
4370
                                Signature:               r.Policy1Signature,
×
4371
                        }
×
4372
                }
×
4373
                if r.Policy2ID.Valid {
×
4374
                        policy2 = &sqlc.ChannelPolicy{
×
4375
                                ID:                      r.Policy2ID.Int64,
×
4376
                                Version:                 r.Policy2Version.Int16,
×
4377
                                ChannelID:               r.Channel.ID,
×
4378
                                NodeID:                  r.Policy2NodeID.Int64,
×
4379
                                Timelock:                r.Policy2Timelock.Int32,
×
4380
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4381
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4382
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4383
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4384
                                LastUpdate:              r.Policy2LastUpdate,
×
4385
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4386
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4387
                                Disabled:                r.Policy2Disabled,
×
4388
                                Signature:               r.Policy2Signature,
×
4389
                        }
×
4390
                }
×
4391

4392
                return policy1, policy2, nil
×
4393

4394
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4395
                if r.Policy1ID.Valid {
×
4396
                        policy1 = &sqlc.ChannelPolicy{
×
4397
                                ID:                      r.Policy1ID.Int64,
×
4398
                                Version:                 r.Policy1Version.Int16,
×
4399
                                ChannelID:               r.Channel.ID,
×
4400
                                NodeID:                  r.Policy1NodeID.Int64,
×
4401
                                Timelock:                r.Policy1Timelock.Int32,
×
4402
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4403
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4404
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4405
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4406
                                LastUpdate:              r.Policy1LastUpdate,
×
4407
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4408
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4409
                                Disabled:                r.Policy1Disabled,
×
4410
                                Signature:               r.Policy1Signature,
×
4411
                        }
×
4412
                }
×
4413
                if r.Policy2ID.Valid {
×
4414
                        policy2 = &sqlc.ChannelPolicy{
×
4415
                                ID:                      r.Policy2ID.Int64,
×
4416
                                Version:                 r.Policy2Version.Int16,
×
4417
                                ChannelID:               r.Channel.ID,
×
4418
                                NodeID:                  r.Policy2NodeID.Int64,
×
4419
                                Timelock:                r.Policy2Timelock.Int32,
×
4420
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4421
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4422
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4423
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4424
                                LastUpdate:              r.Policy2LastUpdate,
×
4425
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4426
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4427
                                Disabled:                r.Policy2Disabled,
×
4428
                                Signature:               r.Policy2Signature,
×
4429
                        }
×
4430
                }
×
4431

4432
                return policy1, policy2, nil
×
4433
        default:
×
4434
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4435
                        "extractChannelPolicies: %T", r)
×
4436
        }
4437
}
4438

4439
// channelIDToBytes converts a channel ID (SCID) to a byte array
4440
// representation.
4441
func channelIDToBytes(channelID uint64) [8]byte {
×
4442
        var chanIDB [8]byte
×
4443
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4444

×
4445
        return chanIDB
×
4446
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc