• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 13586005509

28 Feb 2025 10:14AM UTC coverage: 68.629% (+9.9%) from 58.77%
13586005509

Pull #9521

github

web-flow
Merge 37d3a70a5 into 8532955b3
Pull Request #9521: unit: remove GOACC, use Go 1.20 native coverage functionality

129950 of 189351 relevant lines covered (68.63%)

23726.46 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

79.02
/lnwire/query_short_chan_ids.go
1
package lnwire
2

3
import (
4
        "bytes"
5
        "compress/zlib"
6
        "fmt"
7
        "io"
8
        "sort"
9
        "sync"
10

11
        "github.com/btcsuite/btcd/chaincfg/chainhash"
12
)
13

14
const (
15
        // maxZlibBufSize is the max number of bytes that we'll accept from a
16
        // zlib decoding instance. We do this in order to limit the total
17
        // amount of memory allocated during a decoding instance.
18
        maxZlibBufSize = 67413630
19
)
20

21
// ErrUnsortedSIDs is returned when decoding a QueryShortChannelID request whose
22
// items were not sorted.
23
type ErrUnsortedSIDs struct {
24
        prevSID ShortChannelID
25
        curSID  ShortChannelID
26
}
27

28
// Error returns a human-readable description of the error.
29
func (e ErrUnsortedSIDs) Error() string {
×
30
        return fmt.Sprintf("current sid: %v isn't greater than last sid: %v",
×
31
                e.curSID, e.prevSID)
×
32
}
×
33

34
// zlibDecodeMtx is a package level mutex that we'll use in order to ensure
35
// that we'll only attempt a single zlib decoding instance at a time. This
36
// allows us to also further bound our memory usage.
37
var zlibDecodeMtx sync.Mutex
38

39
// ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we
40
// came across an unknown short channel ID encoding, and therefore were unable
41
// to continue parsing.
42
func ErrUnknownShortChanIDEncoding(encoding QueryEncoding) error {
4✔
43
        return fmt.Errorf("unknown short chan id encoding: %v", encoding)
4✔
44
}
4✔
45

46
// QueryShortChanIDs is a message that allows the sender to query a set of
47
// channel announcement and channel update messages that correspond to the set
48
// of encoded short channel ID's. The encoding of the short channel ID's is
49
// detailed in the query message ensuring that the receiver knows how to
50
// properly decode each encode short channel ID which may be encoded using a
51
// compression format. The receiver should respond with a series of channel
52
// announcement and channel updates, finally sending a ReplyShortChanIDsEnd
53
// message.
54
type QueryShortChanIDs struct {
55
        // ChainHash denotes the target chain that we're querying for the
56
        // channel ID's of.
57
        ChainHash chainhash.Hash
58

59
        // EncodingType is a signal to the receiver of the message that
60
        // indicates exactly how the set of short channel ID's that follow have
61
        // been encoded.
62
        EncodingType QueryEncoding
63

64
        // ShortChanIDs is a slice of decoded short channel ID's.
65
        ShortChanIDs []ShortChannelID
66

67
        // ExtraData is the set of data that was appended to this message to
68
        // fill out the full maximum transport message size. These fields can
69
        // be used to specify optional data such as custom TLV fields.
70
        ExtraData ExtraOpaqueData
71

72
        // noSort indicates whether or not to sort the short channel ids before
73
        // writing them out.
74
        //
75
        // NOTE: This should only be used during testing.
76
        noSort bool
77
}
78

79
// NewQueryShortChanIDs creates a new QueryShortChanIDs message.
80
func NewQueryShortChanIDs(h chainhash.Hash, e QueryEncoding,
81
        s []ShortChannelID) *QueryShortChanIDs {
×
82

×
83
        return &QueryShortChanIDs{
×
84
                ChainHash:    h,
×
85
                EncodingType: e,
×
86
                ShortChanIDs: s,
×
87
        }
×
88
}
×
89

90
// A compile time check to ensure QueryShortChanIDs implements the
91
// lnwire.Message interface.
92
var _ Message = (*QueryShortChanIDs)(nil)
93

94
// Decode deserializes a serialized QueryShortChanIDs message stored in the
95
// passed io.Reader observing the specified protocol version.
96
//
97
// This is part of the lnwire.Message interface.
98
func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error {
1,166✔
99
        err := ReadElements(r, q.ChainHash[:])
1,166✔
100
        if err != nil {
1,168✔
101
                return err
2✔
102
        }
2✔
103

104
        q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r)
1,164✔
105
        if err != nil {
1,507✔
106
                return err
343✔
107
        }
343✔
108

109
        return q.ExtraData.Decode(r)
821✔
110
}
111

112
// decodeShortChanIDs decodes a set of short channel ID's that have been
113
// encoded. The first byte of the body details how the short chan ID's were
114
// encoded. We'll use this type to govern exactly how we go about encoding the
115
// set of short channel ID's.
116
func decodeShortChanIDs(r io.Reader) (QueryEncoding, []ShortChannelID, error) {
2,380✔
117
        // First, we'll attempt to read the number of bytes in the body of the
2,380✔
118
        // set of encoded short channel ID's.
2,380✔
119
        var numBytesResp uint16
2,380✔
120
        err := ReadElements(r, &numBytesResp)
2,380✔
121
        if err != nil {
2,383✔
122
                return 0, nil, err
3✔
123
        }
3✔
124

125
        if numBytesResp == 0 {
2,437✔
126
                return 0, nil, nil
60✔
127
        }
60✔
128

129
        queryBody := make([]byte, numBytesResp)
2,317✔
130
        if _, err := io.ReadFull(r, queryBody); err != nil {
2,319✔
131
                return 0, nil, err
2✔
132
        }
2✔
133

134
        // The first byte is the encoding type, so we'll extract that so we can
135
        // continue our parsing.
136
        encodingType := QueryEncoding(queryBody[0])
2,315✔
137

2,315✔
138
        // Before continuing, we'll snip off the first byte of the query body
2,315✔
139
        // as that was just the encoding type.
2,315✔
140
        queryBody = queryBody[1:]
2,315✔
141

2,315✔
142
        // Otherwise, depending on the encoding type, we'll decode the encode
2,315✔
143
        // short channel ID's in a different manner.
2,315✔
144
        switch encodingType {
2,315✔
145

146
        // In this encoding, we'll simply read a sort array of encoded short
147
        // channel ID's from the buffer.
148
        case EncodingSortedPlain:
152✔
149
                // If after extracting the encoding type, the number of
152✔
150
                // remaining bytes is not a whole multiple of the size of an
152✔
151
                // encoded short channel ID (8 bytes), then we'll return a
152✔
152
                // parsing error.
152✔
153
                if len(queryBody)%8 != 0 {
156✔
154
                        return 0, nil, fmt.Errorf("whole number of short "+
4✔
155
                                "chan ID's cannot be encoded in len=%v",
4✔
156
                                len(queryBody))
4✔
157
                }
4✔
158

159
                // As each short channel ID is encoded as 8 bytes, we can
160
                // compute the number of bytes encoded based on the size of the
161
                // query body.
162
                numShortChanIDs := len(queryBody) / 8
148✔
163
                if numShortChanIDs == 0 {
167✔
164
                        return encodingType, nil, nil
19✔
165
                }
19✔
166

167
                // Finally, we'll read out the exact number of short channel
168
                // ID's to conclude our parsing.
169
                shortChanIDs := make([]ShortChannelID, numShortChanIDs)
132✔
170
                bodyReader := bytes.NewReader(queryBody)
132✔
171
                var lastChanID ShortChannelID
132✔
172
                for i := 0; i < numShortChanIDs; i++ {
203,385✔
173
                        if err := ReadElements(bodyReader, &shortChanIDs[i]); err != nil {
203,253✔
174
                                return 0, nil, fmt.Errorf("unable to parse "+
×
175
                                        "short chan ID: %v", err)
×
176
                        }
×
177

178
                        // We'll ensure that this short chan ID is greater than
179
                        // the last one. This is a requirement within the
180
                        // encoding, and if violated can aide us in detecting
181
                        // malicious payloads. This can only be true starting
182
                        // at the second chanID.
183
                        cid := shortChanIDs[i]
203,253✔
184
                        if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
203,261✔
185
                                return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
8✔
186
                        }
8✔
187
                        lastChanID = cid
203,245✔
188
                }
189

190
                return encodingType, shortChanIDs, nil
124✔
191

192
        // In this encoding, we'll use zlib to decode the compressed payload.
193
        // However, we'll pay attention to ensure that we don't open our selves
194
        // up to a memory exhaustion attack.
195
        case EncodingSortedZlib:
2,159✔
196
                // We'll obtain an ultimately release the zlib decode mutex.
2,159✔
197
                // This guards us against allocating too much memory to decode
2,159✔
198
                // each instance from concurrent peers.
2,159✔
199
                zlibDecodeMtx.Lock()
2,159✔
200
                defer zlibDecodeMtx.Unlock()
2,159✔
201

2,159✔
202
                // At this point, if there's no body remaining, then only the encoding
2,159✔
203
                // type was specified, meaning that there're no further bytes to be
2,159✔
204
                // parsed.
2,159✔
205
                if len(queryBody) == 0 {
2,242✔
206
                        return encodingType, nil, nil
83✔
207
                }
83✔
208

209
                // Before we start to decode, we'll create a limit reader over
210
                // the current reader. This will ensure that we can control how
211
                // much memory we're allocating during the decoding process.
212
                limitedDecompressor, err := zlib.NewReader(&io.LimitedReader{
2,076✔
213
                        R: bytes.NewReader(queryBody),
2,076✔
214
                        N: maxZlibBufSize,
2,076✔
215
                })
2,076✔
216
                if err != nil {
2,086✔
217
                        return 0, nil, fmt.Errorf("unable to create zlib "+
10✔
218
                                "reader: %w", err)
10✔
219
                }
10✔
220

221
                var (
2,066✔
222
                        shortChanIDs []ShortChannelID
2,066✔
223
                        lastChanID   ShortChannelID
2,066✔
224
                        i            int
2,066✔
225
                )
2,066✔
226
                for {
260,093✔
227
                        // We'll now attempt to read the next short channel ID
258,027✔
228
                        // encoded in the payload.
258,027✔
229
                        var cid ShortChannelID
258,027✔
230
                        err := ReadElements(limitedDecompressor, &cid)
258,027✔
231

258,027✔
232
                        switch {
258,027✔
233
                        // If we get an EOF error, then that either means we've
234
                        // read all that's contained in the buffer, or have hit
235
                        // our limit on the number of bytes we'll read. In
236
                        // either case, we'll return what we have so far.
237
                        case err == io.ErrUnexpectedEOF || err == io.EOF:
1,372✔
238
                                return encodingType, shortChanIDs, nil
1,372✔
239

240
                        // Otherwise, we hit some other sort of error, possibly
241
                        // an invalid payload, so we'll exit early with the
242
                        // error.
243
                        case err != nil:
80✔
244
                                return 0, nil, fmt.Errorf("unable to "+
80✔
245
                                        "deflate next short chan "+
80✔
246
                                        "ID: %v", err)
80✔
247
                        }
248

249
                        // We successfully read the next ID, so we'll collect
250
                        // that in the set of final ID's to return.
251
                        shortChanIDs = append(shortChanIDs, cid)
256,575✔
252

256,575✔
253
                        // Finally, we'll ensure that this short chan ID is
256,575✔
254
                        // greater than the last one. This is a requirement
256,575✔
255
                        // within the encoding, and if violated can aide us in
256,575✔
256
                        // detecting malicious payloads. This can only be true
256,575✔
257
                        // starting at the second chanID.
256,575✔
258
                        if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
257,189✔
259
                                return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
614✔
260
                        }
614✔
261

262
                        lastChanID = cid
255,961✔
263
                        i++
255,961✔
264
                }
265

266
        default:
4✔
267
                // If we've been sent an encoding type that we don't know of,
4✔
268
                // then we'll return a parsing error as we can't continue if
4✔
269
                // we're unable to encode them.
4✔
270
                return 0, nil, ErrUnknownShortChanIDEncoding(encodingType)
4✔
271
        }
272
}
273

274
// Encode serializes the target QueryShortChanIDs into the passed io.Writer
275
// observing the protocol version specified.
276
//
277
// This is part of the lnwire.Message interface.
278
func (q *QueryShortChanIDs) Encode(w *bytes.Buffer, pver uint32) error {
467✔
279
        // First, we'll write out the chain hash.
467✔
280
        if err := WriteBytes(w, q.ChainHash[:]); err != nil {
467✔
281
                return err
×
282
        }
×
283

284
        // For both of the current encoding types, the channel ID's are to be
285
        // sorted in place, so we'll do that now. The sorting is applied unless
286
        // we were specifically requested not to for testing purposes.
287
        if !q.noSort {
928✔
288
                sort.Slice(q.ShortChanIDs, func(i, j int) bool {
3,112,626✔
289
                        return q.ShortChanIDs[i].ToUint64() <
3,112,165✔
290
                                q.ShortChanIDs[j].ToUint64()
3,112,165✔
291
                })
3,112,165✔
292
        }
293

294
        // Base on our encoding type, we'll write out the set of short channel
295
        // ID's.
296
        err := encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs)
467✔
297
        if err != nil {
467✔
298
                return err
×
299
        }
×
300

301
        return WriteBytes(w, q.ExtraData)
467✔
302
}
303

304
// encodeShortChanIDs encodes the passed short channel ID's into the passed
305
// io.Writer, respecting the specified encoding type.
306
func encodeShortChanIDs(w *bytes.Buffer, encodingType QueryEncoding,
307
        shortChanIDs []ShortChannelID) error {
903✔
308

903✔
309
        switch encodingType {
903✔
310

311
        // In this encoding, we'll simply write a sorted array of encoded short
312
        // channel ID's from the buffer.
313
        case EncodingSortedPlain:
126✔
314
                // First, we'll write out the number of bytes of the query
126✔
315
                // body. We add 1 as the response will have the encoding type
126✔
316
                // prepended to it.
126✔
317
                numBytesBody := uint16(len(shortChanIDs)*8) + 1
126✔
318
                if err := WriteUint16(w, numBytesBody); err != nil {
126✔
319
                        return err
×
320
                }
×
321

322
                // We'll then write out the encoding that that follows the
323
                // actual encoded short channel ID's.
324
                err := WriteQueryEncoding(w, encodingType)
126✔
325
                if err != nil {
126✔
326
                        return err
×
327
                }
×
328

329
                // Now that we know they're sorted, we can write out each short
330
                // channel ID to the buffer.
331
                for _, chanID := range shortChanIDs {
203,305✔
332
                        if err := WriteShortChannelID(w, chanID); err != nil {
203,179✔
333
                                return fmt.Errorf("unable to write short chan "+
×
334
                                        "ID: %v", err)
×
335
                        }
×
336
                }
337

338
                return nil
126✔
339

340
        // For this encoding we'll first write out a serialized version of all
341
        // the channel ID's into a buffer, then zlib encode that. The final
342
        // payload is what we'll write out to the passed io.Writer.
343
        //
344
        // TODO(roasbeef): assumes the caller knows the proper chunk size to
345
        // pass to avoid bin-packing here
346
        case EncodingSortedZlib:
777✔
347
                // If we don't have anything at all to write, then we'll write
777✔
348
                // an empty payload so we don't include things like the zlib
777✔
349
                // header when the remote party is expecting no actual short
777✔
350
                // channel IDs.
777✔
351
                var compressedPayload []byte
777✔
352
                if len(shortChanIDs) > 0 {
1,474✔
353
                        // We'll make a new write buffer to hold the bytes of
697✔
354
                        // shortChanIDs.
697✔
355
                        var wb bytes.Buffer
697✔
356

697✔
357
                        // Next, we'll write out all the channel ID's directly
697✔
358
                        // into the zlib writer, which will do compressing on
697✔
359
                        // the fly.
697✔
360
                        for _, chanID := range shortChanIDs {
252,348✔
361
                                err := WriteShortChannelID(&wb, chanID)
251,651✔
362
                                if err != nil {
251,651✔
363
                                        return fmt.Errorf(
×
364
                                                "unable to write short chan "+
×
365
                                                        "ID: %v", err,
×
366
                                        )
×
367
                                }
×
368
                        }
369

370
                        // With shortChanIDs written into wb, we'll create a
371
                        // zlib writer and write all the compressed bytes.
372
                        var zlibBuffer bytes.Buffer
697✔
373
                        zlibWriter := zlib.NewWriter(&zlibBuffer)
697✔
374

697✔
375
                        if _, err := zlibWriter.Write(wb.Bytes()); err != nil {
697✔
376
                                return fmt.Errorf(
×
377
                                        "unable to write compressed short chan"+
×
378
                                                "ID: %w", err)
×
379
                        }
×
380

381
                        // Now that we've written all the elements, we'll
382
                        // ensure the compressed stream is written to the
383
                        // underlying buffer.
384
                        if err := zlibWriter.Close(); err != nil {
697✔
385
                                return fmt.Errorf("unable to finalize "+
×
386
                                        "compression: %v", err)
×
387
                        }
×
388

389
                        compressedPayload = zlibBuffer.Bytes()
697✔
390
                }
391

392
                // Now that we have all the items compressed, we can compute
393
                // what the total payload size will be. We add one to account
394
                // for the byte to encode the type.
395
                //
396
                // If we don't have any actual bytes to write, then we'll end
397
                // up emitting one byte for the length, followed by the
398
                // encoding type, and nothing more. The spec isn't 100% clear
399
                // in this area, but we do this as this is what most of the
400
                // other implementations do.
401
                numBytesBody := len(compressedPayload) + 1
777✔
402

777✔
403
                // Finally, we can write out the number of bytes, the
777✔
404
                // compression type, and finally the buffer itself.
777✔
405
                if err := WriteUint16(w, uint16(numBytesBody)); err != nil {
777✔
406
                        return err
×
407
                }
×
408
                err := WriteQueryEncoding(w, encodingType)
777✔
409
                if err != nil {
777✔
410
                        return err
×
411
                }
×
412

413
                return WriteBytes(w, compressedPayload)
777✔
414

415
        default:
×
416
                // If we're trying to encode with an encoding type that we
×
417
                // don't know of, then we'll return a parsing error as we can't
×
418
                // continue if we're unable to encode them.
×
419
                return ErrUnknownShortChanIDEncoding(encodingType)
×
420
        }
421
}
422

423
// MsgType returns the integer uniquely identifying this message type on the
424
// wire.
425
//
426
// This is part of the lnwire.Message interface.
427
func (q *QueryShortChanIDs) MsgType() MessageType {
461✔
428
        return MsgQueryShortChanIDs
461✔
429
}
461✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc