• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 12199391122

06 Dec 2024 01:10PM UTC coverage: 49.807% (-9.1%) from 58.933%
12199391122

push

github

web-flow
Merge pull request #9337 from Guayaba221/patch-1

chore: fix typo in ruby.md

100137 of 201051 relevant lines covered (49.81%)

2.07 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

75.99
/rpcperms/middleware_handler.go
1
package rpcperms
2

3
import (
4
        "context"
5
        "encoding/hex"
6
        "errors"
7
        "fmt"
8
        "sync"
9
        "sync/atomic"
10
        "time"
11

12
        "github.com/btcsuite/btcd/chaincfg"
13
        "github.com/lightningnetwork/lnd/lnrpc"
14
        "github.com/lightningnetwork/lnd/macaroons"
15
        "google.golang.org/protobuf/proto"
16
        "google.golang.org/protobuf/reflect/protoreflect"
17
        "google.golang.org/protobuf/reflect/protoregistry"
18
        "gopkg.in/macaroon.v2"
19
)
20

21
// ErrShuttingDown is returned when a request can no longer be served because
// the server is in the process of shutting down.
var ErrShuttingDown = errors.New("server shutting down")

// ErrTimeoutReached is returned when a middleware does not complete one of its
// tasks within the configured timeout.
var ErrTimeoutReached = errors.New("intercept timeout reached")

// errClientQuit is returned when the middleware client closes its
// communication stream before a request has been fully handled.
var errClientQuit = errors.New("interceptor RPC client quit")
34

35
// MiddlewareHandler is a type that communicates with a middleware over the
36
// established bi-directional RPC stream. It sends messages to the middleware
37
// whenever the custom business logic implemented there should give feedback to
38
// a request or response that's happening on the main gRPC server.
39
type MiddlewareHandler struct {
40
        // lastMsgID is the ID of the last intercept message that was forwarded
41
        // to the middleware.
42
        //
43
        // NOTE: Must be used atomically!
44
        lastMsgID uint64
45

46
        middlewareName string
47

48
        readOnly bool
49

50
        customCaveatName string
51

52
        receive func() (*lnrpc.RPCMiddlewareResponse, error)
53

54
        send func(request *lnrpc.RPCMiddlewareRequest) error
55

56
        interceptRequests chan *interceptRequest
57

58
        timeout time.Duration
59

60
        // params are our current chain params.
61
        params *chaincfg.Params
62

63
        // done is closed when the rpc client terminates.
64
        done chan struct{}
65

66
        // quit is closed when lnd is shutting down.
67
        quit chan struct{}
68

69
        wg sync.WaitGroup
70
}
71

72
// NewMiddlewareHandler creates a new handler for the middleware with the given
73
// name and custom caveat name.
74
func NewMiddlewareHandler(name, customCaveatName string, readOnly bool,
75
        receive func() (*lnrpc.RPCMiddlewareResponse, error),
76
        send func(request *lnrpc.RPCMiddlewareRequest) error,
77
        timeout time.Duration, params *chaincfg.Params,
78
        quit chan struct{}) *MiddlewareHandler {
4✔
79

4✔
80
        // We explicitly want to log this as a warning since intercepting any
4✔
81
        // gRPC messages can also be used for malicious purposes and the user
4✔
82
        // should be made aware of the risks.
4✔
83
        log.Warnf("A new gRPC middleware with the name '%s' was registered "+
4✔
84
                " with custom_macaroon_caveat='%s', read_only=%v. Make sure "+
4✔
85
                "you trust the middleware author since that code will be able "+
4✔
86
                "to intercept and possibly modify any gRPC messages sent/"+
4✔
87
                "received to/from a client that has a macaroon with that "+
4✔
88
                "custom caveat.", name, customCaveatName, readOnly)
4✔
89

4✔
90
        return &MiddlewareHandler{
4✔
91
                middlewareName:    name,
4✔
92
                customCaveatName:  customCaveatName,
4✔
93
                readOnly:          readOnly,
4✔
94
                receive:           receive,
4✔
95
                send:              send,
4✔
96
                interceptRequests: make(chan *interceptRequest),
4✔
97
                timeout:           timeout,
4✔
98
                params:            params,
4✔
99
                done:              make(chan struct{}),
4✔
100
                quit:              quit,
4✔
101
        }
4✔
102
}
4✔
103

104
// intercept handles the full interception lifecycle of a single middleware
105
// event (stream authentication, request interception or response interception).
106
// The lifecycle consists of sending a message to the middleware, receiving a
107
// feedback on it and sending the feedback to the appropriate channel. All steps
108
// are guarded by the configured timeout to make sure a middleware cannot slow
109
// down requests too much.
110
func (h *MiddlewareHandler) intercept(requestID uint64,
111
        req *InterceptionRequest) (*interceptResponse, error) {
4✔
112

4✔
113
        respChan := make(chan *interceptResponse, 1)
4✔
114

4✔
115
        newRequest := &interceptRequest{
4✔
116
                requestID: requestID,
4✔
117
                request:   req,
4✔
118
                response:  respChan,
4✔
119
        }
4✔
120

4✔
121
        // timeout is the time after which intercept requests expire.
4✔
122
        timeout := time.After(h.timeout)
4✔
123

4✔
124
        // Send the request to the interceptRequests channel for the main
4✔
125
        // goroutine to be picked up.
4✔
126
        select {
4✔
127
        case h.interceptRequests <- newRequest:
4✔
128

129
        case <-timeout:
×
130
                log.Errorf("MiddlewareHandler returned error - reached "+
×
131
                        "timeout of %v for request interception", h.timeout)
×
132

×
133
                return nil, ErrTimeoutReached
×
134

135
        case <-h.done:
×
136
                return nil, errClientQuit
×
137

138
        case <-h.quit:
×
139
                return nil, ErrShuttingDown
×
140
        }
141

142
        // Receive the response and return it. If no response has been received
143
        // in AcceptorTimeout, then return false.
144
        select {
4✔
145
        case resp := <-respChan:
4✔
146
                return resp, nil
4✔
147

148
        case <-timeout:
×
149
                log.Errorf("MiddlewareHandler returned error - reached "+
×
150
                        "timeout of %v for response interception", h.timeout)
×
151
                return nil, ErrTimeoutReached
×
152

153
        case <-h.done:
×
154
                return nil, errClientQuit
×
155

156
        case <-h.quit:
×
157
                return nil, ErrShuttingDown
×
158
        }
159
}
160

161
// Run is the main loop for the middleware handler. This function will block
162
// until it receives the signal that lnd is shutting down, or the rpc stream is
163
// cancelled by the client.
164
func (h *MiddlewareHandler) Run() error {
4✔
165
        // Wait for our goroutines to exit before we return.
4✔
166
        defer h.wg.Wait()
4✔
167
        defer log.Debugf("Exiting middleware run loop for %s", h.middlewareName)
4✔
168

4✔
169
        // Create a channel that responses from middlewares are sent into.
4✔
170
        responses := make(chan *lnrpc.RPCMiddlewareResponse)
4✔
171

4✔
172
        // errChan is used by the receive loop to signal any errors that occur
4✔
173
        // during reading from the stream. This is primarily used to shutdown
4✔
174
        // the send loop in the case of an RPC client disconnecting.
4✔
175
        errChan := make(chan error, 1)
4✔
176

4✔
177
        // Start a goroutine to receive responses from the interceptor. We
4✔
178
        // expect the receive function to block, so it must be run in a
4✔
179
        // goroutine (otherwise we could not send more than one intercept
4✔
180
        // request to the client).
4✔
181
        h.wg.Add(1)
4✔
182
        go func() {
8✔
183
                defer h.wg.Done()
4✔
184

4✔
185
                h.receiveResponses(errChan, responses)
4✔
186
        }()
4✔
187

188
        return h.sendInterceptRequests(errChan, responses)
4✔
189
}
190

191
// receiveResponses receives responses for our intercept requests and dispatches
192
// them into the responses channel provided, sending any errors that occur into
193
// the error channel provided.
194
func (h *MiddlewareHandler) receiveResponses(errChan chan error,
195
        responses chan *lnrpc.RPCMiddlewareResponse) {
4✔
196

4✔
197
        for {
8✔
198
                resp, err := h.receive()
4✔
199
                if err != nil {
8✔
200
                        errChan <- err
4✔
201
                        return
4✔
202
                }
4✔
203

204
                select {
4✔
205
                case responses <- resp:
4✔
206

207
                case <-h.done:
×
208
                        return
×
209

210
                case <-h.quit:
×
211
                        return
×
212
                }
213
        }
214
}
215

216
// sendInterceptRequests handles intercept requests sent to us by our Accept()
217
// function, dispatching them to our acceptor stream and coordinating return of
218
// responses to their callers.
219
func (h *MiddlewareHandler) sendInterceptRequests(errChan chan error,
220
        responses chan *lnrpc.RPCMiddlewareResponse) error {
4✔
221

4✔
222
        // Close the done channel to indicate that the interceptor is no longer
4✔
223
        // listening and any in-progress requests should be terminated.
4✔
224
        defer close(h.done)
4✔
225

4✔
226
        interceptRequests := make(map[uint64]*interceptRequest)
4✔
227

4✔
228
        for {
8✔
229
                select {
4✔
230
                // Consume requests passed to us from our Accept() function and
231
                // send them into our stream.
232
                case newRequest := <-h.interceptRequests:
4✔
233
                        msgID := atomic.AddUint64(&h.lastMsgID, 1)
4✔
234

4✔
235
                        req := newRequest.request
4✔
236
                        interceptRequests[msgID] = newRequest
4✔
237

4✔
238
                        interceptReq, err := req.ToRPC(
4✔
239
                                newRequest.requestID, msgID,
4✔
240
                        )
4✔
241
                        if err != nil {
4✔
242
                                return err
×
243
                        }
×
244

245
                        if err := h.send(interceptReq); err != nil {
4✔
246
                                return err
×
247
                        }
×
248

249
                // Process newly received responses from our interceptor,
250
                // looking the original request up in our map of requests and
251
                // dispatching the response.
252
                case resp := <-responses:
4✔
253
                        requestInfo, ok := interceptRequests[resp.RefMsgId]
4✔
254
                        if !ok {
4✔
255
                                continue
×
256
                        }
257

258
                        response := &interceptResponse{}
4✔
259
                        switch msg := resp.GetMiddlewareMessage().(type) {
4✔
260
                        case *lnrpc.RPCMiddlewareResponse_Feedback:
4✔
261
                                t := msg.Feedback
4✔
262
                                if t.Error != "" {
4✔
263
                                        response.err = fmt.Errorf("%s", t.Error)
×
264
                                        break
×
265
                                }
266

267
                                // If there's nothing to replace, we're done,
268
                                // this request was just accepted.
269
                                if !t.ReplaceResponse {
8✔
270
                                        break
4✔
271
                                }
272

273
                                // We are replacing the response, the question
274
                                // now just is: was it an error or a proper
275
                                // proto message?
276
                                response.replace = true
4✔
277
                                if requestInfo.request.IsError {
8✔
278
                                        response.replacement = errors.New(
4✔
279
                                                string(t.ReplacementSerialized),
4✔
280
                                        )
4✔
281

4✔
282
                                        break
4✔
283
                                }
284

285
                                // Not an error but a proper proto message that
286
                                // needs to be replaced. For that we need to
287
                                // parse it from the raw bytes into the full RPC
288
                                // message.
289
                                protoMsg, err := parseProto(
4✔
290
                                        requestInfo.request.ProtoTypeName,
4✔
291
                                        t.ReplacementSerialized,
4✔
292
                                )
4✔
293

4✔
294
                                if err != nil {
4✔
295
                                        response.err = err
×
296

×
297
                                        break
×
298
                                }
299

300
                                response.replacement = protoMsg
4✔
301

302
                        default:
×
303
                                return fmt.Errorf("unknown middleware "+
×
304
                                        "message: %v", msg)
×
305
                        }
306

307
                        select {
4✔
308
                        case requestInfo.response <- response:
4✔
309
                        case <-h.quit:
×
310
                        }
311

312
                        delete(interceptRequests, resp.RefMsgId)
4✔
313

314
                // If we failed to receive from our middleware, we exit.
315
                case err := <-errChan:
4✔
316
                        log.Errorf("Received an error: %v, shutting down", err)
4✔
317
                        return err
4✔
318

319
                // Exit if we are shutting down.
320
                case <-h.quit:
×
321
                        return ErrShuttingDown
×
322
                }
323
        }
324
}
325

326
// InterceptType enumerates the kinds of interception messages a middleware can
// be sent.
type InterceptType uint8

const (
	// TypeStreamAuth (value 1) is sent once when a streaming RPC is
	// initialized. It allows the middleware to accept or deny the stream
	// as a whole, not only single messages on it.
	TypeStreamAuth InterceptType = iota + 1

	// TypeRequest (value 2) is sent for each RPC request message sent to
	// lnd. For client-streaming RPCs this means once per request message
	// on the stream. The middleware may modify a request message before
	// it is delivered to lnd.
	TypeRequest

	// TypeResponse (value 3) is sent for each RPC response message sent
	// from lnd to a client. For server-streaming RPCs this means once per
	// response message on the stream. The middleware may modify a
	// response message before it is sent out to the client.
	TypeResponse
)
352

353
// InterceptionRequest is a struct holding all information that is sent to a
354
// middleware whenever there is something to intercept (auth, request,
355
// response).
356
type InterceptionRequest struct {
357
        // Type is the type of the interception message.
358
        Type InterceptType
359

360
        // StreamRPC is set to true if the invoked RPC method is client or
361
        // server streaming.
362
        StreamRPC bool
363

364
        // Macaroon holds the macaroon that the client sent to lnd.
365
        Macaroon *macaroon.Macaroon
366

367
        // RawMacaroon holds the raw binary serialized macaroon that the client
368
        // sent to lnd.
369
        RawMacaroon []byte
370

371
        // CustomCaveatName is the name of the custom caveat that the middleware
372
        // was intercepting for.
373
        CustomCaveatName string
374

375
        // CustomCaveatCondition is the condition of the custom caveat that the
376
        // middleware was intercepting for. This can be empty for custom caveats
377
        // that only have a name (marker caveats).
378
        CustomCaveatCondition string
379

380
        // FullURI is the full RPC method URI that was invoked.
381
        FullURI string
382

383
        // ProtoSerialized is the full request or response object in the
384
        // protobuf binary serialization format.
385
        ProtoSerialized []byte
386

387
        // ProtoTypeName is the fully qualified name of the protobuf type of the
388
        // request or response message that is serialized in the field above.
389
        ProtoTypeName string
390

391
        // IsError indicates that the message contained within this request is
392
        // an error. Will only ever be true for response messages.
393
        IsError bool
394
}
395

396
// NewMessageInterceptionRequest creates a new interception request for either
397
// a request or response message.
398
func NewMessageInterceptionRequest(ctx context.Context,
399
        authType InterceptType, isStream bool, fullMethod string,
400
        m interface{}) (*InterceptionRequest, error) {
4✔
401

4✔
402
        mac, rawMacaroon, err := macaroonFromContext(ctx)
4✔
403
        if err != nil {
4✔
404
                return nil, err
×
405
        }
×
406

407
        req := &InterceptionRequest{
4✔
408
                Type:        authType,
4✔
409
                StreamRPC:   isStream,
4✔
410
                Macaroon:    mac,
4✔
411
                RawMacaroon: rawMacaroon,
4✔
412
                FullURI:     fullMethod,
4✔
413
        }
4✔
414

4✔
415
        // The message is either a proto message or an error, we don't support
4✔
416
        // any other types being intercepted.
4✔
417
        switch t := m.(type) {
4✔
418
        case proto.Message:
4✔
419
                req.ProtoSerialized, err = proto.Marshal(t)
4✔
420
                if err != nil {
4✔
421
                        return nil, fmt.Errorf("cannot marshal proto msg: %w",
×
422
                                err)
×
423
                }
×
424
                req.ProtoTypeName = string(proto.MessageName(t))
4✔
425

426
        case error:
4✔
427
                req.ProtoSerialized = []byte(t.Error())
4✔
428
                req.ProtoTypeName = "error"
4✔
429
                req.IsError = true
4✔
430

431
        default:
×
432
                return nil, fmt.Errorf("unsupported type for interception "+
×
433
                        "request: %v", m)
×
434
        }
435

436
        return req, nil
4✔
437
}
438

439
// NewStreamAuthInterceptionRequest creates a new interception request for a
440
// stream authentication message.
441
func NewStreamAuthInterceptionRequest(ctx context.Context,
442
        fullMethod string) (*InterceptionRequest, error) {
4✔
443

4✔
444
        mac, rawMacaroon, err := macaroonFromContext(ctx)
4✔
445
        if err != nil {
4✔
446
                return nil, err
×
447
        }
×
448

449
        return &InterceptionRequest{
4✔
450
                Type:        TypeStreamAuth,
4✔
451
                StreamRPC:   true,
4✔
452
                Macaroon:    mac,
4✔
453
                RawMacaroon: rawMacaroon,
4✔
454
                FullURI:     fullMethod,
4✔
455
        }, nil
4✔
456
}
457

458
// macaroonFromContext tries to extract the macaroon from the incoming context.
459
// If there is no macaroon, a nil error is returned since some RPCs might not
460
// require a macaroon. But in case there is something in the macaroon header
461
// field that cannot be parsed, a non-nil error is returned.
462
func macaroonFromContext(ctx context.Context) (*macaroon.Macaroon, []byte,
463
        error) {
4✔
464

4✔
465
        macHex, err := macaroons.RawMacaroonFromContext(ctx)
4✔
466
        if err != nil {
4✔
467
                // If there is no macaroon, we continue anyway as it might be an
×
468
                // RPC that doesn't require a macaroon.
×
469
                return nil, nil, nil
×
470
        }
×
471

472
        macBytes, err := hex.DecodeString(macHex)
4✔
473
        if err != nil {
4✔
474
                return nil, nil, err
×
475
        }
×
476

477
        mac := &macaroon.Macaroon{}
4✔
478
        if err := mac.UnmarshalBinary(macBytes); err != nil {
4✔
479
                return nil, nil, err
×
480
        }
×
481

482
        return mac, macBytes, nil
4✔
483
}
484

485
// ToRPC converts the interception request to its RPC counterpart.
486
func (r *InterceptionRequest) ToRPC(requestID,
487
        msgID uint64) (*lnrpc.RPCMiddlewareRequest, error) {
4✔
488

4✔
489
        rpcRequest := &lnrpc.RPCMiddlewareRequest{
4✔
490
                RequestId:             requestID,
4✔
491
                MsgId:                 msgID,
4✔
492
                RawMacaroon:           r.RawMacaroon,
4✔
493
                CustomCaveatCondition: r.CustomCaveatCondition,
4✔
494
        }
4✔
495

4✔
496
        switch r.Type {
4✔
497
        case TypeStreamAuth:
4✔
498
                rpcRequest.InterceptType = &lnrpc.RPCMiddlewareRequest_StreamAuth{
4✔
499
                        StreamAuth: &lnrpc.StreamAuth{
4✔
500
                                MethodFullUri: r.FullURI,
4✔
501
                        },
4✔
502
                }
4✔
503

504
        case TypeRequest:
4✔
505
                rpcRequest.InterceptType = &lnrpc.RPCMiddlewareRequest_Request{
4✔
506
                        Request: &lnrpc.RPCMessage{
4✔
507
                                MethodFullUri: r.FullURI,
4✔
508
                                StreamRpc:     r.StreamRPC,
4✔
509
                                TypeName:      r.ProtoTypeName,
4✔
510
                                Serialized:    r.ProtoSerialized,
4✔
511
                        },
4✔
512
                }
4✔
513

514
        case TypeResponse:
4✔
515
                rpcRequest.InterceptType = &lnrpc.RPCMiddlewareRequest_Response{
4✔
516
                        Response: &lnrpc.RPCMessage{
4✔
517
                                MethodFullUri: r.FullURI,
4✔
518
                                StreamRpc:     r.StreamRPC,
4✔
519
                                TypeName:      r.ProtoTypeName,
4✔
520
                                Serialized:    r.ProtoSerialized,
4✔
521
                                IsError:       r.IsError,
4✔
522
                        },
4✔
523
                }
4✔
524

525
        default:
×
526
                return nil, fmt.Errorf("unknown intercept type %v", r.Type)
×
527
        }
528

529
        return rpcRequest, nil
4✔
530
}
531

532
// interceptRequest is a struct that keeps track of an interception request sent
533
// out to a middleware and the response that is eventually sent back by the
534
// middleware.
535
type interceptRequest struct {
536
        requestID uint64
537
        request   *InterceptionRequest
538
        response  chan *interceptResponse
539
}
540

541
// interceptResponse is the outcome of a single interception, as reported back
// by the middleware.
type interceptResponse struct {
	// err is set if the middleware rejected the message or its feedback
	// could not be processed.
	err error

	// replace is true if the middleware asked for the intercepted message
	// to be replaced.
	replace bool

	// replacement holds the message to use instead of the original one. It
	// is an error if the original message was an error, otherwise the
	// parsed proto message.
	replacement interface{}
}
548

549
// parseProto parses a proto serialized message of the given type into its
550
// native version.
551
func parseProto(typeName string, serialized []byte) (proto.Message, error) {
4✔
552
        messageType, err := protoregistry.GlobalTypes.FindMessageByName(
4✔
553
                protoreflect.FullName(typeName),
4✔
554
        )
4✔
555
        if err != nil {
4✔
556
                return nil, err
×
557
        }
×
558
        msg := messageType.New()
4✔
559
        err = proto.Unmarshal(serialized, msg.Interface())
4✔
560
        if err != nil {
4✔
561
                return nil, err
×
562
        }
×
563

564
        return msg.Interface(), nil
4✔
565
}
566

567
// replaceProtoMsg replaces the given target message with the content of the
568
// replacement message.
569
func replaceProtoMsg(target interface{}, replacement interface{}) error {
4✔
570
        targetMsg, ok := target.(proto.Message)
4✔
571
        if !ok {
4✔
572
                return fmt.Errorf("target is not a proto message: %v", target)
×
573
        }
×
574

575
        replacementMsg, ok := replacement.(proto.Message)
4✔
576
        if !ok {
4✔
577
                return fmt.Errorf("replacement is not a proto message: %v",
×
578
                        replacement)
×
579
        }
×
580

581
        if targetMsg.ProtoReflect().Type() !=
4✔
582
                replacementMsg.ProtoReflect().Type() {
4✔
583

×
584
                return fmt.Errorf("replacement message is of wrong type")
×
585
        }
×
586

587
        replacementBytes, err := proto.Marshal(replacementMsg)
4✔
588
        if err != nil {
4✔
589
                return fmt.Errorf("error marshaling replacement: %w", err)
×
590
        }
×
591
        err = proto.Unmarshal(replacementBytes, targetMsg)
4✔
592
        if err != nil {
4✔
593
                return fmt.Errorf("error unmarshaling replacement: %w", err)
×
594
        }
×
595

596
        return nil
4✔
597
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc