• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15561477203

10 Jun 2025 01:54PM UTC coverage: 58.351% (-10.1%) from 68.487%
15561477203

Pull #9356

github

web-flow
Merge 6440b25db into c6d6d4c0b
Pull Request #9356: lnrpc: add incoming/outgoing channel ids filter to forwarding history request

33 of 36 new or added lines in 2 files covered. (91.67%)

28366 existing lines in 455 files now uncovered.

97715 of 167461 relevant lines covered (58.35%)

1.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

32.6
/lnwire/query_short_chan_ids.go
1
package lnwire
2

3
import (
4
        "bytes"
5
        "compress/zlib"
6
        "fmt"
7
        "io"
8
        "sort"
9
        "sync"
10

11
        "github.com/btcsuite/btcd/chaincfg/chainhash"
12
)
13

14
// maxZlibBufSize bounds how many bytes we are willing to accept from a single
// zlib decoding instance (roughly 64 MiB). Capping this keeps the total amount
// of memory allocated while decoding under control.
const maxZlibBufSize = 67413630
20

21
// ErrUnsortedSIDs is returned when decoding a QueryShortChannelID request whose
22
// items were not sorted.
23
type ErrUnsortedSIDs struct {
24
        prevSID ShortChannelID
25
        curSID  ShortChannelID
26
}
27

28
// Error returns a human-readable description of the error.
29
func (e ErrUnsortedSIDs) Error() string {
×
30
        return fmt.Sprintf("current sid: %v isn't greater than last sid: %v",
×
31
                e.curSID, e.prevSID)
×
32
}
×
33

34
// zlibDecodeMtx serializes zlib decoding across the whole package: only one
// decoding instance may run at a time, which further bounds memory usage.
var (
	zlibDecodeMtx sync.Mutex
)
38

39
// ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we
40
// came across an unknown short channel ID encoding, and therefore were unable
41
// to continue parsing.
UNCOV
42
func ErrUnknownShortChanIDEncoding(encoding QueryEncoding) error {
×
UNCOV
43
        return fmt.Errorf("unknown short chan id encoding: %v", encoding)
×
UNCOV
44
}
×
45

46
// QueryShortChanIDs is a message that allows the sender to query a set of
47
// channel announcement and channel update messages that correspond to the set
48
// of encoded short channel ID's. The encoding of the short channel ID's is
49
// detailed in the query message ensuring that the receiver knows how to
50
// properly decode each encode short channel ID which may be encoded using a
51
// compression format. The receiver should respond with a series of channel
52
// announcement and channel updates, finally sending a ReplyShortChanIDsEnd
53
// message.
54
type QueryShortChanIDs struct {
55
        // ChainHash denotes the target chain that we're querying for the
56
        // channel ID's of.
57
        ChainHash chainhash.Hash
58

59
        // EncodingType is a signal to the receiver of the message that
60
        // indicates exactly how the set of short channel ID's that follow have
61
        // been encoded.
62
        EncodingType QueryEncoding
63

64
        // ShortChanIDs is a slice of decoded short channel ID's.
65
        ShortChanIDs []ShortChannelID
66

67
        // ExtraData is the set of data that was appended to this message to
68
        // fill out the full maximum transport message size. These fields can
69
        // be used to specify optional data such as custom TLV fields.
70
        ExtraData ExtraOpaqueData
71

72
        // noSort indicates whether or not to sort the short channel ids before
73
        // writing them out.
74
        //
75
        // NOTE: This should only be used during testing.
76
        noSort bool
77
}
78

79
// NewQueryShortChanIDs creates a new QueryShortChanIDs message.
80
func NewQueryShortChanIDs(h chainhash.Hash, e QueryEncoding,
81
        s []ShortChannelID) *QueryShortChanIDs {
×
82

×
83
        return &QueryShortChanIDs{
×
84
                ChainHash:    h,
×
85
                EncodingType: e,
×
86
                ShortChanIDs: s,
×
87
        }
×
88
}
×
89

90
// A compile time check to ensure QueryShortChanIDs implements the
91
// lnwire.Message interface.
92
var _ Message = (*QueryShortChanIDs)(nil)
93

94
// A compile time check to ensure QueryShortChanIDs implements the
95
// lnwire.SizeableMessage interface.
96
var _ SizeableMessage = (*QueryShortChanIDs)(nil)
97

98
// Decode deserializes a serialized QueryShortChanIDs message stored in the
99
// passed io.Reader observing the specified protocol version.
100
//
101
// This is part of the lnwire.Message interface.
102
func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error {
3✔
103
        err := ReadElements(r, q.ChainHash[:])
3✔
104
        if err != nil {
3✔
UNCOV
105
                return err
×
UNCOV
106
        }
×
107

108
        q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r)
3✔
109
        if err != nil {
3✔
UNCOV
110
                return err
×
UNCOV
111
        }
×
112

113
        return q.ExtraData.Decode(r)
3✔
114
}
115

116
// decodeShortChanIDs decodes a set of short channel ID's that have been
117
// encoded. The first byte of the body details how the short chan ID's were
118
// encoded. We'll use this type to govern exactly how we go about encoding the
119
// set of short channel ID's.
120
func decodeShortChanIDs(r io.Reader) (QueryEncoding, []ShortChannelID, error) {
3✔
121
        // First, we'll attempt to read the number of bytes in the body of the
3✔
122
        // set of encoded short channel ID's.
3✔
123
        var numBytesResp uint16
3✔
124
        err := ReadElements(r, &numBytesResp)
3✔
125
        if err != nil {
3✔
UNCOV
126
                return 0, nil, err
×
UNCOV
127
        }
×
128

129
        if numBytesResp == 0 {
3✔
UNCOV
130
                return 0, nil, nil
×
UNCOV
131
        }
×
132

133
        queryBody := make([]byte, numBytesResp)
3✔
134
        if _, err := io.ReadFull(r, queryBody); err != nil {
3✔
UNCOV
135
                return 0, nil, err
×
UNCOV
136
        }
×
137

138
        // The first byte is the encoding type, so we'll extract that so we can
139
        // continue our parsing.
140
        encodingType := QueryEncoding(queryBody[0])
3✔
141

3✔
142
        // Before continuing, we'll snip off the first byte of the query body
3✔
143
        // as that was just the encoding type.
3✔
144
        queryBody = queryBody[1:]
3✔
145

3✔
146
        // Otherwise, depending on the encoding type, we'll decode the encode
3✔
147
        // short channel ID's in a different manner.
3✔
148
        switch encodingType {
3✔
149

150
        // In this encoding, we'll simply read a sort array of encoded short
151
        // channel ID's from the buffer.
152
        case EncodingSortedPlain:
3✔
153
                // If after extracting the encoding type, the number of
3✔
154
                // remaining bytes is not a whole multiple of the size of an
3✔
155
                // encoded short channel ID (8 bytes), then we'll return a
3✔
156
                // parsing error.
3✔
157
                if len(queryBody)%8 != 0 {
3✔
UNCOV
158
                        return 0, nil, fmt.Errorf("whole number of short "+
×
UNCOV
159
                                "chan ID's cannot be encoded in len=%v",
×
UNCOV
160
                                len(queryBody))
×
UNCOV
161
                }
×
162

163
                // As each short channel ID is encoded as 8 bytes, we can
164
                // compute the number of bytes encoded based on the size of the
165
                // query body.
166
                numShortChanIDs := len(queryBody) / 8
3✔
167
                if numShortChanIDs == 0 {
6✔
168
                        return encodingType, nil, nil
3✔
169
                }
3✔
170

171
                // Finally, we'll read out the exact number of short channel
172
                // ID's to conclude our parsing.
173
                shortChanIDs := make([]ShortChannelID, numShortChanIDs)
3✔
174
                bodyReader := bytes.NewReader(queryBody)
3✔
175
                var lastChanID ShortChannelID
3✔
176
                for i := 0; i < numShortChanIDs; i++ {
6✔
177
                        if err := ReadElements(bodyReader, &shortChanIDs[i]); err != nil {
3✔
178
                                return 0, nil, fmt.Errorf("unable to parse "+
×
179
                                        "short chan ID: %v", err)
×
180
                        }
×
181

182
                        // We'll ensure that this short chan ID is greater than
183
                        // the last one. This is a requirement within the
184
                        // encoding, and if violated can aide us in detecting
185
                        // malicious payloads. This can only be true starting
186
                        // at the second chanID.
187
                        cid := shortChanIDs[i]
3✔
188
                        if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
3✔
UNCOV
189
                                return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
×
UNCOV
190
                        }
×
191
                        lastChanID = cid
3✔
192
                }
193

194
                return encodingType, shortChanIDs, nil
3✔
195

196
        // In this encoding, we'll use zlib to decode the compressed payload.
197
        // However, we'll pay attention to ensure that we don't open our selves
198
        // up to a memory exhaustion attack.
UNCOV
199
        case EncodingSortedZlib:
×
UNCOV
200
                // We'll obtain an ultimately release the zlib decode mutex.
×
UNCOV
201
                // This guards us against allocating too much memory to decode
×
UNCOV
202
                // each instance from concurrent peers.
×
UNCOV
203
                zlibDecodeMtx.Lock()
×
UNCOV
204
                defer zlibDecodeMtx.Unlock()
×
UNCOV
205

×
UNCOV
206
                // At this point, if there's no body remaining, then only the encoding
×
UNCOV
207
                // type was specified, meaning that there're no further bytes to be
×
UNCOV
208
                // parsed.
×
UNCOV
209
                if len(queryBody) == 0 {
×
UNCOV
210
                        return encodingType, nil, nil
×
UNCOV
211
                }
×
212

213
                // Before we start to decode, we'll create a limit reader over
214
                // the current reader. This will ensure that we can control how
215
                // much memory we're allocating during the decoding process.
UNCOV
216
                limitedDecompressor, err := zlib.NewReader(&io.LimitedReader{
×
UNCOV
217
                        R: bytes.NewReader(queryBody),
×
UNCOV
218
                        N: maxZlibBufSize,
×
UNCOV
219
                })
×
UNCOV
220
                if err != nil {
×
UNCOV
221
                        return 0, nil, fmt.Errorf("unable to create zlib "+
×
UNCOV
222
                                "reader: %w", err)
×
UNCOV
223
                }
×
224

UNCOV
225
                var (
×
UNCOV
226
                        shortChanIDs []ShortChannelID
×
UNCOV
227
                        lastChanID   ShortChannelID
×
UNCOV
228
                        i            int
×
UNCOV
229
                )
×
UNCOV
230
                for {
×
UNCOV
231
                        // We'll now attempt to read the next short channel ID
×
UNCOV
232
                        // encoded in the payload.
×
UNCOV
233
                        var cid ShortChannelID
×
UNCOV
234
                        err := ReadElements(limitedDecompressor, &cid)
×
UNCOV
235

×
UNCOV
236
                        switch {
×
237
                        // If we get an EOF error, then that either means we've
238
                        // read all that's contained in the buffer, or have hit
239
                        // our limit on the number of bytes we'll read. In
240
                        // either case, we'll return what we have so far.
UNCOV
241
                        case err == io.ErrUnexpectedEOF || err == io.EOF:
×
UNCOV
242
                                return encodingType, shortChanIDs, nil
×
243

244
                        // Otherwise, we hit some other sort of error, possibly
245
                        // an invalid payload, so we'll exit early with the
246
                        // error.
UNCOV
247
                        case err != nil:
×
UNCOV
248
                                return 0, nil, fmt.Errorf("unable to "+
×
UNCOV
249
                                        "deflate next short chan "+
×
UNCOV
250
                                        "ID: %v", err)
×
251
                        }
252

253
                        // We successfully read the next ID, so we'll collect
254
                        // that in the set of final ID's to return.
UNCOV
255
                        shortChanIDs = append(shortChanIDs, cid)
×
UNCOV
256

×
UNCOV
257
                        // Finally, we'll ensure that this short chan ID is
×
UNCOV
258
                        // greater than the last one. This is a requirement
×
UNCOV
259
                        // within the encoding, and if violated can aide us in
×
UNCOV
260
                        // detecting malicious payloads. This can only be true
×
UNCOV
261
                        // starting at the second chanID.
×
UNCOV
262
                        if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
×
UNCOV
263
                                return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
×
UNCOV
264
                        }
×
265

UNCOV
266
                        lastChanID = cid
×
UNCOV
267
                        i++
×
268
                }
269

UNCOV
270
        default:
×
UNCOV
271
                // If we've been sent an encoding type that we don't know of,
×
UNCOV
272
                // then we'll return a parsing error as we can't continue if
×
UNCOV
273
                // we're unable to encode them.
×
UNCOV
274
                return 0, nil, ErrUnknownShortChanIDEncoding(encodingType)
×
275
        }
276
}
277

278
// Encode serializes the target QueryShortChanIDs into the passed io.Writer
279
// observing the protocol version specified.
280
//
281
// This is part of the lnwire.Message interface.
282
func (q *QueryShortChanIDs) Encode(w *bytes.Buffer, pver uint32) error {
3✔
283
        // First, we'll write out the chain hash.
3✔
284
        if err := WriteBytes(w, q.ChainHash[:]); err != nil {
3✔
285
                return err
×
286
        }
×
287

288
        // For both of the current encoding types, the channel ID's are to be
289
        // sorted in place, so we'll do that now. The sorting is applied unless
290
        // we were specifically requested not to for testing purposes.
291
        if !q.noSort {
6✔
292
                sort.Slice(q.ShortChanIDs, func(i, j int) bool {
6✔
293
                        return q.ShortChanIDs[i].ToUint64() <
3✔
294
                                q.ShortChanIDs[j].ToUint64()
3✔
295
                })
3✔
296
        }
297

298
        // Base on our encoding type, we'll write out the set of short channel
299
        // ID's.
300
        err := encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs)
3✔
301
        if err != nil {
3✔
302
                return err
×
303
        }
×
304

305
        return WriteBytes(w, q.ExtraData)
3✔
306
}
307

308
// encodeShortChanIDs encodes the passed short channel ID's into the passed
309
// io.Writer, respecting the specified encoding type.
310
func encodeShortChanIDs(w *bytes.Buffer, encodingType QueryEncoding,
311
        shortChanIDs []ShortChannelID) error {
3✔
312

3✔
313
        switch encodingType {
3✔
314

315
        // In this encoding, we'll simply write a sorted array of encoded short
316
        // channel ID's from the buffer.
317
        case EncodingSortedPlain:
3✔
318
                // First, we'll write out the number of bytes of the query
3✔
319
                // body. We add 1 as the response will have the encoding type
3✔
320
                // prepended to it.
3✔
321
                numBytesBody := uint16(len(shortChanIDs)*8) + 1
3✔
322
                if err := WriteUint16(w, numBytesBody); err != nil {
3✔
323
                        return err
×
324
                }
×
325

326
                // We'll then write out the encoding that that follows the
327
                // actual encoded short channel ID's.
328
                err := WriteQueryEncoding(w, encodingType)
3✔
329
                if err != nil {
3✔
330
                        return err
×
331
                }
×
332

333
                // Now that we know they're sorted, we can write out each short
334
                // channel ID to the buffer.
335
                for _, chanID := range shortChanIDs {
6✔
336
                        if err := WriteShortChannelID(w, chanID); err != nil {
3✔
337
                                return fmt.Errorf("unable to write short chan "+
×
338
                                        "ID: %v", err)
×
339
                        }
×
340
                }
341

342
                return nil
3✔
343

344
        // For this encoding we'll first write out a serialized version of all
345
        // the channel ID's into a buffer, then zlib encode that. The final
346
        // payload is what we'll write out to the passed io.Writer.
347
        //
348
        // TODO(roasbeef): assumes the caller knows the proper chunk size to
349
        // pass to avoid bin-packing here
UNCOV
350
        case EncodingSortedZlib:
×
UNCOV
351
                // If we don't have anything at all to write, then we'll write
×
UNCOV
352
                // an empty payload so we don't include things like the zlib
×
UNCOV
353
                // header when the remote party is expecting no actual short
×
UNCOV
354
                // channel IDs.
×
UNCOV
355
                var compressedPayload []byte
×
UNCOV
356
                if len(shortChanIDs) > 0 {
×
UNCOV
357
                        // We'll make a new write buffer to hold the bytes of
×
UNCOV
358
                        // shortChanIDs.
×
UNCOV
359
                        var wb bytes.Buffer
×
UNCOV
360

×
UNCOV
361
                        // Next, we'll write out all the channel ID's directly
×
UNCOV
362
                        // into the zlib writer, which will do compressing on
×
UNCOV
363
                        // the fly.
×
UNCOV
364
                        for _, chanID := range shortChanIDs {
×
UNCOV
365
                                err := WriteShortChannelID(&wb, chanID)
×
UNCOV
366
                                if err != nil {
×
367
                                        return fmt.Errorf(
×
368
                                                "unable to write short chan "+
×
369
                                                        "ID: %v", err,
×
370
                                        )
×
371
                                }
×
372
                        }
373

374
                        // With shortChanIDs written into wb, we'll create a
375
                        // zlib writer and write all the compressed bytes.
UNCOV
376
                        var zlibBuffer bytes.Buffer
×
UNCOV
377
                        zlibWriter := zlib.NewWriter(&zlibBuffer)
×
UNCOV
378

×
UNCOV
379
                        if _, err := zlibWriter.Write(wb.Bytes()); err != nil {
×
380
                                return fmt.Errorf(
×
381
                                        "unable to write compressed short chan"+
×
382
                                                "ID: %w", err)
×
383
                        }
×
384

385
                        // Now that we've written all the elements, we'll
386
                        // ensure the compressed stream is written to the
387
                        // underlying buffer.
UNCOV
388
                        if err := zlibWriter.Close(); err != nil {
×
389
                                return fmt.Errorf("unable to finalize "+
×
390
                                        "compression: %v", err)
×
391
                        }
×
392

UNCOV
393
                        compressedPayload = zlibBuffer.Bytes()
×
394
                }
395

396
                // Now that we have all the items compressed, we can compute
397
                // what the total payload size will be. We add one to account
398
                // for the byte to encode the type.
399
                //
400
                // If we don't have any actual bytes to write, then we'll end
401
                // up emitting one byte for the length, followed by the
402
                // encoding type, and nothing more. The spec isn't 100% clear
403
                // in this area, but we do this as this is what most of the
404
                // other implementations do.
UNCOV
405
                numBytesBody := len(compressedPayload) + 1
×
UNCOV
406

×
UNCOV
407
                // Finally, we can write out the number of bytes, the
×
UNCOV
408
                // compression type, and finally the buffer itself.
×
UNCOV
409
                if err := WriteUint16(w, uint16(numBytesBody)); err != nil {
×
410
                        return err
×
411
                }
×
UNCOV
412
                err := WriteQueryEncoding(w, encodingType)
×
UNCOV
413
                if err != nil {
×
414
                        return err
×
415
                }
×
416

UNCOV
417
                return WriteBytes(w, compressedPayload)
×
418

419
        default:
×
420
                // If we're trying to encode with an encoding type that we
×
421
                // don't know of, then we'll return a parsing error as we can't
×
422
                // continue if we're unable to encode them.
×
423
                return ErrUnknownShortChanIDEncoding(encodingType)
×
424
        }
425
}
426

427
// MsgType returns the integer uniquely identifying this message type on the
428
// wire.
429
//
430
// This is part of the lnwire.Message interface.
431
func (q *QueryShortChanIDs) MsgType() MessageType {
3✔
432
        return MsgQueryShortChanIDs
3✔
433
}
3✔
434

435
// SerializedSize returns the serialized size of the message in bytes.
436
//
437
// This is part of the lnwire.SizeableMessage interface.
438
func (q *QueryShortChanIDs) SerializedSize() (uint32, error) {
3✔
439
        return MessageSerializedSize(q)
3✔
440
}
3✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc