• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15571403561

10 Jun 2025 10:14PM UTC coverage: 58.306% (-0.03%) from 58.331%
15571403561

push

github

web-flow
Merge pull request #9911 from ziggie1984/exit-early-for-fwd-adds

htlcswitch: exit early if no adds are in the fwd pkg

9 of 9 new or added lines in 1 file covered. (100.0%)

96 existing lines in 10 files now uncovered.

97659 of 167493 relevant lines covered (58.31%)

1.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

77.58
/lnrpc/websocket_proxy.go
1
// The code in this file is a heavily modified version of
2
// https://github.com/tmc/grpc-websocket-proxy/
3

4
package lnrpc
5

6
import (
7
        "bufio"
8
        "io"
9
        "net/http"
10
        "net/textproto"
11
        "regexp"
12
        "strings"
13
        "time"
14

15
        "github.com/btcsuite/btclog/v2"
16
        "github.com/gorilla/websocket"
17
        "golang.org/x/net/context"
18
)
19

20
const (
	// MethodOverrideParam names the GET query parameter that selects the
	// HTTP method used for the forwarded REST request. It is needed because
	// a WebSocket handshake must always be performed with a GET request.
	MethodOverrideParam = "method"

	// HeaderWebSocketProtocol is the WebSocket handshake header field we
	// use to carry additional header fields from clients that cannot set
	// arbitrary headers themselves.
	HeaderWebSocketProtocol = "Sec-Websocket-Protocol"

	// WebSocketProtocolDelimiter separates a transported header field name
	// from its value inside the protocol header. A plus sign is used since
	// the usual delimiters are not permitted in protocol names.
	WebSocketProtocolDelimiter = "+"

	// PingContent is the arbitrary, non-empty payload of the ping messages
	// we send. It carries no meaning but should be echoed back by the
	// client in its pong message.
	PingContent = "are you there?"

	// MaxWsMsgSize is the largest message, in bytes, that the gRPC <-> WS
	// proxy will attempt to decode (4 MiB). gRPC has a comparable setting
	// configured elsewhere.
	MaxWsMsgSize = 4 << 20
)
47

48
var (
49
        // defaultHeadersToForward is a map of all HTTP header fields that are
50
        // forwarded by default. The keys must be in the canonical MIME header
51
        // format.
52
        defaultHeadersToForward = map[string]bool{
53
                "Origin":                 true,
54
                "Referer":                true,
55
                "Grpc-Metadata-Macaroon": true,
56
        }
57

58
        // defaultProtocolsToAllow are additional header fields that we allow
59
        // to be transported inside of the Sec-Websocket-Protocol field to be
60
        // forwarded to the backend.
61
        defaultProtocolsToAllow = map[string]bool{
62
                "Grpc-Metadata-Macaroon": true,
63
        }
64

65
        // DefaultPingInterval is the default number of seconds to wait between
66
        // sending ping requests.
67
        DefaultPingInterval = time.Second * 30
68

69
        // DefaultPongWait is the maximum duration we wait for a pong response
70
        // to a ping we sent before we assume the connection died.
71
        DefaultPongWait = time.Second * 5
72
)
73

74
// NewWebSocketProxy attempts to expose the underlying handler as a response-
75
// streaming WebSocket stream with newline-delimited JSON as the content
76
// encoding. If pingInterval is a non-zero duration, a ping message will be
77
// sent out periodically and a pong response message is expected from the
78
// client. The clientStreamingURIs parameter can hold a list of all patterns
79
// for URIs that are mapped to client-streaming RPC methods. We need to keep
80
// track of those to make sure we initialize the request body correctly for the
81
// underlying grpc-gateway library.
82
func NewWebSocketProxy(h http.Handler, logger btclog.Logger,
83
        pingInterval, pongWait time.Duration,
84
        clientStreamingURIs []*regexp.Regexp) http.Handler {
3✔
85

3✔
86
        p := &WebsocketProxy{
3✔
87
                backend: h,
3✔
88
                logger:  logger,
3✔
89
                upgrader: &websocket.Upgrader{
3✔
90
                        ReadBufferSize:  1024,
3✔
91
                        WriteBufferSize: 1024,
3✔
92
                        CheckOrigin: func(r *http.Request) bool {
6✔
93
                                return true
3✔
94
                        },
3✔
95
                },
96
                clientStreamingURIs: clientStreamingURIs,
97
        }
98

99
        if pingInterval > 0 && pongWait > 0 {
6✔
100
                p.pingInterval = pingInterval
3✔
101
                p.pongWait = pongWait
3✔
102
        }
3✔
103

104
        return p
3✔
105
}
106

107
// WebsocketProxy provides websocket transport upgrade to compatible endpoints.
108
type WebsocketProxy struct {
109
        backend  http.Handler
110
        logger   btclog.Logger
111
        upgrader *websocket.Upgrader
112

113
        // clientStreamingURIs holds a list of all patterns for URIs that are
114
        // mapped to client-streaming RPC methods. We need to keep track of
115
        // those to make sure we initialize the request body correctly for the
116
        // underlying grpc-gateway library.
117
        clientStreamingURIs []*regexp.Regexp
118

119
        pingInterval time.Duration
120
        pongWait     time.Duration
121
}
122

123
// pingPongEnabled returns true if a ping interval is set to enable sending and
124
// expecting regular ping/pong messages.
125
func (p *WebsocketProxy) pingPongEnabled() bool {
3✔
126
        return p.pingInterval > 0 && p.pongWait > 0
3✔
127
}
3✔
128

129
// ServeHTTP handles the incoming HTTP request. If the request is an
130
// "upgradeable" WebSocket request (identified by header fields), then the
131
// WS proxy handles the request. Otherwise the request is passed directly to the
132
// underlying REST proxy.
133
func (p *WebsocketProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
3✔
134
        if !websocket.IsWebSocketUpgrade(r) {
6✔
135
                p.backend.ServeHTTP(w, r)
3✔
136
                return
3✔
137
        }
3✔
138
        p.upgradeToWebSocketProxy(w, r)
3✔
139
}
140

141
// upgradeToWebSocketProxy upgrades the incoming request to a WebSocket, reads
// incoming messages and streams all responses until either the client or the
// server quits the connection.
//
// The WebSocket is bridged to the backend handler through two pipes:
//   - Messages read from the socket are written, newline-terminated, into the
//     body of a forwarded HTTP request.
//   - Every non-empty line the backend writes to its response is sent back as
//     a WebSocket text message.
//
// All helper goroutines share one context. Any of them that stops cancels it,
// which closes both pipes and winds the others down.
func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
	r *http.Request) {

	conn, err := p.upgrader.Upgrade(w, r, nil)
	if err != nil {
		p.logger.Errorf("error upgrading websocket: %v", err)
		return
	}
	defer func() {
		err := conn.Close()
		if err != nil && !IsClosedConnError(err) {
			p.logger.Errorf("WS: error closing upgraded conn: %v",
				err)
		}
	}()

	ctx, cancelFn := context.WithCancel(r.Context())
	defer cancelFn()

	// The forwarded request reads its body from a pipe that is fed with the
	// messages received over the WebSocket.
	requestForwarder := newRequestForwardingReader()
	request, err := http.NewRequestWithContext(
		ctx, r.Method, r.URL.String(), requestForwarder,
	)
	if err != nil {
		p.logger.Errorf("WS: error preparing request: %v", err)
		return
	}

	// Allow certain headers to be forwarded, either from source headers
	// or the special Sec-Websocket-Protocol header field.
	forwardHeaders(r.Header, request.Header)

	// Also allow the target request method to be overwritten, as all
	// WebSocket establishment calls MUST be GET requests.
	if m := r.URL.Query().Get(MethodOverrideParam); m != "" {
		request.Method = m
	}

	// Is this a call to a client-streaming RPC method? A single matching
	// pattern is enough, so stop at the first hit.
	clientStreaming := false
	for _, pattern := range p.clientStreamingURIs {
		if pattern.MatchString(r.URL.Path) {
			clientStreaming = true
			break
		}
	}

	// Once the context is cancelled (by the caller or by any goroutine
	// below), close both pipes so blocked readers and writers return.
	responseForwarder := newResponseForwardingWriter()
	go func() {
		<-ctx.Done()
		responseForwarder.Close()
		requestForwarder.CloseWriter()
	}()

	// Serve the forwarded request with the backend. Its response body is
	// written into the response forwarder's pipe.
	go func() {
		defer cancelFn()
		p.backend.ServeHTTP(responseForwarder, request)
	}()

	// Read loop: Take messages from websocket and write them to the payload
	// channel. This needs to be its own goroutine because for non-client
	// streaming RPCs, the requestForwarder.Write() in the second goroutine
	// will block until the request has fully completed. But for the ping/
	// pong handler to work, we need to have an active call to
	// conn.ReadMessage() going on. So we make sure we have such an active
	// call by starting a second read as soon as the first one has
	// completed.
	payloadChannel := make(chan []byte, 1)
	go func() {
		defer cancelFn()
		defer close(payloadChannel)

		for {
			select {
			case <-ctx.Done():
				return
			default:
			}

			_, payload, err := conn.ReadMessage()
			if err != nil {
				if IsClosedConnError(err) {
					p.logger.Tracef("WS: socket "+
						"closed: %v", err)
					return
				}
				p.logger.Errorf("error reading message: %v",
					err)
				return
			}

			select {
			case payloadChannel <- payload:
			case <-ctx.Done():
				return
			}
		}
	}()

	// Forward loop: Take messages from the incoming payload channel and
	// write them to the http request.
	go func() {
		defer cancelFn()
		for {
			var payload []byte
			select {
			case <-ctx.Done():
				return
			case newPayload, more := <-payloadChannel:
				if !more {
					p.logger.Infof("WS: incoming payload " +
						"chan closed")
					return
				}

				payload = newPayload
			}

			_, err := requestForwarder.Write(payload)
			if err != nil {
				p.logger.Errorf("WS: error writing message "+
					"to upstream http server: %v", err)
				return
			}
			_, _ = requestForwarder.Write([]byte{'\n'})

			// The grpc-gateway library uses a different request
			// reader depending on whether it is a client streaming
			// RPC or not. For a non-streaming request we need to
			// close with EOF to signal the request was completed.
			if !clientStreaming {
				requestForwarder.CloseWriter()
			}
		}
	}()

	// Ping write loop: Send a ping message regularly if ping/pong is
	// enabled.
	if p.pingPongEnabled() {
		// We'll send out our first ping in pingInterval. So the initial
		// deadline is that interval plus the time we allow for a
		// response to be sent.
		initialDeadline := time.Now().Add(p.pingInterval + p.pongWait)
		_ = conn.SetReadDeadline(initialDeadline)

		// Whenever a pong message comes in, we extend the deadline
		// until the next read is expected by the interval plus pong
		// wait time. Since we can never _reach_ any of the deadlines,
		// we also have to advance the deadline for the next expected
		// write to happen, in case the next thing we actually write is
		// the next ping.
		conn.SetPongHandler(func(appData string) error {
			nextDeadline := time.Now().Add(
				p.pingInterval + p.pongWait,
			)
			_ = conn.SetReadDeadline(nextDeadline)
			_ = conn.SetWriteDeadline(nextDeadline)

			return nil
		})
		go func() {
			ticker := time.NewTicker(p.pingInterval)
			defer ticker.Stop()

			for {
				select {
				case <-ctx.Done():
					p.logger.Debug("WS: ping loop done")
					return

				case <-ticker.C:
					// Writing the ping shouldn't take any
					// longer than we'll wait for a response
					// in the first place.
					writeDeadline := time.Now().Add(
						p.pongWait,
					)
					err := conn.WriteControl(
						websocket.PingMessage,
						[]byte(PingContent),
						writeDeadline,
					)
					if err != nil {
						p.logger.Warnf("WS: could not "+
							"send ping message: %v",
							err)
						return
					}
				}
			}
		}()
	}

	// Write loop: Take messages from the response forwarder and write them
	// to the WebSocket.
	for responseForwarder.Scan() {
		if len(responseForwarder.Bytes()) == 0 {
			p.logger.Errorf("WS: empty scan: %v",
				responseForwarder.Err())

			continue
		}

		err := conn.WriteMessage(
			websocket.TextMessage, responseForwarder.Bytes(),
		)
		if err != nil {
			p.logger.Errorf("WS: error writing message: %v", err)
			return
		}
	}
	if err := responseForwarder.Err(); err != nil && !IsClosedConnError(err) {
		p.logger.Errorf("WS: scanner err: %v", err)
	}
}
358

359
// forwardHeaders forwards certain allowed header fields from the source request
360
// to the target request. Because browsers are limited in what header fields
361
// they can send on the WebSocket setup call, we also allow additional fields to
362
// be transported in the special Sec-Websocket-Protocol field.
363
func forwardHeaders(source, target http.Header) {
3✔
364
        // Forward allowed header fields directly.
3✔
365
        for header := range source {
6✔
366
                headerName := textproto.CanonicalMIMEHeaderKey(header)
3✔
367
                forward, ok := defaultHeadersToForward[headerName]
3✔
368
                if ok && forward {
6✔
369
                        target.Set(headerName, source.Get(header))
3✔
370
                }
3✔
371
        }
372

373
        // Browser aren't allowed to set custom header fields on WebSocket
374
        // requests. We need to allow them to submit the macaroon as a WS
375
        // protocol, which is the only allowed header. Set any "protocols" we
376
        // declare valid as header fields on the forwarded request.
377
        protocol := source.Get(HeaderWebSocketProtocol)
3✔
378
        for key := range defaultProtocolsToAllow {
6✔
379
                if strings.HasPrefix(protocol, key) {
6✔
380
                        // The format is "<protocol name>+<value>". We know the
3✔
381
                        // protocol string starts with the name so we only need
3✔
382
                        // to set the value.
3✔
383
                        values := strings.Split(
3✔
384
                                protocol, WebSocketProtocolDelimiter,
3✔
385
                        )
3✔
386
                        target.Set(key, values[1])
3✔
387
                }
3✔
388
        }
389
}
390

391
// newRequestForwardingReader creates a new request forwarding pipe.
392
func newRequestForwardingReader() *requestForwardingReader {
3✔
393
        r, w := io.Pipe()
3✔
394
        return &requestForwardingReader{
3✔
395
                Reader: r,
3✔
396
                Writer: w,
3✔
397
                pipeR:  r,
3✔
398
                pipeW:  w,
3✔
399
        }
3✔
400
}
3✔
401

402
// requestForwardingReader is a wrapper around io.Pipe that embeds both the
403
// io.Reader and io.Writer interface and can be closed.
404
type requestForwardingReader struct {
405
        io.Reader
406
        io.Writer
407

408
        pipeR *io.PipeReader
409
        pipeW *io.PipeWriter
410
}
411

412
// CloseWriter closes the underlying pipe writer.
413
func (r *requestForwardingReader) CloseWriter() {
3✔
414
        _ = r.pipeW.CloseWithError(io.EOF)
3✔
415
}
3✔
416

417
// newResponseForwardingWriter creates a new http.ResponseWriter that intercepts
418
// what's written to it and presents it through a bufio.Scanner interface.
419
func newResponseForwardingWriter() *responseForwardingWriter {
3✔
420
        r, w := io.Pipe()
3✔
421

3✔
422
        scanner := bufio.NewScanner(r)
3✔
423

3✔
424
        // We pass in a custom buffer for the bufio scanner to use. We'll keep
3✔
425
        // with a normal 64KB buffer, but allow a larger max message size,
3✔
426
        // which may cause buffer expansion when needed.
3✔
427
        buf := make([]byte, 0, bufio.MaxScanTokenSize)
3✔
428
        scanner.Buffer(buf, MaxWsMsgSize)
3✔
429

3✔
430
        return &responseForwardingWriter{
3✔
431
                Writer:  w,
3✔
432
                Scanner: scanner,
3✔
433
                pipeR:   r,
3✔
434
                pipeW:   w,
3✔
435
                header:  http.Header{},
3✔
436
                closed:  make(chan bool, 1),
3✔
437
        }
3✔
438
}
3✔
439

440
// responseForwardingWriter is a type that implements the http.ResponseWriter
441
// interface but internally forwards what's written to the writer through a pipe
442
// so it can easily be read again through the bufio.Scanner interface.
443
type responseForwardingWriter struct {
444
        io.Writer
445
        *bufio.Scanner
446

447
        pipeR *io.PipeReader
448
        pipeW *io.PipeWriter
449

450
        header http.Header
451
        code   int
452
        closed chan bool
453
}
454

455
// Write writes the given bytes to the internal pipe.
456
//
457
// NOTE: This is part of the http.ResponseWriter interface.
458
func (w *responseForwardingWriter) Write(b []byte) (int, error) {
3✔
459
        return w.Writer.Write(b)
3✔
460
}
3✔
461

462
// Header returns the HTTP header fields intercepted so far.
463
//
464
// NOTE: This is part of the http.ResponseWriter interface.
465
func (w *responseForwardingWriter) Header() http.Header {
3✔
466
        return w.header
3✔
467
}
3✔
468

469
// WriteHeader indicates that the header part of the response is now finished
470
// and sets the response code.
471
//
472
// NOTE: This is part of the http.ResponseWriter interface.
473
func (w *responseForwardingWriter) WriteHeader(code int) {
×
474
        w.code = code
×
475
}
×
476

477
// CloseNotify returns a channel that indicates if a connection was closed.
478
//
479
// NOTE: This is part of the http.CloseNotifier interface.
480
func (w *responseForwardingWriter) CloseNotify() <-chan bool {
×
481
        return w.closed
×
482
}
×
483

484
// Flush empties all buffers. We implement this to indicate to our backend that
485
// we support flushing our content. There is no actual implementation because
486
// all writes happen immediately, there is no internal buffering.
487
//
488
// NOTE: This is part of the http.Flusher interface.
489
func (w *responseForwardingWriter) Flush() {}
3✔
490

491
func (w *responseForwardingWriter) Close() {
3✔
492
        _ = w.pipeR.CloseWithError(io.EOF)
3✔
493
        _ = w.pipeW.CloseWithError(io.EOF)
3✔
494
        w.closed <- true
3✔
495
}
3✔
496

497
// IsClosedConnError is a helper function that returns true if the given error
498
// is an error indicating we are using a closed connection.
499
func IsClosedConnError(err error) bool {
3✔
500
        if err == nil {
3✔
501
                return false
×
502
        }
×
503
        if err == http.ErrServerClosed {
3✔
504
                return true
×
505
        }
×
506

507
        str := err.Error()
3✔
508
        if strings.Contains(str, "use of closed network connection") {
6✔
509
                return true
3✔
510
        }
3✔
511
        if strings.Contains(str, "closed pipe") {
6✔
512
                return true
3✔
513
        }
3✔
514
        if strings.Contains(str, "broken pipe") {
6✔
515
                return true
3✔
516
        }
3✔
517
        if strings.Contains(str, "connection reset by peer") {
6✔
518
                return true
3✔
519
        }
3✔
520
        return websocket.IsCloseError(
3✔
521
                err, websocket.CloseNormalClosure, websocket.CloseGoingAway,
3✔
522
        )
3✔
523
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc