• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mendersoftware / mender-server / 1971772518

07 Aug 2025 09:30AM UTC coverage: 65.502% (-0.002%) from 65.504%
1971772518

push

gitlab-ci

web-flow
Merge pull request #827 from alfrunes/QA-721

QA-721: Change database cleanup strategy to preserve indexes

32314 of 49333 relevant lines covered (65.5%)

1.39 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

73.32
/backend/services/deviceconnect/api/http/management.go
1
// Copyright 2023 Northern.tech AS
2
//
3
//    Licensed under the Apache License, Version 2.0 (the "License");
4
//    you may not use this file except in compliance with the License.
5
//    You may obtain a copy of the License at
6
//
7
//        http://www.apache.org/licenses/LICENSE-2.0
8
//
9
//    Unless required by applicable law or agreed to in writing, software
10
//    distributed under the License is distributed on an "AS IS" BASIS,
11
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
//    See the License for the specific language governing permissions and
13
//    limitations under the License.
14

15
package http
16

17
import (
18
        "bufio"
19
        "context"
20
        "encoding/binary"
21
        "encoding/json"
22
        "io"
23
        "net/http"
24
        "strconv"
25
        "sync"
26
        "time"
27

28
        "github.com/gin-gonic/gin"
29
        validation "github.com/go-ozzo/ozzo-validation/v4"
30
        "github.com/gorilla/websocket"
31
        natsio "github.com/nats-io/nats.go"
32
        "github.com/pkg/errors"
33
        "github.com/vmihailenco/msgpack/v5"
34

35
        "github.com/mendersoftware/mender-server/pkg/identity"
36
        "github.com/mendersoftware/mender-server/pkg/log"
37
        "github.com/mendersoftware/mender-server/pkg/requestid"
38
        "github.com/mendersoftware/mender-server/pkg/rest.utils"
39
        "github.com/mendersoftware/mender-server/pkg/ws"
40
        "github.com/mendersoftware/mender-server/pkg/ws/menderclient"
41
        "github.com/mendersoftware/mender-server/pkg/ws/shell"
42

43
        "github.com/mendersoftware/mender-server/services/deviceconnect/app"
44
        "github.com/mendersoftware/mender-server/services/deviceconnect/client/nats"
45
        "github.com/mendersoftware/mender-server/services/deviceconnect/model"
46
)
47

48
// HTTP errors and session recording/playback settings.
var (
	// ErrMissingUserAuthentication is reported by the handlers in this file
	// when the request identity is not a user identity.
	ErrMissingUserAuthentication = errors.New(
		"missing or non-user identity in the authorization headers",
	)
	// ErrMsgSessionLimit is the message body sent to the user and to the
	// device once a session goes over its byte limit.
	ErrMsgSessionLimit = "session byte limit exceeded"

	// PlaybackSleepIntervalMsField is the name of the field holding the
	// number of milliseconds to sleep between consecutive writes of session
	// recording data. Note that it has nothing to do with the sleep between
	// the keystrokes sent, lines printed, or screen blinks; we are only
	// aware of the stream of bytes.
	PlaybackSleepIntervalMsField = "sleep_ms"

	// PlaybackSessionIDField is the name of the request parameter of the
	// playback GET request that holds the ID of the session.
	PlaybackSessionIDField = "sessionId"

	// keyStrokeDelayRecordingThresholdNs is the delay between received
	// shell commands (keystrokes), in nanoseconds, at or above which a
	// delay control message is saved (1.5 seconds).
	keyStrokeDelayRecordingThresholdNs = int64(1500 * 1000000)

	// keyStrokeMaxDelayRecording is the maximal delay, in nanoseconds, that
	// can be recorded: the keystroke delay is stored in two bytes (as
	// milliseconds), so larger delays are rounded down to this value.
	keyStrokeMaxDelayRecording = int64(65535 * 1000000)
)

// channelSize is the buffer size of the channels carrying device messages.
const channelSize = 25 // TODO make configurable

const (
	// PropertyUserID is the message header property key under which the
	// ID of the session's user is stored.
	PropertyUserID = "user_id"
)
78

79
var wsUpgrader = websocket.Upgrader{
80
        Subprotocols: []string{"protomsg/msgpack"},
81
        CheckOrigin:  allowAllOrigins,
82
        Error: func(
83
                w http.ResponseWriter, r *http.Request, s int, e error,
84
        ) {
3✔
85
                w.WriteHeader(s)
3✔
86
                enc := json.NewEncoder(w)
3✔
87
                _ = enc.Encode(rest.Error{
3✔
88
                        Err:       e.Error(),
3✔
89
                        RequestID: requestid.FromContext(r.Context())},
3✔
90
                )
3✔
91
        },
3✔
92
}
93

94
// ManagementController is the container for the management end-points.
type ManagementController struct {
	// app provides the device and session operations used by the handlers.
	app app.App
	// nats is the client used to subscribe to session subjects and to
	// publish messages to devices.
	nats nats.Client
}
99

100
// NewManagementController returns a new ManagementController
101
func NewManagementController(
102
        app app.App,
103
        nc nats.Client,
104
) *ManagementController {
3✔
105
        return &ManagementController{
3✔
106
                app:  app,
3✔
107
                nats: nc,
3✔
108
        }
3✔
109
}
3✔
110

111
// GetDevice returns a device
112
func (h ManagementController) GetDevice(c *gin.Context) {
3✔
113
        ctx := c.Request.Context()
3✔
114

3✔
115
        idata := identity.FromContext(ctx)
3✔
116
        if idata == nil || !idata.IsUser {
3✔
117
                c.JSON(http.StatusBadRequest, gin.H{
×
118
                        "error": ErrMissingUserAuthentication.Error(),
×
119
                })
×
120
                return
×
121
        }
×
122
        tenantID := idata.Tenant
3✔
123
        deviceID := c.Param("deviceId")
3✔
124

3✔
125
        device, err := h.app.GetDevice(ctx, tenantID, deviceID)
3✔
126
        if err == app.ErrDeviceNotFound {
4✔
127
                c.JSON(http.StatusNotFound, gin.H{
1✔
128
                        "error": err.Error(),
1✔
129
                })
1✔
130
                return
1✔
131
        } else if err != nil {
5✔
132
                c.JSON(http.StatusBadRequest, gin.H{
1✔
133
                        "error": err.Error(),
1✔
134
                })
1✔
135
                return
1✔
136
        }
1✔
137

138
        c.JSON(http.StatusOK, device)
3✔
139
}
140

141
// Connect extracts identity from request, checks user permissions
142
// and calls ConnectDevice
143
func (h ManagementController) Connect(c *gin.Context) {
2✔
144
        ctx := c.Request.Context()
2✔
145
        l := log.FromContext(ctx)
2✔
146

2✔
147
        idata := identity.FromContext(ctx)
2✔
148
        if !idata.IsUser {
2✔
149
                c.JSON(http.StatusBadRequest, gin.H{
×
150
                        "error": ErrMissingUserAuthentication.Error(),
×
151
                })
×
152
                return
×
153
        }
×
154

155
        tenantID := idata.Tenant
2✔
156
        userID := idata.Subject
2✔
157
        deviceID := c.Param("deviceId")
2✔
158

2✔
159
        session := &model.Session{
2✔
160
                TenantID:           tenantID,
2✔
161
                UserID:             userID,
2✔
162
                DeviceID:           deviceID,
2✔
163
                StartTS:            time.Now(),
2✔
164
                BytesRecordedMutex: &sync.Mutex{},
2✔
165
                Types:              []string{},
2✔
166
        }
2✔
167

2✔
168
        // Prepare the user session
2✔
169
        err := h.app.PrepareUserSession(ctx, session)
2✔
170
        if err == app.ErrDeviceNotFound || err == app.ErrDeviceNotConnected {
4✔
171
                c.JSON(http.StatusNotFound, gin.H{
2✔
172
                        "error": err.Error(),
2✔
173
                })
2✔
174
                return
2✔
175
        } else if _, ok := errors.Cause(err).(validation.Errors); ok {
4✔
176
                c.JSON(http.StatusBadRequest, gin.H{
×
177
                        "error": err.Error(),
×
178
                })
×
179
                return
×
180
        } else if err != nil {
3✔
181
                l.Error(err)
1✔
182
                c.JSON(http.StatusInternalServerError, gin.H{
1✔
183
                        "error": err.Error(),
1✔
184
                })
1✔
185
                return
1✔
186
        }
1✔
187
        defer func() {
4✔
188
                err := h.app.FreeUserSession(ctx, session.ID, session.Types)
2✔
189
                if err != nil {
2✔
190
                        l.Warnf("failed to free session: %s", err.Error())
×
191
                }
×
192
        }()
193

194
        deviceChan := make(chan *natsio.Msg, channelSize)
2✔
195
        sub, err := h.nats.ChanSubscribe(session.Subject(tenantID), deviceChan)
2✔
196
        if err != nil {
2✔
197
                l.Error(err)
×
198
                c.JSON(http.StatusInternalServerError, gin.H{
×
199
                        "error": "failed to establish internal device session",
×
200
                })
×
201
                return
×
202
        }
×
203
        //nolint:errcheck
204
        defer sub.Unsubscribe()
2✔
205

2✔
206
        // upgrade get request to websocket protocol
2✔
207
        conn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil)
2✔
208
        if err != nil {
4✔
209
                err = errors.Wrap(err, "unable to upgrade the request to websocket protocol")
2✔
210
                l.Error(err)
2✔
211
                // upgrader.Upgrade has already responded
2✔
212
                return
2✔
213
        }
2✔
214
        conn.SetReadLimit(int64(app.MessageSizeLimit))
2✔
215

2✔
216
        //nolint:errcheck
2✔
217
        h.ConnectServeWS(ctx, conn, session, deviceChan)
2✔
218
}
219

220
func (h ManagementController) Playback(c *gin.Context) {
2✔
221
        ctx := c.Request.Context()
2✔
222
        l := log.FromContext(ctx)
2✔
223

2✔
224
        idata := identity.FromContext(ctx)
2✔
225
        if !idata.IsUser {
2✔
226
                c.JSON(http.StatusBadRequest, gin.H{
×
227
                        "error": ErrMissingUserAuthentication.Error(),
×
228
                })
×
229
                return
×
230
        }
×
231

232
        tenantID := idata.Tenant
2✔
233
        userID := idata.Subject
2✔
234
        sessionID := c.Param(PlaybackSessionIDField)
2✔
235
        session := &model.Session{
2✔
236
                TenantID:           tenantID,
2✔
237
                UserID:             userID,
2✔
238
                StartTS:            time.Now(),
2✔
239
                BytesRecordedMutex: &sync.Mutex{},
2✔
240
        }
2✔
241
        sleepInterval := c.Param(PlaybackSleepIntervalMsField)
2✔
242
        sleepMilliseconds := uint(app.DefaultPlaybackSleepIntervalMs)
2✔
243
        if len(sleepInterval) > 1 {
2✔
244
                n, err := strconv.ParseUint(sleepInterval, 10, 32)
×
245
                if err != nil {
×
246
                        sleepMilliseconds = uint(n)
×
247
                }
×
248
        }
249

250
        l.Infof("Playing back the session session_id=%s", sessionID)
2✔
251

2✔
252
        // upgrade get request to websocket protocol
2✔
253
        conn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil)
2✔
254
        if err != nil {
4✔
255
                err = errors.Wrap(err, "unable to upgrade the request to websocket protocol")
2✔
256
                l.Error(err)
2✔
257
                return
2✔
258
        }
2✔
259
        conn.SetReadLimit(int64(app.MessageSizeLimit))
1✔
260

1✔
261
        deviceChan := make(chan *natsio.Msg, channelSize)
1✔
262
        errChan := make(chan error, 1)
1✔
263

1✔
264
        //nolint:errcheck
1✔
265
        go h.websocketWriter(ctx,
1✔
266
                conn,
1✔
267
                session,
1✔
268
                deviceChan,
1✔
269
                errChan,
1✔
270
                bufio.NewWriterSize(io.Discard, app.RecorderBufferSize),
1✔
271
                bufio.NewWriterSize(io.Discard, app.RecorderBufferSize))
1✔
272

1✔
273
        go func() {
2✔
274
                err = h.app.GetSessionRecording(ctx,
1✔
275
                        sessionID,
1✔
276
                        app.NewPlayback(deviceChan, sleepMilliseconds))
1✔
277
                if err != nil {
1✔
278
                        err = errors.Wrap(err, "unable to get the session.")
×
279
                        errChan <- err
×
280
                        return
×
281
                }
×
282
        }()
283
        // We need to keep reading in order to keep ping/pong handlers functioning.
284
        for ; err == nil; _, _, err = conn.NextReader() {
2✔
285
        }
1✔
286
}
287

288
func writerFinalizer(conn *websocket.Conn, e *error, l *log.Logger) {
2✔
289
        err := *e
2✔
290
        if err != nil {
4✔
291
                if !websocket.IsUnexpectedCloseError(errors.Cause(err)) {
4✔
292
                        errMsg := err.Error()
2✔
293
                        errBody := make([]byte, len(errMsg)+2)
2✔
294
                        binary.BigEndian.PutUint16(errBody,
2✔
295
                                websocket.CloseInternalServerErr)
2✔
296
                        copy(errBody[2:], errMsg)
2✔
297
                        errClose := conn.WriteControl(
2✔
298
                                websocket.CloseMessage,
2✔
299
                                errBody,
2✔
300
                                time.Now().Add(writeWait),
2✔
301
                        )
2✔
302
                        if errClose != nil {
3✔
303
                                err = errors.Wrapf(err,
1✔
304
                                        "error sending websocket close frame: %s",
1✔
305
                                        errClose.Error(),
1✔
306
                                )
1✔
307
                        }
1✔
308
                }
309
                l.Errorf("websocket closed with error: %s", err.Error())
2✔
310
        }
311
        conn.Close()
2✔
312
}
313

314
// websocketWriter is the go-routine responsible for the writing end of the
// websocket. The routine forwards messages received on deviceChan to the
// websocket and answers pings with pongs. It returns when a message cannot be
// decoded, recorded or written, when ctx is done, or when a value (including
// the zero value of a closed channel) is received from errChan. On return,
// writerFinalizer closes the connection.
//
// Shell command messages are passed to recordSession until either byte
// counter reaches app.MessageSizeLimit; from then on handleSessLimit is
// invoked and messages are no longer written to the websocket.
func (h ManagementController) websocketWriter(
	ctx context.Context,
	conn *websocket.Conn,
	session *model.Session,
	deviceChan <-chan *natsio.Msg,
	errChan <-chan error,
	recorderBuffered *bufio.Writer,
	controlRecorderBuffered *bufio.Writer,
) (err error) {
	l := log.FromContext(ctx)
	// The finalizer reads the named result, so it sees the final error.
	defer writerFinalizer(conn, &err, l)

	// handle the ping-pong connection health check
	conn.SetPingHandler(func(msg string) error {
		// NOTE(review): err is the writer's named result; confirm which
		// goroutine invokes this handler, as it may read err concurrently.
		if err != nil {
			return err
		}
		return conn.WriteControl(
			websocket.PongMessage,
			[]byte(msg),
			time.Now().Add(writeWait),
		)
	})

	defer recorderBuffered.Flush()
	defer controlRecorderBuffered.Flush()
	// Byte counters maintained by recordSession through the pointers below.
	recordedBytes := 0
	controlBytes := 0

	sessOverLimit := false
	sessOverLimitHandled := false

	lastKeystrokeAt := time.Now().UTC().UnixNano()
Loop:
	for {
		var forwardedMsg []byte

		select {
		case msg := <-deviceChan:
			// The payload is decoded only for inspection; the raw
			// bytes are what gets forwarded.
			mr := &ws.ProtoMsg{}
			err = msgpack.Unmarshal(msg.Data, mr)
			if err != nil {
				return err
			}

			forwardedMsg = msg.Data

			if mr.Header.Proto == ws.ProtoTypeShell {
				switch mr.Header.MsgType {
				case shell.MessageTypeShellCommand:

					if recordedBytes >= app.MessageSizeLimit ||
						controlBytes >= app.MessageSizeLimit {
						sessOverLimit = true

						errMsg := h.handleSessLimit(ctx,
							session,
							&sessOverLimitHandled,
						)

						// override original message with shell error
						// NOTE(review): sessOverLimit is already true
						// here, so the write below is skipped and this
						// error message is never sent to the user.
						// Confirm whether that is intended.
						if errMsg != nil {
							forwardedMsg = errMsg
						}
					} else {
						if err = recordSession(ctx,
							mr,
							recorderBuffered,
							controlRecorderBuffered,
							&recordedBytes,
							&controlBytes,
							&lastKeystrokeAt,
							session,
						); err != nil {
							return err
						}
					}

				case shell.MessageTypeStopShell:
					l.Debugf("session logging: recorderBuffered.Flush()"+
						" at %d on stop shell", recordedBytes)
					recorderBuffered.Flush()
				}
			}

			// Once the session is over the limit nothing is forwarded.
			if !sessOverLimit {
				err = conn.WriteMessage(websocket.BinaryMessage, forwardedMsg)
				if err != nil {
					l.Error(err)
					break Loop
				}
			}
		case <-ctx.Done():
			break Loop
		case err := <-errChan:
			// The explicit return stores the received value in the
			// named result, even though err is shadowed here.
			return err
		}
	}
	return err
}
418

419
func (h ManagementController) handleSessLimit(ctx context.Context,
420
        session *model.Session,
421
        handled *bool,
422
) []byte {
1✔
423
        l := log.FromContext(ctx)
1✔
424

1✔
425
        // possible error return message (ws->user)
1✔
426
        var retMsg []byte
1✔
427

1✔
428
        // attempt to clean up once
1✔
429
        if !(*handled) {
2✔
430
                sendLimitErrDevice(ctx, session, h.nats)
1✔
431
                userErrMsg, err := prepLimitErrUser(ctx, session)
1✔
432
                if err != nil {
1✔
433
                        l.Errorf("session limit: " +
×
434
                                "failed to notify user")
×
435
                }
×
436

437
                retMsg = userErrMsg
1✔
438

1✔
439
                err = h.app.FreeUserSession(ctx, session.ID, session.Types)
1✔
440
                if err != nil {
1✔
441
                        l.Warnf("failed to free session"+
×
442
                                "that went over limit: %s", err.Error())
×
443
                }
×
444

445
                *handled = true
1✔
446
        }
447

448
        return retMsg
1✔
449
}
450

451
func recordSession(ctx context.Context,
452
        msg *ws.ProtoMsg,
453
        recorder io.Writer,
454
        recorderCtrl io.Writer,
455
        recBytes *int,
456
        ctrlBytes *int,
457
        lastKeystrokeAt *int64,
458
        session *model.Session) error {
2✔
459
        l := log.FromContext(ctx)
2✔
460

2✔
461
        b, e := recorder.Write(msg.Body)
2✔
462
        if e != nil {
2✔
463
                l.Errorf("session logging: "+
×
464
                        "recorderBuffered.Write"+
×
465
                        "(len=%d)=%d,%+v",
×
466
                        len(msg.Body), b, e)
×
467
        }
×
468
        timeNowUTC := time.Now().UTC().UnixNano()
2✔
469
        keystrokeDelay := timeNowUTC - (*lastKeystrokeAt)
2✔
470
        if keystrokeDelay >= keyStrokeDelayRecordingThresholdNs {
2✔
471
                if keystrokeDelay > keyStrokeMaxDelayRecording {
×
472
                        keystrokeDelay = keyStrokeMaxDelayRecording
×
473
                }
×
474

475
                controlMsg := app.Control{
×
476
                        Type:   app.DelayMessage,
×
477
                        Offset: *recBytes,
×
478
                        DelayMs: uint16(float64(keystrokeDelay) *
×
479
                                0.000001),
×
480
                        TerminalHeight: 0,
×
481
                        TerminalWidth:  0,
×
482
                }
×
483
                n, _ := recorderCtrl.Write(
×
484
                        controlMsg.MarshalBinary())
×
485
                l.Debugf("saving control delay message: %+v/%d",
×
486
                        controlMsg, n)
×
487
                (*ctrlBytes) += n
×
488
        }
489

490
        (*lastKeystrokeAt) = timeNowUTC
2✔
491

2✔
492
        (*recBytes) += len(msg.Body)
2✔
493
        session.BytesRecordedMutex.Lock()
2✔
494
        session.BytesRecorded = *recBytes
2✔
495
        session.BytesRecordedMutex.Unlock()
2✔
496

2✔
497
        return nil
2✔
498
}
499

500
// prepLimitErrUser preps a session limit exceeded error for the user (shell cmd + err status)
501
func prepLimitErrUser(ctx context.Context, session *model.Session) ([]byte, error) {
1✔
502
        userErrMsg := ws.ProtoMsg{
1✔
503
                Header: ws.ProtoHdr{
1✔
504
                        Proto:     ws.ProtoTypeShell,
1✔
505
                        MsgType:   shell.MessageTypeShellCommand,
1✔
506
                        SessionID: session.ID,
1✔
507
                        Properties: map[string]interface{}{
1✔
508
                                "status": shell.ErrorMessage,
1✔
509
                        },
1✔
510
                },
1✔
511
                Body: []byte(ErrMsgSessionLimit),
1✔
512
        }
1✔
513

1✔
514
        return msgpack.Marshal(userErrMsg)
1✔
515
}
1✔
516

517
// sendLimitErrDevice preps and sends
518
// session limit exceeded error to device (stop shell + err status)
519
// this is best effort, log and swallow errors
520
func sendLimitErrDevice(ctx context.Context, session *model.Session, nats nats.Client) {
1✔
521
        l := log.FromContext(ctx)
1✔
522

1✔
523
        msg := ws.ProtoMsg{
1✔
524
                Header: ws.ProtoHdr{
1✔
525
                        Proto:     ws.ProtoTypeShell,
1✔
526
                        MsgType:   shell.MessageTypeStopShell,
1✔
527
                        SessionID: session.ID,
1✔
528
                        Properties: map[string]interface{}{
1✔
529
                                "status":       shell.ErrorMessage,
1✔
530
                                PropertyUserID: session.UserID,
1✔
531
                        },
1✔
532
                },
1✔
533
                Body: []byte(ErrMsgSessionLimit),
1✔
534
        }
1✔
535
        data, err := msgpack.Marshal(msg)
1✔
536
        if err != nil {
1✔
537
                l.Errorf(
×
538
                        "session limit: "+
×
539
                                "failed to prep stop session"+
×
540
                                "%s message to device: %s, error %v",
×
541
                        session.ID,
×
542
                        session.DeviceID,
×
543
                        err,
×
544
                )
×
545
        }
×
546
        err = nats.Publish(model.GetDeviceSubject(
1✔
547
                session.TenantID, session.DeviceID),
1✔
548
                data,
1✔
549
        )
1✔
550
        if err != nil {
1✔
551
                l.Errorf(
×
552
                        "session limit: failed to send stop session"+
×
553
                                "%s message to device: %s, error %v",
×
554
                        session.ID,
×
555
                        session.DeviceID,
×
556
                        err,
×
557
                )
×
558
        }
×
559
}
560

561
// ConnectServeWS starts a websocket connection with the device
// Currently this handler only properly handles a single terminal session.
//
// It starts websocketWriter in its own goroutine for the device-to-user
// direction and processes the user-to-device messages itself via
// connectServeWSProcessMessages, whose error it returns. On return, a non-nil
// error is handed to the writer through errChan, a stop shell message is
// published to the device if a remote terminal is still marked as running,
// and errChan is closed.
func (h ManagementController) ConnectServeWS(
	ctx context.Context,
	conn *websocket.Conn,
	sess *model.Session,
	deviceChan chan *natsio.Msg,
) (err error) {
	l := log.FromContext(ctx)
	id := identity.FromContext(ctx)
	errChan := make(chan error, 1)
	// Updated by connectServeWSProcessMessages through the pointer passed
	// below; read by the deferred clean-up.
	remoteTerminalRunning := false

	defer func() {
		if err != nil {
			// Hand the error to the writer, giving up after a second
			// if it cannot be delivered.
			select {
			case errChan <- err:

			case <-time.After(time.Second):
				l.Warn("Failed to propagate error to client")
			}
		}
		if remoteTerminalRunning {
			msg := ws.ProtoMsg{
				Header: ws.ProtoHdr{
					Proto:     ws.ProtoTypeShell,
					MsgType:   shell.MessageTypeStopShell,
					SessionID: sess.ID,
					Properties: map[string]interface{}{
						"status":       shell.ErrorMessage,
						PropertyUserID: sess.UserID,
					},
				},
				Body: []byte("user disconnected"),
			}
			data, _ := msgpack.Marshal(msg)
			errPublish := h.nats.Publish(model.GetDeviceSubject(
				id.Tenant, sess.DeviceID),
				data,
			)
			if errPublish != nil {
				l.Warnf(
					"failed to propagate stop session "+
						"message to device: %s",
					errPublish.Error(),
				)
			}
		}
		// Receiving from the closed channel makes websocketWriter return.
		close(errChan)
	}()

	// NOTE(review): both buffered writers are flushed here by defer and are
	// also passed to websocketWriter, which defers its own Flush on them in
	// another goroutine; confirm the two cannot flush concurrently.
	controlRecorder := h.app.GetControlRecorder(ctx, sess.ID)
	controlRecorderBuffered := bufio.NewWriterSize(controlRecorder, app.RecorderBufferSize)
	defer controlRecorderBuffered.Flush()

	sessionRecorder := h.app.GetRecorder(ctx, sess.ID)
	sessionRecorderBuffered := bufio.NewWriterSize(sessionRecorder, app.RecorderBufferSize)
	defer sessionRecorderBuffered.Flush()

	// websocketWriter is responsible for closing the websocket
	//nolint:errcheck
	go h.websocketWriter(ctx,
		conn,
		sess,
		deviceChan,
		errChan,
		sessionRecorderBuffered,
		controlRecorderBuffered)

	return h.connectServeWSProcessMessages(ctx, conn, sess, deviceChan,
		&remoteTerminalRunning, controlRecorderBuffered)
}
633

634
// connectServeWSProcessMessages reads messages from the websocket until the
// connection fails, stamps each one with the session ID and the user ID, and
// publishes it on the device's subject.
//
// The first shell and the first port forward message each trigger an audit
// log entry via LogUserSession and add their type to sess.Types. Spawn and
// stop shell messages update *remoteTerminalRunning; resize messages are
// additionally passed to sendResizeMessage until app.MessageSizeLimit control
// bytes have been written.
//
// It returns nil when reading fails with a websocket close error, and the
// underlying error when reading, decoding, audit logging or publishing fails.
func (h ManagementController) connectServeWSProcessMessages(
	ctx context.Context,
	conn *websocket.Conn,
	sess *model.Session,
	deviceChan chan *natsio.Msg,
	remoteTerminalRunning *bool,
	controlRecorderBuffered *bufio.Writer,
) (err error) {
	l := log.FromContext(ctx)
	id := identity.FromContext(ctx)
	// Each session type is audit logged only once.
	logTerminal := false
	logPortForward := false

	var data []byte
	controlBytes := 0
	ignoreControlMessages := false
	for {
		_, data, err = conn.ReadMessage()
		if err != nil {
			if _, ok := err.(*websocket.CloseError); ok {
				return nil
			}
			return err
		}
		m := &ws.ProtoMsg{}
		err = msgpack.Unmarshal(data, m)
		if err != nil {
			return err
		}

		// Overwrite whatever session and user ID the client supplied.
		m.Header.SessionID = sess.ID
		if m.Header.Properties == nil {
			m.Header.Properties = make(map[string]interface{})
		}
		m.Header.Properties[PropertyUserID] = sess.UserID
		// NOTE(review): the encoding error is discarded; on failure the
		// data published below would be whatever Marshal returned.
		data, _ = msgpack.Marshal(m)
		switch m.Header.Proto {
		case ws.ProtoTypeShell:
			// send the audit log for remote terminal
			if !logTerminal {
				if err := h.app.LogUserSession(ctx, sess,
					model.SessionTypeTerminal); err != nil {
					return err
				}
				sess.Types = append(sess.Types, model.SessionTypeTerminal)
				logTerminal = true
			}
			// handle remote terminal-specific messages
			switch m.Header.MsgType {
			case shell.MessageTypeSpawnShell:
				*remoteTerminalRunning = true
			case shell.MessageTypeStopShell:
				*remoteTerminalRunning = false
			case shell.MessageTypeResizeShell:
				// NOTE(review): both continue statements below skip
				// the Publish at the end of the loop, so once the
				// control data limit is reached resize messages are
				// no longer sent to the device. Confirm that this is
				// intended.
				if ignoreControlMessages {
					continue
				}
				if controlBytes >= app.MessageSizeLimit {
					l.Infof("session_id=%s control data limit reached.",
						sess.ID)
					//see https://northerntech.atlassian.net/browse/MEN-4448
					ignoreControlMessages = true
					continue
				}

				controlBytes += sendResizeMessage(m, sess, controlRecorderBuffered)
			}
		case ws.ProtoTypePortForward:
			// send the audit log for port forwarding
			if !logPortForward {
				if err := h.app.LogUserSession(ctx, sess,
					model.SessionTypePortForward); err != nil {
					return err
				}
				sess.Types = append(sess.Types, model.SessionTypePortForward)
				logPortForward = true
			}
		}

		err = h.nats.Publish(model.GetDeviceSubject(id.Tenant, sess.DeviceID), data)
		if err != nil {
			return err
		}
	}
}
718

719
func sendResizeMessage(m *ws.ProtoMsg,
720
        sess *model.Session,
721
        controlRecorderBuffered *bufio.Writer) (n int) {
×
722
        if _, ok := m.Header.Properties[model.ResizeMessageTermHeightField]; ok {
×
723
                return 0
×
724
        }
×
725
        if _, ok := m.Header.Properties[model.ResizeMessageTermWidthField]; ok {
×
726
                return 0
×
727
        }
×
728

729
        var height uint16 = 0
×
730
        switch m.Header.Properties[model.ResizeMessageTermHeightField].(type) {
×
731
        case uint8:
×
732
                height = uint16(m.Header.Properties[model.ResizeMessageTermHeightField].(uint8))
×
733
        case int8:
×
734
                height = uint16(m.Header.Properties[model.ResizeMessageTermHeightField].(int8))
×
735
        }
736

737
        var width uint16 = 0
×
738
        switch m.Header.Properties[model.ResizeMessageTermWidthField].(type) {
×
739
        case uint8:
×
740
                width = uint16(m.Header.Properties[model.ResizeMessageTermWidthField].(uint8))
×
741
        case int8:
×
742
                width = uint16(m.Header.Properties[model.ResizeMessageTermWidthField].(int8))
×
743
        }
744

745
        sess.BytesRecordedMutex.Lock()
×
746
        controlMsg := app.Control{
×
747
                Type:           app.ResizeMessage,
×
748
                Offset:         sess.BytesRecorded,
×
749
                DelayMs:        0,
×
750
                TerminalHeight: height,
×
751
                TerminalWidth:  width,
×
752
        }
×
753
        sess.BytesRecordedMutex.Unlock()
×
754

×
755
        n, _ = controlRecorderBuffered.Write(
×
756
                controlMsg.MarshalBinary(),
×
757
        )
×
758
        return n
×
759
}
760

761
func (h ManagementController) CheckUpdate(c *gin.Context) {
1✔
762
        h.sendMenderCommand(c, menderclient.MessageTypeMenderClientCheckUpdate)
1✔
763
}
1✔
764

765
func (h ManagementController) SendInventory(c *gin.Context) {
1✔
766
        h.sendMenderCommand(c, menderclient.MessageTypeMenderClientSendInventory)
1✔
767
}
1✔
768

769
func (h ManagementController) sendMenderCommand(c *gin.Context, msgType string) {
1✔
770
        ctx := c.Request.Context()
1✔
771

1✔
772
        idata := identity.FromContext(ctx)
1✔
773
        if idata == nil || !idata.IsUser {
1✔
774
                c.JSON(http.StatusBadRequest, gin.H{
×
775
                        "error": ErrMissingUserAuthentication.Error(),
×
776
                })
×
777
                return
×
778
        }
×
779
        tenantID := idata.Tenant
1✔
780
        deviceID := c.Param("deviceId")
1✔
781

1✔
782
        device, err := h.app.GetDevice(ctx, tenantID, deviceID)
1✔
783
        if err == app.ErrDeviceNotFound {
2✔
784
                c.JSON(http.StatusNotFound, gin.H{
1✔
785
                        "error": err.Error(),
1✔
786
                })
1✔
787
                return
1✔
788
        } else if err != nil {
3✔
789
                c.JSON(http.StatusBadRequest, gin.H{
1✔
790
                        "error": err.Error(),
1✔
791
                })
1✔
792
                return
1✔
793
        } else if device.Status != model.DeviceStatusConnected {
3✔
794
                c.JSON(http.StatusConflict, gin.H{
1✔
795
                        "error": app.ErrDeviceNotConnected,
1✔
796
                })
1✔
797
                return
1✔
798
        }
1✔
799

800
        msg := &ws.ProtoMsg{
1✔
801
                Header: ws.ProtoHdr{
1✔
802
                        Proto:   ws.ProtoTypeMenderClient,
1✔
803
                        MsgType: msgType,
1✔
804
                        Properties: map[string]interface{}{
1✔
805
                                PropertyUserID: idata.Subject,
1✔
806
                        },
1✔
807
                },
1✔
808
        }
1✔
809
        data, _ := msgpack.Marshal(msg)
1✔
810

1✔
811
        err = h.nats.Publish(model.GetDeviceSubject(idata.Tenant, device.ID), data)
1✔
812
        if err != nil {
2✔
813
                c.JSON(http.StatusInternalServerError, gin.H{
1✔
814
                        "error": err.Error(),
1✔
815
                })
1✔
816
        }
1✔
817

818
        c.JSON(http.StatusAccepted, nil)
1✔
819
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc