• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 12428593038

20 Dec 2024 09:02AM UTC coverage: 58.33% (-0.2%) from 58.576%
12428593038

Pull #9382

github

guggero
.golangci.yml: speed up linter by updating start commit

With this we allow the linter to only look at recent changes, since
everything between that old commit and this most recent one has been
linted correctly anyway.
Pull Request #9382: lint: deprecate old linters, use new ref commit

133769 of 229330 relevant lines covered (58.33%)

19284.53 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

76.87
/lnrpc/websocket_proxy.go
1
// The code in this file is a heavily modified version of
2
// https://github.com/tmc/grpc-websocket-proxy/
3

4
package lnrpc
5

6
import (
	"bufio"
	"errors"
	"io"
	"net/http"
	"net/textproto"
	"regexp"
	"strings"
	"time"

	"github.com/btcsuite/btclog/v2"
	"github.com/gorilla/websocket"
	"golang.org/x/net/context"
)
19

20
const (
	// MethodOverrideParam names the GET query parameter that carries the
	// HTTP method to use for the forwarded REST request. It is needed
	// because a WebSocket handshake must always be a GET request.
	MethodOverrideParam = "method"

	// HeaderWebSocketProtocol is the WebSocket handshake header field in
	// which additional header fields are transported to the backend.
	HeaderWebSocketProtocol = "Sec-Websocket-Protocol"

	// WebSocketProtocolDelimiter separates an additional header field's
	// name from its value. The plus sign is used because the default
	// delimiters are not allowed inside protocol names.
	WebSocketProtocolDelimiter = "+"

	// PingContent is the payload of the ping messages we send. It carries
	// no meaning beyond being non-empty; the client should echo it back
	// in its pong message.
	PingContent = "are you there?"

	// MaxWsMsgSize is the largest WebSocket message, in bytes (4 MiB),
	// that the gRPC <-> WS proxy will attempt to decode. gRPC uses a
	// similar setting elsewhere.
	MaxWsMsgSize = 4 << 20
)
47

48
var (
	// defaultHeadersToForward is the set of HTTP header fields that are
	// copied to the backend request by default. Keys must be written in
	// canonical MIME header format.
	defaultHeadersToForward = map[string]bool{
		"Origin":                 true,
		"Referer":                true,
		"Grpc-Metadata-Macaroon": true,
	}

	// defaultProtocolsToAllow is the set of additional header fields a
	// client may transport inside the Sec-Websocket-Protocol field so
	// that they get forwarded to the backend.
	defaultProtocolsToAllow = map[string]bool{
		"Grpc-Metadata-Macaroon": true,
	}

	// DefaultPingInterval is the default time to wait between sending two
	// ping requests.
	DefaultPingInterval = 30 * time.Second

	// DefaultPongWait is the default maximum time we wait for the pong
	// answering one of our pings before assuming the connection died.
	DefaultPongWait = 5 * time.Second
)
73

74
// NewWebSocketProxy attempts to expose the underlying handler as a response-
75
// streaming WebSocket stream with newline-delimited JSON as the content
76
// encoding. If pingInterval is a non-zero duration, a ping message will be
77
// sent out periodically and a pong response message is expected from the
78
// client. The clientStreamingURIs parameter can hold a list of all patterns
79
// for URIs that are mapped to client-streaming RPC methods. We need to keep
80
// track of those to make sure we initialize the request body correctly for the
81
// underlying grpc-gateway library.
82
func NewWebSocketProxy(h http.Handler, logger btclog.Logger,
83
        pingInterval, pongWait time.Duration,
84
        clientStreamingURIs []*regexp.Regexp) http.Handler {
1✔
85

1✔
86
        p := &WebsocketProxy{
1✔
87
                backend: h,
1✔
88
                logger:  logger,
1✔
89
                upgrader: &websocket.Upgrader{
1✔
90
                        ReadBufferSize:  1024,
1✔
91
                        WriteBufferSize: 1024,
1✔
92
                        CheckOrigin: func(r *http.Request) bool {
2✔
93
                                return true
1✔
94
                        },
1✔
95
                },
96
                clientStreamingURIs: clientStreamingURIs,
97
        }
98

99
        if pingInterval > 0 && pongWait > 0 {
2✔
100
                p.pingInterval = pingInterval
1✔
101
                p.pongWait = pongWait
1✔
102
        }
1✔
103

104
        return p
1✔
105
}
106

107
// WebsocketProxy provides websocket transport upgrade to compatible endpoints.
108
type WebsocketProxy struct {
109
        backend  http.Handler
110
        logger   btclog.Logger
111
        upgrader *websocket.Upgrader
112

113
        // clientStreamingURIs holds a list of all patterns for URIs that are
114
        // mapped to client-streaming RPC methods. We need to keep track of
115
        // those to make sure we initialize the request body correctly for the
116
        // underlying grpc-gateway library.
117
        clientStreamingURIs []*regexp.Regexp
118

119
        pingInterval time.Duration
120
        pongWait     time.Duration
121
}
122

123
// pingPongEnabled returns true if a ping interval is set to enable sending and
124
// expecting regular ping/pong messages.
125
func (p *WebsocketProxy) pingPongEnabled() bool {
1✔
126
        return p.pingInterval > 0 && p.pongWait > 0
1✔
127
}
1✔
128

129
// ServeHTTP handles the incoming HTTP request. If the request is an
130
// "upgradeable" WebSocket request (identified by header fields), then the
131
// WS proxy handles the request. Otherwise the request is passed directly to the
132
// underlying REST proxy.
133
func (p *WebsocketProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
1✔
134
        if !websocket.IsWebSocketUpgrade(r) {
2✔
135
                p.backend.ServeHTTP(w, r)
1✔
136
                return
1✔
137
        }
1✔
138
        p.upgradeToWebSocketProxy(w, r)
1✔
139
}
140

141
// upgradeToWebSocketProxy upgrades the incoming request to a WebSocket, reads
142
// one incoming message then streams all responses until either the client or
143
// server quit the connection.
144
func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
145
        r *http.Request) {
1✔
146

1✔
147
        conn, err := p.upgrader.Upgrade(w, r, nil)
1✔
148
        if err != nil {
1✔
149
                p.logger.Errorf("error upgrading websocket:", err)
×
150
                return
×
151
        }
×
152
        defer func() {
2✔
153
                err := conn.Close()
1✔
154
                if err != nil && !IsClosedConnError(err) {
1✔
155
                        p.logger.Errorf("WS: error closing upgraded conn: %v",
×
156
                                err)
×
157
                }
×
158
        }()
159

160
        ctx, cancelFn := context.WithCancel(r.Context())
1✔
161
        defer cancelFn()
1✔
162

1✔
163
        requestForwarder := newRequestForwardingReader()
1✔
164
        request, err := http.NewRequestWithContext(
1✔
165
                ctx, r.Method, r.URL.String(), requestForwarder,
1✔
166
        )
1✔
167
        if err != nil {
1✔
168
                p.logger.Errorf("WS: error preparing request:", err)
×
169
                return
×
170
        }
×
171

172
        // Allow certain headers to be forwarded, either from source headers
173
        // or the special Sec-Websocket-Protocol header field.
174
        forwardHeaders(r.Header, request.Header)
1✔
175

1✔
176
        // Also allow the target request method to be overwritten, as all
1✔
177
        // WebSocket establishment calls MUST be GET requests.
1✔
178
        if m := r.URL.Query().Get(MethodOverrideParam); m != "" {
2✔
179
                request.Method = m
1✔
180
        }
1✔
181

182
        // Is this a call to a client-streaming RPC method?
183
        clientStreaming := false
1✔
184
        for _, pattern := range p.clientStreamingURIs {
2✔
185
                if pattern.MatchString(r.URL.Path) {
2✔
186
                        clientStreaming = true
1✔
187
                }
1✔
188
        }
189

190
        responseForwarder := newResponseForwardingWriter()
1✔
191
        go func() {
2✔
192
                <-ctx.Done()
1✔
193
                responseForwarder.Close()
1✔
194
                requestForwarder.CloseWriter()
1✔
195
        }()
1✔
196

197
        go func() {
2✔
198
                defer cancelFn()
1✔
199
                p.backend.ServeHTTP(responseForwarder, request)
1✔
200
        }()
1✔
201

202
        // Read loop: Take messages from websocket and write them to the payload
203
        // channel. This needs to be its own goroutine because for non-client
204
        // streaming RPCs, the requestForwarder.Write() in the second goroutine
205
        // will block until the request has fully completed. But for the ping/
206
        // pong handler to work, we need to have an active call to
207
        // conn.ReadMessage() going on. So we make sure we have such an active
208
        // call by starting a second read as soon as the first one has
209
        // completed.
210
        payloadChannel := make(chan []byte, 1)
1✔
211
        go func() {
2✔
212
                defer cancelFn()
1✔
213
                defer close(payloadChannel)
1✔
214

1✔
215
                for {
2✔
216
                        select {
1✔
217
                        case <-ctx.Done():
×
218
                                return
×
219
                        default:
1✔
220
                        }
221

222
                        _, payload, err := conn.ReadMessage()
1✔
223
                        if err != nil {
2✔
224
                                if IsClosedConnError(err) {
2✔
225
                                        p.logger.Tracef("WS: socket "+
1✔
226
                                                "closed: %v", err)
1✔
227
                                        return
1✔
228
                                }
1✔
229
                                p.logger.Errorf("error reading message: %v",
×
230
                                        err)
×
231
                                return
×
232
                        }
233

234
                        select {
1✔
235
                        case payloadChannel <- payload:
1✔
236
                        case <-ctx.Done():
×
237
                                return
×
238
                        }
239
                }
240
        }()
241

242
        // Forward loop: Take messages from the incoming payload channel and
243
        // write them to the http request.
244
        go func() {
2✔
245
                defer cancelFn()
1✔
246
                for {
2✔
247
                        var payload []byte
1✔
248
                        select {
1✔
249
                        case <-ctx.Done():
×
250
                                return
×
251
                        case newPayload, more := <-payloadChannel:
1✔
252
                                if !more {
2✔
253
                                        p.logger.Infof("WS: incoming payload " +
1✔
254
                                                "chan closed")
1✔
255
                                        return
1✔
256
                                }
1✔
257

258
                                payload = newPayload
1✔
259
                        }
260

261
                        _, err := requestForwarder.Write(payload)
1✔
262
                        if err != nil {
2✔
263
                                p.logger.Errorf("WS: error writing message "+
1✔
264
                                        "to upstream http server: %v", err)
1✔
265
                                return
1✔
266
                        }
1✔
267
                        _, _ = requestForwarder.Write([]byte{'\n'})
1✔
268

1✔
269
                        // The grpc-gateway library uses a different request
1✔
270
                        // reader depending on whether it is a client streaming
1✔
271
                        // RPC or not. For a non-streaming request we need to
1✔
272
                        // close with EOF to signal the request was completed.
1✔
273
                        if !clientStreaming {
2✔
274
                                requestForwarder.CloseWriter()
1✔
275
                        }
1✔
276
                }
277
        }()
278

279
        // Ping write loop: Send a ping message regularly if ping/pong is
280
        // enabled.
281
        if p.pingPongEnabled() {
2✔
282
                // We'll send out our first ping in pingInterval. So the initial
1✔
283
                // deadline is that interval plus the time we allow for a
1✔
284
                // response to be sent.
1✔
285
                initialDeadline := time.Now().Add(p.pingInterval + p.pongWait)
1✔
286
                _ = conn.SetReadDeadline(initialDeadline)
1✔
287

1✔
288
                // Whenever a pong message comes in, we extend the deadline
1✔
289
                // until the next read is expected by the interval plus pong
1✔
290
                // wait time. Since we can never _reach_ any of the deadlines,
1✔
291
                // we also have to advance the deadline for the next expected
1✔
292
                // write to happen, in case the next thing we actually write is
1✔
293
                // the next ping.
1✔
294
                conn.SetPongHandler(func(appData string) error {
1✔
295
                        nextDeadline := time.Now().Add(
×
296
                                p.pingInterval + p.pongWait,
×
297
                        )
×
298
                        _ = conn.SetReadDeadline(nextDeadline)
×
299
                        _ = conn.SetWriteDeadline(nextDeadline)
×
300

×
301
                        return nil
×
302
                })
×
303
                go func() {
2✔
304
                        ticker := time.NewTicker(p.pingInterval)
1✔
305
                        defer ticker.Stop()
1✔
306

1✔
307
                        for {
2✔
308
                                select {
1✔
309
                                case <-ctx.Done():
1✔
310
                                        p.logger.Debug("WS: ping loop done")
1✔
311
                                        return
1✔
312

313
                                case <-ticker.C:
×
314
                                        // Writing the ping shouldn't take any
×
315
                                        // longer than we'll wait for a response
×
316
                                        // in the first place.
×
317
                                        writeDeadline := time.Now().Add(
×
318
                                                p.pongWait,
×
319
                                        )
×
320
                                        err := conn.WriteControl(
×
321
                                                websocket.PingMessage,
×
322
                                                []byte(PingContent),
×
323
                                                writeDeadline,
×
324
                                        )
×
325
                                        if err != nil {
×
326
                                                p.logger.Warnf("WS: could not "+
×
327
                                                        "send ping message: %v",
×
328
                                                        err)
×
329
                                                return
×
330
                                        }
×
331
                                }
332
                        }
333
                }()
334
        }
335

336
        // Write loop: Take messages from the response forwarder and write them
337
        // to the WebSocket.
338
        for responseForwarder.Scan() {
2✔
339
                if len(responseForwarder.Bytes()) == 0 {
1✔
340
                        p.logger.Errorf("WS: empty scan: %v",
×
341
                                responseForwarder.Err())
×
342

×
343
                        continue
×
344
                }
345

346
                err := conn.WriteMessage(
1✔
347
                        websocket.TextMessage, responseForwarder.Bytes(),
1✔
348
                )
1✔
349
                if err != nil {
1✔
350
                        p.logger.Errorf("WS: error writing message: %v", err)
×
351
                        return
×
352
                }
×
353
        }
354
        if err := responseForwarder.Err(); err != nil && !IsClosedConnError(err) {
1✔
355
                p.logger.Errorf("WS: scanner err: %v", err)
×
356
        }
×
357
}
358

359
// forwardHeaders forwards certain allowed header fields from the source request
360
// to the target request. Because browsers are limited in what header fields
361
// they can send on the WebSocket setup call, we also allow additional fields to
362
// be transported in the special Sec-Websocket-Protocol field.
363
func forwardHeaders(source, target http.Header) {
1✔
364
        // Forward allowed header fields directly.
1✔
365
        for header := range source {
2✔
366
                headerName := textproto.CanonicalMIMEHeaderKey(header)
1✔
367
                forward, ok := defaultHeadersToForward[headerName]
1✔
368
                if ok && forward {
2✔
369
                        target.Set(headerName, source.Get(header))
1✔
370
                }
1✔
371
        }
372

373
        // Browser aren't allowed to set custom header fields on WebSocket
374
        // requests. We need to allow them to submit the macaroon as a WS
375
        // protocol, which is the only allowed header. Set any "protocols" we
376
        // declare valid as header fields on the forwarded request.
377
        protocol := source.Get(HeaderWebSocketProtocol)
1✔
378
        for key := range defaultProtocolsToAllow {
2✔
379
                if strings.HasPrefix(protocol, key) {
2✔
380
                        // The format is "<protocol name>+<value>". We know the
1✔
381
                        // protocol string starts with the name so we only need
1✔
382
                        // to set the value.
1✔
383
                        values := strings.Split(
1✔
384
                                protocol, WebSocketProtocolDelimiter,
1✔
385
                        )
1✔
386
                        target.Set(key, values[1])
1✔
387
                }
1✔
388
        }
389
}
390

391
// newRequestForwardingReader returns a requestForwardingReader backed by a
// fresh io.Pipe.
func newRequestForwardingReader() *requestForwardingReader {
	pipeReader, pipeWriter := io.Pipe()

	return &requestForwardingReader{
		pipeR:  pipeReader,
		pipeW:  pipeWriter,
		Reader: pipeReader,
		Writer: pipeWriter,
	}
}

// requestForwardingReader wraps an io.Pipe. Bytes written through the embedded
// io.Writer become readable through the embedded io.Reader, so the value can be
// used as the body of a forwarded HTTP request while another goroutine feeds
// it.
type requestForwardingReader struct {
	io.Reader
	io.Writer

	// pipeR and pipeW are the same pipe ends as the embedded Reader and
	// Writer, kept with their concrete types so the pipe can be closed.
	pipeR *io.PipeReader
	pipeW *io.PipeWriter
}

// CloseWriter closes the writing end of the pipe with io.EOF. Subsequent reads
// return io.EOF and subsequent writes fail.
func (f *requestForwardingReader) CloseWriter() {
	// CloseWithError always returns nil, so there is nothing to handle.
	_ = f.pipeW.CloseWithError(io.EOF)
}
1✔
416

417
// newResponseForwardingWriter creates a new http.ResponseWriter that intercepts
418
// what's written to it and presents it through a bufio.Scanner interface.
419
func newResponseForwardingWriter() *responseForwardingWriter {
1✔
420
        r, w := io.Pipe()
1✔
421

1✔
422
        scanner := bufio.NewScanner(r)
1✔
423

1✔
424
        // We pass in a custom buffer for the bufio scanner to use. We'll keep
1✔
425
        // with a normal 64KB buffer, but allow a larger max message size,
1✔
426
        // which may cause buffer expansion when needed.
1✔
427
        buf := make([]byte, 0, bufio.MaxScanTokenSize)
1✔
428
        scanner.Buffer(buf, MaxWsMsgSize)
1✔
429

1✔
430
        return &responseForwardingWriter{
1✔
431
                Writer:  w,
1✔
432
                Scanner: scanner,
1✔
433
                pipeR:   r,
1✔
434
                pipeW:   w,
1✔
435
                header:  http.Header{},
1✔
436
                closed:  make(chan bool, 1),
1✔
437
        }
1✔
438
}
1✔
439

440
// responseForwardingWriter is a type that implements the http.ResponseWriter
441
// interface but internally forwards what's written to the writer through a pipe
442
// so it can easily be read again through the bufio.Scanner interface.
443
type responseForwardingWriter struct {
444
        io.Writer
445
        *bufio.Scanner
446

447
        pipeR *io.PipeReader
448
        pipeW *io.PipeWriter
449

450
        header http.Header
451
        code   int
452
        closed chan bool
453
}
454

455
// Write writes the given bytes to the internal pipe.
456
//
457
// NOTE: This is part of the http.ResponseWriter interface.
458
func (w *responseForwardingWriter) Write(b []byte) (int, error) {
1✔
459
        return w.Writer.Write(b)
1✔
460
}
1✔
461

462
// Header returns the HTTP header fields intercepted so far.
463
//
464
// NOTE: This is part of the http.ResponseWriter interface.
465
func (w *responseForwardingWriter) Header() http.Header {
1✔
466
        return w.header
1✔
467
}
1✔
468

469
// WriteHeader indicates that the header part of the response is now finished
470
// and sets the response code.
471
//
472
// NOTE: This is part of the http.ResponseWriter interface.
473
func (w *responseForwardingWriter) WriteHeader(code int) {
×
474
        w.code = code
×
475
}
×
476

477
// CloseNotify returns a channel that indicates if a connection was closed.
478
//
479
// NOTE: This is part of the http.CloseNotifier interface.
480
func (w *responseForwardingWriter) CloseNotify() <-chan bool {
×
481
        return w.closed
×
482
}
×
483

484
// Flush empties all buffers. We implement this to indicate to our backend that
485
// we support flushing our content. There is no actual implementation because
486
// all writes happen immediately, there is no internal buffering.
487
//
488
// NOTE: This is part of the http.Flusher interface.
489
func (w *responseForwardingWriter) Flush() {}
1✔
490

491
func (w *responseForwardingWriter) Close() {
1✔
492
        _ = w.pipeR.CloseWithError(io.EOF)
1✔
493
        _ = w.pipeW.CloseWithError(io.EOF)
1✔
494
        w.closed <- true
1✔
495
}
1✔
496

497
// IsClosedConnError is a helper function that returns true if the given error
498
// is an error indicating we are using a closed connection.
499
func IsClosedConnError(err error) bool {
1✔
500
        if err == nil {
1✔
501
                return false
×
502
        }
×
503
        if err == http.ErrServerClosed {
1✔
504
                return true
×
505
        }
×
506

507
        str := err.Error()
1✔
508
        if strings.Contains(str, "use of closed network connection") {
2✔
509
                return true
1✔
510
        }
1✔
511
        if strings.Contains(str, "closed pipe") {
2✔
512
                return true
1✔
513
        }
1✔
514
        if strings.Contains(str, "broken pipe") {
2✔
515
                return true
1✔
516
        }
1✔
517
        if strings.Contains(str, "connection reset by peer") {
1✔
518
                return true
×
519
        }
×
520
        return websocket.IsCloseError(
1✔
521
                err, websocket.CloseNormalClosure, websocket.CloseGoingAway,
1✔
522
        )
1✔
523
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc