• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 14358372723

09 Apr 2025 01:26PM UTC coverage: 56.696% (-12.3%) from 69.037%
14358372723

Pull #9696

github

web-flow
Merge e2837e400 into 867d27d68
Pull Request #9696: Add `development_guidelines.md` for both human and machine

107055 of 188823 relevant lines covered (56.7%)

22721.56 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/lnrpc/websocket_proxy.go
1
// The code in this file is a heavily modified version of
2
// https://github.com/tmc/grpc-websocket-proxy/
3

4
package lnrpc
5

6
import (
7
        "bufio"
8
        "io"
9
        "net/http"
10
        "net/textproto"
11
        "regexp"
12
        "strings"
13
        "time"
14

15
        "github.com/btcsuite/btclog/v2"
16
        "github.com/gorilla/websocket"
17
        "golang.org/x/net/context"
18
)
19

20
const (
	// MethodOverrideParam names the GET query parameter that selects the
	// HTTP method used for the forwarded REST request. It is needed because
	// the WebSocket API requires the handshake request to always be a GET
	// request.
	MethodOverrideParam = "method"

	// HeaderWebSocketProtocol is the WebSocket protocol negotiation header
	// field that we use to carry additional header fields.
	HeaderWebSocketProtocol = "Sec-Websocket-Protocol"

	// WebSocketProtocolDelimiter separates an additional header field name
	// from its value. The plus sign is used because the default delimiters
	// are not permitted in protocol names.
	WebSocketProtocolDelimiter = "+"

	// PingContent is the payload of the ping messages we send. Its content
	// is arbitrary (it just must not be empty); the client should send it
	// back in its pong message.
	PingContent = "are you there?"

	// MaxWsMsgSize is the largest websockets message (4 MiB) that the
	// gRPC <-> WS proxy will attempt to decode. gRPC has a similar setting
	// used elsewhere.
	MaxWsMsgSize = 4 << 20
)
47

48
var (
	// defaultHeadersToForward is the set of HTTP header fields that are
	// copied from the WebSocket handshake request to the forwarded request.
	// Keys must be written in canonical MIME header format.
	defaultHeadersToForward = map[string]bool{
		"Grpc-Metadata-Macaroon": true,
		"Origin":                 true,
		"Referer":                true,
	}

	// defaultProtocolsToAllow is the set of additional header fields that
	// may be smuggled inside the Sec-Websocket-Protocol field and are then
	// forwarded to the backend.
	defaultProtocolsToAllow = map[string]bool{
		"Grpc-Metadata-Macaroon": true,
	}

	// DefaultPingInterval is the default duration (30 seconds) to wait
	// between two ping requests.
	DefaultPingInterval = 30 * time.Second

	// DefaultPongWait is the default longest duration (5 seconds) we wait
	// for the pong answering one of our pings before considering the
	// connection dead.
	DefaultPongWait = 5 * time.Second
)
73

74
// NewWebSocketProxy attempts to expose the underlying handler as a response-
75
// streaming WebSocket stream with newline-delimited JSON as the content
76
// encoding. If pingInterval is a non-zero duration, a ping message will be
77
// sent out periodically and a pong response message is expected from the
78
// client. The clientStreamingURIs parameter can hold a list of all patterns
79
// for URIs that are mapped to client-streaming RPC methods. We need to keep
80
// track of those to make sure we initialize the request body correctly for the
81
// underlying grpc-gateway library.
82
func NewWebSocketProxy(h http.Handler, logger btclog.Logger,
83
        pingInterval, pongWait time.Duration,
84
        clientStreamingURIs []*regexp.Regexp) http.Handler {
×
85

×
86
        p := &WebsocketProxy{
×
87
                backend: h,
×
88
                logger:  logger,
×
89
                upgrader: &websocket.Upgrader{
×
90
                        ReadBufferSize:  1024,
×
91
                        WriteBufferSize: 1024,
×
92
                        CheckOrigin: func(r *http.Request) bool {
×
93
                                return true
×
94
                        },
×
95
                },
96
                clientStreamingURIs: clientStreamingURIs,
97
        }
98

99
        if pingInterval > 0 && pongWait > 0 {
×
100
                p.pingInterval = pingInterval
×
101
                p.pongWait = pongWait
×
102
        }
×
103

104
        return p
×
105
}
106

107
// WebsocketProxy provides websocket transport upgrade to compatible endpoints.
108
type WebsocketProxy struct {
109
        backend  http.Handler
110
        logger   btclog.Logger
111
        upgrader *websocket.Upgrader
112

113
        // clientStreamingURIs holds a list of all patterns for URIs that are
114
        // mapped to client-streaming RPC methods. We need to keep track of
115
        // those to make sure we initialize the request body correctly for the
116
        // underlying grpc-gateway library.
117
        clientStreamingURIs []*regexp.Regexp
118

119
        pingInterval time.Duration
120
        pongWait     time.Duration
121
}
122

123
// pingPongEnabled returns true if a ping interval is set to enable sending and
124
// expecting regular ping/pong messages.
125
func (p *WebsocketProxy) pingPongEnabled() bool {
×
126
        return p.pingInterval > 0 && p.pongWait > 0
×
127
}
×
128

129
// ServeHTTP handles the incoming HTTP request. If the request is an
130
// "upgradeable" WebSocket request (identified by header fields), then the
131
// WS proxy handles the request. Otherwise the request is passed directly to the
132
// underlying REST proxy.
133
func (p *WebsocketProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
×
134
        if !websocket.IsWebSocketUpgrade(r) {
×
135
                p.backend.ServeHTTP(w, r)
×
136
                return
×
137
        }
×
138
        p.upgradeToWebSocketProxy(w, r)
×
139
}
140

141
// upgradeToWebSocketProxy upgrades the incoming request to a WebSocket, reads
142
// one incoming message then streams all responses until either the client or
143
// server quit the connection.
144
func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
145
        r *http.Request) {
×
146

×
147
        conn, err := p.upgrader.Upgrade(w, r, nil)
×
148
        if err != nil {
×
149
                p.logger.Errorf("error upgrading websocket:", err)
×
150
                return
×
151
        }
×
152
        defer func() {
×
153
                err := conn.Close()
×
154
                if err != nil && !IsClosedConnError(err) {
×
155
                        p.logger.Errorf("WS: error closing upgraded conn: %v",
×
156
                                err)
×
157
                }
×
158
        }()
159

160
        ctx, cancelFn := context.WithCancel(r.Context())
×
161
        defer cancelFn()
×
162

×
163
        requestForwarder := newRequestForwardingReader()
×
164
        request, err := http.NewRequestWithContext(
×
165
                ctx, r.Method, r.URL.String(), requestForwarder,
×
166
        )
×
167
        if err != nil {
×
168
                p.logger.Errorf("WS: error preparing request:", err)
×
169
                return
×
170
        }
×
171

172
        // Allow certain headers to be forwarded, either from source headers
173
        // or the special Sec-Websocket-Protocol header field.
174
        forwardHeaders(r.Header, request.Header)
×
175

×
176
        // Also allow the target request method to be overwritten, as all
×
177
        // WebSocket establishment calls MUST be GET requests.
×
178
        if m := r.URL.Query().Get(MethodOverrideParam); m != "" {
×
179
                request.Method = m
×
180
        }
×
181

182
        // Is this a call to a client-streaming RPC method?
183
        clientStreaming := false
×
184
        for _, pattern := range p.clientStreamingURIs {
×
185
                if pattern.MatchString(r.URL.Path) {
×
186
                        clientStreaming = true
×
187
                }
×
188
        }
189

190
        responseForwarder := newResponseForwardingWriter()
×
191
        go func() {
×
192
                <-ctx.Done()
×
193
                responseForwarder.Close()
×
194
                requestForwarder.CloseWriter()
×
195
        }()
×
196

197
        go func() {
×
198
                defer cancelFn()
×
199
                p.backend.ServeHTTP(responseForwarder, request)
×
200
        }()
×
201

202
        // Read loop: Take messages from websocket and write them to the payload
203
        // channel. This needs to be its own goroutine because for non-client
204
        // streaming RPCs, the requestForwarder.Write() in the second goroutine
205
        // will block until the request has fully completed. But for the ping/
206
        // pong handler to work, we need to have an active call to
207
        // conn.ReadMessage() going on. So we make sure we have such an active
208
        // call by starting a second read as soon as the first one has
209
        // completed.
210
        payloadChannel := make(chan []byte, 1)
×
211
        go func() {
×
212
                defer cancelFn()
×
213
                defer close(payloadChannel)
×
214

×
215
                for {
×
216
                        select {
×
217
                        case <-ctx.Done():
×
218
                                return
×
219
                        default:
×
220
                        }
221

222
                        _, payload, err := conn.ReadMessage()
×
223
                        if err != nil {
×
224
                                if IsClosedConnError(err) {
×
225
                                        p.logger.Tracef("WS: socket "+
×
226
                                                "closed: %v", err)
×
227
                                        return
×
228
                                }
×
229
                                p.logger.Errorf("error reading message: %v",
×
230
                                        err)
×
231
                                return
×
232
                        }
233

234
                        select {
×
235
                        case payloadChannel <- payload:
×
236
                        case <-ctx.Done():
×
237
                                return
×
238
                        }
239
                }
240
        }()
241

242
        // Forward loop: Take messages from the incoming payload channel and
243
        // write them to the http request.
244
        go func() {
×
245
                defer cancelFn()
×
246
                for {
×
247
                        var payload []byte
×
248
                        select {
×
249
                        case <-ctx.Done():
×
250
                                return
×
251
                        case newPayload, more := <-payloadChannel:
×
252
                                if !more {
×
253
                                        p.logger.Infof("WS: incoming payload " +
×
254
                                                "chan closed")
×
255
                                        return
×
256
                                }
×
257

258
                                payload = newPayload
×
259
                        }
260

261
                        _, err := requestForwarder.Write(payload)
×
262
                        if err != nil {
×
263
                                p.logger.Errorf("WS: error writing message "+
×
264
                                        "to upstream http server: %v", err)
×
265
                                return
×
266
                        }
×
267
                        _, _ = requestForwarder.Write([]byte{'\n'})
×
268

×
269
                        // The grpc-gateway library uses a different request
×
270
                        // reader depending on whether it is a client streaming
×
271
                        // RPC or not. For a non-streaming request we need to
×
272
                        // close with EOF to signal the request was completed.
×
273
                        if !clientStreaming {
×
274
                                requestForwarder.CloseWriter()
×
275
                        }
×
276
                }
277
        }()
278

279
        // Ping write loop: Send a ping message regularly if ping/pong is
280
        // enabled.
281
        if p.pingPongEnabled() {
×
282
                // We'll send out our first ping in pingInterval. So the initial
×
283
                // deadline is that interval plus the time we allow for a
×
284
                // response to be sent.
×
285
                initialDeadline := time.Now().Add(p.pingInterval + p.pongWait)
×
286
                _ = conn.SetReadDeadline(initialDeadline)
×
287

×
288
                // Whenever a pong message comes in, we extend the deadline
×
289
                // until the next read is expected by the interval plus pong
×
290
                // wait time. Since we can never _reach_ any of the deadlines,
×
291
                // we also have to advance the deadline for the next expected
×
292
                // write to happen, in case the next thing we actually write is
×
293
                // the next ping.
×
294
                conn.SetPongHandler(func(appData string) error {
×
295
                        nextDeadline := time.Now().Add(
×
296
                                p.pingInterval + p.pongWait,
×
297
                        )
×
298
                        _ = conn.SetReadDeadline(nextDeadline)
×
299
                        _ = conn.SetWriteDeadline(nextDeadline)
×
300

×
301
                        return nil
×
302
                })
×
303
                go func() {
×
304
                        ticker := time.NewTicker(p.pingInterval)
×
305
                        defer ticker.Stop()
×
306

×
307
                        for {
×
308
                                select {
×
309
                                case <-ctx.Done():
×
310
                                        p.logger.Debug("WS: ping loop done")
×
311
                                        return
×
312

313
                                case <-ticker.C:
×
314
                                        // Writing the ping shouldn't take any
×
315
                                        // longer than we'll wait for a response
×
316
                                        // in the first place.
×
317
                                        writeDeadline := time.Now().Add(
×
318
                                                p.pongWait,
×
319
                                        )
×
320
                                        err := conn.WriteControl(
×
321
                                                websocket.PingMessage,
×
322
                                                []byte(PingContent),
×
323
                                                writeDeadline,
×
324
                                        )
×
325
                                        if err != nil {
×
326
                                                p.logger.Warnf("WS: could not "+
×
327
                                                        "send ping message: %v",
×
328
                                                        err)
×
329
                                                return
×
330
                                        }
×
331
                                }
332
                        }
333
                }()
334
        }
335

336
        // Write loop: Take messages from the response forwarder and write them
337
        // to the WebSocket.
338
        for responseForwarder.Scan() {
×
339
                if len(responseForwarder.Bytes()) == 0 {
×
340
                        p.logger.Errorf("WS: empty scan: %v",
×
341
                                responseForwarder.Err())
×
342

×
343
                        continue
×
344
                }
345

346
                err := conn.WriteMessage(
×
347
                        websocket.TextMessage, responseForwarder.Bytes(),
×
348
                )
×
349
                if err != nil {
×
350
                        p.logger.Errorf("WS: error writing message: %v", err)
×
351
                        return
×
352
                }
×
353
        }
354
        if err := responseForwarder.Err(); err != nil && !IsClosedConnError(err) {
×
355
                p.logger.Errorf("WS: scanner err: %v", err)
×
356
        }
×
357
}
358

359
// forwardHeaders forwards certain allowed header fields from the source request
360
// to the target request. Because browsers are limited in what header fields
361
// they can send on the WebSocket setup call, we also allow additional fields to
362
// be transported in the special Sec-Websocket-Protocol field.
363
func forwardHeaders(source, target http.Header) {
×
364
        // Forward allowed header fields directly.
×
365
        for header := range source {
×
366
                headerName := textproto.CanonicalMIMEHeaderKey(header)
×
367
                forward, ok := defaultHeadersToForward[headerName]
×
368
                if ok && forward {
×
369
                        target.Set(headerName, source.Get(header))
×
370
                }
×
371
        }
372

373
        // Browser aren't allowed to set custom header fields on WebSocket
374
        // requests. We need to allow them to submit the macaroon as a WS
375
        // protocol, which is the only allowed header. Set any "protocols" we
376
        // declare valid as header fields on the forwarded request.
377
        protocol := source.Get(HeaderWebSocketProtocol)
×
378
        for key := range defaultProtocolsToAllow {
×
379
                if strings.HasPrefix(protocol, key) {
×
380
                        // The format is "<protocol name>+<value>". We know the
×
381
                        // protocol string starts with the name so we only need
×
382
                        // to set the value.
×
383
                        values := strings.Split(
×
384
                                protocol, WebSocketProtocolDelimiter,
×
385
                        )
×
386
                        target.Set(key, values[1])
×
387
                }
×
388
        }
389
}
390

391
// newRequestForwardingReader creates a new request forwarding pipe. Bytes
// written through the returned value's Writer side become readable on its
// Reader side.
func newRequestForwardingReader() *requestForwardingReader {
	pipeReader, pipeWriter := io.Pipe()

	fwd := &requestForwardingReader{
		pipeR: pipeReader,
		pipeW: pipeWriter,
	}
	fwd.Reader = pipeReader
	fwd.Writer = pipeWriter

	return fwd
}

// requestForwardingReader is a wrapper around io.Pipe that embeds both the
// io.Reader and io.Writer interface and whose writing side can be closed.
type requestForwardingReader struct {
	// Reader and Writer expose the two pipe halves through the standard
	// interfaces, so the value can serve as a request body and be written
	// to directly.
	io.Reader
	io.Writer

	// pipeR and pipeW keep the concrete pipe halves so the writing side
	// can be closed explicitly.
	pipeR *io.PipeReader
	pipeW *io.PipeWriter
}

// CloseWriter closes the underlying pipe writer with io.EOF, so subsequent
// reads from the reading side return io.EOF.
func (f *requestForwardingReader) CloseWriter() {
	_ = f.pipeW.CloseWithError(io.EOF)
}
×
416

417
// newResponseForwardingWriter creates a new http.ResponseWriter that intercepts
418
// what's written to it and presents it through a bufio.Scanner interface.
419
func newResponseForwardingWriter() *responseForwardingWriter {
×
420
        r, w := io.Pipe()
×
421

×
422
        scanner := bufio.NewScanner(r)
×
423

×
424
        // We pass in a custom buffer for the bufio scanner to use. We'll keep
×
425
        // with a normal 64KB buffer, but allow a larger max message size,
×
426
        // which may cause buffer expansion when needed.
×
427
        buf := make([]byte, 0, bufio.MaxScanTokenSize)
×
428
        scanner.Buffer(buf, MaxWsMsgSize)
×
429

×
430
        return &responseForwardingWriter{
×
431
                Writer:  w,
×
432
                Scanner: scanner,
×
433
                pipeR:   r,
×
434
                pipeW:   w,
×
435
                header:  http.Header{},
×
436
                closed:  make(chan bool, 1),
×
437
        }
×
438
}
×
439

440
// responseForwardingWriter is a type that implements the http.ResponseWriter
441
// interface but internally forwards what's written to the writer through a pipe
442
// so it can easily be read again through the bufio.Scanner interface.
443
type responseForwardingWriter struct {
444
        io.Writer
445
        *bufio.Scanner
446

447
        pipeR *io.PipeReader
448
        pipeW *io.PipeWriter
449

450
        header http.Header
451
        code   int
452
        closed chan bool
453
}
454

455
// Write writes the given bytes to the internal pipe.
456
//
457
// NOTE: This is part of the http.ResponseWriter interface.
458
func (w *responseForwardingWriter) Write(b []byte) (int, error) {
×
459
        return w.Writer.Write(b)
×
460
}
×
461

462
// Header returns the HTTP header fields intercepted so far.
463
//
464
// NOTE: This is part of the http.ResponseWriter interface.
465
func (w *responseForwardingWriter) Header() http.Header {
×
466
        return w.header
×
467
}
×
468

469
// WriteHeader indicates that the header part of the response is now finished
470
// and sets the response code.
471
//
472
// NOTE: This is part of the http.ResponseWriter interface.
473
func (w *responseForwardingWriter) WriteHeader(code int) {
×
474
        w.code = code
×
475
}
×
476

477
// CloseNotify returns a channel that indicates if a connection was closed.
478
//
479
// NOTE: This is part of the http.CloseNotifier interface.
480
func (w *responseForwardingWriter) CloseNotify() <-chan bool {
×
481
        return w.closed
×
482
}
×
483

484
// Flush empties all buffers. We implement this to indicate to our backend that
485
// we support flushing our content. There is no actual implementation because
486
// all writes happen immediately, there is no internal buffering.
487
//
488
// NOTE: This is part of the http.Flusher interface.
489
func (w *responseForwardingWriter) Flush() {}
×
490

491
func (w *responseForwardingWriter) Close() {
×
492
        _ = w.pipeR.CloseWithError(io.EOF)
×
493
        _ = w.pipeW.CloseWithError(io.EOF)
×
494
        w.closed <- true
×
495
}
×
496

497
// IsClosedConnError is a helper function that returns true if the given error
498
// is an error indicating we are using a closed connection.
499
func IsClosedConnError(err error) bool {
×
500
        if err == nil {
×
501
                return false
×
502
        }
×
503
        if err == http.ErrServerClosed {
×
504
                return true
×
505
        }
×
506

507
        str := err.Error()
×
508
        if strings.Contains(str, "use of closed network connection") {
×
509
                return true
×
510
        }
×
511
        if strings.Contains(str, "closed pipe") {
×
512
                return true
×
513
        }
×
514
        if strings.Contains(str, "broken pipe") {
×
515
                return true
×
516
        }
×
517
        if strings.Contains(str, "connection reset by peer") {
×
518
                return true
×
519
        }
×
520
        return websocket.IsCloseError(
×
521
                err, websocket.CloseNormalClosure, websocket.CloseGoingAway,
×
522
        )
×
523
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc