• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 12312390362

13 Dec 2024 08:44AM UTC coverage: 57.458% (+8.5%) from 48.92%
12312390362

Pull #9343

github

ellemouton
fn: rework the ContextGuard and add tests

In this commit, the ContextGuard struct is reworked such that the
context provided by its new main WithCtx method is cancelled in sync
with either a parent context being cancelled or its quit channel being
closed. Tests are added to assert this behaviour. In order for the
closing of the quit channel to be consistent with the cancellation of the
derived contexts, the quit channel _must_ be internal to the
ContextGuard so that callers can only close the channel via the
exposed Quit method, which takes care to first cancel any
derived contexts that depend on the quit channel before returning.
Pull Request #9343: fn: expand the ContextGuard and add tests

101853 of 177264 relevant lines covered (57.46%)

24972.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

5.59
/rpcperms/middleware_handler.go
1
package rpcperms
2

3
import (
4
        "context"
5
        "encoding/hex"
6
        "errors"
7
        "fmt"
8
        "sync"
9
        "sync/atomic"
10
        "time"
11

12
        "github.com/btcsuite/btcd/chaincfg"
13
        "github.com/lightningnetwork/lnd/lnrpc"
14
        "github.com/lightningnetwork/lnd/macaroons"
15
        "google.golang.org/protobuf/proto"
16
        "google.golang.org/protobuf/reflect/protoreflect"
17
        "google.golang.org/protobuf/reflect/protoregistry"
18
        "gopkg.in/macaroon.v2"
19
)
20

21
// ErrShuttingDown is returned when lnd is shutting down and a request can no
// longer be served.
var ErrShuttingDown = errors.New("server shutting down")

// ErrTimeoutReached is returned when a middleware does not complete one of its
// tasks within the configured timeout.
var ErrTimeoutReached = errors.New("intercept timeout reached")

// errClientQuit is returned when the middleware client closes its
// communication stream before a request has been fully handled.
var errClientQuit = errors.New("interceptor RPC client quit")
34

35
// MiddlewareHandler is a type that communicates with a middleware over the
36
// established bi-directional RPC stream. It sends messages to the middleware
37
// whenever the custom business logic implemented there should give feedback to
38
// a request or response that's happening on the main gRPC server.
39
type MiddlewareHandler struct {
40
        // lastMsgID is the ID of the last intercept message that was forwarded
41
        // to the middleware.
42
        //
43
        // NOTE: Must be used atomically!
44
        lastMsgID uint64
45

46
        middlewareName string
47

48
        readOnly bool
49

50
        customCaveatName string
51

52
        receive func() (*lnrpc.RPCMiddlewareResponse, error)
53

54
        send func(request *lnrpc.RPCMiddlewareRequest) error
55

56
        interceptRequests chan *interceptRequest
57

58
        timeout time.Duration
59

60
        // params are our current chain params.
61
        params *chaincfg.Params
62

63
        // done is closed when the rpc client terminates.
64
        done chan struct{}
65

66
        // quit is closed when lnd is shutting down.
67
        quit chan struct{}
68

69
        wg sync.WaitGroup
70
}
71

72
// NewMiddlewareHandler creates a new handler for the middleware with the given
73
// name and custom caveat name.
74
func NewMiddlewareHandler(name, customCaveatName string, readOnly bool,
75
        receive func() (*lnrpc.RPCMiddlewareResponse, error),
76
        send func(request *lnrpc.RPCMiddlewareRequest) error,
77
        timeout time.Duration, params *chaincfg.Params,
78
        quit chan struct{}) *MiddlewareHandler {
×
79

×
80
        // We explicitly want to log this as a warning since intercepting any
×
81
        // gRPC messages can also be used for malicious purposes and the user
×
82
        // should be made aware of the risks.
×
83
        log.Warnf("A new gRPC middleware with the name '%s' was registered "+
×
84
                " with custom_macaroon_caveat='%s', read_only=%v. Make sure "+
×
85
                "you trust the middleware author since that code will be able "+
×
86
                "to intercept and possibly modify any gRPC messages sent/"+
×
87
                "received to/from a client that has a macaroon with that "+
×
88
                "custom caveat.", name, customCaveatName, readOnly)
×
89

×
90
        return &MiddlewareHandler{
×
91
                middlewareName:    name,
×
92
                customCaveatName:  customCaveatName,
×
93
                readOnly:          readOnly,
×
94
                receive:           receive,
×
95
                send:              send,
×
96
                interceptRequests: make(chan *interceptRequest),
×
97
                timeout:           timeout,
×
98
                params:            params,
×
99
                done:              make(chan struct{}),
×
100
                quit:              quit,
×
101
        }
×
102
}
×
103

104
// intercept handles the full interception lifecycle of a single middleware
105
// event (stream authentication, request interception or response interception).
106
// The lifecycle consists of sending a message to the middleware, receiving a
107
// feedback on it and sending the feedback to the appropriate channel. All steps
108
// are guarded by the configured timeout to make sure a middleware cannot slow
109
// down requests too much.
110
func (h *MiddlewareHandler) intercept(requestID uint64,
111
        req *InterceptionRequest) (*interceptResponse, error) {
×
112

×
113
        respChan := make(chan *interceptResponse, 1)
×
114

×
115
        newRequest := &interceptRequest{
×
116
                requestID: requestID,
×
117
                request:   req,
×
118
                response:  respChan,
×
119
        }
×
120

×
121
        // timeout is the time after which intercept requests expire.
×
122
        timeout := time.After(h.timeout)
×
123

×
124
        // Send the request to the interceptRequests channel for the main
×
125
        // goroutine to be picked up.
×
126
        select {
×
127
        case h.interceptRequests <- newRequest:
×
128

129
        case <-timeout:
×
130
                log.Errorf("MiddlewareHandler returned error - reached "+
×
131
                        "timeout of %v for request interception", h.timeout)
×
132

×
133
                return nil, ErrTimeoutReached
×
134

135
        case <-h.done:
×
136
                return nil, errClientQuit
×
137

138
        case <-h.quit:
×
139
                return nil, ErrShuttingDown
×
140
        }
141

142
        // Receive the response and return it. If no response has been received
143
        // in AcceptorTimeout, then return false.
144
        select {
×
145
        case resp := <-respChan:
×
146
                return resp, nil
×
147

148
        case <-timeout:
×
149
                log.Errorf("MiddlewareHandler returned error - reached "+
×
150
                        "timeout of %v for response interception", h.timeout)
×
151
                return nil, ErrTimeoutReached
×
152

153
        case <-h.done:
×
154
                return nil, errClientQuit
×
155

156
        case <-h.quit:
×
157
                return nil, ErrShuttingDown
×
158
        }
159
}
160

161
// Run is the main loop for the middleware handler. This function will block
162
// until it receives the signal that lnd is shutting down, or the rpc stream is
163
// cancelled by the client.
164
func (h *MiddlewareHandler) Run() error {
×
165
        // Wait for our goroutines to exit before we return.
×
166
        defer h.wg.Wait()
×
167
        defer log.Debugf("Exiting middleware run loop for %s", h.middlewareName)
×
168

×
169
        // Create a channel that responses from middlewares are sent into.
×
170
        responses := make(chan *lnrpc.RPCMiddlewareResponse)
×
171

×
172
        // errChan is used by the receive loop to signal any errors that occur
×
173
        // during reading from the stream. This is primarily used to shutdown
×
174
        // the send loop in the case of an RPC client disconnecting.
×
175
        errChan := make(chan error, 1)
×
176

×
177
        // Start a goroutine to receive responses from the interceptor. We
×
178
        // expect the receive function to block, so it must be run in a
×
179
        // goroutine (otherwise we could not send more than one intercept
×
180
        // request to the client).
×
181
        h.wg.Add(1)
×
182
        go func() {
×
183
                defer h.wg.Done()
×
184

×
185
                h.receiveResponses(errChan, responses)
×
186
        }()
×
187

188
        return h.sendInterceptRequests(errChan, responses)
×
189
}
190

191
// receiveResponses receives responses for our intercept requests and dispatches
192
// them into the responses channel provided, sending any errors that occur into
193
// the error channel provided.
194
func (h *MiddlewareHandler) receiveResponses(errChan chan error,
195
        responses chan *lnrpc.RPCMiddlewareResponse) {
×
196

×
197
        for {
×
198
                resp, err := h.receive()
×
199
                if err != nil {
×
200
                        errChan <- err
×
201
                        return
×
202
                }
×
203

204
                select {
×
205
                case responses <- resp:
×
206

207
                case <-h.done:
×
208
                        return
×
209

210
                case <-h.quit:
×
211
                        return
×
212
                }
213
        }
214
}
215

216
// sendInterceptRequests handles intercept requests sent to us by our Accept()
217
// function, dispatching them to our acceptor stream and coordinating return of
218
// responses to their callers.
219
func (h *MiddlewareHandler) sendInterceptRequests(errChan chan error,
220
        responses chan *lnrpc.RPCMiddlewareResponse) error {
×
221

×
222
        // Close the done channel to indicate that the interceptor is no longer
×
223
        // listening and any in-progress requests should be terminated.
×
224
        defer close(h.done)
×
225

×
226
        interceptRequests := make(map[uint64]*interceptRequest)
×
227

×
228
        for {
×
229
                select {
×
230
                // Consume requests passed to us from our Accept() function and
231
                // send them into our stream.
232
                case newRequest := <-h.interceptRequests:
×
233
                        msgID := atomic.AddUint64(&h.lastMsgID, 1)
×
234

×
235
                        req := newRequest.request
×
236
                        interceptRequests[msgID] = newRequest
×
237

×
238
                        interceptReq, err := req.ToRPC(
×
239
                                newRequest.requestID, msgID,
×
240
                        )
×
241
                        if err != nil {
×
242
                                return err
×
243
                        }
×
244

245
                        if err := h.send(interceptReq); err != nil {
×
246
                                return err
×
247
                        }
×
248

249
                // Process newly received responses from our interceptor,
250
                // looking the original request up in our map of requests and
251
                // dispatching the response.
252
                case resp := <-responses:
×
253
                        requestInfo, ok := interceptRequests[resp.RefMsgId]
×
254
                        if !ok {
×
255
                                continue
×
256
                        }
257

258
                        response := &interceptResponse{}
×
259
                        switch msg := resp.GetMiddlewareMessage().(type) {
×
260
                        case *lnrpc.RPCMiddlewareResponse_Feedback:
×
261
                                t := msg.Feedback
×
262
                                if t.Error != "" {
×
263
                                        response.err = fmt.Errorf("%s", t.Error)
×
264
                                        break
×
265
                                }
266

267
                                // If there's nothing to replace, we're done,
268
                                // this request was just accepted.
269
                                if !t.ReplaceResponse {
×
270
                                        break
×
271
                                }
272

273
                                // We are replacing the response, the question
274
                                // now just is: was it an error or a proper
275
                                // proto message?
276
                                response.replace = true
×
277
                                if requestInfo.request.IsError {
×
278
                                        response.replacement = errors.New(
×
279
                                                string(t.ReplacementSerialized),
×
280
                                        )
×
281

×
282
                                        break
×
283
                                }
284

285
                                // Not an error but a proper proto message that
286
                                // needs to be replaced. For that we need to
287
                                // parse it from the raw bytes into the full RPC
288
                                // message.
289
                                protoMsg, err := parseProto(
×
290
                                        requestInfo.request.ProtoTypeName,
×
291
                                        t.ReplacementSerialized,
×
292
                                )
×
293

×
294
                                if err != nil {
×
295
                                        response.err = err
×
296

×
297
                                        break
×
298
                                }
299

300
                                response.replacement = protoMsg
×
301

302
                        default:
×
303
                                return fmt.Errorf("unknown middleware "+
×
304
                                        "message: %v", msg)
×
305
                        }
306

307
                        select {
×
308
                        case requestInfo.response <- response:
×
309
                        case <-h.quit:
×
310
                        }
311

312
                        delete(interceptRequests, resp.RefMsgId)
×
313

314
                // If we failed to receive from our middleware, we exit.
315
                case err := <-errChan:
×
316
                        log.Errorf("Received an error: %v, shutting down", err)
×
317
                        return err
×
318

319
                // Exit if we are shutting down.
320
                case <-h.quit:
×
321
                        return ErrShuttingDown
×
322
                }
323
        }
324
}
325

326
// InterceptType enumerates the kinds of intercept messages a middleware can
// be sent.
type InterceptType uint8

// The intercept types are numbered from 1, so the zero value of InterceptType
// does not name any of them.
const (
	// TypeStreamAuth is sent when a client or server streaming RPC is
	// initialized. It lets a middleware accept or deny the whole stream
	// up front instead of judging only the individual messages on it.
	TypeStreamAuth InterceptType = iota + 1

	// TypeRequest is sent for an RPC request message on its way to lnd.
	// Client-streaming RPCs produce one such message per request sent to
	// the stream. A middleware may modify the request before lnd sees it.
	TypeRequest

	// TypeResponse is sent for an RPC response message on its way from
	// lnd to the client. Server-streaming RPCs produce one such message
	// per response sent to the stream. A middleware may modify the
	// response before the client sees it.
	TypeResponse
)
352

353
// InterceptionRequest is a struct holding all information that is sent to a
354
// middleware whenever there is something to intercept (auth, request,
355
// response).
356
type InterceptionRequest struct {
357
        // Type is the type of the interception message.
358
        Type InterceptType
359

360
        // StreamRPC is set to true if the invoked RPC method is client or
361
        // server streaming.
362
        StreamRPC bool
363

364
        // Macaroon holds the macaroon that the client sent to lnd.
365
        Macaroon *macaroon.Macaroon
366

367
        // RawMacaroon holds the raw binary serialized macaroon that the client
368
        // sent to lnd.
369
        RawMacaroon []byte
370

371
        // CustomCaveatName is the name of the custom caveat that the middleware
372
        // was intercepting for.
373
        CustomCaveatName string
374

375
        // CustomCaveatCondition is the condition of the custom caveat that the
376
        // middleware was intercepting for. This can be empty for custom caveats
377
        // that only have a name (marker caveats).
378
        CustomCaveatCondition string
379

380
        // FullURI is the full RPC method URI that was invoked.
381
        FullURI string
382

383
        // ProtoSerialized is the full request or response object in the
384
        // protobuf binary serialization format.
385
        ProtoSerialized []byte
386

387
        // ProtoTypeName is the fully qualified name of the protobuf type of the
388
        // request or response message that is serialized in the field above.
389
        ProtoTypeName string
390

391
        // IsError indicates that the message contained within this request is
392
        // an error. Will only ever be true for response messages.
393
        IsError bool
394
}
395

396
// NewMessageInterceptionRequest creates a new interception request for either
397
// a request or response message.
398
func NewMessageInterceptionRequest(ctx context.Context,
399
        authType InterceptType, isStream bool, fullMethod string,
400
        m interface{}) (*InterceptionRequest, error) {
×
401

×
402
        mac, rawMacaroon, err := macaroonFromContext(ctx)
×
403
        if err != nil {
×
404
                return nil, err
×
405
        }
×
406

407
        req := &InterceptionRequest{
×
408
                Type:        authType,
×
409
                StreamRPC:   isStream,
×
410
                Macaroon:    mac,
×
411
                RawMacaroon: rawMacaroon,
×
412
                FullURI:     fullMethod,
×
413
        }
×
414

×
415
        // The message is either a proto message or an error, we don't support
×
416
        // any other types being intercepted.
×
417
        switch t := m.(type) {
×
418
        case proto.Message:
×
419
                req.ProtoSerialized, err = proto.Marshal(t)
×
420
                if err != nil {
×
421
                        return nil, fmt.Errorf("cannot marshal proto msg: %w",
×
422
                                err)
×
423
                }
×
424
                req.ProtoTypeName = string(proto.MessageName(t))
×
425

426
        case error:
×
427
                req.ProtoSerialized = []byte(t.Error())
×
428
                req.ProtoTypeName = "error"
×
429
                req.IsError = true
×
430

431
        default:
×
432
                return nil, fmt.Errorf("unsupported type for interception "+
×
433
                        "request: %v", m)
×
434
        }
435

436
        return req, nil
×
437
}
438

439
// NewStreamAuthInterceptionRequest creates a new interception request for a
440
// stream authentication message.
441
func NewStreamAuthInterceptionRequest(ctx context.Context,
442
        fullMethod string) (*InterceptionRequest, error) {
×
443

×
444
        mac, rawMacaroon, err := macaroonFromContext(ctx)
×
445
        if err != nil {
×
446
                return nil, err
×
447
        }
×
448

449
        return &InterceptionRequest{
×
450
                Type:        TypeStreamAuth,
×
451
                StreamRPC:   true,
×
452
                Macaroon:    mac,
×
453
                RawMacaroon: rawMacaroon,
×
454
                FullURI:     fullMethod,
×
455
        }, nil
×
456
}
457

458
// macaroonFromContext tries to extract the macaroon from the incoming context.
459
// If there is no macaroon, a nil error is returned since some RPCs might not
460
// require a macaroon. But in case there is something in the macaroon header
461
// field that cannot be parsed, a non-nil error is returned.
462
func macaroonFromContext(ctx context.Context) (*macaroon.Macaroon, []byte,
463
        error) {
×
464

×
465
        macHex, err := macaroons.RawMacaroonFromContext(ctx)
×
466
        if err != nil {
×
467
                // If there is no macaroon, we continue anyway as it might be an
×
468
                // RPC that doesn't require a macaroon.
×
469
                return nil, nil, nil
×
470
        }
×
471

472
        macBytes, err := hex.DecodeString(macHex)
×
473
        if err != nil {
×
474
                return nil, nil, err
×
475
        }
×
476

477
        mac := &macaroon.Macaroon{}
×
478
        if err := mac.UnmarshalBinary(macBytes); err != nil {
×
479
                return nil, nil, err
×
480
        }
×
481

482
        return mac, macBytes, nil
×
483
}
484

485
// ToRPC converts the interception request to its RPC counterpart.
486
func (r *InterceptionRequest) ToRPC(requestID,
487
        msgID uint64) (*lnrpc.RPCMiddlewareRequest, error) {
×
488

×
489
        rpcRequest := &lnrpc.RPCMiddlewareRequest{
×
490
                RequestId:             requestID,
×
491
                MsgId:                 msgID,
×
492
                RawMacaroon:           r.RawMacaroon,
×
493
                CustomCaveatCondition: r.CustomCaveatCondition,
×
494
        }
×
495

×
496
        switch r.Type {
×
497
        case TypeStreamAuth:
×
498
                rpcRequest.InterceptType = &lnrpc.RPCMiddlewareRequest_StreamAuth{
×
499
                        StreamAuth: &lnrpc.StreamAuth{
×
500
                                MethodFullUri: r.FullURI,
×
501
                        },
×
502
                }
×
503

504
        case TypeRequest:
×
505
                rpcRequest.InterceptType = &lnrpc.RPCMiddlewareRequest_Request{
×
506
                        Request: &lnrpc.RPCMessage{
×
507
                                MethodFullUri: r.FullURI,
×
508
                                StreamRpc:     r.StreamRPC,
×
509
                                TypeName:      r.ProtoTypeName,
×
510
                                Serialized:    r.ProtoSerialized,
×
511
                        },
×
512
                }
×
513

514
        case TypeResponse:
×
515
                rpcRequest.InterceptType = &lnrpc.RPCMiddlewareRequest_Response{
×
516
                        Response: &lnrpc.RPCMessage{
×
517
                                MethodFullUri: r.FullURI,
×
518
                                StreamRpc:     r.StreamRPC,
×
519
                                TypeName:      r.ProtoTypeName,
×
520
                                Serialized:    r.ProtoSerialized,
×
521
                                IsError:       r.IsError,
×
522
                        },
×
523
                }
×
524

525
        default:
×
526
                return nil, fmt.Errorf("unknown intercept type %v", r.Type)
×
527
        }
528

529
        return rpcRequest, nil
×
530
}
531

532
// interceptRequest is a struct that keeps track of an interception request sent
533
// out to a middleware and the response that is eventually sent back by the
534
// middleware.
535
type interceptRequest struct {
536
        requestID uint64
537
        request   *InterceptionRequest
538
        response  chan *interceptResponse
539
}
540

541
// interceptResponse carries a middleware's verdict on one intercepted message.
type interceptResponse struct {
	// err is set when the middleware rejected the message or when its
	// replacement could not be parsed.
	err error

	// replace reports whether the middleware asked for the message to be
	// replaced.
	replace bool

	// replacement is what the message should be replaced with when
	// replace is set: either an error or a parsed proto message.
	replacement interface{}
}
548

549
// parseProto parses a proto serialized message of the given type into its
550
// native version.
551
func parseProto(typeName string, serialized []byte) (proto.Message, error) {
×
552
        messageType, err := protoregistry.GlobalTypes.FindMessageByName(
×
553
                protoreflect.FullName(typeName),
×
554
        )
×
555
        if err != nil {
×
556
                return nil, err
×
557
        }
×
558
        msg := messageType.New()
×
559
        err = proto.Unmarshal(serialized, msg.Interface())
×
560
        if err != nil {
×
561
                return nil, err
×
562
        }
×
563

564
        return msg.Interface(), nil
×
565
}
566

567
// replaceProtoMsg replaces the given target message with the content of the
568
// replacement message.
569
func replaceProtoMsg(target interface{}, replacement interface{}) error {
5✔
570
        targetMsg, ok := target.(proto.Message)
5✔
571
        if !ok {
6✔
572
                return fmt.Errorf("target is not a proto message: %v", target)
1✔
573
        }
1✔
574

575
        replacementMsg, ok := replacement.(proto.Message)
4✔
576
        if !ok {
4✔
577
                return fmt.Errorf("replacement is not a proto message: %v",
×
578
                        replacement)
×
579
        }
×
580

581
        if targetMsg.ProtoReflect().Type() !=
4✔
582
                replacementMsg.ProtoReflect().Type() {
5✔
583

1✔
584
                return fmt.Errorf("replacement message is of wrong type")
1✔
585
        }
1✔
586

587
        replacementBytes, err := proto.Marshal(replacementMsg)
3✔
588
        if err != nil {
3✔
589
                return fmt.Errorf("error marshaling replacement: %w", err)
×
590
        }
×
591
        err = proto.Unmarshal(replacementBytes, targetMsg)
3✔
592
        if err != nil {
3✔
593
                return fmt.Errorf("error unmarshaling replacement: %w", err)
×
594
        }
×
595

596
        return nil
3✔
597
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc