• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 13980275562

20 Mar 2025 10:06PM UTC coverage: 58.6% (-10.2%) from 68.789%
13980275562

Pull #9623

github

web-flow
Merge b9b960345 into 09b674508
Pull Request #9623: Size msg test msg

0 of 1518 new or added lines in 42 files covered. (0.0%)

26603 existing lines in 443 files now uncovered.

96807 of 165200 relevant lines covered (58.6%)

1.82 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

27.1
/lnwire/query_short_chan_ids.go
1
package lnwire
2

3
import (
4
        "bytes"
5
        "compress/zlib"
6
        "fmt"
7
        "io"
8
        "sort"
9
        "sync"
10

11
        "github.com/btcsuite/btcd/chaincfg/chainhash"
12
        "pgregory.net/rapid"
13
)
14

15
// maxZlibBufSize is the read limit applied on the zlib decoding path of
// decodeShortChanIDs. It caps how many bytes a single zlib decoding instance
// will consume, with the goal of bounding the total amount of memory allocated
// while decoding one payload.
const maxZlibBufSize = 67413630
21

22
// ErrUnsortedSIDs is returned when decoding a QueryShortChannelID request whose
23
// items were not sorted.
24
type ErrUnsortedSIDs struct {
25
        prevSID ShortChannelID
26
        curSID  ShortChannelID
27
}
28

29
// Error returns a human-readable description of the error.
30
func (e ErrUnsortedSIDs) Error() string {
×
31
        return fmt.Sprintf("current sid: %v isn't greater than last sid: %v",
×
32
                e.curSID, e.prevSID)
×
33
}
×
34

35
// zlibDecodeMtx is a package-level mutex taken by decodeShortChanIDs around
// zlib decoding, so that at most one such decoding instance runs at a time.
// Serializing the decodes also bounds the memory they can use concurrently.
var zlibDecodeMtx = sync.Mutex{}
39

40
// ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we
41
// came across an unknown short channel ID encoding, and therefore were unable
42
// to continue parsing.
UNCOV
43
func ErrUnknownShortChanIDEncoding(encoding QueryEncoding) error {
×
UNCOV
44
        return fmt.Errorf("unknown short chan id encoding: %v", encoding)
×
UNCOV
45
}
×
46

47
// QueryShortChanIDs is a message that allows the sender to query a set of
// channel announcement and channel update messages that correspond to the set
// of encoded short channel ID's. The encoding of the short channel ID's is
// detailed in the query message ensuring that the receiver knows how to
// properly decode each encoded short channel ID, which may be encoded using a
// compression format. The receiver should respond with a series of channel
// announcements and channel updates, finally sending a ReplyShortChanIDsEnd
// message.
type QueryShortChanIDs struct {
	// ChainHash denotes the target chain that we're querying for the
	// channel ID's of.
	ChainHash chainhash.Hash

	// EncodingType is a signal to the receiver of the message that
	// indicates exactly how the set of short channel ID's that follow have
	// been encoded.
	EncodingType QueryEncoding

	// ShortChanIDs is a slice of decoded short channel ID's. Note that
	// Encode sorts this slice in place unless noSort is set.
	ShortChanIDs []ShortChannelID

	// ExtraData is the set of data that was appended to this message to
	// fill out the full maximum transport message size. These fields can
	// be used to specify optional data such as custom TLV fields. It is
	// read/written after the encoded short channel ID's.
	ExtraData ExtraOpaqueData

	// noSort indicates whether or not to sort the short channel ids before
	// writing them out.
	//
	// NOTE: This should only be used during testing.
	noSort bool
}
79

80
// NewQueryShortChanIDs creates a new QueryShortChanIDs message.
81
func NewQueryShortChanIDs(h chainhash.Hash, e QueryEncoding,
82
        s []ShortChannelID) *QueryShortChanIDs {
×
83

×
84
        return &QueryShortChanIDs{
×
85
                ChainHash:    h,
×
86
                EncodingType: e,
×
87
                ShortChanIDs: s,
×
88
        }
×
89
}
×
90

91
// A compile time check to ensure QueryShortChanIDs implements the
92
// lnwire.Message interface.
93
var _ Message = (*QueryShortChanIDs)(nil)
94

95
// A compile time check to ensure QueryShortChanIDs implements the lnwire.SizeableMessage
96
// interface.
97
var _ SizeableMessage = (*QueryShortChanIDs)(nil)
98

99
// A compile time check to ensure QueryShortChanIDs implements the lnwire.TestMessage
100
// interface.
101
var _ TestMessage = (*QueryShortChanIDs)(nil)
102

103
// Decode deserializes a serialized QueryShortChanIDs message stored in the
104
// passed io.Reader observing the specified protocol version.
105
//
106
// This is part of the lnwire.Message interface.
107
func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error {
3✔
108
        err := ReadElements(r, q.ChainHash[:])
3✔
109
        if err != nil {
3✔
UNCOV
110
                return err
×
UNCOV
111
        }
×
112

113
        q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r)
3✔
114
        if err != nil {
3✔
UNCOV
115
                return err
×
UNCOV
116
        }
×
117

118
        return q.ExtraData.Decode(r)
3✔
119
}
120

121
// decodeShortChanIDs decodes a set of short channel ID's that have been
122
// encoded. The first byte of the body details how the short chan ID's were
123
// encoded. We'll use this type to govern exactly how we go about encoding the
124
// set of short channel ID's.
125
func decodeShortChanIDs(r io.Reader) (QueryEncoding, []ShortChannelID, error) {
3✔
126
        // First, we'll attempt to read the number of bytes in the body of the
3✔
127
        // set of encoded short channel ID's.
3✔
128
        var numBytesResp uint16
3✔
129
        err := ReadElements(r, &numBytesResp)
3✔
130
        if err != nil {
3✔
UNCOV
131
                return 0, nil, err
×
UNCOV
132
        }
×
133

134
        if numBytesResp == 0 {
3✔
UNCOV
135
                return 0, nil, nil
×
UNCOV
136
        }
×
137

138
        queryBody := make([]byte, numBytesResp)
3✔
139
        if _, err := io.ReadFull(r, queryBody); err != nil {
3✔
UNCOV
140
                return 0, nil, err
×
UNCOV
141
        }
×
142

143
        // The first byte is the encoding type, so we'll extract that so we can
144
        // continue our parsing.
145
        encodingType := QueryEncoding(queryBody[0])
3✔
146

3✔
147
        // Before continuing, we'll snip off the first byte of the query body
3✔
148
        // as that was just the encoding type.
3✔
149
        queryBody = queryBody[1:]
3✔
150

3✔
151
        // Otherwise, depending on the encoding type, we'll decode the encode
3✔
152
        // short channel ID's in a different manner.
3✔
153
        switch encodingType {
3✔
154

155
        // In this encoding, we'll simply read a sort array of encoded short
156
        // channel ID's from the buffer.
157
        case EncodingSortedPlain:
3✔
158
                // If after extracting the encoding type, the number of
3✔
159
                // remaining bytes is not a whole multiple of the size of an
3✔
160
                // encoded short channel ID (8 bytes), then we'll return a
3✔
161
                // parsing error.
3✔
162
                if len(queryBody)%8 != 0 {
3✔
UNCOV
163
                        return 0, nil, fmt.Errorf("whole number of short "+
×
UNCOV
164
                                "chan ID's cannot be encoded in len=%v",
×
UNCOV
165
                                len(queryBody))
×
UNCOV
166
                }
×
167

168
                // As each short channel ID is encoded as 8 bytes, we can
169
                // compute the number of bytes encoded based on the size of the
170
                // query body.
171
                numShortChanIDs := len(queryBody) / 8
3✔
172
                if numShortChanIDs == 0 {
6✔
173
                        return encodingType, nil, nil
3✔
174
                }
3✔
175

176
                // Finally, we'll read out the exact number of short channel
177
                // ID's to conclude our parsing.
178
                shortChanIDs := make([]ShortChannelID, numShortChanIDs)
3✔
179
                bodyReader := bytes.NewReader(queryBody)
3✔
180
                var lastChanID ShortChannelID
3✔
181
                for i := 0; i < numShortChanIDs; i++ {
6✔
182
                        if err := ReadElements(bodyReader, &shortChanIDs[i]); err != nil {
3✔
183
                                return 0, nil, fmt.Errorf("unable to parse "+
×
184
                                        "short chan ID: %v", err)
×
185
                        }
×
186

187
                        // We'll ensure that this short chan ID is greater than
188
                        // the last one. This is a requirement within the
189
                        // encoding, and if violated can aide us in detecting
190
                        // malicious payloads. This can only be true starting
191
                        // at the second chanID.
192
                        cid := shortChanIDs[i]
3✔
193
                        if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
3✔
UNCOV
194
                                return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
×
UNCOV
195
                        }
×
196
                        lastChanID = cid
3✔
197
                }
198

199
                return encodingType, shortChanIDs, nil
3✔
200

201
        // In this encoding, we'll use zlib to decode the compressed payload.
202
        // However, we'll pay attention to ensure that we don't open our selves
203
        // up to a memory exhaustion attack.
UNCOV
204
        case EncodingSortedZlib:
×
UNCOV
205
                // We'll obtain an ultimately release the zlib decode mutex.
×
UNCOV
206
                // This guards us against allocating too much memory to decode
×
UNCOV
207
                // each instance from concurrent peers.
×
UNCOV
208
                zlibDecodeMtx.Lock()
×
UNCOV
209
                defer zlibDecodeMtx.Unlock()
×
UNCOV
210

×
UNCOV
211
                // At this point, if there's no body remaining, then only the encoding
×
UNCOV
212
                // type was specified, meaning that there're no further bytes to be
×
UNCOV
213
                // parsed.
×
UNCOV
214
                if len(queryBody) == 0 {
×
UNCOV
215
                        return encodingType, nil, nil
×
UNCOV
216
                }
×
217

218
                // Before we start to decode, we'll create a limit reader over
219
                // the current reader. This will ensure that we can control how
220
                // much memory we're allocating during the decoding process.
UNCOV
221
                limitedDecompressor, err := zlib.NewReader(&io.LimitedReader{
×
UNCOV
222
                        R: bytes.NewReader(queryBody),
×
UNCOV
223
                        N: maxZlibBufSize,
×
UNCOV
224
                })
×
UNCOV
225
                if err != nil {
×
UNCOV
226
                        return 0, nil, fmt.Errorf("unable to create zlib "+
×
UNCOV
227
                                "reader: %w", err)
×
UNCOV
228
                }
×
229

UNCOV
230
                var (
×
UNCOV
231
                        shortChanIDs []ShortChannelID
×
UNCOV
232
                        lastChanID   ShortChannelID
×
UNCOV
233
                        i            int
×
UNCOV
234
                )
×
UNCOV
235
                for {
×
UNCOV
236
                        // We'll now attempt to read the next short channel ID
×
UNCOV
237
                        // encoded in the payload.
×
UNCOV
238
                        var cid ShortChannelID
×
UNCOV
239
                        err := ReadElements(limitedDecompressor, &cid)
×
UNCOV
240

×
UNCOV
241
                        switch {
×
242
                        // If we get an EOF error, then that either means we've
243
                        // read all that's contained in the buffer, or have hit
244
                        // our limit on the number of bytes we'll read. In
245
                        // either case, we'll return what we have so far.
UNCOV
246
                        case err == io.ErrUnexpectedEOF || err == io.EOF:
×
UNCOV
247
                                return encodingType, shortChanIDs, nil
×
248

249
                        // Otherwise, we hit some other sort of error, possibly
250
                        // an invalid payload, so we'll exit early with the
251
                        // error.
UNCOV
252
                        case err != nil:
×
UNCOV
253
                                return 0, nil, fmt.Errorf("unable to "+
×
UNCOV
254
                                        "deflate next short chan "+
×
UNCOV
255
                                        "ID: %v", err)
×
256
                        }
257

258
                        // We successfully read the next ID, so we'll collect
259
                        // that in the set of final ID's to return.
UNCOV
260
                        shortChanIDs = append(shortChanIDs, cid)
×
UNCOV
261

×
UNCOV
262
                        // Finally, we'll ensure that this short chan ID is
×
UNCOV
263
                        // greater than the last one. This is a requirement
×
UNCOV
264
                        // within the encoding, and if violated can aide us in
×
UNCOV
265
                        // detecting malicious payloads. This can only be true
×
UNCOV
266
                        // starting at the second chanID.
×
UNCOV
267
                        if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
×
UNCOV
268
                                return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
×
UNCOV
269
                        }
×
270

UNCOV
271
                        lastChanID = cid
×
UNCOV
272
                        i++
×
273
                }
274

UNCOV
275
        default:
×
UNCOV
276
                // If we've been sent an encoding type that we don't know of,
×
UNCOV
277
                // then we'll return a parsing error as we can't continue if
×
UNCOV
278
                // we're unable to encode them.
×
UNCOV
279
                return 0, nil, ErrUnknownShortChanIDEncoding(encodingType)
×
280
        }
281
}
282

283
// Encode serializes the target QueryShortChanIDs into the passed io.Writer
284
// observing the protocol version specified.
285
//
286
// This is part of the lnwire.Message interface.
287
func (q *QueryShortChanIDs) Encode(w *bytes.Buffer, pver uint32) error {
3✔
288
        // First, we'll write out the chain hash.
3✔
289
        if err := WriteBytes(w, q.ChainHash[:]); err != nil {
3✔
290
                return err
×
291
        }
×
292

293
        // For both of the current encoding types, the channel ID's are to be
294
        // sorted in place, so we'll do that now. The sorting is applied unless
295
        // we were specifically requested not to for testing purposes.
296
        if !q.noSort {
6✔
297
                sort.Slice(q.ShortChanIDs, func(i, j int) bool {
6✔
298
                        return q.ShortChanIDs[i].ToUint64() <
3✔
299
                                q.ShortChanIDs[j].ToUint64()
3✔
300
                })
3✔
301
        }
302

303
        // Base on our encoding type, we'll write out the set of short channel
304
        // ID's.
305
        err := encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs)
3✔
306
        if err != nil {
3✔
307
                return err
×
308
        }
×
309

310
        return WriteBytes(w, q.ExtraData)
3✔
311
}
312

313
// encodeShortChanIDs encodes the passed short channel ID's into the passed
314
// io.Writer, respecting the specified encoding type.
315
func encodeShortChanIDs(w *bytes.Buffer, encodingType QueryEncoding,
316
        shortChanIDs []ShortChannelID) error {
3✔
317

3✔
318
        switch encodingType {
3✔
319

320
        // In this encoding, we'll simply write a sorted array of encoded short
321
        // channel ID's from the buffer.
322
        case EncodingSortedPlain:
3✔
323
                // First, we'll write out the number of bytes of the query
3✔
324
                // body. We add 1 as the response will have the encoding type
3✔
325
                // prepended to it.
3✔
326
                numBytesBody := uint16(len(shortChanIDs)*8) + 1
3✔
327
                if err := WriteUint16(w, numBytesBody); err != nil {
3✔
328
                        return err
×
329
                }
×
330

331
                // We'll then write out the encoding that that follows the
332
                // actual encoded short channel ID's.
333
                err := WriteQueryEncoding(w, encodingType)
3✔
334
                if err != nil {
3✔
335
                        return err
×
336
                }
×
337

338
                // Now that we know they're sorted, we can write out each short
339
                // channel ID to the buffer.
340
                for _, chanID := range shortChanIDs {
6✔
341
                        if err := WriteShortChannelID(w, chanID); err != nil {
3✔
342
                                return fmt.Errorf("unable to write short chan "+
×
343
                                        "ID: %v", err)
×
344
                        }
×
345
                }
346

347
                return nil
3✔
348

349
        // For this encoding we'll first write out a serialized version of all
350
        // the channel ID's into a buffer, then zlib encode that. The final
351
        // payload is what we'll write out to the passed io.Writer.
352
        //
353
        // TODO(roasbeef): assumes the caller knows the proper chunk size to
354
        // pass to avoid bin-packing here
UNCOV
355
        case EncodingSortedZlib:
×
UNCOV
356
                // If we don't have anything at all to write, then we'll write
×
UNCOV
357
                // an empty payload so we don't include things like the zlib
×
UNCOV
358
                // header when the remote party is expecting no actual short
×
UNCOV
359
                // channel IDs.
×
UNCOV
360
                var compressedPayload []byte
×
UNCOV
361
                if len(shortChanIDs) > 0 {
×
UNCOV
362
                        // We'll make a new write buffer to hold the bytes of
×
UNCOV
363
                        // shortChanIDs.
×
UNCOV
364
                        var wb bytes.Buffer
×
UNCOV
365

×
UNCOV
366
                        // Next, we'll write out all the channel ID's directly
×
UNCOV
367
                        // into the zlib writer, which will do compressing on
×
UNCOV
368
                        // the fly.
×
UNCOV
369
                        for _, chanID := range shortChanIDs {
×
UNCOV
370
                                err := WriteShortChannelID(&wb, chanID)
×
UNCOV
371
                                if err != nil {
×
372
                                        return fmt.Errorf(
×
373
                                                "unable to write short chan "+
×
374
                                                        "ID: %v", err,
×
375
                                        )
×
376
                                }
×
377
                        }
378

379
                        // With shortChanIDs written into wb, we'll create a
380
                        // zlib writer and write all the compressed bytes.
UNCOV
381
                        var zlibBuffer bytes.Buffer
×
UNCOV
382
                        zlibWriter := zlib.NewWriter(&zlibBuffer)
×
UNCOV
383

×
UNCOV
384
                        if _, err := zlibWriter.Write(wb.Bytes()); err != nil {
×
385
                                return fmt.Errorf(
×
386
                                        "unable to write compressed short chan"+
×
387
                                                "ID: %w", err)
×
388
                        }
×
389

390
                        // Now that we've written all the elements, we'll
391
                        // ensure the compressed stream is written to the
392
                        // underlying buffer.
UNCOV
393
                        if err := zlibWriter.Close(); err != nil {
×
394
                                return fmt.Errorf("unable to finalize "+
×
395
                                        "compression: %v", err)
×
396
                        }
×
397

UNCOV
398
                        compressedPayload = zlibBuffer.Bytes()
×
399
                }
400

401
                // Now that we have all the items compressed, we can compute
402
                // what the total payload size will be. We add one to account
403
                // for the byte to encode the type.
404
                //
405
                // If we don't have any actual bytes to write, then we'll end
406
                // up emitting one byte for the length, followed by the
407
                // encoding type, and nothing more. The spec isn't 100% clear
408
                // in this area, but we do this as this is what most of the
409
                // other implementations do.
UNCOV
410
                numBytesBody := len(compressedPayload) + 1
×
UNCOV
411

×
UNCOV
412
                // Finally, we can write out the number of bytes, the
×
UNCOV
413
                // compression type, and finally the buffer itself.
×
UNCOV
414
                if err := WriteUint16(w, uint16(numBytesBody)); err != nil {
×
415
                        return err
×
416
                }
×
UNCOV
417
                err := WriteQueryEncoding(w, encodingType)
×
UNCOV
418
                if err != nil {
×
419
                        return err
×
420
                }
×
421

UNCOV
422
                return WriteBytes(w, compressedPayload)
×
423

424
        default:
×
425
                // If we're trying to encode with an encoding type that we
×
426
                // don't know of, then we'll return a parsing error as we can't
×
427
                // continue if we're unable to encode them.
×
428
                return ErrUnknownShortChanIDEncoding(encodingType)
×
429
        }
430
}
431

432
// MsgType returns the integer uniquely identifying this message type on the
433
// wire.
434
//
435
// This is part of the lnwire.Message interface.
436
func (q *QueryShortChanIDs) MsgType() MessageType {
3✔
437
        return MsgQueryShortChanIDs
3✔
438
}
3✔
439

440
// SerializedSize returns the serialized size of the message in bytes.
441
//
442
// This is part of the lnwire.SizeableMessage interface.
NEW
443
func (q *QueryShortChanIDs) SerializedSize() (uint32, error) {
×
NEW
444
        return MessageSerializedSize(q)
×
NEW
445
}
×
446

447
// RandTestMessage populates the message with random data suitable for testing.
448
// It uses the rapid testing framework to generate random values.
449
//
450
// This is part of the TestMessage interface.
NEW
451
func (q *QueryShortChanIDs) RandTestMessage(t *rapid.T) Message {
×
NEW
452
        var chainHash chainhash.Hash
×
NEW
453
        hashBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "chainHash")
×
NEW
454
        copy(chainHash[:], hashBytes)
×
NEW
455

×
NEW
456
        encodingType := EncodingSortedPlain
×
NEW
457
        if rapid.Bool().Draw(t, "useZlibEncoding") {
×
NEW
458
                encodingType = EncodingSortedZlib
×
NEW
459
        }
×
460

NEW
461
        msg := &QueryShortChanIDs{
×
NEW
462
                ChainHash:    chainHash,
×
NEW
463
                EncodingType: encodingType,
×
NEW
464
                ExtraData:    RandExtraOpaqueData(t, nil),
×
NEW
465
                noSort:       false,
×
NEW
466
        }
×
NEW
467

×
NEW
468
        numIDs := rapid.IntRange(2, 20).Draw(t, "numShortChanIDs")
×
NEW
469

×
NEW
470
        // Generate sorted short channel IDs.
×
NEW
471
        shortChanIDs := make([]ShortChannelID, numIDs)
×
NEW
472
        for i := 0; i < numIDs; i++ {
×
NEW
473
                shortChanIDs[i] = RandShortChannelID(t)
×
NEW
474

×
NEW
475
                // Ensure they're properly sorted.
×
NEW
476
                if i > 0 && shortChanIDs[i].ToUint64() <=
×
NEW
477
                        shortChanIDs[i-1].ToUint64() {
×
NEW
478

×
NEW
479
                        // Ensure this ID is larger than the previous one.
×
NEW
480
                        shortChanIDs[i] = NewShortChanIDFromInt(
×
NEW
481
                                shortChanIDs[i-1].ToUint64() + 1,
×
NEW
482
                        )
×
NEW
483
                }
×
484
        }
485

NEW
486
        msg.ShortChanIDs = shortChanIDs
×
NEW
487

×
NEW
488
        return msg
×
489
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc