• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15736109134

18 Jun 2025 02:46PM UTC coverage: 58.197% (-10.1%) from 68.248%
15736109134

Pull #9752

github

web-flow
Merge d2634a68c into 31c74f20f
Pull Request #9752: routerrpc: reject payment to invoice that don't have payment secret or blinded paths

6 of 13 new or added lines in 2 files covered. (46.15%)

28331 existing lines in 455 files now uncovered.

97860 of 168153 relevant lines covered (58.2%)

1.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

76.9
/rpcperms/middleware_handler.go
1
package rpcperms
2

3
import (
4
        "context"
5
        "encoding/hex"
6
        "errors"
7
        "fmt"
8
        "sync"
9
        "sync/atomic"
10
        "time"
11

12
        "github.com/btcsuite/btcd/chaincfg"
13
        "github.com/lightningnetwork/lnd/lnrpc"
14
        "github.com/lightningnetwork/lnd/macaroons"
15
        "google.golang.org/grpc/metadata"
16
        "google.golang.org/protobuf/proto"
17
        "google.golang.org/protobuf/reflect/protoreflect"
18
        "google.golang.org/protobuf/reflect/protoregistry"
19
        "gopkg.in/macaroon.v2"
20
)
21

22
var (
	// ErrShuttingDown signals that lnd is in the process of shutting down,
	// so no further requests can be handled.
	ErrShuttingDown = errors.New("server shutting down")

	// ErrTimeoutReached signals that a middleware did not finish one of
	// its tasks within the configured timeout.
	ErrTimeoutReached = errors.New("intercept timeout reached")

	// errClientQuit signals that the middleware client closed its
	// communication stream while a request was still being handled.
	errClientQuit = errors.New("interceptor RPC client quit")
)
35

36
// MiddlewareHandler is a type that communicates with a middleware over the
37
// established bi-directional RPC stream. It sends messages to the middleware
38
// whenever the custom business logic implemented there should give feedback to
39
// a request or response that's happening on the main gRPC server.
40
type MiddlewareHandler struct {
41
        // lastMsgID is the ID of the last intercept message that was forwarded
42
        // to the middleware.
43
        //
44
        // NOTE: Must be used atomically!
45
        lastMsgID uint64
46

47
        middlewareName string
48

49
        readOnly bool
50

51
        customCaveatName string
52

53
        receive func() (*lnrpc.RPCMiddlewareResponse, error)
54

55
        send func(request *lnrpc.RPCMiddlewareRequest) error
56

57
        interceptRequests chan *interceptRequest
58

59
        timeout time.Duration
60

61
        // params are our current chain params.
62
        params *chaincfg.Params
63

64
        // done is closed when the rpc client terminates.
65
        done chan struct{}
66

67
        // quit is closed when lnd is shutting down.
68
        quit chan struct{}
69

70
        wg sync.WaitGroup
71
}
72

73
// NewMiddlewareHandler creates a new handler for the middleware with the given
74
// name and custom caveat name.
75
func NewMiddlewareHandler(name, customCaveatName string, readOnly bool,
76
        receive func() (*lnrpc.RPCMiddlewareResponse, error),
77
        send func(request *lnrpc.RPCMiddlewareRequest) error,
78
        timeout time.Duration, params *chaincfg.Params,
79
        quit chan struct{}) *MiddlewareHandler {
3✔
80

3✔
81
        // We explicitly want to log this as a warning since intercepting any
3✔
82
        // gRPC messages can also be used for malicious purposes and the user
3✔
83
        // should be made aware of the risks.
3✔
84
        log.Warnf("A new gRPC middleware with the name '%s' was registered "+
3✔
85
                " with custom_macaroon_caveat='%s', read_only=%v. Make sure "+
3✔
86
                "you trust the middleware author since that code will be able "+
3✔
87
                "to intercept and possibly modify any gRPC messages sent/"+
3✔
88
                "received to/from a client that has a macaroon with that "+
3✔
89
                "custom caveat.", name, customCaveatName, readOnly)
3✔
90

3✔
91
        return &MiddlewareHandler{
3✔
92
                middlewareName:    name,
3✔
93
                customCaveatName:  customCaveatName,
3✔
94
                readOnly:          readOnly,
3✔
95
                receive:           receive,
3✔
96
                send:              send,
3✔
97
                interceptRequests: make(chan *interceptRequest),
3✔
98
                timeout:           timeout,
3✔
99
                params:            params,
3✔
100
                done:              make(chan struct{}),
3✔
101
                quit:              quit,
3✔
102
        }
3✔
103
}
3✔
104

105
// intercept handles the full interception lifecycle of a single middleware
106
// event (stream authentication, request interception or response interception).
107
// The lifecycle consists of sending a message to the middleware, receiving a
108
// feedback on it and sending the feedback to the appropriate channel. All steps
109
// are guarded by the configured timeout to make sure a middleware cannot slow
110
// down requests too much.
111
func (h *MiddlewareHandler) intercept(requestID uint64,
112
        req *InterceptionRequest) (*interceptResponse, error) {
3✔
113

3✔
114
        respChan := make(chan *interceptResponse, 1)
3✔
115

3✔
116
        newRequest := &interceptRequest{
3✔
117
                requestID: requestID,
3✔
118
                request:   req,
3✔
119
                response:  respChan,
3✔
120
        }
3✔
121

3✔
122
        // timeout is the time after which intercept requests expire.
3✔
123
        timeout := time.After(h.timeout)
3✔
124

3✔
125
        // Send the request to the interceptRequests channel for the main
3✔
126
        // goroutine to be picked up.
3✔
127
        select {
3✔
128
        case h.interceptRequests <- newRequest:
3✔
129

130
        case <-timeout:
×
131
                log.Errorf("MiddlewareHandler returned error - reached "+
×
132
                        "timeout of %v for request interception", h.timeout)
×
133

×
134
                return nil, ErrTimeoutReached
×
135

136
        case <-h.done:
×
137
                return nil, errClientQuit
×
138

139
        case <-h.quit:
×
140
                return nil, ErrShuttingDown
×
141
        }
142

143
        // Receive the response and return it. If no response has been received
144
        // in AcceptorTimeout, then return false.
145
        select {
3✔
146
        case resp := <-respChan:
3✔
147
                return resp, nil
3✔
148

149
        case <-timeout:
×
150
                log.Errorf("MiddlewareHandler returned error - reached "+
×
151
                        "timeout of %v for response interception", h.timeout)
×
152
                return nil, ErrTimeoutReached
×
153

154
        case <-h.done:
×
155
                return nil, errClientQuit
×
156

157
        case <-h.quit:
×
158
                return nil, ErrShuttingDown
×
159
        }
160
}
161

162
// Run is the main loop for the middleware handler. This function will block
163
// until it receives the signal that lnd is shutting down, or the rpc stream is
164
// cancelled by the client.
165
func (h *MiddlewareHandler) Run() error {
3✔
166
        // Wait for our goroutines to exit before we return.
3✔
167
        defer h.wg.Wait()
3✔
168
        defer log.Debugf("Exiting middleware run loop for %s", h.middlewareName)
3✔
169

3✔
170
        // Create a channel that responses from middlewares are sent into.
3✔
171
        responses := make(chan *lnrpc.RPCMiddlewareResponse)
3✔
172

3✔
173
        // errChan is used by the receive loop to signal any errors that occur
3✔
174
        // during reading from the stream. This is primarily used to shutdown
3✔
175
        // the send loop in the case of an RPC client disconnecting.
3✔
176
        errChan := make(chan error, 1)
3✔
177

3✔
178
        // Start a goroutine to receive responses from the interceptor. We
3✔
179
        // expect the receive function to block, so it must be run in a
3✔
180
        // goroutine (otherwise we could not send more than one intercept
3✔
181
        // request to the client).
3✔
182
        h.wg.Add(1)
3✔
183
        go func() {
6✔
184
                defer h.wg.Done()
3✔
185

3✔
186
                h.receiveResponses(errChan, responses)
3✔
187
        }()
3✔
188

189
        return h.sendInterceptRequests(errChan, responses)
3✔
190
}
191

192
// receiveResponses receives responses for our intercept requests and dispatches
193
// them into the responses channel provided, sending any errors that occur into
194
// the error channel provided.
195
func (h *MiddlewareHandler) receiveResponses(errChan chan error,
196
        responses chan *lnrpc.RPCMiddlewareResponse) {
3✔
197

3✔
198
        for {
6✔
199
                resp, err := h.receive()
3✔
200
                if err != nil {
6✔
201
                        errChan <- err
3✔
202
                        return
3✔
203
                }
3✔
204

205
                select {
3✔
206
                case responses <- resp:
3✔
207

208
                case <-h.done:
×
209
                        return
×
210

211
                case <-h.quit:
×
212
                        return
×
213
                }
214
        }
215
}
216

217
// sendInterceptRequests handles intercept requests sent to us by our Accept()
218
// function, dispatching them to our acceptor stream and coordinating return of
219
// responses to their callers.
220
func (h *MiddlewareHandler) sendInterceptRequests(errChan chan error,
221
        responses chan *lnrpc.RPCMiddlewareResponse) error {
3✔
222

3✔
223
        // Close the done channel to indicate that the interceptor is no longer
3✔
224
        // listening and any in-progress requests should be terminated.
3✔
225
        defer close(h.done)
3✔
226

3✔
227
        interceptRequests := make(map[uint64]*interceptRequest)
3✔
228

3✔
229
        for {
6✔
230
                select {
3✔
231
                // Consume requests passed to us from our Accept() function and
232
                // send them into our stream.
233
                case newRequest := <-h.interceptRequests:
3✔
234
                        msgID := atomic.AddUint64(&h.lastMsgID, 1)
3✔
235

3✔
236
                        req := newRequest.request
3✔
237
                        interceptRequests[msgID] = newRequest
3✔
238

3✔
239
                        interceptReq, err := req.ToRPC(
3✔
240
                                newRequest.requestID, msgID,
3✔
241
                        )
3✔
242
                        if err != nil {
3✔
243
                                return err
×
244
                        }
×
245

246
                        if err := h.send(interceptReq); err != nil {
3✔
247
                                return err
×
248
                        }
×
249

250
                // Process newly received responses from our interceptor,
251
                // looking the original request up in our map of requests and
252
                // dispatching the response.
253
                case resp := <-responses:
3✔
254
                        requestInfo, ok := interceptRequests[resp.RefMsgId]
3✔
255
                        if !ok {
3✔
256
                                continue
×
257
                        }
258

259
                        response := &interceptResponse{}
3✔
260
                        switch msg := resp.GetMiddlewareMessage().(type) {
3✔
261
                        case *lnrpc.RPCMiddlewareResponse_Feedback:
3✔
262
                                t := msg.Feedback
3✔
263
                                if t.Error != "" {
3✔
264
                                        response.err = fmt.Errorf("%s", t.Error)
×
265
                                        break
×
266
                                }
267

268
                                // If there's nothing to replace, we're done,
269
                                // this request was just accepted.
270
                                if !t.ReplaceResponse {
6✔
271
                                        break
3✔
272
                                }
273

274
                                // We are replacing the response, the question
275
                                // now just is: was it an error or a proper
276
                                // proto message?
277
                                response.replace = true
3✔
278
                                if requestInfo.request.IsError {
6✔
279
                                        response.replacement = errors.New(
3✔
280
                                                string(t.ReplacementSerialized),
3✔
281
                                        )
3✔
282

3✔
283
                                        break
3✔
284
                                }
285

286
                                // Not an error but a proper proto message that
287
                                // needs to be replaced. For that we need to
288
                                // parse it from the raw bytes into the full RPC
289
                                // message.
290
                                protoMsg, err := parseProto(
3✔
291
                                        requestInfo.request.ProtoTypeName,
3✔
292
                                        t.ReplacementSerialized,
3✔
293
                                )
3✔
294

3✔
295
                                if err != nil {
3✔
296
                                        response.err = err
×
297

×
298
                                        break
×
299
                                }
300

301
                                response.replacement = protoMsg
3✔
302

303
                        default:
×
304
                                return fmt.Errorf("unknown middleware "+
×
305
                                        "message: %v", msg)
×
306
                        }
307

308
                        select {
3✔
309
                        case requestInfo.response <- response:
3✔
310
                        case <-h.quit:
×
311
                        }
312

313
                        delete(interceptRequests, resp.RefMsgId)
3✔
314

315
                // If we failed to receive from our middleware, we exit.
316
                case err := <-errChan:
3✔
317
                        log.Errorf("Received an error: %v, shutting down", err)
3✔
318
                        return err
3✔
319

320
                // Exit if we are shutting down.
321
                case <-h.quit:
×
322
                        return ErrShuttingDown
×
323
                }
324
        }
325
}
326

327
// InterceptType enumerates the kinds of intercept messages that can be sent to
// a middleware.
type InterceptType uint8

const (
	// TypeStreamAuth (1) is sent once when a client- or server-streaming
	// RPC is set up. It lets a middleware accept or deny the stream as a
	// whole rather than judging each individual message on it.
	TypeStreamAuth InterceptType = iota + 1

	// TypeRequest (2) is sent for every RPC request message that reaches
	// lnd. Client-streaming RPCs produce one such message per request on
	// the stream. The middleware may modify the request before lnd
	// processes it.
	TypeRequest

	// TypeResponse (3) is sent for every RPC response message lnd sends
	// back to a client. Server-streaming RPCs produce one such message
	// per response on the stream. The middleware may modify the response
	// before it reaches the client.
	TypeResponse
)
353

354
// InterceptionRequest is a struct holding all information that is sent to a
355
// middleware whenever there is something to intercept (auth, request,
356
// response).
357
type InterceptionRequest struct {
358
        // Type is the type of the interception message.
359
        Type InterceptType
360

361
        // StreamRPC is set to true if the invoked RPC method is client or
362
        // server streaming.
363
        StreamRPC bool
364

365
        // Macaroon holds the macaroon that the client sent to lnd.
366
        Macaroon *macaroon.Macaroon
367

368
        // RawMacaroon holds the raw binary serialized macaroon that the client
369
        // sent to lnd.
370
        RawMacaroon []byte
371

372
        // CustomCaveatName is the name of the custom caveat that the middleware
373
        // was intercepting for.
374
        CustomCaveatName string
375

376
        // CustomCaveatCondition is the condition of the custom caveat that the
377
        // middleware was intercepting for. This can be empty for custom caveats
378
        // that only have a name (marker caveats).
379
        CustomCaveatCondition string
380

381
        // FullURI is the full RPC method URI that was invoked.
382
        FullURI string
383

384
        // ProtoSerialized is the full request or response object in the
385
        // protobuf binary serialization format.
386
        ProtoSerialized []byte
387

388
        // ProtoTypeName is the fully qualified name of the protobuf type of the
389
        // request or response message that is serialized in the field above.
390
        ProtoTypeName string
391

392
        // IsError indicates that the message contained within this request is
393
        // an error. Will only ever be true for response messages.
394
        IsError bool
395

396
        // CtxMetadataPairs contains the metadata pairs that were sent along
397
        // with the RPC request via the context.
398
        CtxMetadataPairs metadata.MD
399
}
400

401
// NewMessageInterceptionRequest creates a new interception request for either
402
// a request or response message.
403
func NewMessageInterceptionRequest(ctx context.Context,
404
        authType InterceptType, isStream bool, fullMethod string,
405
        m interface{}) (*InterceptionRequest, error) {
3✔
406

3✔
407
        mac, rawMacaroon, err := macaroonFromContext(ctx)
3✔
408
        if err != nil {
3✔
409
                return nil, err
×
410
        }
×
411

412
        md, _ := metadata.FromIncomingContext(ctx)
3✔
413

3✔
414
        req := &InterceptionRequest{
3✔
415
                Type:             authType,
3✔
416
                StreamRPC:        isStream,
3✔
417
                Macaroon:         mac,
3✔
418
                RawMacaroon:      rawMacaroon,
3✔
419
                FullURI:          fullMethod,
3✔
420
                CtxMetadataPairs: md,
3✔
421
        }
3✔
422

3✔
423
        // The message is either a proto message or an error, we don't support
3✔
424
        // any other types being intercepted.
3✔
425
        switch t := m.(type) {
3✔
426
        case proto.Message:
3✔
427
                req.ProtoSerialized, err = proto.Marshal(t)
3✔
428
                if err != nil {
3✔
429
                        return nil, fmt.Errorf("cannot marshal proto msg: %w",
×
430
                                err)
×
431
                }
×
432
                req.ProtoTypeName = string(proto.MessageName(t))
3✔
433

434
        case error:
3✔
435
                req.ProtoSerialized = []byte(t.Error())
3✔
436
                req.ProtoTypeName = "error"
3✔
437
                req.IsError = true
3✔
438

439
        default:
×
440
                return nil, fmt.Errorf("unsupported type for interception "+
×
441
                        "request: %v", m)
×
442
        }
443

444
        return req, nil
3✔
445
}
446

447
// NewStreamAuthInterceptionRequest creates a new interception request for a
448
// stream authentication message.
449
func NewStreamAuthInterceptionRequest(ctx context.Context,
450
        fullMethod string) (*InterceptionRequest, error) {
3✔
451

3✔
452
        mac, rawMacaroon, err := macaroonFromContext(ctx)
3✔
453
        if err != nil {
3✔
454
                return nil, err
×
455
        }
×
456

457
        return &InterceptionRequest{
3✔
458
                Type:        TypeStreamAuth,
3✔
459
                StreamRPC:   true,
3✔
460
                Macaroon:    mac,
3✔
461
                RawMacaroon: rawMacaroon,
3✔
462
                FullURI:     fullMethod,
3✔
463
        }, nil
3✔
464
}
465

466
// macaroonFromContext tries to extract the macaroon from the incoming context.
467
// If there is no macaroon, a nil error is returned since some RPCs might not
468
// require a macaroon. But in case there is something in the macaroon header
469
// field that cannot be parsed, a non-nil error is returned.
470
func macaroonFromContext(ctx context.Context) (*macaroon.Macaroon, []byte,
471
        error) {
3✔
472

3✔
473
        macHex, err := macaroons.RawMacaroonFromContext(ctx)
3✔
474
        if err != nil {
3✔
475
                // If there is no macaroon, we continue anyway as it might be an
×
476
                // RPC that doesn't require a macaroon.
×
477
                return nil, nil, nil
×
478
        }
×
479

480
        macBytes, err := hex.DecodeString(macHex)
3✔
481
        if err != nil {
3✔
482
                return nil, nil, err
×
483
        }
×
484

485
        mac := &macaroon.Macaroon{}
3✔
486
        if err := mac.UnmarshalBinary(macBytes); err != nil {
3✔
487
                return nil, nil, err
×
488
        }
×
489

490
        return mac, macBytes, nil
3✔
491
}
492

493
// ToRPC converts the interception request to its RPC counterpart.
494
func (r *InterceptionRequest) ToRPC(requestID,
495
        msgID uint64) (*lnrpc.RPCMiddlewareRequest, error) {
3✔
496

3✔
497
        mdPairs := make(
3✔
498
                map[string]*lnrpc.MetadataValues, len(r.CtxMetadataPairs),
3✔
499
        )
3✔
500
        for key, values := range r.CtxMetadataPairs {
6✔
501
                mdPairs[key] = &lnrpc.MetadataValues{
3✔
502
                        Values: values,
3✔
503
                }
3✔
504
        }
3✔
505

506
        rpcRequest := &lnrpc.RPCMiddlewareRequest{
3✔
507
                RequestId:             requestID,
3✔
508
                MsgId:                 msgID,
3✔
509
                RawMacaroon:           r.RawMacaroon,
3✔
510
                CustomCaveatCondition: r.CustomCaveatCondition,
3✔
511
                MetadataPairs:         mdPairs,
3✔
512
        }
3✔
513

3✔
514
        switch r.Type {
3✔
515
        case TypeStreamAuth:
3✔
516
                rpcRequest.InterceptType = &lnrpc.RPCMiddlewareRequest_StreamAuth{
3✔
517
                        StreamAuth: &lnrpc.StreamAuth{
3✔
518
                                MethodFullUri: r.FullURI,
3✔
519
                        },
3✔
520
                }
3✔
521

522
        case TypeRequest:
3✔
523
                rpcRequest.InterceptType = &lnrpc.RPCMiddlewareRequest_Request{
3✔
524
                        Request: &lnrpc.RPCMessage{
3✔
525
                                MethodFullUri: r.FullURI,
3✔
526
                                StreamRpc:     r.StreamRPC,
3✔
527
                                TypeName:      r.ProtoTypeName,
3✔
528
                                Serialized:    r.ProtoSerialized,
3✔
529
                        },
3✔
530
                }
3✔
531

532
        case TypeResponse:
3✔
533
                rpcRequest.InterceptType = &lnrpc.RPCMiddlewareRequest_Response{
3✔
534
                        Response: &lnrpc.RPCMessage{
3✔
535
                                MethodFullUri: r.FullURI,
3✔
536
                                StreamRpc:     r.StreamRPC,
3✔
537
                                TypeName:      r.ProtoTypeName,
3✔
538
                                Serialized:    r.ProtoSerialized,
3✔
539
                                IsError:       r.IsError,
3✔
540
                        },
3✔
541
                }
3✔
542

543
        default:
×
544
                return nil, fmt.Errorf("unknown intercept type %v", r.Type)
×
545
        }
546

547
        return rpcRequest, nil
3✔
548
}
549

550
// interceptRequest is a struct that keeps track of an interception request sent
551
// out to a middleware and the response that is eventually sent back by the
552
// middleware.
553
type interceptRequest struct {
554
        requestID uint64
555
        request   *InterceptionRequest
556
        response  chan *interceptResponse
557
}
558

559
// interceptResponse holds the middleware's verdict on one intercepted message.
type interceptResponse struct {
	// err is set when the middleware rejected the message or when its
	// replacement could not be parsed.
	err error

	// replace is true when the middleware asked for the original message
	// to be swapped out.
	replace bool

	// replacement is the substitute message: an error when an error was
	// intercepted, otherwise the parsed proto message.
	replacement interface{}
}
566

567
// parseProto parses a proto serialized message of the given type into its
568
// native version.
569
func parseProto(typeName string, serialized []byte) (proto.Message, error) {
3✔
570
        messageType, err := protoregistry.GlobalTypes.FindMessageByName(
3✔
571
                protoreflect.FullName(typeName),
3✔
572
        )
3✔
573
        if err != nil {
3✔
574
                return nil, err
×
575
        }
×
576
        msg := messageType.New()
3✔
577
        err = proto.Unmarshal(serialized, msg.Interface())
3✔
578
        if err != nil {
3✔
579
                return nil, err
×
580
        }
×
581

582
        return msg.Interface(), nil
3✔
583
}
584

585
// replaceProtoMsg replaces the given target message with the content of the
586
// replacement message.
587
func replaceProtoMsg(target interface{}, replacement interface{}) error {
3✔
588
        targetMsg, ok := target.(proto.Message)
3✔
589
        if !ok {
3✔
UNCOV
590
                return fmt.Errorf("target is not a proto message: %v", target)
×
UNCOV
591
        }
×
592

593
        replacementMsg, ok := replacement.(proto.Message)
3✔
594
        if !ok {
3✔
595
                return fmt.Errorf("replacement is not a proto message: %v",
×
596
                        replacement)
×
597
        }
×
598

599
        if targetMsg.ProtoReflect().Type() !=
3✔
600
                replacementMsg.ProtoReflect().Type() {
3✔
UNCOV
601

×
UNCOV
602
                return fmt.Errorf("replacement message is of wrong type")
×
UNCOV
603
        }
×
604

605
        replacementBytes, err := proto.Marshal(replacementMsg)
3✔
606
        if err != nil {
3✔
607
                return fmt.Errorf("error marshaling replacement: %w", err)
×
608
        }
×
609
        err = proto.Unmarshal(replacementBytes, targetMsg)
3✔
610
        if err != nil {
3✔
611
                return fmt.Errorf("error unmarshaling replacement: %w", err)
×
612
        }
×
613

614
        return nil
3✔
615
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc