• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mendersoftware / mender-cli / 1779286128

22 Apr 2025 08:09AM UTC coverage: 1.737% (-30.1%) from 31.802%
1779286128

Pull #277

gitlab-ci

alfrunes
chore: Improve dockerfile by pinning versions and pruning dependencies

Changed the builder image to the official Debian-based golang image and
upgraded it to the latest version.

Signed-off-by: Alf-Rune Siqveland <alf.rune@northern.tech>
Pull Request #277: MEN-7794: Add support for pagination when listing devices

28 of 82 new or added lines in 4 files covered. (34.15%)

770 existing lines in 17 files now uncovered.

45 of 2590 relevant lines covered (1.74%)

0.04 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/cmd/portforward.go
1
// Copyright 2022 Northern.tech AS
2
//
3
//    Licensed under the Apache License, Version 2.0 (the "License");
4
//    you may not use this file except in compliance with the License.
5
//    You may obtain a copy of the License at
6
//
7
//        http://www.apache.org/licenses/LICENSE-2.0
8
//
9
//    Unless required by applicable law or agreed to in writing, software
10
//    distributed under the License is distributed on an "AS IS" BASIS,
11
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
//    See the License for the specific language governing permissions and
13
//    limitations under the License.
14

15
package cmd
16

17
import (
18
        "context"
19
        "fmt"
20
        "os"
21
        "os/signal"
22
        "strconv"
23
        "strings"
24
        "time"
25

26
        "github.com/mendersoftware/go-lib-micro/ws"
27
        wspf "github.com/mendersoftware/go-lib-micro/ws/portforward"
28
        "github.com/pkg/errors"
29
        "github.com/spf13/cobra"
30
        "github.com/spf13/viper"
31
        "github.com/vmihailenco/msgpack"
32
        "golang.org/x/sys/unix"
33

34
        "github.com/mendersoftware/mender-cli/client/deviceconnect"
35
)
36

37
// argBindHost is the name of the flag that selects the local address on
// which the port forwarders listen.
const argBindHost = "bind"

// readBuffLength is a read-buffer size; it is not referenced in this part of
// the file (presumably consumed by the TCP/UDP forwarders — confirm).
const readBuffLength = 4096

// localhost is the IPv4 loopback address. It is the default value of the
// bind flag and the default remote host of a port mapping.
const localhost = "127.0.0.1"

43
var portForwardCmd = &cobra.Command{
44
        Use: "port-forward DEVICE_ID [tcp|udp/]LOCAL_PORT[:REMOTE_PORT]" +
45
                " [[tcp|udp/]LOCAL_PORT[:REMOTE_PORT]...]",
46
        Short: "Forward one or more local ports to remote port(s) on the device",
47
        Long: "This command supports both TCP and UDP port-forwarding.\n\n" +
48
                "The port specification can be prefixed with \"tcp/\" or \"udp/\".\n" +
49
                "If no prefix is specified, TCP is the default.\n\n" +
50
                "REMOTE_PORT can also be specified in the form REMOTE_HOST:REMOTE_PORT, making\n" +
51
                "it possible to port-forward to third hosts running in the device's network.\n" +
52
                "In this case, the specification will be LOCAL_PORT:REMOTE_HOST:REMOTE_PORT.\n\n" +
53
                "You can specify multiple port mapping specifications.",
54
        Example: "  mender-cli port-forward DEVICE_ID 8000:8000\n" +
55
                "  mender-cli port-forward DEVICE_ID udp/8000:8000\n" +
56
                "  mender-cli port-forward DEVICE_ID tcp/8000:192.168.1.1:8000",
57
        Args: cobra.MinimumNArgs(2),
58
        Run: func(c *cobra.Command, args []string) {
×
59
                cmd, err := NewPortForwardCmd(c, args)
×
60
                CheckErr(err)
×
61
                CheckErr(cmd.Run())
×
62
        },
×
63
}
64

65
var portForwardMaxDuration = 24 * time.Hour
66

67
var errPortForwardNotImplemented = errors.New(
68
        "port forward not implemented or enabled on the device",
69
)
70
var errRestart = errors.New("restart")
71

UNCOV
72
func init() {
×
UNCOV
73
        portForwardCmd.Flags().StringP(argBindHost, "", localhost, "binding host")
×
UNCOV
74
}
×
75

76
// protocolTCP selects TCP forwarding. It is the default when a port-mapping
// specification carries no protocol prefix.
const protocolTCP = "tcp"

// protocolUDP selects UDP forwarding (the "udp/" prefix).
const protocolUDP = "udp"

// portMapping is one parsed forwarding rule. Traffic of the given Protocol
// arriving on LocalPort is forwarded to RemoteHost:RemotePort, which may be
// a host in the device's network.
type portMapping struct {
	// Protocol is protocolTCP or protocolUDP.
	Protocol   string
	// LocalPort is the port the local forwarder binds to.
	LocalPort  uint16
	// RemoteHost is the forwarding target; localhost unless specified.
	RemoteHost string
	// RemotePort is the target port on RemoteHost.
	RemotePort uint16
}
87

88
// PortForwardCmd handles the port-forward command
89
type PortForwardCmd struct {
90
        server       string
91
        token        string
92
        skipVerify   bool
93
        deviceID     string
94
        sessionID    string
95
        bindingHost  string
96
        portMappings []portMapping
97
        recvChans    map[string]chan *ws.ProtoMsg
98
        running      bool
99
        stop         chan struct{}
100
        err          error
101
}
102

103
func getPortMappings(args []string) ([]portMapping, error) {
×
104
        var err error
×
105
        portMappings := []portMapping{}
×
106
        for _, arg := range args {
×
107
                remoteHost := localhost
×
108
                protocol := wspf.PortForwardProtocolTCP
×
109
                if strings.Contains(arg, "/") {
×
110
                        parts := strings.SplitN(arg, "/", 2)
×
111
                        if parts[0] == protocolTCP {
×
112
                                protocol = protocolTCP
×
113
                        } else if parts[0] == protocolUDP {
×
114
                                protocol = protocolUDP
×
115
                        } else {
×
116
                                return nil, errors.New("unknown protocol: " + parts[0])
×
117
                        }
×
118
                        arg = parts[1]
×
119
                }
120
                var localPort, remotePort int
×
121
                if strings.Contains(arg, ":") {
×
122
                        parts := strings.SplitN(arg, ":", 3)
×
123
                        if len(parts) == 3 {
×
124
                                remoteHost = parts[1]
×
125
                                parts = []string{parts[0], parts[2]}
×
126
                        }
×
127
                        localPort, err = strconv.Atoi(parts[0])
×
128
                        if err != nil || localPort < 0 || localPort > 65536 {
×
129
                                return nil, errors.New("invalid port number: " + parts[0])
×
130
                        }
×
131
                        remotePort, err = strconv.Atoi(parts[1])
×
132
                        if err != nil || remotePort < 0 || remotePort > 65536 {
×
133
                                return nil, errors.New("invalid port number: " + parts[1])
×
134
                        }
×
135
                } else {
×
136
                        port, err := strconv.Atoi(arg)
×
137
                        if err != nil || port < 0 || port > 65536 {
×
138
                                return nil, errors.New("invalid port number: " + arg)
×
139
                        }
×
140
                        localPort = port
×
141
                        remotePort = port
×
142
                }
143
                portMappings = append(portMappings, portMapping{
×
144
                        Protocol:   protocol,
×
145
                        LocalPort:  uint16(localPort),
×
146
                        RemoteHost: remoteHost,
×
147
                        RemotePort: uint16(remotePort),
×
148
                })
×
149
        }
150
        return portMappings, nil
×
151
}
152

153
// NewPortForwardCmd returns a new PortForwardCmd
154
func NewPortForwardCmd(cmd *cobra.Command, args []string) (*PortForwardCmd, error) {
×
155
        server := viper.GetString(argRootServer)
×
156
        if server == "" {
×
157
                return nil, errors.New("No server")
×
158
        }
×
159

160
        skipVerify, err := cmd.Flags().GetBool(argRootSkipVerify)
×
161
        if err != nil {
×
162
                return nil, err
×
163
        }
×
164

165
        bindingHost, err := cmd.Flags().GetString(argBindHost)
×
166
        if err != nil {
×
167
                return nil, err
×
168
        }
×
169

170
        portMappings, err := getPortMappings(args[1:])
×
171
        if err != nil {
×
172
                return nil, err
×
173
        }
×
174

175
        token, err := getAuthToken(cmd)
×
176
        if err != nil {
×
177
                return nil, err
×
178
        }
×
179

180
        return &PortForwardCmd{
×
181
                server:       server,
×
182
                token:        token,
×
183
                skipVerify:   skipVerify,
×
184
                deviceID:     args[0],
×
185
                bindingHost:  bindingHost,
×
186
                portMappings: portMappings,
×
187
                recvChans:    make(map[string]chan *ws.ProtoMsg),
×
188
                stop:         make(chan struct{}),
×
189
        }, nil
×
190
}
191

192
// Run executes the command
193
func (c *PortForwardCmd) Run() error {
×
194
        for {
×
195
                if err := c.run(); err != errRestart {
×
196
                        return err
×
197
                }
×
198
        }
199
}
200

201
func (c *PortForwardCmd) run() error {
×
202
        ctx, cancelContext := context.WithCancel(context.Background())
×
203
        defer cancelContext()
×
204

×
205
        client := deviceconnect.NewClient(c.server, c.token, c.skipVerify)
×
206

×
207
        // check if the device is connected
×
208
        device, err := client.GetDevice(c.deviceID)
×
209
        if err != nil {
×
210
                return errors.Wrap(err, "unable to get the device")
×
211
        } else if device.Status != deviceconnect.CONNECTED {
×
212
                return errors.New("the device is not connected")
×
213
        }
×
214

215
        // connect to the websocket and start the ping-pong connection health-check
216
        err = client.Connect(c.deviceID, c.token)
×
217
        if err != nil {
×
218
                return err
×
219
        }
×
220

221
        go client.PingPong(ctx)
×
222
        defer client.Close()
×
223

×
224
        // perform ws protocol handshake
×
225
        err = c.handshake(client)
×
226
        if err != nil {
×
227
                return err
×
228
        }
×
229

230
        // message channel
231
        msgChan := make(chan *ws.ProtoMsg)
×
232

×
233
        // start the local TCP listeners
×
234
        for _, portMapping := range c.portMappings {
×
235
                switch portMapping.Protocol {
×
236
                case protocolTCP:
×
237
                        forwarder, err := NewTCPPortForwarder(c.bindingHost, portMapping.LocalPort,
×
238
                                portMapping.RemoteHost, portMapping.RemotePort)
×
239
                        if err != nil {
×
240
                                return err
×
241
                        }
×
242
                        go forwarder.Run(ctx, c.sessionID, msgChan, c.recvChans)
×
243
                case protocolUDP:
×
244
                        forwarder, err := NewUDPPortForwarder(c.bindingHost, portMapping.LocalPort,
×
245
                                portMapping.RemoteHost, portMapping.RemotePort)
×
246
                        if err != nil {
×
247
                                return err
×
248
                        }
×
249
                        go forwarder.Run(ctx, c.sessionID, msgChan, c.recvChans)
×
250
                default:
×
251
                        return errors.New("unknown protocol: " + portMapping.Protocol)
×
252
                }
253
        }
254

255
        c.running = true
×
256
        go c.processIncomingMessages(msgChan, client)
×
257

×
258
        // handle CTRL+C and signals
×
259
        quit := make(chan os.Signal, 1)
×
260
        signal.Notify(quit, unix.SIGINT, unix.SIGTERM)
×
261

×
262
        // wait for CTRL+C, signals or stop
×
263
        restart := false
×
264
        timeout := time.Now().Add(portForwardMaxDuration)
×
265
        for c.running {
×
266
                select {
×
267
                case msg := <-msgChan:
×
268
                        err := client.WriteMessage(msg)
×
269
                        if err != nil {
×
270
                                c.err = err
×
271
                                break
×
272
                        }
273
                case <-time.After(time.Until(timeout)):
×
274
                        c.err = errors.New("port forward timed out: max duration reached")
×
275
                        c.running = false
×
276
                case <-quit:
×
277
                        c.running = false
×
278
                case <-c.stop:
×
279
                        restart = true
×
280
                        c.running = false
×
281
                }
282
        }
283

284
        // cancel the context
285
        cancelContext()
×
286

×
287
        // close the ws session
×
288
        err = c.closeSession(client)
×
289
        if c.err == nil && err != nil {
×
290
                c.err = err
×
291
        }
×
292

293
        // if stopping because of an error, restart the port-forwarding command
294
        if restart {
×
295
                return errRestart
×
296
        }
×
297

298
        // return the error message (if any)
299
        return c.err
×
300
}
301

302
func (c *PortForwardCmd) Stop() {
×
303
        c.stop <- struct{}{}
×
304
}
×
305

306
// handshake initiates a handshake and checks that the device
307
// is willing to accept port forward requests.
308
func (c *PortForwardCmd) handshake(client *deviceconnect.Client) error {
×
309
        // open the session
×
310
        body, err := msgpack.Marshal(&ws.Open{
×
311
                Versions: []int{ws.ProtocolVersion},
×
312
        })
×
313
        if err != nil {
×
314
                return err
×
315
        }
×
316
        m := &ws.ProtoMsg{
×
317
                Header: ws.ProtoHdr{
×
318
                        Proto:   ws.ProtoTypeControl,
×
319
                        MsgType: ws.MessageTypeOpen,
×
320
                },
×
321
                Body: body,
×
322
        }
×
323
        err = client.WriteMessage(m)
×
324
        if err != nil {
×
325
                return err
×
326
        }
×
327

328
        msg, err := client.ReadMessage()
×
329
        if err != nil {
×
330
                return err
×
331
        }
×
332
        if msg.Header.MsgType == ws.MessageTypeError {
×
333
                erro := new(ws.Error)
×
334
                _ = msgpack.Unmarshal(msg.Body, erro)
×
335
                return errors.Errorf("handshake error from client: %s", erro.Error)
×
336
        } else if msg.Header.MsgType != ws.MessageTypeAccept {
×
337
                return errPortForwardNotImplemented
×
338
        }
×
339

340
        accept := new(ws.Accept)
×
341
        err = msgpack.Unmarshal(msg.Body, accept)
×
342
        if err != nil {
×
343
                return err
×
344
        }
×
345

346
        found := false
×
347
        for _, proto := range accept.Protocols {
×
348
                if proto == ws.ProtoTypePortForward {
×
349
                        found = true
×
350
                        break
×
351
                }
352
        }
353
        if !found {
×
354
                return errPortForwardNotImplemented
×
355
        }
×
356

357
        c.sessionID = msg.Header.SessionID
×
358
        return nil
×
359
}
360

361
// closeSession closes the WS session
362
func (c *PortForwardCmd) closeSession(client *deviceconnect.Client) error {
×
363
        m := &ws.ProtoMsg{
×
364
                Header: ws.ProtoHdr{
×
365
                        Proto:   ws.ProtoTypeControl,
×
366
                        MsgType: ws.MessageTypeClose,
×
367
                },
×
368
        }
×
369
        err := client.WriteMessage(m)
×
370
        if err != nil {
×
371
                return err
×
372
        }
×
373

374
        return nil
×
375
}
376

377
func (c *PortForwardCmd) processIncomingMessages(
378
        msgChan chan *ws.ProtoMsg,
379
        client *deviceconnect.Client,
380
) {
×
381
        for c.running {
×
382
                m, err := client.ReadMessage()
×
383
                if err != nil {
×
384
                        c.err = err
×
385
                        c.Stop()
×
386
                        break
×
387
                } else if m.Header.Proto == ws.ProtoTypeControl && m.Header.MsgType == ws.MessageTypePing {
×
388
                        m := &ws.ProtoMsg{
×
389
                                Header: ws.ProtoHdr{
×
390
                                        Proto:     ws.ProtoTypeControl,
×
391
                                        MsgType:   ws.MessageTypePong,
×
392
                                        SessionID: c.sessionID,
×
393
                                },
×
394
                        }
×
395
                        msgChan <- m
×
396
                } else if m.Header.Proto == ws.ProtoTypePortForward &&
×
397
                        m.Header.MsgType == ws.MessageTypeError {
×
398
                        erro := new(ws.Error)
×
399
                        if err := msgpack.Unmarshal(m.Body, erro); err != nil &&
×
400
                                erro.MessageType != wspf.MessageTypePortForwardStop {
×
401
                                c.err = errors.New(fmt.Sprintf(
×
402
                                        "Unable to start the port-forwarding: %s",
×
403
                                        string(m.Body),
×
404
                                ))
×
405
                                c.running = false
×
406
                                c.Stop()
×
407
                        }
×
408
                } else if m.Header.Proto == ws.ProtoTypePortForward &&
×
409
                        (m.Header.MsgType == wspf.MessageTypePortForward ||
×
410
                                m.Header.MsgType == wspf.MessageTypePortForwardAck ||
×
411
                                m.Header.MsgType == wspf.MessageTypePortForwardStop) {
×
412
                        connectionID, _ := m.Header.Properties[wspf.PropertyConnectionID].(string)
×
413
                        if connectionID != "" {
×
414
                                if recvChan, ok := c.recvChans[connectionID]; ok {
×
415
                                        recvChan <- m
×
416
                                }
×
417
                        }
418
                }
419
        }
420
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc