• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 14358372723

09 Apr 2025 01:26PM UTC coverage: 56.696% (-12.3%) from 69.037%
14358372723

Pull #9696

github

web-flow
Merge e2837e400 into 867d27d68
Pull Request #9696: Add `development_guidelines.md` for both human and machine

107055 of 188823 relevant lines covered (56.7%)

22721.56 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

77.97
/lnwire/query_short_chan_ids.go
1
package lnwire
2

3
import (
4
        "bytes"
5
        "compress/zlib"
6
        "fmt"
7
        "io"
8
        "sort"
9
        "sync"
10

11
        "github.com/btcsuite/btcd/chaincfg/chainhash"
12
)
13

14
// maxZlibBufSize bounds the number of bytes handled by a single zlib decoding
// instance. Capping this keeps the total amount of memory allocated while
// decoding a single message under control.
const maxZlibBufSize = 67413630
20

21
// ErrUnsortedSIDs is returned when decoding a QueryShortChannelID request whose
22
// items were not sorted.
23
type ErrUnsortedSIDs struct {
24
        prevSID ShortChannelID
25
        curSID  ShortChannelID
26
}
27

28
// Error returns a human-readable description of the error.
29
func (e ErrUnsortedSIDs) Error() string {
×
30
        return fmt.Sprintf("current sid: %v isn't greater than last sid: %v",
×
31
                e.curSID, e.prevSID)
×
32
}
×
33

34
// zlibDecodeMtx is a package-level lock held for the duration of each zlib
// decoding instance. As a result, at most one such instance runs at a time,
// which further bounds overall memory usage.
var (
	zlibDecodeMtx sync.Mutex
)
38

39
// ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we
40
// came across an unknown short channel ID encoding, and therefore were unable
41
// to continue parsing.
42
func ErrUnknownShortChanIDEncoding(encoding QueryEncoding) error {
4✔
43
        return fmt.Errorf("unknown short chan id encoding: %v", encoding)
4✔
44
}
4✔
45

46
// QueryShortChanIDs is a message that allows the sender to query a set of
47
// channel announcement and channel update messages that correspond to the set
48
// of encoded short channel ID's. The encoding of the short channel ID's is
49
// detailed in the query message ensuring that the receiver knows how to
50
// properly decode each encode short channel ID which may be encoded using a
51
// compression format. The receiver should respond with a series of channel
52
// announcement and channel updates, finally sending a ReplyShortChanIDsEnd
53
// message.
54
type QueryShortChanIDs struct {
55
        // ChainHash denotes the target chain that we're querying for the
56
        // channel ID's of.
57
        ChainHash chainhash.Hash
58

59
        // EncodingType is a signal to the receiver of the message that
60
        // indicates exactly how the set of short channel ID's that follow have
61
        // been encoded.
62
        EncodingType QueryEncoding
63

64
        // ShortChanIDs is a slice of decoded short channel ID's.
65
        ShortChanIDs []ShortChannelID
66

67
        // ExtraData is the set of data that was appended to this message to
68
        // fill out the full maximum transport message size. These fields can
69
        // be used to specify optional data such as custom TLV fields.
70
        ExtraData ExtraOpaqueData
71

72
        // noSort indicates whether or not to sort the short channel ids before
73
        // writing them out.
74
        //
75
        // NOTE: This should only be used during testing.
76
        noSort bool
77
}
78

79
// NewQueryShortChanIDs creates a new QueryShortChanIDs message.
80
func NewQueryShortChanIDs(h chainhash.Hash, e QueryEncoding,
81
        s []ShortChannelID) *QueryShortChanIDs {
×
82

×
83
        return &QueryShortChanIDs{
×
84
                ChainHash:    h,
×
85
                EncodingType: e,
×
86
                ShortChanIDs: s,
×
87
        }
×
88
}
×
89

90
// A compile time check to ensure QueryShortChanIDs implements the
91
// lnwire.Message interface.
92
var _ Message = (*QueryShortChanIDs)(nil)
93

94
// A compile time check to ensure QueryShortChanIDs implements the
95
// lnwire.SizeableMessage interface.
96
var _ SizeableMessage = (*QueryShortChanIDs)(nil)
97

98
// Decode deserializes a serialized QueryShortChanIDs message stored in the
99
// passed io.Reader observing the specified protocol version.
100
//
101
// This is part of the lnwire.Message interface.
102
func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error {
1,163✔
103
        err := ReadElements(r, q.ChainHash[:])
1,163✔
104
        if err != nil {
1,165✔
105
                return err
2✔
106
        }
2✔
107

108
        q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r)
1,161✔
109
        if err != nil {
1,504✔
110
                return err
343✔
111
        }
343✔
112

113
        return q.ExtraData.Decode(r)
818✔
114
}
115

116
// decodeShortChanIDs decodes a set of short channel ID's that have been
117
// encoded. The first byte of the body details how the short chan ID's were
118
// encoded. We'll use this type to govern exactly how we go about encoding the
119
// set of short channel ID's.
120
func decodeShortChanIDs(r io.Reader) (QueryEncoding, []ShortChannelID, error) {
2,377✔
121
        // First, we'll attempt to read the number of bytes in the body of the
2,377✔
122
        // set of encoded short channel ID's.
2,377✔
123
        var numBytesResp uint16
2,377✔
124
        err := ReadElements(r, &numBytesResp)
2,377✔
125
        if err != nil {
2,380✔
126
                return 0, nil, err
3✔
127
        }
3✔
128

129
        if numBytesResp == 0 {
2,434✔
130
                return 0, nil, nil
60✔
131
        }
60✔
132

133
        queryBody := make([]byte, numBytesResp)
2,314✔
134
        if _, err := io.ReadFull(r, queryBody); err != nil {
2,316✔
135
                return 0, nil, err
2✔
136
        }
2✔
137

138
        // The first byte is the encoding type, so we'll extract that so we can
139
        // continue our parsing.
140
        encodingType := QueryEncoding(queryBody[0])
2,312✔
141

2,312✔
142
        // Before continuing, we'll snip off the first byte of the query body
2,312✔
143
        // as that was just the encoding type.
2,312✔
144
        queryBody = queryBody[1:]
2,312✔
145

2,312✔
146
        // Otherwise, depending on the encoding type, we'll decode the encode
2,312✔
147
        // short channel ID's in a different manner.
2,312✔
148
        switch encodingType {
2,312✔
149

150
        // In this encoding, we'll simply read a sort array of encoded short
151
        // channel ID's from the buffer.
152
        case EncodingSortedPlain:
157✔
153
                // If after extracting the encoding type, the number of
157✔
154
                // remaining bytes is not a whole multiple of the size of an
157✔
155
                // encoded short channel ID (8 bytes), then we'll return a
157✔
156
                // parsing error.
157✔
157
                if len(queryBody)%8 != 0 {
161✔
158
                        return 0, nil, fmt.Errorf("whole number of short "+
4✔
159
                                "chan ID's cannot be encoded in len=%v",
4✔
160
                                len(queryBody))
4✔
161
                }
4✔
162

163
                // As each short channel ID is encoded as 8 bytes, we can
164
                // compute the number of bytes encoded based on the size of the
165
                // query body.
166
                numShortChanIDs := len(queryBody) / 8
153✔
167
                if numShortChanIDs == 0 {
177✔
168
                        return encodingType, nil, nil
24✔
169
                }
24✔
170

171
                // Finally, we'll read out the exact number of short channel
172
                // ID's to conclude our parsing.
173
                shortChanIDs := make([]ShortChannelID, numShortChanIDs)
129✔
174
                bodyReader := bytes.NewReader(queryBody)
129✔
175
                var lastChanID ShortChannelID
129✔
176
                for i := 0; i < numShortChanIDs; i++ {
1,027✔
177
                        if err := ReadElements(bodyReader, &shortChanIDs[i]); err != nil {
898✔
178
                                return 0, nil, fmt.Errorf("unable to parse "+
×
179
                                        "short chan ID: %v", err)
×
180
                        }
×
181

182
                        // We'll ensure that this short chan ID is greater than
183
                        // the last one. This is a requirement within the
184
                        // encoding, and if violated can aide us in detecting
185
                        // malicious payloads. This can only be true starting
186
                        // at the second chanID.
187
                        cid := shortChanIDs[i]
898✔
188
                        if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
906✔
189
                                return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
8✔
190
                        }
8✔
191
                        lastChanID = cid
890✔
192
                }
193

194
                return encodingType, shortChanIDs, nil
121✔
195

196
        // In this encoding, we'll use zlib to decode the compressed payload.
197
        // However, we'll pay attention to ensure that we don't open our selves
198
        // up to a memory exhaustion attack.
199
        case EncodingSortedZlib:
2,151✔
200
                // We'll obtain an ultimately release the zlib decode mutex.
2,151✔
201
                // This guards us against allocating too much memory to decode
2,151✔
202
                // each instance from concurrent peers.
2,151✔
203
                zlibDecodeMtx.Lock()
2,151✔
204
                defer zlibDecodeMtx.Unlock()
2,151✔
205

2,151✔
206
                // At this point, if there's no body remaining, then only the encoding
2,151✔
207
                // type was specified, meaning that there're no further bytes to be
2,151✔
208
                // parsed.
2,151✔
209
                if len(queryBody) == 0 {
2,241✔
210
                        return encodingType, nil, nil
90✔
211
                }
90✔
212

213
                // Before we start to decode, we'll create a limit reader over
214
                // the current reader. This will ensure that we can control how
215
                // much memory we're allocating during the decoding process.
216
                limitedDecompressor, err := zlib.NewReader(&io.LimitedReader{
2,061✔
217
                        R: bytes.NewReader(queryBody),
2,061✔
218
                        N: maxZlibBufSize,
2,061✔
219
                })
2,061✔
220
                if err != nil {
2,071✔
221
                        return 0, nil, fmt.Errorf("unable to create zlib "+
10✔
222
                                "reader: %w", err)
10✔
223
                }
10✔
224

225
                var (
2,051✔
226
                        shortChanIDs []ShortChannelID
2,051✔
227
                        lastChanID   ShortChannelID
2,051✔
228
                        i            int
2,051✔
229
                )
2,051✔
230
                for {
12,246✔
231
                        // We'll now attempt to read the next short channel ID
10,195✔
232
                        // encoded in the payload.
10,195✔
233
                        var cid ShortChannelID
10,195✔
234
                        err := ReadElements(limitedDecompressor, &cid)
10,195✔
235

10,195✔
236
                        switch {
10,195✔
237
                        // If we get an EOF error, then that either means we've
238
                        // read all that's contained in the buffer, or have hit
239
                        // our limit on the number of bytes we'll read. In
240
                        // either case, we'll return what we have so far.
241
                        case err == io.ErrUnexpectedEOF || err == io.EOF:
1,357✔
242
                                return encodingType, shortChanIDs, nil
1,357✔
243

244
                        // Otherwise, we hit some other sort of error, possibly
245
                        // an invalid payload, so we'll exit early with the
246
                        // error.
247
                        case err != nil:
80✔
248
                                return 0, nil, fmt.Errorf("unable to "+
80✔
249
                                        "deflate next short chan "+
80✔
250
                                        "ID: %v", err)
80✔
251
                        }
252

253
                        // We successfully read the next ID, so we'll collect
254
                        // that in the set of final ID's to return.
255
                        shortChanIDs = append(shortChanIDs, cid)
8,758✔
256

8,758✔
257
                        // Finally, we'll ensure that this short chan ID is
8,758✔
258
                        // greater than the last one. This is a requirement
8,758✔
259
                        // within the encoding, and if violated can aide us in
8,758✔
260
                        // detecting malicious payloads. This can only be true
8,758✔
261
                        // starting at the second chanID.
8,758✔
262
                        if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
9,372✔
263
                                return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
614✔
264
                        }
614✔
265

266
                        lastChanID = cid
8,144✔
267
                        i++
8,144✔
268
                }
269

270
        default:
4✔
271
                // If we've been sent an encoding type that we don't know of,
4✔
272
                // then we'll return a parsing error as we can't continue if
4✔
273
                // we're unable to encode them.
4✔
274
                return 0, nil, ErrUnknownShortChanIDEncoding(encodingType)
4✔
275
        }
276
}
277

278
// Encode serializes the target QueryShortChanIDs into the passed io.Writer
279
// observing the protocol version specified.
280
//
281
// This is part of the lnwire.Message interface.
282
func (q *QueryShortChanIDs) Encode(w *bytes.Buffer, pver uint32) error {
464✔
283
        // First, we'll write out the chain hash.
464✔
284
        if err := WriteBytes(w, q.ChainHash[:]); err != nil {
464✔
285
                return err
×
286
        }
×
287

288
        // For both of the current encoding types, the channel ID's are to be
289
        // sorted in place, so we'll do that now. The sorting is applied unless
290
        // we were specifically requested not to for testing purposes.
291
        if !q.noSort {
922✔
292
                sort.Slice(q.ShortChanIDs, func(i, j int) bool {
2,505✔
293
                        return q.ShortChanIDs[i].ToUint64() <
2,047✔
294
                                q.ShortChanIDs[j].ToUint64()
2,047✔
295
                })
2,047✔
296
        }
297

298
        // Base on our encoding type, we'll write out the set of short channel
299
        // ID's.
300
        err := encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs)
464✔
301
        if err != nil {
464✔
302
                return err
×
303
        }
×
304

305
        return WriteBytes(w, q.ExtraData)
464✔
306
}
307

308
// encodeShortChanIDs encodes the passed short channel ID's into the passed
309
// io.Writer, respecting the specified encoding type.
310
func encodeShortChanIDs(w *bytes.Buffer, encodingType QueryEncoding,
311
        shortChanIDs []ShortChannelID) error {
900✔
312

900✔
313
        switch encodingType {
900✔
314

315
        // In this encoding, we'll simply write a sorted array of encoded short
316
        // channel ID's from the buffer.
317
        case EncodingSortedPlain:
131✔
318
                // First, we'll write out the number of bytes of the query
131✔
319
                // body. We add 1 as the response will have the encoding type
131✔
320
                // prepended to it.
131✔
321
                numBytesBody := uint16(len(shortChanIDs)*8) + 1
131✔
322
                if err := WriteUint16(w, numBytesBody); err != nil {
131✔
323
                        return err
×
324
                }
×
325

326
                // We'll then write out the encoding that that follows the
327
                // actual encoded short channel ID's.
328
                err := WriteQueryEncoding(w, encodingType)
131✔
329
                if err != nil {
131✔
330
                        return err
×
331
                }
×
332

333
                // Now that we know they're sorted, we can write out each short
334
                // channel ID to the buffer.
335
                for _, chanID := range shortChanIDs {
955✔
336
                        if err := WriteShortChannelID(w, chanID); err != nil {
824✔
337
                                return fmt.Errorf("unable to write short chan "+
×
338
                                        "ID: %v", err)
×
339
                        }
×
340
                }
341

342
                return nil
131✔
343

344
        // For this encoding we'll first write out a serialized version of all
345
        // the channel ID's into a buffer, then zlib encode that. The final
346
        // payload is what we'll write out to the passed io.Writer.
347
        //
348
        // TODO(roasbeef): assumes the caller knows the proper chunk size to
349
        // pass to avoid bin-packing here
350
        case EncodingSortedZlib:
769✔
351
                // If we don't have anything at all to write, then we'll write
769✔
352
                // an empty payload so we don't include things like the zlib
769✔
353
                // header when the remote party is expecting no actual short
769✔
354
                // channel IDs.
769✔
355
                var compressedPayload []byte
769✔
356
                if len(shortChanIDs) > 0 {
1,451✔
357
                        // We'll make a new write buffer to hold the bytes of
682✔
358
                        // shortChanIDs.
682✔
359
                        var wb bytes.Buffer
682✔
360

682✔
361
                        // Next, we'll write out all the channel ID's directly
682✔
362
                        // into the zlib writer, which will do compressing on
682✔
363
                        // the fly.
682✔
364
                        for _, chanID := range shortChanIDs {
4,516✔
365
                                err := WriteShortChannelID(&wb, chanID)
3,834✔
366
                                if err != nil {
3,834✔
367
                                        return fmt.Errorf(
×
368
                                                "unable to write short chan "+
×
369
                                                        "ID: %v", err,
×
370
                                        )
×
371
                                }
×
372
                        }
373

374
                        // With shortChanIDs written into wb, we'll create a
375
                        // zlib writer and write all the compressed bytes.
376
                        var zlibBuffer bytes.Buffer
682✔
377
                        zlibWriter := zlib.NewWriter(&zlibBuffer)
682✔
378

682✔
379
                        if _, err := zlibWriter.Write(wb.Bytes()); err != nil {
682✔
380
                                return fmt.Errorf(
×
381
                                        "unable to write compressed short chan"+
×
382
                                                "ID: %w", err)
×
383
                        }
×
384

385
                        // Now that we've written all the elements, we'll
386
                        // ensure the compressed stream is written to the
387
                        // underlying buffer.
388
                        if err := zlibWriter.Close(); err != nil {
682✔
389
                                return fmt.Errorf("unable to finalize "+
×
390
                                        "compression: %v", err)
×
391
                        }
×
392

393
                        compressedPayload = zlibBuffer.Bytes()
682✔
394
                }
395

396
                // Now that we have all the items compressed, we can compute
397
                // what the total payload size will be. We add one to account
398
                // for the byte to encode the type.
399
                //
400
                // If we don't have any actual bytes to write, then we'll end
401
                // up emitting one byte for the length, followed by the
402
                // encoding type, and nothing more. The spec isn't 100% clear
403
                // in this area, but we do this as this is what most of the
404
                // other implementations do.
405
                numBytesBody := len(compressedPayload) + 1
769✔
406

769✔
407
                // Finally, we can write out the number of bytes, the
769✔
408
                // compression type, and finally the buffer itself.
769✔
409
                if err := WriteUint16(w, uint16(numBytesBody)); err != nil {
769✔
410
                        return err
×
411
                }
×
412
                err := WriteQueryEncoding(w, encodingType)
769✔
413
                if err != nil {
769✔
414
                        return err
×
415
                }
×
416

417
                return WriteBytes(w, compressedPayload)
769✔
418

419
        default:
×
420
                // If we're trying to encode with an encoding type that we
×
421
                // don't know of, then we'll return a parsing error as we can't
×
422
                // continue if we're unable to encode them.
×
423
                return ErrUnknownShortChanIDEncoding(encodingType)
×
424
        }
425
}
426

427
// MsgType returns the integer uniquely identifying this message type on the
428
// wire.
429
//
430
// This is part of the lnwire.Message interface.
431
func (q *QueryShortChanIDs) MsgType() MessageType {
458✔
432
        return MsgQueryShortChanIDs
458✔
433
}
458✔
434

435
// SerializedSize returns the serialized size of the message in bytes.
436
//
437
// This is part of the lnwire.SizeableMessage interface.
438
func (q *QueryShortChanIDs) SerializedSize() (uint32, error) {
×
439
        return MessageSerializedSize(q)
×
440
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc