• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mendersoftware / deviceconnect / 1092400258

01 Dec 2023 02:41PM UTC coverage: 76.254% (-3.9%) from 80.201%
1092400258

push

gitlab-ci

web-flow
Merge pull request #334 from alfrunes/ALV-182

ALV-182: Add support for `HEAD` HTTP method for download API endpoint

153 of 194 new or added lines in 2 files covered. (78.87%)

7 existing lines in 1 file now uncovered.

2386 of 3129 relevant lines covered (76.25%)

22.77 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

78.42
/api/http/management_filetransfer.go
1
// Copyright 2021 Northern.tech AS
2
//
3
//    Licensed under the Apache License, Version 2.0 (the "License");
4
//    you may not use this file except in compliance with the License.
5
//    You may obtain a copy of the License at
6
//
7
//        http://www.apache.org/licenses/LICENSE-2.0
8
//
9
//    Unless required by applicable law or agreed to in writing, software
10
//    distributed under the License is distributed on an "AS IS" BASIS,
11
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
//    See the License for the specific language governing permissions and
13
//    limitations under the License.
14

15
package http
16

17
import (
	"context"
	"fmt"
	"io"
	"net/http"
	"os"
	"path"
	"strconv"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
	natsio "github.com/nats-io/nats.go"
	"github.com/pkg/errors"
	"github.com/vmihailenco/msgpack/v5"

	"github.com/mendersoftware/go-lib-micro/identity"
	"github.com/mendersoftware/go-lib-micro/log"
	"github.com/mendersoftware/go-lib-micro/requestid"
	"github.com/mendersoftware/go-lib-micro/ws"
	wsft "github.com/mendersoftware/go-lib-micro/ws/filetransfer"

	"github.com/mendersoftware/deviceconnect/app"
	"github.com/mendersoftware/deviceconnect/model"
)
42

43
type fileTransferParams struct {
44
        TenantID  string
45
        UserID    string
46
        SessionID string
47
        Device    *model.Device
48
}
49

50
// Response headers written for a file download.
const (
	hdrContentType            = "Content-Type"
	hdrContentDisposition     = "Content-Disposition"
	hdrMenderFileTransferPath = "X-MEN-File-Path"
	hdrMenderFileTransferUID  = "X-MEN-File-UID"
	hdrMenderFileTransferGID  = "X-MEN-File-GID"
	hdrMenderFileTransferMode = "X-MEN-File-Mode"
	hdrMenderFileTransferSize = "X-MEN-File-Size"
)

// Request field names, the protocol header property holding byte offsets,
// and the download query parameter.
const (
	// Field names of the upload request.
	fieldUploadPath = "path"
	fieldUploadUID  = "uid"
	fieldUploadGID  = "gid"
	fieldUploadMode = "mode"
	fieldUploadFile = "file"

	// PropertyOffset is the protocol header property carrying the byte offset
	// of chunk and ACK messages.
	PropertyOffset = "offset"

	// paramDownloadPath is the query parameter naming the file to download.
	paramDownloadPath = "path"
)

// File-transfer tunables. fileTransferTimeout bounds the wait for device
// messages; ackSlidingWindowSend is the number of received chunks after which
// a download sends an ACK. They are variables rather than constants,
// presumably so tests can shrink them (TODO confirm).
var (
	fileTransferTimeout    = 60 * time.Second
	fileTransferBufferSize = 4096
	ackSlidingWindowSend   = 10
	ackSlidingWindowRecv   = 20
)
76

77
// Error is an error annotated with the HTTP status code that
// handleResponseError replies with when it finds the value in an error chain.
type Error struct {
	error      error
	statusCode int
}

// NewError wraps err so that it is reported with the HTTP status code code.
func NewError(err error, code int) error {
	e := new(Error)
	e.error = err
	e.statusCode = code
	return e
}

// Error returns the message of the wrapped error.
func (e *Error) Error() string {
	inner := e.error
	return inner.Error()
}

// Unwrap returns the wrapped error so errors.Is / errors.As can see through
// the status annotation.
func (e *Error) Unwrap() error {
	inner := e.error
	return inner
}
×
96

97
var (
98
        errFileTransferMarshalling   = errors.New("failed to marshal the request")
99
        errFileTransferUnmarshalling = errors.New("failed to unmarshal the request")
100
        errFileTransferPublishing    = errors.New("failed to publish the message")
101
        errFileTransferSubscribing   = errors.New("failed to subscribe to the mesages")
102
        errFileTransferTimeout       = &Error{
103
                error:      errors.New("file transfer timed out"),
104
                statusCode: http.StatusRequestTimeout,
105
        }
106
        errFileTransferFailed = &Error{
107
                error:      errors.New("file transfer failed"),
108
                statusCode: http.StatusBadRequest,
109
        }
110
        errFileTransferNotImplemented = &Error{
111
                error:      errors.New("file transfer not implemented on device"),
112
                statusCode: http.StatusBadGateway,
113
        }
114
        errFileTransferDisabled = &Error{
115
                error:      errors.New("file transfer disabled on device"),
116
                statusCode: http.StatusBadGateway,
117
        }
118
)
119

120
var newFileTransferSessionID = func() (uuid.UUID, error) {
×
121
        return uuid.NewRandom()
×
122
}
×
123

124
func (h ManagementController) getFileTransferParams(c *gin.Context) (*fileTransferParams, int,
125
        error) {
35✔
126
        ctx := c.Request.Context()
35✔
127

35✔
128
        idata := identity.FromContext(ctx)
35✔
129
        if idata == nil || !idata.IsUser {
37✔
130
                return nil, http.StatusUnauthorized, ErrMissingUserAuthentication
2✔
131
        }
2✔
132
        tenantID := idata.Tenant
33✔
133
        deviceID := c.Param("deviceId")
33✔
134

33✔
135
        device, err := h.app.GetDevice(ctx, tenantID, deviceID)
33✔
136
        if err == app.ErrDeviceNotFound {
35✔
137
                return nil, http.StatusNotFound, err
2✔
138
        } else if err != nil {
35✔
139
                return nil, http.StatusBadRequest, err
2✔
140
        } else if device.Status != model.DeviceStatusConnected {
33✔
141
                return nil, http.StatusConflict, app.ErrDeviceNotConnected
2✔
142
        }
2✔
143

144
        if c.Request.Method != http.MethodGet && c.Request.Body == nil {
29✔
145
                return nil, http.StatusBadRequest, errors.New("missing request body")
2✔
146
        }
2✔
147

148
        sessionID, err := newFileTransferSessionID()
25✔
149
        if err != nil {
25✔
150
                return nil, http.StatusInternalServerError,
×
151
                        errors.New("failed to generate session ID")
×
152
        }
×
153

154
        return &fileTransferParams{
25✔
155
                TenantID:  idata.Tenant,
25✔
156
                UserID:    idata.Subject,
25✔
157
                SessionID: sessionID.String(),
25✔
158
                Device:    device,
25✔
159
        }, 0, nil
25✔
160
}
161

162
func (h ManagementController) publishFileTransferProtoMessage(sessionID, userID, deviceTopic,
163
        msgType string, body interface{}, offset int64) error {
30✔
164
        var msgBody []byte
30✔
165
        if msgType == wsft.MessageTypeChunk && body != nil {
33✔
166
                msgBody = body.([]byte)
3✔
167
        } else if msgType == wsft.MessageTypeACK {
38✔
168
                msgBody = nil
8✔
169
        } else if body != nil {
45✔
170
                var err error
18✔
171
                msgBody, err = msgpack.Marshal(body)
18✔
172
                if err != nil {
18✔
173
                        return errors.Wrap(err, errFileTransferMarshalling.Error())
×
174
                }
×
175
        }
176
        proto := ws.ProtoTypeFileTransfer
30✔
177
        if msgType == ws.MessageTypePing || msgType == ws.MessageTypePong {
30✔
178
                proto = ws.ProtoTypeControl
×
179
        }
×
180
        msg := &ws.ProtoMsg{
30✔
181
                Header: ws.ProtoHdr{
30✔
182
                        Proto:     proto,
30✔
183
                        MsgType:   msgType,
30✔
184
                        SessionID: sessionID,
30✔
185
                        Properties: map[string]interface{}{
30✔
186
                                PropertyUserID: userID,
30✔
187
                        },
30✔
188
                },
30✔
189
                Body: msgBody,
30✔
190
        }
30✔
191
        if msgType == wsft.MessageTypeChunk || msgType == wsft.MessageTypeACK {
42✔
192
                msg.Header.Properties[PropertyOffset] = offset
12✔
193
        }
12✔
194
        data, err := msgpack.Marshal(msg)
30✔
195
        if err != nil {
30✔
196
                return errors.Wrap(err, errFileTransferMarshalling.Error())
×
197
        }
×
198

199
        err = h.nats.Publish(deviceTopic, data)
30✔
200
        if err != nil {
30✔
201
                return errors.Wrap(err, errFileTransferPublishing.Error())
×
202
        }
×
203
        return nil
30✔
204
}
205

206
func (h ManagementController) publishControlMessage(
207
        sessionID, deviceTopic, messageType string, body interface{},
208
) error {
31✔
209
        msg := &ws.ProtoMsg{
31✔
210
                Header: ws.ProtoHdr{
31✔
211
                        Proto:     ws.ProtoTypeControl,
31✔
212
                        MsgType:   messageType,
31✔
213
                        SessionID: sessionID,
31✔
214
                },
31✔
215
        }
31✔
216

31✔
217
        if body != nil {
48✔
218
                if b, ok := body.([]byte); ok {
17✔
219
                        msg.Body = b
×
220
                } else {
17✔
221
                        b, err := msgpack.Marshal(body)
17✔
222
                        if err != nil {
17✔
223
                                return errors.Wrap(errFileTransferMarshalling, err.Error())
×
224
                        }
×
225
                        msg.Body = b
17✔
226
                }
227
        }
228

229
        data, err := msgpack.Marshal(msg)
31✔
230
        if err != nil {
31✔
231
                return errors.Wrap(errFileTransferMarshalling, err.Error())
×
232
        }
×
233
        err = h.nats.Publish(deviceTopic, data)
31✔
234
        if err != nil {
31✔
235
                return errors.Wrap(errFileTransferPublishing, err.Error())
×
236
        }
×
237
        return err
31✔
238
}
239

240
func (h ManagementController) decodeFileTransferProtoMessage(data []byte) (*ws.ProtoMsg,
241
        interface{}, error) {
18✔
242
        msg := &ws.ProtoMsg{}
18✔
243
        err := msgpack.Unmarshal(data, msg)
18✔
244
        if err != nil {
18✔
245
                return nil, nil, errors.Wrap(err, errFileTransferUnmarshalling.Error())
×
246
        }
×
247

248
        switch msg.Header.MsgType {
18✔
249
        case wsft.MessageTypeError:
3✔
250
                msgBody := &wsft.Error{}
3✔
251
                err := msgpack.Unmarshal(msg.Body, msgBody)
3✔
252
                if err != nil {
3✔
253
                        return nil, nil, errors.Wrap(err, errFileTransferUnmarshalling.Error())
×
254
                }
×
255
                return msg, msgBody, nil
3✔
UNCOV
256
        case wsft.MessageTypeFileInfo:
×
UNCOV
257
                msgBody := &wsft.FileInfo{}
×
UNCOV
258
                err := msgpack.Unmarshal(msg.Body, msgBody)
×
UNCOV
259
                if err != nil {
×
260
                        return nil, nil, errors.Wrap(err, errFileTransferUnmarshalling.Error())
×
261
                }
×
UNCOV
262
                return msg, msgBody, nil
×
263
        case wsft.MessageTypeACK, wsft.MessageTypeChunk, ws.MessageTypePing, ws.MessageTypePong:
13✔
264
                return msg, nil, nil
13✔
265
        }
266

267
        return nil, nil, errors.Errorf("unexpected message type '%s'", msg.Header.MsgType)
2✔
268
}
269

270
func writeHeaders(c *gin.Context, fileInfo *wsft.FileInfo) {
5✔
271
        c.Writer.Header().Add(hdrContentType, "application/octet-stream")
5✔
272
        if fileInfo.Path != nil {
10✔
273
                filename := path.Base(*fileInfo.Path)
5✔
274
                c.Writer.Header().Add(hdrContentDisposition,
5✔
275
                        "attachment; filename=\""+filename+"\"")
5✔
276
                c.Writer.Header().Add(hdrMenderFileTransferPath, *fileInfo.Path)
5✔
277
        }
5✔
278
        if fileInfo.UID != nil {
6✔
279
                c.Writer.Header().Add(hdrMenderFileTransferUID, fmt.Sprintf("%d", *fileInfo.UID))
1✔
280
        }
1✔
281
        if fileInfo.GID != nil {
6✔
282
                c.Writer.Header().Add(hdrMenderFileTransferGID, fmt.Sprintf("%d", *fileInfo.GID))
1✔
283
        }
1✔
284
        if fileInfo.Mode != nil {
10✔
285
                c.Writer.Header().Add(hdrMenderFileTransferMode, fmt.Sprintf("%o", *fileInfo.Mode))
5✔
286
        }
5✔
287
        if fileInfo.Size != nil {
10✔
288
                c.Writer.Header().Add(hdrMenderFileTransferSize, fmt.Sprintf("%d", *fileInfo.Size))
5✔
289
        }
5✔
290
        c.Writer.WriteHeader(http.StatusOK)
5✔
291
}
292
func (h ManagementController) handleResponseError(c *gin.Context, err error) {
4✔
293
        l := log.FromContext(c.Request.Context())
4✔
294
        l.Errorf("error handling request: %s", err.Error())
4✔
295
        if !c.Writer.Written() {
8✔
296
                var statusError *Error
4✔
297
                var errMsg string = err.Error()
4✔
298
                var statusCode int = http.StatusInternalServerError
4✔
299
                if errors.As(err, &statusError) {
8✔
300
                        statusCode = statusError.statusCode
4✔
301
                }
4✔
302
                if statusCode >= 500 {
5✔
303
                        errMsg = "internal error"
1✔
304
                }
1✔
305
                c.Writer.WriteHeader(statusCode)
4✔
306
                c.JSON(statusCode, gin.H{
4✔
307
                        "error":      errMsg,
4✔
308
                        "request_id": requestid.FromContext(c.Request.Context()),
4✔
309
                })
4✔
NEW
310
        } else {
×
NEW
311
                l.Warn("response already written")
×
UNCOV
312
        }
×
313
}
314

315
func chanTimeout(
316
        src <-chan *natsio.Msg,
317
        timeout time.Duration,
318
) <-chan *natsio.Msg {
9✔
319
        timer := time.NewTimer(timeout)
9✔
320
        dst := make(chan *natsio.Msg)
9✔
321
        go func() {
18✔
322
                for {
44✔
323
                        select {
35✔
324
                        case <-timer.C:
2✔
325
                                close(dst)
2✔
326
                                return
2✔
327
                        case msg, ok := <-src:
33✔
328
                                if !ok {
40✔
329
                                        close(dst)
7✔
330
                                        return
7✔
331
                                }
7✔
332
                                if !timer.Stop() {
26✔
NEW
333
                                        // Timer must be stopped and drained before calling Reset.
×
NEW
334
                                        select {
×
NEW
335
                                        case <-timer.C:
×
NEW
336
                                        default:
×
337
                                        }
338
                                }
339
                                timer.Reset(timeout)
26✔
340
                                dst <- msg
26✔
341
                        }
342
                }
343
        }()
344
        return dst
9✔
345
}
346

347
func (h ManagementController) statFile(
348
        ctx context.Context,
349
        sessChan <-chan *natsio.Msg,
350
        path, sessionID, userID, deviceTopic string) (*wsft.FileInfo, error) {
8✔
351
        // stat the remote file
8✔
352
        req := wsft.StatFile{
8✔
353
                Path: &path,
8✔
354
        }
8✔
355
        if err := h.publishFileTransferProtoMessage(sessionID,
8✔
356
                userID, deviceTopic, wsft.MessageTypeStat, req, 0); err != nil {
8✔
NEW
357
                return nil, err
×
NEW
358
        }
×
359
        var fileInfo *wsft.FileInfo
8✔
360
        select {
8✔
361
        case rsp, ok := <-sessChan:
8✔
362
                if !ok {
9✔
363
                        return nil, errFileTransferTimeout
1✔
364
                }
1✔
365
                var msg ws.ProtoMsg
7✔
366
                err := msgpack.Unmarshal(rsp.Data, &msg)
7✔
367
                if err != nil {
7✔
NEW
368
                        return nil, fmt.Errorf("malformed message from device: %w", err)
×
NEW
369
                }
×
370
                if msg.Header.MsgType == ws.MessageTypeError {
8✔
371
                        var errMsg ws.Error
1✔
372
                        _ = msgpack.Unmarshal(msg.Body, &errMsg)
1✔
373
                        rspErr := NewError(
1✔
374
                                fmt.Errorf("error received from device: %s", errMsg.Error),
1✔
375
                                http.StatusBadRequest,
1✔
376
                        )
1✔
377
                        return nil, rspErr
1✔
378
                }
1✔
379
                if msg.Header.Proto != ws.ProtoTypeFileTransfer ||
6✔
380
                        msg.Header.MsgType != wsft.MessageTypeFileInfo {
6✔
NEW
381
                        return nil, fmt.Errorf("unexpected response from device %q", msg.Header.MsgType)
×
NEW
382
                }
×
383
                err = msgpack.Unmarshal(msg.Body, &fileInfo)
6✔
384
                if err != nil {
6✔
NEW
385
                        return nil, fmt.Errorf("malformed message body from device: %w", err)
×
NEW
386
                }
×
NEW
387
        case <-ctx.Done():
×
NEW
388
                return nil, ctx.Err()
×
389
        }
390
        return fileInfo, nil
6✔
391
}
392

393
func (h ManagementController) downloadFileResponse(c *gin.Context, params *fileTransferParams,
394
        request *model.DownloadFileRequest) {
9✔
395
        ctx := c.Request.Context()
9✔
396
        // send a JSON-encoded error message in case of failure
9✔
397

9✔
398
        // subscribe to messages from the device
9✔
399
        deviceTopic := model.GetDeviceSubject(params.TenantID, params.Device.ID)
9✔
400
        sessionTopic := model.GetSessionSubject(params.TenantID, params.SessionID)
9✔
401
        subChan := make(chan *natsio.Msg, channelSize)
9✔
402
        defer close(subChan)
9✔
403
        sub, err := h.nats.ChanSubscribe(sessionTopic, subChan)
9✔
404
        if err != nil {
9✔
NEW
405
                h.handleResponseError(c, errors.Wrap(err, errFileTransferSubscribing.Error()))
×
406
                return
×
UNCOV
407
        }
×
408
        //nolint:errcheck
409
        defer sub.Unsubscribe()
9✔
410

9✔
411
        msgChan := chanTimeout(subChan, fileTransferTimeout)
9✔
412

9✔
413
        if err = h.filetransferHandshake(msgChan, params.SessionID, deviceTopic); err != nil {
10✔
414
                h.handleResponseError(c, err)
1✔
415
                return
1✔
416
        }
1✔
417
        // Inform the device that we're closing the session
418
        //nolint:errcheck
419
        defer h.publishControlMessage(params.SessionID, deviceTopic, ws.MessageTypeClose, nil)
8✔
420

8✔
421
        fileInfo, err := h.statFile(
8✔
422
                ctx, msgChan, *request.Path,
8✔
423
                params.SessionID, params.UserID, deviceTopic,
8✔
424
        )
8✔
425
        if err != nil {
10✔
426
                h.handleResponseError(c, fmt.Errorf("failed to retrieve file info: %w", err))
2✔
427
                return
2✔
428
        }
2✔
429
        if fileInfo.Mode == nil || !os.FileMode(*fileInfo.Mode).IsRegular() {
7✔
430
                h.handleResponseError(
1✔
431
                        c,
1✔
432
                        NewError(fmt.Errorf("file must be a regular file"), http.StatusBadRequest),
1✔
433
                )
1✔
434
                return
1✔
435
        }
1✔
436
        writeHeaders(c, fileInfo)
5✔
437
        if c.Request.Method == http.MethodHead {
5✔
NEW
438
                return
×
NEW
439
        }
×
440
        err = h.downloadFile(
5✔
441
                ctx, msgChan, c.Writer, *request.Path,
5✔
442
                params.SessionID, params.UserID, deviceTopic,
5✔
443
        )
5✔
444
        if err != nil {
10✔
445
                log.FromContext(ctx).
5✔
446
                        Errorf("error downloading file from device: %s", err.Error())
5✔
447
        }
5✔
448
}
449

450
func (h ManagementController) downloadFile(
451
        ctx context.Context,
452
        msgChan <-chan *natsio.Msg,
453
        dst io.Writer,
454
        path, sessionID, userID, deviceTopic string,
455
) error {
5✔
456
        latestOffset := int64(0)
5✔
457
        numberOfChunks := 0
5✔
458
        req := wsft.GetFile{
5✔
459
                Path: &path,
5✔
460
        }
5✔
461
        if err := h.publishFileTransferProtoMessage(
5✔
462
                sessionID,
5✔
463
                userID,
5✔
464
                deviceTopic,
5✔
465
                wsft.MessageTypeGet,
5✔
466
                req, 0); err != nil {
5✔
467
                return err
×
468
        }
×
469
        for {
16✔
470
                select {
11✔
471
                case wsMessage, ok := <-msgChan:
11✔
472
                        if !ok {
12✔
473
                                return errFileTransferTimeout
1✔
474
                        }
1✔
475

476
                        // process the message
477
                        msg, msgBody, err := h.decodeFileTransferProtoMessage(wsMessage.Data)
10✔
478
                        if err != nil {
10✔
NEW
479
                                return err
×
NEW
480
                        }
×
481

482
                        // process incoming messages from the device by type
483
                        switch msg.Header.MsgType {
10✔
484

485
                        // error message, stop here
486
                        case wsft.MessageTypeError:
1✔
487
                                errorMsg := msgBody.(*wsft.Error)
1✔
488
                                return errors.New(*errorMsg.Error)
1✔
489

490
                        // file data chunk
491
                        case wsft.MessageTypeChunk:
9✔
492
                                if msg.Body == nil {
11✔
493
                                        if err := h.publishFileTransferProtoMessage(
2✔
494
                                                sessionID, userID, deviceTopic,
2✔
495
                                                wsft.MessageTypeACK, nil,
2✔
496
                                                latestOffset); err != nil {
2✔
NEW
497
                                                return err
×
NEW
498
                                        }
×
499
                                        return io.EOF
2✔
500
                                }
501

502
                                // verify the offset property
503
                                propOffset, _ := msg.Header.Properties[PropertyOffset].(int64)
7✔
504
                                if propOffset != latestOffset {
8✔
505
                                        return errors.Wrap(errFileTransferFailed,
1✔
506
                                                "wrong offset received")
1✔
507
                                }
1✔
508
                                latestOffset += int64(len(msg.Body))
6✔
509

6✔
510
                                _, err := dst.Write(msg.Body)
6✔
511
                                if err != nil {
6✔
NEW
512
                                        return err
×
NEW
513
                                }
×
514

515
                                numberOfChunks++
6✔
516
                                if numberOfChunks >= ackSlidingWindowSend {
12✔
517
                                        if err := h.publishFileTransferProtoMessage(
6✔
518
                                                sessionID, userID, deviceTopic,
6✔
519
                                                wsft.MessageTypeACK, nil,
6✔
520
                                                latestOffset); err != nil {
6✔
NEW
521
                                                return err
×
NEW
522
                                        }
×
523
                                        numberOfChunks = 0
6✔
524
                                }
525

NEW
526
                        case ws.MessageTypePing:
×
NEW
527
                                if err := h.publishFileTransferProtoMessage(
×
NEW
528
                                        sessionID, userID, deviceTopic,
×
NEW
529
                                        ws.MessageTypePong, nil,
×
NEW
530
                                        -1); err != nil {
×
NEW
531
                                        return err
×
NEW
532
                                }
×
533
                        }
NEW
534
                case <-ctx.Done():
×
NEW
535
                        return ctx.Err()
×
536
                }
537
        }
538
}
539

540
func (h ManagementController) DownloadFile(c *gin.Context) {
17✔
541
        l := log.FromContext(c.Request.Context())
17✔
542

17✔
543
        params, statusCode, err := h.getFileTransferParams(c)
17✔
544
        if err != nil {
21✔
545
                l.Error(err)
4✔
546
                c.JSON(statusCode, gin.H{"error": err.Error()})
4✔
547
                return
4✔
548
        }
4✔
549

550
        path := c.Request.URL.Query().Get(paramDownloadPath)
13✔
551
        request := &model.DownloadFileRequest{
13✔
552
                Path: &path,
13✔
553
        }
13✔
554

13✔
555
        if err := request.Validate(); err != nil {
16✔
556
                l.Error(err)
3✔
557
                c.JSON(http.StatusBadRequest, gin.H{
3✔
558
                        "error": errors.Wrap(err, "bad request").Error(),
3✔
559
                })
3✔
560
                return
3✔
561
        }
3✔
562

563
        ctx := c.Request.Context()
10✔
564
        if err := h.app.DownloadFile(ctx, params.UserID, params.Device.ID,
10✔
565
                *request.Path); err != nil {
11✔
566
                l.Error(err)
1✔
567
                c.JSON(http.StatusInternalServerError, gin.H{
1✔
568
                        "error": errors.Wrap(err, "bad request").Error(),
1✔
569
                })
1✔
570
                return
1✔
571
        }
1✔
572

573
        h.downloadFileResponse(c, params, request)
9✔
574
}
575

576
func (h ManagementController) uploadFileResponseHandleInboundMessages(
577
        c *gin.Context, params *fileTransferParams,
578
        msgChan chan *natsio.Msg, errorChan chan error,
579
        latestAckOffsets chan int64,
580
) {
3✔
581
        var latestAckOffset int64
3✔
582
        deviceTopic := model.GetDeviceSubject(params.TenantID, params.Device.ID)
3✔
583
        for {
7✔
584
                select {
4✔
585
                case wsMessage := <-msgChan:
4✔
586
                        msg, msgBody, err := h.decodeFileTransferProtoMessage(
4✔
587
                                wsMessage.Data)
4✔
588
                        if err != nil {
6✔
589
                                errorChan <- err
2✔
590
                                return
2✔
591
                        }
2✔
592

593
                        // process incoming messages from the device by type
594
                        switch msg.Header.MsgType {
2✔
595

596
                        // error message, stop here
597
                        case wsft.MessageTypeError:
1✔
598
                                errorMsg := msgBody.(*wsft.Error)
1✔
599
                                errorChan <- errors.New(*errorMsg.Error)
1✔
600
                                return
1✔
601

602
                        // you can continue the upload
603
                        case wsft.MessageTypeACK:
1✔
604
                                propValue := msg.Header.Properties[PropertyOffset]
1✔
605
                                propOffset, _ := propValue.(int64)
1✔
606
                                if propOffset > latestAckOffset {
2✔
607
                                        latestAckOffset = propOffset
1✔
608
                                        select {
1✔
609
                                        case latestAckOffsets <- latestAckOffset:
1✔
610
                                        case <-latestAckOffsets:
×
611
                                                // Replace ack offset with the latest one
×
612
                                                latestAckOffsets <- latestAckOffset
×
613
                                        }
614
                                }
615

616
                        // handle ping messages
617
                        case ws.MessageTypePing:
×
618
                                if err := h.publishFileTransferProtoMessage(
×
619
                                        params.SessionID, params.UserID, deviceTopic,
×
620
                                        ws.MessageTypePong, nil,
×
621
                                        -1); err != nil {
×
622
                                        errorChan <- err
×
623
                                }
×
624
                        }
625
                case <-c.Done():
×
626
                        return
×
627
                }
628
        }
629
}
630

631
// filetransferHandshake initiates a handshake and checks that the device
632
// is willing to accept file transfer requests.
633
func (h ManagementController) filetransferHandshake(
634
        sessChan <-chan *natsio.Msg, sessionID, deviceTopic string,
635
) error {
17✔
636
        if err := h.publishControlMessage(
17✔
637
                sessionID, deviceTopic,
17✔
638
                ws.MessageTypeOpen, ws.Open{
17✔
639
                        Versions: []int{ws.ProtocolVersion},
17✔
640
                }); err != nil {
17✔
641
                return errFileTransferPublishing
×
642
        }
×
643
        select {
17✔
644
        case natsMsg, ok := <-sessChan:
16✔
645
                if !ok {
16✔
NEW
646
                        return errFileTransferTimeout
×
NEW
647
                }
×
648
                var msg ws.ProtoMsg
16✔
649
                err := msgpack.Unmarshal(natsMsg.Data, &msg)
16✔
650
                if err != nil {
16✔
651
                        return errFileTransferUnmarshalling
×
652
                }
×
653

654
                if msg.Header.MsgType == ws.MessageTypeError {
17✔
655
                        erro := new(ws.Error)
1✔
656
                        //nolint:errcheck
1✔
657
                        msgpack.Unmarshal(natsMsg.Data, erro)
1✔
658
                        return errors.Errorf("handshake error from client: %s", erro.Error)
1✔
659
                } else if msg.Header.MsgType != ws.MessageTypeAccept {
17✔
660
                        return errFileTransferNotImplemented
1✔
661
                }
1✔
662
                accept := new(ws.Accept)
14✔
663
                err = msgpack.Unmarshal(msg.Body, accept)
14✔
664
                if err != nil {
14✔
665
                        return errFileTransferUnmarshalling
×
666
                }
×
667

668
                for _, proto := range accept.Protocols {
29✔
669
                        if proto == ws.ProtoTypeFileTransfer {
28✔
670
                                return nil
13✔
671
                        }
13✔
672
                }
673
                // Let's try to be polite and close the session before returning
674
                //nolint:errcheck
675
                h.publishControlMessage(sessionID, deviceTopic, ws.MessageTypeClose, nil)
1✔
676
                return errFileTransferDisabled
1✔
677

678
        case <-time.After(fileTransferTimeout):
1✔
679
                return errFileTransferTimeout
1✔
680
        }
681
}
682

683
func (h ManagementController) uploadFileResponse(c *gin.Context, params *fileTransferParams,
684
        request *model.UploadFileRequest) {
8✔
685
        l := log.FromContext(c.Request.Context())
8✔
686

8✔
687
        // send a JSON-encoded error message in case of failure
8✔
688
        var responseError error
8✔
689
        errorStatusCode := http.StatusInternalServerError
8✔
690
        defer func() {
16✔
691
                if responseError != nil {
15✔
692
                        l.Error(responseError.Error())
7✔
693
                        c.JSON(errorStatusCode, gin.H{
7✔
694
                                "error": responseError.Error(),
7✔
695
                        })
7✔
696
                        return
7✔
697
                }
7✔
698
        }()
699

700
        // subscribe to messages from the device
701
        deviceTopic := model.GetDeviceSubject(params.TenantID, params.Device.ID)
8✔
702
        sessionTopic := model.GetSessionSubject(params.TenantID, params.SessionID)
8✔
703
        msgChan := make(chan *natsio.Msg, channelSize)
8✔
704
        sub, err := h.nats.ChanSubscribe(sessionTopic, msgChan)
8✔
705
        if err != nil {
8✔
706
                responseError = errors.Wrap(err, errFileTransferSubscribing.Error())
×
707
                return
×
708
        }
×
709

710
        //nolint:errcheck
711
        defer sub.Unsubscribe()
8✔
712

8✔
713
        if err = h.filetransferHandshake(msgChan, params.SessionID, deviceTopic); err != nil {
11✔
714
                switch err {
3✔
715
                case errFileTransferTimeout:
1✔
716
                        errorStatusCode = http.StatusRequestTimeout
1✔
717
                case errFileTransferNotImplemented, errFileTransferDisabled:
1✔
718
                        errorStatusCode = http.StatusBadGateway
1✔
719
                }
720
                responseError = err
3✔
721
                return
3✔
722
        }
723

724
        // Inform the device that we're closing the session
725
        //nolint:errcheck
726
        defer h.publishControlMessage(params.SessionID, deviceTopic, ws.MessageTypeClose, nil)
5✔
727

5✔
728
        // initialize the file transfer
5✔
729
        req := wsft.UploadRequest{
5✔
730
                SrcPath: request.SrcPath,
5✔
731
                Path:    request.Path,
5✔
732
                UID:     request.UID,
5✔
733
                GID:     request.GID,
5✔
734
                Mode:    request.Mode,
5✔
735
        }
5✔
736
        if err := h.publishFileTransferProtoMessage(params.SessionID,
5✔
737
                params.UserID, deviceTopic, wsft.MessageTypePut, req, 0); err != nil {
5✔
738
                responseError = err
×
739
                return
×
740
        }
×
741

742
        // receive the message from the device
743
        select {
5✔
744
        case wsMessage := <-msgChan:
4✔
745
                msg, msgBody, err := h.decodeFileTransferProtoMessage(wsMessage.Data)
4✔
746
                if err != nil {
4✔
747
                        responseError = err
×
748
                        return
×
749
                }
×
750

751
                // process incoming messages from the device by type
752
                switch msg.Header.MsgType {
4✔
753

754
                // error message, stop here
755
                case wsft.MessageTypeError:
1✔
756
                        errorMsg := msgBody.(*wsft.Error)
1✔
757
                        errorStatusCode = http.StatusBadRequest
1✔
758
                        responseError = errors.New(*errorMsg.Error)
1✔
759
                        return
1✔
760

761
                // you can continue the upload
762
                case wsft.MessageTypeACK:
3✔
763
                }
764

765
        // no message after timeout expired, stop here
766
        case <-time.After(fileTransferTimeout):
1✔
767
                errorStatusCode = http.StatusRequestTimeout
1✔
768
                responseError = errFileTransferTimeout
1✔
769
                return
1✔
770
        }
771

772
        // receive the ack message from the device
773
        latestAckOffsets := make(chan int64, 1)
3✔
774
        errorChan := make(chan error)
3✔
775
        go h.uploadFileResponseHandleInboundMessages(
3✔
776
                c, params, msgChan, errorChan, latestAckOffsets,
3✔
777
        )
3✔
778

3✔
779
        h.uploadFileResponseWriter(
3✔
780
                c, params, request, errorChan, latestAckOffsets, &errorStatusCode, &responseError,
3✔
781
        )
3✔
782
}
783

784
func (h ManagementController) uploadFileResponseWriter(c *gin.Context,
785
        params *fileTransferParams, request *model.UploadFileRequest,
786
        errorChan chan error, latestAckOffsets <-chan int64,
787
        errorStatusCode *int, responseError *error) {
3✔
788
        var (
3✔
789
                offset          int64
3✔
790
                latestAckOffset int64
3✔
791
        )
3✔
792
        deviceTopic := model.GetDeviceSubject(params.TenantID, params.Device.ID)
3✔
793

3✔
794
        timeout := time.NewTimer(fileTransferTimeout)
3✔
795
        data := make([]byte, fileTransferBufferSize)
3✔
796
        for {
7✔
797
                n, err := request.File.Read(data)
4✔
798
                if err != nil && err != io.EOF {
4✔
799
                        if err == io.ErrUnexpectedEOF {
×
800
                                *errorStatusCode = http.StatusBadRequest
×
801
                                *responseError = errors.New(
×
802
                                        "malformed request body: " +
×
803
                                                "did not find closing multipart boundary",
×
804
                                )
×
805
                        } else {
×
806
                                *responseError = err
×
807
                        }
×
808
                        return
×
809
                } else if n == 0 {
5✔
810
                        if err := h.publishFileTransferProtoMessage(params.SessionID,
1✔
811
                                params.UserID, deviceTopic, wsft.MessageTypeChunk, nil,
1✔
812
                                offset); err != nil {
1✔
813
                                *responseError = err
×
814
                                return
×
815
                        }
×
816
                        break
1✔
817
                }
818

819
                // send the chunk
820
                if err := h.publishFileTransferProtoMessage(params.SessionID,
3✔
821
                        params.UserID, deviceTopic, wsft.MessageTypeChunk, data[0:n],
3✔
822
                        offset); err != nil {
3✔
823
                        *responseError = err
×
824
                        return
×
825
                }
×
826

827
                // update the offset
828
                offset += int64(n)
3✔
829

3✔
830
                // wait for acks, in case the ack sliding window is over
3✔
831
                if offset > latestAckOffset+int64(fileTransferBufferSize*ackSlidingWindowRecv) {
6✔
832
                        timeout.Reset(fileTransferTimeout)
3✔
833
                        select {
3✔
834
                        case err := <-errorChan:
1✔
835
                                *errorStatusCode = http.StatusBadRequest
1✔
836
                                *responseError = err
1✔
837
                                return
1✔
838
                        case latestAckOffset = <-latestAckOffsets:
1✔
839
                        case <-timeout.C:
1✔
840
                                *errorStatusCode = http.StatusRequestTimeout
1✔
841
                                *responseError = errFileTransferTimeout
1✔
842
                                return
1✔
843
                        }
844
                } else {
×
845
                        // in case of error, report it
×
846
                        select {
×
847
                        case err := <-errorChan:
×
848
                                *errorStatusCode = http.StatusBadRequest
×
849
                                *responseError = err
×
850
                                return
×
851
                        default:
×
852
                        }
853
                }
854

855
        }
856

857
        for offset > latestAckOffset {
1✔
858
                timeout.Reset(fileTransferTimeout)
×
859
                select {
×
860
                case latestAckOffset = <-latestAckOffsets:
×
861
                case <-timeout.C:
×
862
                        *errorStatusCode = http.StatusRequestTimeout
×
863
                        *responseError = errFileTransferTimeout
×
864
                        return
×
865
                }
866
        }
867

868
        c.Writer.WriteHeader(http.StatusCreated)
1✔
869
}
870

871
func (h ManagementController) parseUploadFileRequest(c *gin.Context) (*model.UploadFileRequest,
872
        error) {
12✔
873
        reader, err := c.Request.MultipartReader()
12✔
874
        if err != nil {
12✔
875
                return nil, err
×
876
        }
×
877

878
        request := &model.UploadFileRequest{}
12✔
879
        for {
72✔
880
                part, err := reader.NextPart()
60✔
881
                if err == io.EOF {
61✔
882
                        break
1✔
883
                }
884
                if err != nil {
59✔
885
                        return nil, err
×
886
                }
×
887
                var n int
59✔
888
                data := make([]byte, fileTransferBufferSize)
59✔
889
                partName := part.FormName()
59✔
890
                switch partName {
59✔
891
                case fieldUploadPath, fieldUploadUID, fieldUploadGID, fieldUploadMode:
48✔
892
                        n, err = part.Read(data)
48✔
893
                        var value string
48✔
894
                        if err == nil || err == io.EOF {
96✔
895
                                value = string(data[:n])
48✔
896
                        }
48✔
897
                        switch partName {
48✔
898
                        case fieldUploadPath:
12✔
899
                                request.Path = &value
12✔
900
                        case fieldUploadUID:
12✔
901
                                v, err := strconv.Atoi(string(data[:n]))
12✔
902
                                if err != nil {
12✔
903
                                        return nil, err
×
904
                                }
×
905
                                nUID := uint32(v)
12✔
906
                                request.UID = &nUID
12✔
907
                        case fieldUploadGID:
12✔
908
                                v, err := strconv.Atoi(string(data[:n]))
12✔
909
                                if err != nil {
12✔
910
                                        return nil, err
×
911
                                }
×
912
                                nGID := uint32(v)
12✔
913
                                request.GID = &nGID
12✔
914
                        case fieldUploadMode:
12✔
915
                                v, err := strconv.ParseUint(string(data[:n]), 8, 32)
12✔
916
                                if err != nil {
12✔
917
                                        return nil, err
×
918
                                }
×
919
                                nMode := uint32(v)
12✔
920
                                request.Mode = &nMode
12✔
921
                        }
922
                        part.Close()
48✔
923
                case fieldUploadFile:
11✔
924
                        filename := part.FileName()
11✔
925
                        request.SrcPath = &filename
11✔
926
                        request.File = part
11✔
927
                }
928
                // file is the last part we can process, in order to avoid loading it in memory
929
                if request.File != nil {
70✔
930
                        break
11✔
931
                }
932
        }
933

934
        return request, nil
12✔
935
}
936

937
func (h ManagementController) UploadFile(c *gin.Context) {
18✔
938
        l := log.FromContext(c.Request.Context())
18✔
939

18✔
940
        params, statusCode, err := h.getFileTransferParams(c)
18✔
941
        if err != nil {
24✔
942
                l.Error(err.Error())
6✔
943
                c.JSON(statusCode, gin.H{"error": err.Error()})
6✔
944
                return
6✔
945
        }
6✔
946

947
        request, err := h.parseUploadFileRequest(c)
12✔
948
        if err != nil {
12✔
949
                l.Error(err.Error())
×
950
                c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
×
951
                return
×
952
        }
×
953

954
        if err := request.Validate(); err != nil {
15✔
955
                l.Error(err.Error())
3✔
956
                c.JSON(http.StatusBadRequest, gin.H{
3✔
957
                        "error": errors.Wrap(err, "bad request").Error(),
3✔
958
                })
3✔
959
                return
3✔
960
        }
3✔
961

962
        defer request.File.Close()
9✔
963

9✔
964
        ctx := c.Request.Context()
9✔
965
        if err := h.app.UploadFile(ctx, params.UserID, params.Device.ID,
9✔
966
                *request.Path); err != nil {
10✔
967
                l.Error(err)
1✔
968
                c.JSON(http.StatusInternalServerError, gin.H{
1✔
969
                        "error": errors.Wrap(err, "bad request").Error(),
1✔
970
                })
1✔
971
                return
1✔
972
        }
1✔
973

974
        h.uploadFileResponse(c, params, request)
8✔
975
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc