• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 10203737448

01 Aug 2024 06:26PM UTC coverage: 58.674% (+0.05%) from 58.627%
10203737448

push

github

web-flow
Merge pull request #8938 from bhandras/etcd-leader-election-fixups

multi: check leader status with our health checker to correctly shut down LND if network partitions

28 of 73 new or added lines in 6 files covered. (38.36%)

117 existing lines in 18 files now uncovered.

125392 of 213710 relevant lines covered (58.67%)

28078.2 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

77.58
/lnrpc/websocket_proxy.go
1
// The code in this file is a heavily modified version of
2
// https://github.com/tmc/grpc-websocket-proxy/
3

4
package lnrpc
5

6
import (
7
        "bufio"
8
        "io"
9
        "net/http"
10
        "net/textproto"
11
        "regexp"
12
        "strings"
13
        "time"
14

15
        "github.com/btcsuite/btclog"
16
        "github.com/gorilla/websocket"
17
        "golang.org/x/net/context"
18
)
19

20
const (
	// MethodOverrideParam names the GET query parameter carrying the HTTP
	// method that the forwarded REST request should use. WebSocket
	// handshakes are always GET requests, so the real method has to be
	// smuggled in this way.
	MethodOverrideParam = "method"

	// HeaderWebSocketProtocol is the WebSocket protocol negotiation header
	// that we abuse to transport additional header fields from browsers.
	HeaderWebSocketProtocol = "Sec-Websocket-Protocol"

	// WebSocketProtocolDelimiter separates a transported header name from
	// its value inside a protocol entry. The usual delimiters are not
	// allowed in protocol names, hence the plus sign.
	WebSocketProtocolDelimiter = "+"

	// PingContent is the arbitrary, non-empty payload of our ping
	// messages. It carries no meaning, the client just echoes it back in
	// its pong.
	PingContent = "are you there?"

	// MaxWsMsgSize is the upper bound (4 MiB) on a single message the
	// gRPC <-> WS proxy attempts to decode. gRPC has a similar setting
	// used elsewhere.
	MaxWsMsgSize = 4 << 20
)
47

48
var (
	// defaultHeadersToForward is the set of HTTP header fields copied
	// verbatim from the WebSocket handshake request to the forwarded REST
	// request. Keys must be in canonical MIME header format, because the
	// lookups into this map use canonicalized names.
	defaultHeadersToForward = map[string]bool{
		"Grpc-Metadata-Macaroon": true,
		"Origin":                 true,
		"Referer":                true,
	}

	// defaultProtocolsToAllow is the set of header fields that may be
	// smuggled inside the Sec-Websocket-Protocol field and are then
	// forwarded to the backend as real header fields.
	defaultProtocolsToAllow = map[string]bool{
		"Grpc-Metadata-Macaroon": true,
	}

	// DefaultPingInterval is the default duration between two consecutive
	// ping requests.
	DefaultPingInterval = 30 * time.Second

	// DefaultPongWait is the default upper bound on how long we wait for
	// the pong answering one of our pings before we consider the
	// connection dead.
	DefaultPongWait = 5 * time.Second
)
73

74
// NewWebSocketProxy attempts to expose the underlying handler as a response-
75
// streaming WebSocket stream with newline-delimited JSON as the content
76
// encoding. If pingInterval is a non-zero duration, a ping message will be
77
// sent out periodically and a pong response message is expected from the
78
// client. The clientStreamingURIs parameter can hold a list of all patterns
79
// for URIs that are mapped to client-streaming RPC methods. We need to keep
80
// track of those to make sure we initialize the request body correctly for the
81
// underlying grpc-gateway library.
82
func NewWebSocketProxy(h http.Handler, logger btclog.Logger,
83
        pingInterval, pongWait time.Duration,
84
        clientStreamingURIs []*regexp.Regexp) http.Handler {
4✔
85

4✔
86
        p := &WebsocketProxy{
4✔
87
                backend: h,
4✔
88
                logger:  logger,
4✔
89
                upgrader: &websocket.Upgrader{
4✔
90
                        ReadBufferSize:  1024,
4✔
91
                        WriteBufferSize: 1024,
4✔
92
                        CheckOrigin: func(r *http.Request) bool {
8✔
93
                                return true
4✔
94
                        },
4✔
95
                },
96
                clientStreamingURIs: clientStreamingURIs,
97
        }
98

99
        if pingInterval > 0 && pongWait > 0 {
8✔
100
                p.pingInterval = pingInterval
4✔
101
                p.pongWait = pongWait
4✔
102
        }
4✔
103

104
        return p
4✔
105
}
106

107
// WebsocketProxy provides websocket transport upgrade to compatible endpoints.
108
type WebsocketProxy struct {
109
        backend  http.Handler
110
        logger   btclog.Logger
111
        upgrader *websocket.Upgrader
112

113
        // clientStreamingURIs holds a list of all patterns for URIs that are
114
        // mapped to client-streaming RPC methods. We need to keep track of
115
        // those to make sure we initialize the request body correctly for the
116
        // underlying grpc-gateway library.
117
        clientStreamingURIs []*regexp.Regexp
118

119
        pingInterval time.Duration
120
        pongWait     time.Duration
121
}
122

123
// pingPongEnabled returns true if a ping interval is set to enable sending and
124
// expecting regular ping/pong messages.
125
func (p *WebsocketProxy) pingPongEnabled() bool {
4✔
126
        return p.pingInterval > 0 && p.pongWait > 0
4✔
127
}
4✔
128

129
// ServeHTTP handles the incoming HTTP request. If the request is an
130
// "upgradeable" WebSocket request (identified by header fields), then the
131
// WS proxy handles the request. Otherwise the request is passed directly to the
132
// underlying REST proxy.
133
func (p *WebsocketProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
4✔
134
        if !websocket.IsWebSocketUpgrade(r) {
8✔
135
                p.backend.ServeHTTP(w, r)
4✔
136
                return
4✔
137
        }
4✔
138
        p.upgradeToWebSocketProxy(w, r)
4✔
139
}
140

141
// upgradeToWebSocketProxy upgrades the incoming request to a WebSocket, reads
142
// one incoming message then streams all responses until either the client or
143
// server quit the connection.
144
func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
145
        r *http.Request) {
4✔
146

4✔
147
        conn, err := p.upgrader.Upgrade(w, r, nil)
4✔
148
        if err != nil {
4✔
149
                p.logger.Errorf("error upgrading websocket:", err)
×
150
                return
×
151
        }
×
152
        defer func() {
8✔
153
                err := conn.Close()
4✔
154
                if err != nil && !IsClosedConnError(err) {
4✔
155
                        p.logger.Errorf("WS: error closing upgraded conn: %v",
×
156
                                err)
×
157
                }
×
158
        }()
159

160
        ctx, cancelFn := context.WithCancel(r.Context())
4✔
161
        defer cancelFn()
4✔
162

4✔
163
        requestForwarder := newRequestForwardingReader()
4✔
164
        request, err := http.NewRequestWithContext(
4✔
165
                ctx, r.Method, r.URL.String(), requestForwarder,
4✔
166
        )
4✔
167
        if err != nil {
4✔
168
                p.logger.Errorf("WS: error preparing request:", err)
×
169
                return
×
170
        }
×
171

172
        // Allow certain headers to be forwarded, either from source headers
173
        // or the special Sec-Websocket-Protocol header field.
174
        forwardHeaders(r.Header, request.Header)
4✔
175

4✔
176
        // Also allow the target request method to be overwritten, as all
4✔
177
        // WebSocket establishment calls MUST be GET requests.
4✔
178
        if m := r.URL.Query().Get(MethodOverrideParam); m != "" {
8✔
179
                request.Method = m
4✔
180
        }
4✔
181

182
        // Is this a call to a client-streaming RPC method?
183
        clientStreaming := false
4✔
184
        for _, pattern := range p.clientStreamingURIs {
8✔
185
                if pattern.MatchString(r.URL.Path) {
8✔
186
                        clientStreaming = true
4✔
187
                }
4✔
188
        }
189

190
        responseForwarder := newResponseForwardingWriter()
4✔
191
        go func() {
8✔
192
                <-ctx.Done()
4✔
193
                responseForwarder.Close()
4✔
194
                requestForwarder.CloseWriter()
4✔
195
        }()
4✔
196

197
        go func() {
8✔
198
                defer cancelFn()
4✔
199
                p.backend.ServeHTTP(responseForwarder, request)
4✔
200
        }()
4✔
201

202
        // Read loop: Take messages from websocket and write them to the payload
203
        // channel. This needs to be its own goroutine because for non-client
204
        // streaming RPCs, the requestForwarder.Write() in the second goroutine
205
        // will block until the request has fully completed. But for the ping/
206
        // pong handler to work, we need to have an active call to
207
        // conn.ReadMessage() going on. So we make sure we have such an active
208
        // call by starting a second read as soon as the first one has
209
        // completed.
210
        payloadChannel := make(chan []byte, 1)
4✔
211
        go func() {
8✔
212
                defer cancelFn()
4✔
213
                defer close(payloadChannel)
4✔
214

4✔
215
                for {
8✔
216
                        select {
4✔
217
                        case <-ctx.Done():
×
218
                                return
×
219
                        default:
4✔
220
                        }
221

222
                        _, payload, err := conn.ReadMessage()
4✔
223
                        if err != nil {
8✔
224
                                if IsClosedConnError(err) {
8✔
225
                                        p.logger.Tracef("WS: socket "+
4✔
226
                                                "closed: %v", err)
4✔
227
                                        return
4✔
228
                                }
4✔
229
                                p.logger.Errorf("error reading message: %v",
×
230
                                        err)
×
231
                                return
×
232
                        }
233

234
                        select {
4✔
235
                        case payloadChannel <- payload:
4✔
236
                        case <-ctx.Done():
×
237
                                return
×
238
                        }
239
                }
240
        }()
241

242
        // Forward loop: Take messages from the incoming payload channel and
243
        // write them to the http request.
244
        go func() {
8✔
245
                defer cancelFn()
4✔
246
                for {
8✔
247
                        var payload []byte
4✔
248
                        select {
4✔
249
                        case <-ctx.Done():
×
250
                                return
×
251
                        case newPayload, more := <-payloadChannel:
4✔
252
                                if !more {
8✔
253
                                        p.logger.Infof("WS: incoming payload " +
4✔
254
                                                "chan closed")
4✔
255
                                        return
4✔
256
                                }
4✔
257

258
                                payload = newPayload
4✔
259
                        }
260

261
                        _, err := requestForwarder.Write(payload)
4✔
262
                        if err != nil {
8✔
263
                                p.logger.Errorf("WS: error writing message "+
4✔
264
                                        "to upstream http server: %v", err)
4✔
265
                                return
4✔
266
                        }
4✔
267
                        _, _ = requestForwarder.Write([]byte{'\n'})
4✔
268

4✔
269
                        // The grpc-gateway library uses a different request
4✔
270
                        // reader depending on whether it is a client streaming
4✔
271
                        // RPC or not. For a non-streaming request we need to
4✔
272
                        // close with EOF to signal the request was completed.
4✔
273
                        if !clientStreaming {
8✔
274
                                requestForwarder.CloseWriter()
4✔
275
                        }
4✔
276
                }
277
        }()
278

279
        // Ping write loop: Send a ping message regularly if ping/pong is
280
        // enabled.
281
        if p.pingPongEnabled() {
8✔
282
                // We'll send out our first ping in pingInterval. So the initial
4✔
283
                // deadline is that interval plus the time we allow for a
4✔
284
                // response to be sent.
4✔
285
                initialDeadline := time.Now().Add(p.pingInterval + p.pongWait)
4✔
286
                _ = conn.SetReadDeadline(initialDeadline)
4✔
287

4✔
288
                // Whenever a pong message comes in, we extend the deadline
4✔
289
                // until the next read is expected by the interval plus pong
4✔
290
                // wait time. Since we can never _reach_ any of the deadlines,
4✔
291
                // we also have to advance the deadline for the next expected
4✔
292
                // write to happen, in case the next thing we actually write is
4✔
293
                // the next ping.
4✔
294
                conn.SetPongHandler(func(appData string) error {
4✔
UNCOV
295
                        nextDeadline := time.Now().Add(
×
UNCOV
296
                                p.pingInterval + p.pongWait,
×
UNCOV
297
                        )
×
UNCOV
298
                        _ = conn.SetReadDeadline(nextDeadline)
×
UNCOV
299
                        _ = conn.SetWriteDeadline(nextDeadline)
×
UNCOV
300

×
UNCOV
301
                        return nil
×
UNCOV
302
                })
×
303
                go func() {
8✔
304
                        ticker := time.NewTicker(p.pingInterval)
4✔
305
                        defer ticker.Stop()
4✔
306

4✔
307
                        for {
8✔
308
                                select {
4✔
309
                                case <-ctx.Done():
4✔
310
                                        p.logger.Debug("WS: ping loop done")
4✔
311
                                        return
4✔
312

UNCOV
313
                                case <-ticker.C:
×
UNCOV
314
                                        // Writing the ping shouldn't take any
×
UNCOV
315
                                        // longer than we'll wait for a response
×
UNCOV
316
                                        // in the first place.
×
UNCOV
317
                                        writeDeadline := time.Now().Add(
×
UNCOV
318
                                                p.pongWait,
×
UNCOV
319
                                        )
×
UNCOV
320
                                        err := conn.WriteControl(
×
UNCOV
321
                                                websocket.PingMessage,
×
UNCOV
322
                                                []byte(PingContent),
×
UNCOV
323
                                                writeDeadline,
×
UNCOV
324
                                        )
×
UNCOV
325
                                        if err != nil {
×
326
                                                p.logger.Warnf("WS: could not "+
×
327
                                                        "send ping message: %v",
×
328
                                                        err)
×
329
                                                return
×
330
                                        }
×
331
                                }
332
                        }
333
                }()
334
        }
335

336
        // Write loop: Take messages from the response forwarder and write them
337
        // to the WebSocket.
338
        for responseForwarder.Scan() {
8✔
339
                if len(responseForwarder.Bytes()) == 0 {
4✔
340
                        p.logger.Errorf("WS: empty scan: %v",
×
341
                                responseForwarder.Err())
×
342

×
343
                        continue
×
344
                }
345

346
                err := conn.WriteMessage(
4✔
347
                        websocket.TextMessage, responseForwarder.Bytes(),
4✔
348
                )
4✔
349
                if err != nil {
4✔
350
                        p.logger.Errorf("WS: error writing message: %v", err)
×
351
                        return
×
352
                }
×
353
        }
354
        if err := responseForwarder.Err(); err != nil && !IsClosedConnError(err) {
4✔
355
                p.logger.Errorf("WS: scanner err: %v", err)
×
356
        }
×
357
}
358

359
// forwardHeaders forwards certain allowed header fields from the source request
360
// to the target request. Because browsers are limited in what header fields
361
// they can send on the WebSocket setup call, we also allow additional fields to
362
// be transported in the special Sec-Websocket-Protocol field.
363
func forwardHeaders(source, target http.Header) {
4✔
364
        // Forward allowed header fields directly.
4✔
365
        for header := range source {
8✔
366
                headerName := textproto.CanonicalMIMEHeaderKey(header)
4✔
367
                forward, ok := defaultHeadersToForward[headerName]
4✔
368
                if ok && forward {
8✔
369
                        target.Set(headerName, source.Get(header))
4✔
370
                }
4✔
371
        }
372

373
        // Browser aren't allowed to set custom header fields on WebSocket
374
        // requests. We need to allow them to submit the macaroon as a WS
375
        // protocol, which is the only allowed header. Set any "protocols" we
376
        // declare valid as header fields on the forwarded request.
377
        protocol := source.Get(HeaderWebSocketProtocol)
4✔
378
        for key := range defaultProtocolsToAllow {
8✔
379
                if strings.HasPrefix(protocol, key) {
8✔
380
                        // The format is "<protocol name>+<value>". We know the
4✔
381
                        // protocol string starts with the name so we only need
4✔
382
                        // to set the value.
4✔
383
                        values := strings.Split(
4✔
384
                                protocol, WebSocketProtocolDelimiter,
4✔
385
                        )
4✔
386
                        target.Set(key, values[1])
4✔
387
                }
4✔
388
        }
389
}
390

391
// newRequestForwardingReader creates a new request forwarding pipe: bytes
// written to the returned value can be read back from it in the same order.
func newRequestForwardingReader() *requestForwardingReader {
	pr, pw := io.Pipe()

	fwd := &requestForwardingReader{
		pipeR: pr,
		pipeW: pw,
	}
	fwd.Reader = pr
	fwd.Writer = pw

	return fwd
}

// requestForwardingReader is a wrapper around io.Pipe that embeds both the
// io.Reader and io.Writer interface and can be closed.
type requestForwardingReader struct {
	// Reader is the read half of the pipe.
	io.Reader

	// Writer is the write half of the pipe.
	io.Writer

	// pipeR and pipeW keep the concrete pipe ends; pipeW is needed to
	// close the writer with an explicit error.
	pipeR *io.PipeReader
	pipeW *io.PipeWriter
}

// CloseWriter closes the write half of the pipe with io.EOF, so subsequent
// reads report a regular end of stream.
func (r *requestForwardingReader) CloseWriter() {
	// PipeWriter.CloseWithError always returns nil, so the result carries
	// no information worth handling.
	writer := r.pipeW
	_ = writer.CloseWithError(io.EOF)
}
4✔
416

417
// newResponseForwardingWriter creates a new http.ResponseWriter that intercepts
418
// what's written to it and presents it through a bufio.Scanner interface.
419
func newResponseForwardingWriter() *responseForwardingWriter {
4✔
420
        r, w := io.Pipe()
4✔
421

4✔
422
        scanner := bufio.NewScanner(r)
4✔
423

4✔
424
        // We pass in a custom buffer for the bufio scanner to use. We'll keep
4✔
425
        // with a normal 64KB buffer, but allow a larger max message size,
4✔
426
        // which may cause buffer expansion when needed.
4✔
427
        buf := make([]byte, 0, bufio.MaxScanTokenSize)
4✔
428
        scanner.Buffer(buf, MaxWsMsgSize)
4✔
429

4✔
430
        return &responseForwardingWriter{
4✔
431
                Writer:  w,
4✔
432
                Scanner: scanner,
4✔
433
                pipeR:   r,
4✔
434
                pipeW:   w,
4✔
435
                header:  http.Header{},
4✔
436
                closed:  make(chan bool, 1),
4✔
437
        }
4✔
438
}
4✔
439

440
// responseForwardingWriter is a type that implements the http.ResponseWriter
441
// interface but internally forwards what's written to the writer through a pipe
442
// so it can easily be read again through the bufio.Scanner interface.
443
type responseForwardingWriter struct {
444
        io.Writer
445
        *bufio.Scanner
446

447
        pipeR *io.PipeReader
448
        pipeW *io.PipeWriter
449

450
        header http.Header
451
        code   int
452
        closed chan bool
453
}
454

455
// Write writes the given bytes to the internal pipe.
456
//
457
// NOTE: This is part of the http.ResponseWriter interface.
458
func (w *responseForwardingWriter) Write(b []byte) (int, error) {
4✔
459
        return w.Writer.Write(b)
4✔
460
}
4✔
461

462
// Header returns the HTTP header fields intercepted so far.
463
//
464
// NOTE: This is part of the http.ResponseWriter interface.
465
func (w *responseForwardingWriter) Header() http.Header {
4✔
466
        return w.header
4✔
467
}
4✔
468

469
// WriteHeader indicates that the header part of the response is now finished
470
// and sets the response code.
471
//
472
// NOTE: This is part of the http.ResponseWriter interface.
473
func (w *responseForwardingWriter) WriteHeader(code int) {
×
474
        w.code = code
×
475
}
×
476

477
// CloseNotify returns a channel that indicates if a connection was closed.
478
//
479
// NOTE: This is part of the http.CloseNotifier interface.
480
func (w *responseForwardingWriter) CloseNotify() <-chan bool {
×
481
        return w.closed
×
482
}
×
483

484
// Flush empties all buffers. We implement this to indicate to our backend that
485
// we support flushing our content. There is no actual implementation because
486
// all writes happen immediately, there is no internal buffering.
487
//
488
// NOTE: This is part of the http.Flusher interface.
489
func (w *responseForwardingWriter) Flush() {}
4✔
490

491
func (w *responseForwardingWriter) Close() {
4✔
492
        _ = w.pipeR.CloseWithError(io.EOF)
4✔
493
        _ = w.pipeW.CloseWithError(io.EOF)
4✔
494
        w.closed <- true
4✔
495
}
4✔
496

497
// IsClosedConnError is a helper function that returns true if the given error
498
// is an error indicating we are using a closed connection.
499
func IsClosedConnError(err error) bool {
4✔
500
        if err == nil {
4✔
501
                return false
×
502
        }
×
503
        if err == http.ErrServerClosed {
4✔
504
                return true
×
505
        }
×
506

507
        str := err.Error()
4✔
508
        if strings.Contains(str, "use of closed network connection") {
8✔
509
                return true
4✔
510
        }
4✔
511
        if strings.Contains(str, "closed pipe") {
8✔
512
                return true
4✔
513
        }
4✔
514
        if strings.Contains(str, "broken pipe") {
8✔
515
                return true
4✔
516
        }
4✔
517
        if strings.Contains(str, "connection reset by peer") {
5✔
518
                return true
1✔
519
        }
1✔
520
        return websocket.IsCloseError(
4✔
521
                err, websocket.CloseNormalClosure, websocket.CloseGoingAway,
4✔
522
        )
4✔
523
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc