• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 13211764208

08 Feb 2025 03:08AM UTC coverage: 49.288% (-9.5%) from 58.815%
13211764208

Pull #9489

github

calvinrzachman
itest: verify switchrpc server enforces send then track

We prevent the RPC server from allowing onion dispatches for
attempt IDs which have already been tracked by RPC clients.

This helps protect the client from leaking a duplicate onion
attempt. NOTE: This is not the only method for solving this
issue! The issue could also be addressed via careful client-side
programming which accounts for the uncertainty and async
nature of dispatching onions to a remote process via RPC.
That approach, however, would require some changes to how lnd's
ChannelRouter uses these RPCs.
Pull Request #9489: multi: add BuildOnion, SendOnion, and TrackOnion RPCs

474 of 990 new or added lines in 11 files covered. (47.88%)

27321 existing lines in 435 files now uncovered.

101192 of 205306 relevant lines covered (49.29%)

1.54 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

31.7
/lnwire/query_short_chan_ids.go
1
package lnwire
2

3
import (
4
        "bytes"
5
        "compress/zlib"
6
        "fmt"
7
        "io"
8
        "sort"
9
        "sync"
10

11
        "github.com/btcsuite/btcd/chaincfg/chainhash"
12
)
13

14
// maxZlibBufSize is the byte limit applied when decoding a zlib-compressed
// set of short channel IDs. It exists to bound the amount of memory a single
// decoding instance can allocate.
const maxZlibBufSize = 67413630
20

21
// ErrUnsortedSIDs is returned when decoding a QueryShortChannelID request whose
22
// items were not sorted.
23
type ErrUnsortedSIDs struct {
24
        prevSID ShortChannelID
25
        curSID  ShortChannelID
26
}
27

28
// Error returns a human-readable description of the error.
29
func (e ErrUnsortedSIDs) Error() string {
×
30
        return fmt.Sprintf("current sid: %v isn't greater than last sid: %v",
×
31
                e.curSID, e.prevSID)
×
32
}
×
33

34
var (
	// zlibDecodeMtx serializes zlib decoding across the whole package so
	// that only one compressed payload of short channel IDs is inflated
	// at a time. This further bounds memory usage when several peers
	// send compressed queries concurrently.
	zlibDecodeMtx sync.Mutex
)
38

39
// ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we
40
// came across an unknown short channel ID encoding, and therefore were unable
41
// to continue parsing.
UNCOV
42
func ErrUnknownShortChanIDEncoding(encoding QueryEncoding) error {
×
UNCOV
43
        return fmt.Errorf("unknown short chan id encoding: %v", encoding)
×
UNCOV
44
}
×
45

46
// QueryShortChanIDs is a message that allows the sender to query a set of
47
// channel announcement and channel update messages that correspond to the set
48
// of encoded short channel ID's. The encoding of the short channel ID's is
49
// detailed in the query message ensuring that the receiver knows how to
50
// properly decode each encode short channel ID which may be encoded using a
51
// compression format. The receiver should respond with a series of channel
52
// announcement and channel updates, finally sending a ReplyShortChanIDsEnd
53
// message.
54
type QueryShortChanIDs struct {
55
        // ChainHash denotes the target chain that we're querying for the
56
        // channel ID's of.
57
        ChainHash chainhash.Hash
58

59
        // EncodingType is a signal to the receiver of the message that
60
        // indicates exactly how the set of short channel ID's that follow have
61
        // been encoded.
62
        EncodingType QueryEncoding
63

64
        // ShortChanIDs is a slice of decoded short channel ID's.
65
        ShortChanIDs []ShortChannelID
66

67
        // ExtraData is the set of data that was appended to this message to
68
        // fill out the full maximum transport message size. These fields can
69
        // be used to specify optional data such as custom TLV fields.
70
        ExtraData ExtraOpaqueData
71

72
        // noSort indicates whether or not to sort the short channel ids before
73
        // writing them out.
74
        //
75
        // NOTE: This should only be used during testing.
76
        noSort bool
77
}
78

79
// NewQueryShortChanIDs creates a new QueryShortChanIDs message.
80
func NewQueryShortChanIDs(h chainhash.Hash, e QueryEncoding,
81
        s []ShortChannelID) *QueryShortChanIDs {
×
82

×
83
        return &QueryShortChanIDs{
×
84
                ChainHash:    h,
×
85
                EncodingType: e,
×
86
                ShortChanIDs: s,
×
87
        }
×
88
}
×
89

90
// A compile time check to ensure QueryShortChanIDs implements the
91
// lnwire.Message interface.
92
var _ Message = (*QueryShortChanIDs)(nil)
93

94
// Decode deserializes a serialized QueryShortChanIDs message stored in the
95
// passed io.Reader observing the specified protocol version.
96
//
97
// This is part of the lnwire.Message interface.
98
func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error {
3✔
99
        err := ReadElements(r, q.ChainHash[:])
3✔
100
        if err != nil {
3✔
UNCOV
101
                return err
×
UNCOV
102
        }
×
103

104
        q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r)
3✔
105
        if err != nil {
3✔
UNCOV
106
                return err
×
UNCOV
107
        }
×
108

109
        return q.ExtraData.Decode(r)
3✔
110
}
111

112
// decodeShortChanIDs decodes a set of short channel ID's that have been
113
// encoded. The first byte of the body details how the short chan ID's were
114
// encoded. We'll use this type to govern exactly how we go about encoding the
115
// set of short channel ID's.
116
func decodeShortChanIDs(r io.Reader) (QueryEncoding, []ShortChannelID, error) {
3✔
117
        // First, we'll attempt to read the number of bytes in the body of the
3✔
118
        // set of encoded short channel ID's.
3✔
119
        var numBytesResp uint16
3✔
120
        err := ReadElements(r, &numBytesResp)
3✔
121
        if err != nil {
3✔
UNCOV
122
                return 0, nil, err
×
UNCOV
123
        }
×
124

125
        if numBytesResp == 0 {
3✔
UNCOV
126
                return 0, nil, nil
×
UNCOV
127
        }
×
128

129
        queryBody := make([]byte, numBytesResp)
3✔
130
        if _, err := io.ReadFull(r, queryBody); err != nil {
3✔
UNCOV
131
                return 0, nil, err
×
UNCOV
132
        }
×
133

134
        // The first byte is the encoding type, so we'll extract that so we can
135
        // continue our parsing.
136
        encodingType := QueryEncoding(queryBody[0])
3✔
137

3✔
138
        // Before continuing, we'll snip off the first byte of the query body
3✔
139
        // as that was just the encoding type.
3✔
140
        queryBody = queryBody[1:]
3✔
141

3✔
142
        // Otherwise, depending on the encoding type, we'll decode the encode
3✔
143
        // short channel ID's in a different manner.
3✔
144
        switch encodingType {
3✔
145

146
        // In this encoding, we'll simply read a sort array of encoded short
147
        // channel ID's from the buffer.
148
        case EncodingSortedPlain:
3✔
149
                // If after extracting the encoding type, the number of
3✔
150
                // remaining bytes is not a whole multiple of the size of an
3✔
151
                // encoded short channel ID (8 bytes), then we'll return a
3✔
152
                // parsing error.
3✔
153
                if len(queryBody)%8 != 0 {
3✔
UNCOV
154
                        return 0, nil, fmt.Errorf("whole number of short "+
×
UNCOV
155
                                "chan ID's cannot be encoded in len=%v",
×
UNCOV
156
                                len(queryBody))
×
UNCOV
157
                }
×
158

159
                // As each short channel ID is encoded as 8 bytes, we can
160
                // compute the number of bytes encoded based on the size of the
161
                // query body.
162
                numShortChanIDs := len(queryBody) / 8
3✔
163
                if numShortChanIDs == 0 {
6✔
164
                        return encodingType, nil, nil
3✔
165
                }
3✔
166

167
                // Finally, we'll read out the exact number of short channel
168
                // ID's to conclude our parsing.
169
                shortChanIDs := make([]ShortChannelID, numShortChanIDs)
3✔
170
                bodyReader := bytes.NewReader(queryBody)
3✔
171
                var lastChanID ShortChannelID
3✔
172
                for i := 0; i < numShortChanIDs; i++ {
6✔
173
                        if err := ReadElements(bodyReader, &shortChanIDs[i]); err != nil {
3✔
174
                                return 0, nil, fmt.Errorf("unable to parse "+
×
175
                                        "short chan ID: %v", err)
×
176
                        }
×
177

178
                        // We'll ensure that this short chan ID is greater than
179
                        // the last one. This is a requirement within the
180
                        // encoding, and if violated can aide us in detecting
181
                        // malicious payloads. This can only be true starting
182
                        // at the second chanID.
183
                        cid := shortChanIDs[i]
3✔
184
                        if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
3✔
UNCOV
185
                                return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
×
UNCOV
186
                        }
×
187
                        lastChanID = cid
3✔
188
                }
189

190
                return encodingType, shortChanIDs, nil
3✔
191

192
        // In this encoding, we'll use zlib to decode the compressed payload.
193
        // However, we'll pay attention to ensure that we don't open our selves
194
        // up to a memory exhaustion attack.
UNCOV
195
        case EncodingSortedZlib:
×
UNCOV
196
                // We'll obtain an ultimately release the zlib decode mutex.
×
UNCOV
197
                // This guards us against allocating too much memory to decode
×
UNCOV
198
                // each instance from concurrent peers.
×
UNCOV
199
                zlibDecodeMtx.Lock()
×
UNCOV
200
                defer zlibDecodeMtx.Unlock()
×
UNCOV
201

×
UNCOV
202
                // At this point, if there's no body remaining, then only the encoding
×
UNCOV
203
                // type was specified, meaning that there're no further bytes to be
×
UNCOV
204
                // parsed.
×
UNCOV
205
                if len(queryBody) == 0 {
×
UNCOV
206
                        return encodingType, nil, nil
×
UNCOV
207
                }
×
208

209
                // Before we start to decode, we'll create a limit reader over
210
                // the current reader. This will ensure that we can control how
211
                // much memory we're allocating during the decoding process.
UNCOV
212
                limitedDecompressor, err := zlib.NewReader(&io.LimitedReader{
×
UNCOV
213
                        R: bytes.NewReader(queryBody),
×
UNCOV
214
                        N: maxZlibBufSize,
×
UNCOV
215
                })
×
UNCOV
216
                if err != nil {
×
UNCOV
217
                        return 0, nil, fmt.Errorf("unable to create zlib "+
×
UNCOV
218
                                "reader: %w", err)
×
UNCOV
219
                }
×
220

UNCOV
221
                var (
×
UNCOV
222
                        shortChanIDs []ShortChannelID
×
UNCOV
223
                        lastChanID   ShortChannelID
×
UNCOV
224
                        i            int
×
UNCOV
225
                )
×
UNCOV
226
                for {
×
UNCOV
227
                        // We'll now attempt to read the next short channel ID
×
UNCOV
228
                        // encoded in the payload.
×
UNCOV
229
                        var cid ShortChannelID
×
UNCOV
230
                        err := ReadElements(limitedDecompressor, &cid)
×
UNCOV
231

×
UNCOV
232
                        switch {
×
233
                        // If we get an EOF error, then that either means we've
234
                        // read all that's contained in the buffer, or have hit
235
                        // our limit on the number of bytes we'll read. In
236
                        // either case, we'll return what we have so far.
UNCOV
237
                        case err == io.ErrUnexpectedEOF || err == io.EOF:
×
UNCOV
238
                                return encodingType, shortChanIDs, nil
×
239

240
                        // Otherwise, we hit some other sort of error, possibly
241
                        // an invalid payload, so we'll exit early with the
242
                        // error.
UNCOV
243
                        case err != nil:
×
UNCOV
244
                                return 0, nil, fmt.Errorf("unable to "+
×
UNCOV
245
                                        "deflate next short chan "+
×
UNCOV
246
                                        "ID: %v", err)
×
247
                        }
248

249
                        // We successfully read the next ID, so we'll collect
250
                        // that in the set of final ID's to return.
UNCOV
251
                        shortChanIDs = append(shortChanIDs, cid)
×
UNCOV
252

×
UNCOV
253
                        // Finally, we'll ensure that this short chan ID is
×
UNCOV
254
                        // greater than the last one. This is a requirement
×
UNCOV
255
                        // within the encoding, and if violated can aide us in
×
UNCOV
256
                        // detecting malicious payloads. This can only be true
×
UNCOV
257
                        // starting at the second chanID.
×
UNCOV
258
                        if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
×
UNCOV
259
                                return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
×
UNCOV
260
                        }
×
261

UNCOV
262
                        lastChanID = cid
×
UNCOV
263
                        i++
×
264
                }
265

UNCOV
266
        default:
×
UNCOV
267
                // If we've been sent an encoding type that we don't know of,
×
UNCOV
268
                // then we'll return a parsing error as we can't continue if
×
UNCOV
269
                // we're unable to encode them.
×
UNCOV
270
                return 0, nil, ErrUnknownShortChanIDEncoding(encodingType)
×
271
        }
272
}
273

274
// Encode serializes the target QueryShortChanIDs into the passed io.Writer
275
// observing the protocol version specified.
276
//
277
// This is part of the lnwire.Message interface.
278
func (q *QueryShortChanIDs) Encode(w *bytes.Buffer, pver uint32) error {
3✔
279
        // First, we'll write out the chain hash.
3✔
280
        if err := WriteBytes(w, q.ChainHash[:]); err != nil {
3✔
281
                return err
×
282
        }
×
283

284
        // For both of the current encoding types, the channel ID's are to be
285
        // sorted in place, so we'll do that now. The sorting is applied unless
286
        // we were specifically requested not to for testing purposes.
287
        if !q.noSort {
6✔
288
                sort.Slice(q.ShortChanIDs, func(i, j int) bool {
6✔
289
                        return q.ShortChanIDs[i].ToUint64() <
3✔
290
                                q.ShortChanIDs[j].ToUint64()
3✔
291
                })
3✔
292
        }
293

294
        // Base on our encoding type, we'll write out the set of short channel
295
        // ID's.
296
        err := encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs)
3✔
297
        if err != nil {
3✔
298
                return err
×
299
        }
×
300

301
        return WriteBytes(w, q.ExtraData)
3✔
302
}
303

304
// encodeShortChanIDs encodes the passed short channel ID's into the passed
305
// io.Writer, respecting the specified encoding type.
306
func encodeShortChanIDs(w *bytes.Buffer, encodingType QueryEncoding,
307
        shortChanIDs []ShortChannelID) error {
3✔
308

3✔
309
        switch encodingType {
3✔
310

311
        // In this encoding, we'll simply write a sorted array of encoded short
312
        // channel ID's from the buffer.
313
        case EncodingSortedPlain:
3✔
314
                // First, we'll write out the number of bytes of the query
3✔
315
                // body. We add 1 as the response will have the encoding type
3✔
316
                // prepended to it.
3✔
317
                numBytesBody := uint16(len(shortChanIDs)*8) + 1
3✔
318
                if err := WriteUint16(w, numBytesBody); err != nil {
3✔
319
                        return err
×
320
                }
×
321

322
                // We'll then write out the encoding that that follows the
323
                // actual encoded short channel ID's.
324
                err := WriteQueryEncoding(w, encodingType)
3✔
325
                if err != nil {
3✔
326
                        return err
×
327
                }
×
328

329
                // Now that we know they're sorted, we can write out each short
330
                // channel ID to the buffer.
331
                for _, chanID := range shortChanIDs {
6✔
332
                        if err := WriteShortChannelID(w, chanID); err != nil {
3✔
333
                                return fmt.Errorf("unable to write short chan "+
×
334
                                        "ID: %v", err)
×
335
                        }
×
336
                }
337

338
                return nil
3✔
339

340
        // For this encoding we'll first write out a serialized version of all
341
        // the channel ID's into a buffer, then zlib encode that. The final
342
        // payload is what we'll write out to the passed io.Writer.
343
        //
344
        // TODO(roasbeef): assumes the caller knows the proper chunk size to
345
        // pass to avoid bin-packing here
UNCOV
346
        case EncodingSortedZlib:
×
UNCOV
347
                // If we don't have anything at all to write, then we'll write
×
UNCOV
348
                // an empty payload so we don't include things like the zlib
×
UNCOV
349
                // header when the remote party is expecting no actual short
×
UNCOV
350
                // channel IDs.
×
UNCOV
351
                var compressedPayload []byte
×
UNCOV
352
                if len(shortChanIDs) > 0 {
×
UNCOV
353
                        // We'll make a new write buffer to hold the bytes of
×
UNCOV
354
                        // shortChanIDs.
×
UNCOV
355
                        var wb bytes.Buffer
×
UNCOV
356

×
UNCOV
357
                        // Next, we'll write out all the channel ID's directly
×
UNCOV
358
                        // into the zlib writer, which will do compressing on
×
UNCOV
359
                        // the fly.
×
UNCOV
360
                        for _, chanID := range shortChanIDs {
×
UNCOV
361
                                err := WriteShortChannelID(&wb, chanID)
×
UNCOV
362
                                if err != nil {
×
363
                                        return fmt.Errorf(
×
364
                                                "unable to write short chan "+
×
365
                                                        "ID: %v", err,
×
366
                                        )
×
367
                                }
×
368
                        }
369

370
                        // With shortChanIDs written into wb, we'll create a
371
                        // zlib writer and write all the compressed bytes.
UNCOV
372
                        var zlibBuffer bytes.Buffer
×
UNCOV
373
                        zlibWriter := zlib.NewWriter(&zlibBuffer)
×
UNCOV
374

×
UNCOV
375
                        if _, err := zlibWriter.Write(wb.Bytes()); err != nil {
×
376
                                return fmt.Errorf(
×
377
                                        "unable to write compressed short chan"+
×
378
                                                "ID: %w", err)
×
379
                        }
×
380

381
                        // Now that we've written all the elements, we'll
382
                        // ensure the compressed stream is written to the
383
                        // underlying buffer.
UNCOV
384
                        if err := zlibWriter.Close(); err != nil {
×
385
                                return fmt.Errorf("unable to finalize "+
×
386
                                        "compression: %v", err)
×
387
                        }
×
388

UNCOV
389
                        compressedPayload = zlibBuffer.Bytes()
×
390
                }
391

392
                // Now that we have all the items compressed, we can compute
393
                // what the total payload size will be. We add one to account
394
                // for the byte to encode the type.
395
                //
396
                // If we don't have any actual bytes to write, then we'll end
397
                // up emitting one byte for the length, followed by the
398
                // encoding type, and nothing more. The spec isn't 100% clear
399
                // in this area, but we do this as this is what most of the
400
                // other implementations do.
UNCOV
401
                numBytesBody := len(compressedPayload) + 1
×
UNCOV
402

×
UNCOV
403
                // Finally, we can write out the number of bytes, the
×
UNCOV
404
                // compression type, and finally the buffer itself.
×
UNCOV
405
                if err := WriteUint16(w, uint16(numBytesBody)); err != nil {
×
406
                        return err
×
407
                }
×
UNCOV
408
                err := WriteQueryEncoding(w, encodingType)
×
UNCOV
409
                if err != nil {
×
410
                        return err
×
411
                }
×
412

UNCOV
413
                return WriteBytes(w, compressedPayload)
×
414

415
        default:
×
416
                // If we're trying to encode with an encoding type that we
×
417
                // don't know of, then we'll return a parsing error as we can't
×
418
                // continue if we're unable to encode them.
×
419
                return ErrUnknownShortChanIDEncoding(encodingType)
×
420
        }
421
}
422

423
// MsgType returns the integer uniquely identifying this message type on the
424
// wire.
425
//
426
// This is part of the lnwire.Message interface.
427
func (q *QueryShortChanIDs) MsgType() MessageType {
3✔
428
        return MsgQueryShortChanIDs
3✔
429
}
3✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc