From f991fa0498c781badec675720d44666728ad9ea3 Mon Sep 17 00:00:00 2001 From: Alf-Rune Siqveland Date: Thu, 30 Nov 2023 10:52:30 +0100 Subject: [PATCH 1/3] chore: Move control logic earlier in the filetransfer handler Signed-off-by: Alf-Rune Siqveland --- api/http/management_filetransfer.go | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/api/http/management_filetransfer.go b/api/http/management_filetransfer.go index b5738106..ee504884 100644 --- a/api/http/management_filetransfer.go +++ b/api/http/management_filetransfer.go @@ -296,14 +296,16 @@ func (h ManagementController) downloadFileResponse(c *gin.Context, params *fileT responseError = errors.Wrap(err, errFileTransferSubscribing.Error()) return } + //nolint:errcheck + defer sub.Unsubscribe() if err = h.filetransferHandshake(msgChan, params.SessionID, deviceTopic); err != nil { responseError = err return } - + // Inform the device that we're closing the session //nolint:errcheck - defer sub.Unsubscribe() + defer h.publishControlMessage(params.SessionID, deviceTopic, ws.MessageTypeClose, nil) // stat the remote file req := wsft.StatFile{ @@ -315,10 +317,6 @@ func (h ManagementController) downloadFileResponse(c *gin.Context, params *fileT return } - // Inform the device that we're closing the session - //nolint:errcheck - defer h.publishControlMessage(params.SessionID, deviceTopic, ws.MessageTypeClose, nil) - ticker := time.NewTicker(fileTransferPingInterval) defer ticker.Stop() From 176380fd8ee5003652ad9885265bb13a41d8c9d6 Mon Sep 17 00:00:00 2001 From: Alf-Rune Siqveland Date: Thu, 30 Nov 2023 11:23:59 +0100 Subject: [PATCH 2/3] chore: Improve idle timeout logic for filetransfer downloads Signed-off-by: Alf-Rune Siqveland --- api/http/management_filetransfer.go | 60 +++++++++++++++++++++++------ 1 file changed, 48 insertions(+), 12 deletions(-) diff --git a/api/http/management_filetransfer.go b/api/http/management_filetransfer.go index ee504884..da57cb9f 100644 --- a/api/http/management_filetransfer.go +++ b/api/http/management_filetransfer.go @@ -15,6 +15,7 @@ package http import ( + "context" "fmt" "io" "net/http" @@ -280,8 +281,41 @@ func (h ManagementController) downloadFileResponseError(c *gin.Context, } } +func chanTimeout( + src <-chan *natsio.Msg, + timeout time.Duration, +) <-chan *natsio.Msg { + timer := time.NewTimer(timeout) + dst := make(chan *natsio.Msg) + go func() { + for { + select { + case <-timer.C: + close(dst) + return + case msg, ok := <-src: + if !ok { + close(dst) + return + } + if !timer.Stop() { + // Timer must be stopped and drained before calling Reset. + select { + case <-timer.C: + default: + } + } + timer.Reset(timeout) + dst <- msg + } + } + }() + return dst +} + func (h ManagementController) downloadFileResponse(c *gin.Context, params *fileTransferParams, request *model.DownloadFileRequest) { + ctx := c.Request.Context() // send a JSON-encoded error message in case of failure var responseError error var responseHeaderSent bool @@ -290,8 +324,9 @@ func (h ManagementController) downloadFileResponse(c *gin.Context, params *fileT // subscribe to messages from the device deviceTopic := model.GetDeviceSubject(params.TenantID, params.Device.ID) sessionTopic := model.GetSessionSubject(params.TenantID, params.SessionID) - msgChan := make(chan *natsio.Msg, channelSize) - sub, err := h.nats.ChanSubscribe(sessionTopic, msgChan) + subChan := make(chan *natsio.Msg, channelSize) + defer close(subChan) + sub, err := h.nats.ChanSubscribe(sessionTopic, subChan) if err != nil { responseError = errors.Wrap(err, errFileTransferSubscribing.Error()) return @@ -299,6 +334,8 @@ func (h ManagementController) downloadFileResponse(c *gin.Context, params *fileT //nolint:errcheck defer sub.Unsubscribe() + msgChan := chanTimeout(subChan, fileTransferTimeout) + if err = h.filetransferHandshake(msgChan, params.SessionID, deviceTopic); err != nil { responseError = err return @@ -321,15 +358,16 @@ func (h ManagementController) downloadFileResponse(c *gin.Context, params *fileT defer ticker.Stop() // handle messages from the device - timeout := time.NewTimer(fileTransferTimeout) latestOffset := int64(0) numberOfChunks := 0 var fileInfo wsft.FileInfo for { select { - case wsMessage := <-msgChan: - // reset the timeout ticket - timeout.Reset(fileTransferTimeout) + case wsMessage, ok := <-msgChan: + if !ok { + responseError = errFileTransferTimeout + return + } // process the message err := h.downloadFileResponseProcessMessage(c, params, request, wsMessage, deviceTopic, &latestOffset, &numberOfChunks, @@ -348,11 +386,6 @@ func (h ManagementController) downloadFileResponse(c *gin.Context, params *fileT if responseError != nil { return } - - // no message after timeout expired, stop here - case <-timeout.C: - responseError = errFileTransferTimeout - return } } } @@ -555,7 +588,10 @@ func (h ManagementController) filetransferHandshake( return errFileTransferPublishing } select { - case natsMsg := <-sessChan: + case natsMsg, ok := <-sessChan: + if !ok { + return errFileTransferTimeout + } var msg ws.ProtoMsg err := msgpack.Unmarshal(natsMsg.Data, &msg) if err != nil { From d11bb25d9b3ff8e63d44ce76302d66bab657be38 Mon Sep 17 00:00:00 2001 From: Alf-Rune Siqveland Date: Thu, 30 Nov 2023 17:03:25 +0100 Subject: [PATCH 3/3] feat: Add support for HEAD to download API endpoint Changelog: Title Ticket: ALV-182 Signed-off-by: Alf-Rune Siqveland --- api/http/management_filetransfer.go | 353 +++++++++++++---------- api/http/management_filetransfer_test.go | 20 +- api/http/router.go | 1 + 3 files changed, 221 insertions(+), 153 deletions(-) diff --git a/api/http/management_filetransfer.go b/api/http/management_filetransfer.go index da57cb9f..7d8eedbc 100644 --- a/api/http/management_filetransfer.go +++ b/api/http/management_filetransfer.go @@ -22,7 +22,6 @@ import ( "os" "path" "strconv" - "strings" "time" "github.com/gin-gonic/gin" @@ -33,6 +32,7 @@ import ( "github.com/mendersoftware/go-lib-micro/identity" "github.com/mendersoftware/go-lib-micro/log" + "github.com/mendersoftware/go-lib-micro/requestid" "github.com/mendersoftware/go-lib-micro/ws" wsft "github.com/mendersoftware/go-lib-micro/ws/filetransfer" @@ -69,21 +69,52 @@ const ( paramDownloadPath = "path" ) -var fileTransferPingInterval = 30 * time.Second var fileTransferTimeout = 60 * time.Second var fileTransferBufferSize = 4096 var ackSlidingWindowSend = 10 var ackSlidingWindowRecv = 20 +type Error struct { + error error + statusCode int +} + +func NewError(err error, code int) error { + return &Error{ + error: err, + statusCode: code, + } +} + +func (err *Error) Error() string { + return err.error.Error() +} + +func (err *Error) Unwrap() error { + return err.error +} + var ( - errFileTransferMarshalling = errors.New("failed to marshal the request") - errFileTransferUnmarshalling = errors.New("failed to unmarshal the request") - errFileTransferPublishing = errors.New("failed to publish the message") - errFileTransferSubscribing = errors.New("failed to subscribe to the mesages") - errFileTransferTimeout = errors.New("file transfer timed out") - errFileTransferFailed = errors.New("file transfer failed") - errFileTransferNotImplemented = errors.New("file transfer not implemented on device") - errFileTransferDisabled = errors.New("file transfer disabled on device") + errFileTransferMarshalling = errors.New("failed to marshal the request") + errFileTransferUnmarshalling = errors.New("failed to unmarshal the request") + errFileTransferPublishing = errors.New("failed to publish the message") + errFileTransferSubscribing = errors.New("failed to subscribe to the mesages") + errFileTransferTimeout = &Error{ + error: errors.New("file transfer timed out"), + statusCode: http.StatusRequestTimeout, + } + errFileTransferFailed = &Error{ + error: errors.New("file transfer failed"), + statusCode: http.StatusBadRequest, + } + errFileTransferNotImplemented = &Error{ + error: errors.New("file transfer not implemented on device"), + statusCode: http.StatusBadGateway, + } + errFileTransferDisabled = &Error{ + error: errors.New("file transfer disabled on device"), + statusCode: http.StatusBadGateway, + } ) var newFileTransferSessionID = func() (uuid.UUID, error) { @@ -237,7 +268,6 @@ func (h ManagementController) decodeFileTransferProtoMessage(data []byte) (*ws.P } func writeHeaders(c *gin.Context, fileInfo *wsft.FileInfo) { - c.Writer.WriteHeader(http.StatusOK) c.Writer.Header().Add(hdrContentType, "application/octet-stream") if fileInfo.Path != nil { filename := path.Base(*fileInfo.Path) @@ -257,27 +287,28 @@ func writeHeaders(c *gin.Context, fileInfo *wsft.FileInfo) { if fileInfo.Size != nil { c.Writer.Header().Add(hdrMenderFileTransferSize, fmt.Sprintf("%d", *fileInfo.Size)) } + c.Writer.WriteHeader(http.StatusOK) } - -func (h ManagementController) downloadFileResponseError(c *gin.Context, - responseHeaderSent *bool, responseError *error) { +func (h ManagementController) handleResponseError(c *gin.Context, err error) { l := log.FromContext(c.Request.Context()) - if !*responseHeaderSent && *responseError != nil { - l.Error((*responseError).Error()) - status := http.StatusInternalServerError - // errFileTranserFailed is a special case, we return 400 instead of 500 - if strings.Contains((*responseError).Error(), errFileTransferFailed.Error()) { - status = http.StatusBadRequest - } else if *responseError == errFileTransferTimeout { - status = http.StatusRequestTimeout - } else if *responseError == errFileTransferNotImplemented || - *responseError == errFileTransferDisabled { - status = http.StatusBadGateway - } - c.JSON(status, gin.H{ - "error": (*responseError).Error(), + l.Errorf("error handling request: %s", err.Error()) + if !c.Writer.Written() { + var statusError *Error + var errMsg string = err.Error() + var statusCode int = http.StatusInternalServerError + if errors.As(err, &statusError) { + statusCode = statusError.statusCode + } + if statusCode >= 500 { + errMsg = "internal error" + } + c.Writer.WriteHeader(statusCode) + c.JSON(statusCode, gin.H{ + "error": errMsg, + "request_id": requestid.FromContext(c.Request.Context()), }) - return + } else { + l.Warn("response already written") } } @@ -313,13 +344,56 @@ func chanTimeout( return dst } +func (h ManagementController) statFile( + ctx context.Context, + sessChan <-chan *natsio.Msg, + path, sessionID, userID, deviceTopic string) (*wsft.FileInfo, error) { + // stat the remote file + req := wsft.StatFile{ + Path: &path, + } + if err := h.publishFileTransferProtoMessage(sessionID, + userID, deviceTopic, wsft.MessageTypeStat, req, 0); err != nil { + return nil, err + } + var fileInfo *wsft.FileInfo + select { + case rsp, ok := <-sessChan: + if !ok { + return nil, errFileTransferTimeout + } + var msg ws.ProtoMsg + err := msgpack.Unmarshal(rsp.Data, &msg) + if err != nil { + return nil, fmt.Errorf("malformed message from device: %w", err) + } + if msg.Header.MsgType == ws.MessageTypeError { + var errMsg ws.Error + _ = msgpack.Unmarshal(msg.Body, &errMsg) + rspErr := NewError( + fmt.Errorf("error received from device: %s", errMsg.Error), + http.StatusBadRequest, + ) + return nil, rspErr + } + if msg.Header.Proto != ws.ProtoTypeFileTransfer || + msg.Header.MsgType != wsft.MessageTypeFileInfo { + return nil, fmt.Errorf("unexpected response from device %q", msg.Header.MsgType) + } + err = msgpack.Unmarshal(msg.Body, &fileInfo) + if err != nil { + return nil, fmt.Errorf("malformed message body from device: %w", err) + } + case <-ctx.Done(): + return nil, ctx.Err() + } + return fileInfo, nil +} + func (h ManagementController) downloadFileResponse(c *gin.Context, params *fileTransferParams, request *model.DownloadFileRequest) { ctx := c.Request.Context() // send a JSON-encoded error message in case of failure - var responseError error - var responseHeaderSent bool - defer h.downloadFileResponseError(c, &responseHeaderSent, &responseError) // subscribe to messages from the device deviceTopic := model.GetDeviceSubject(params.TenantID, params.Device.ID) @@ -328,7 +402,7 @@ func (h ManagementController) downloadFileResponse(c *gin.Context, params *fileT defer close(subChan) sub, err := h.nats.ChanSubscribe(sessionTopic, subChan) if err != nil { - responseError = errors.Wrap(err, errFileTransferSubscribing.Error()) + h.handleResponseError(c, errors.Wrap(err, errFileTransferSubscribing.Error())) return } //nolint:errcheck @@ -337,151 +411,130 @@ func (h ManagementController) downloadFileResponse(c *gin.Context, params *fileT msgChan := chanTimeout(subChan, fileTransferTimeout) if err = h.filetransferHandshake(msgChan, params.SessionID, deviceTopic); err != nil { - responseError = err + h.handleResponseError(c, err) return } // Inform the device that we're closing the session //nolint:errcheck defer h.publishControlMessage(params.SessionID, deviceTopic, ws.MessageTypeClose, nil) - // stat the remote file - req := wsft.StatFile{ - Path: request.Path, + fileInfo, err := h.statFile( + ctx, msgChan, *request.Path, + params.SessionID, params.UserID, deviceTopic, + ) + if err != nil { + h.handleResponseError(c, fmt.Errorf("failed to retrieve file info: %w", err)) + return } - if err := h.publishFileTransferProtoMessage(params.SessionID, - params.UserID, deviceTopic, wsft.MessageTypeStat, req, 0); err != nil { - responseError = err + if fileInfo.Mode == nil || !os.FileMode(*fileInfo.Mode).IsRegular() { + h.handleResponseError( + c, + NewError(fmt.Errorf("file must be a regular file"), http.StatusBadRequest), + ) return } + writeHeaders(c, fileInfo) + if c.Request.Method == http.MethodHead { + return + } + err = h.downloadFile( + ctx, msgChan, c.Writer, *request.Path, + params.SessionID, params.UserID, deviceTopic, + ) + if err != nil { + log.FromContext(ctx). + Errorf("error downloading file from device: %s", err.Error()) + } +} - ticker := time.NewTicker(fileTransferPingInterval) - defer ticker.Stop() - - // handle messages from the device +func (h ManagementController) downloadFile( + ctx context.Context, + msgChan <-chan *natsio.Msg, + dst io.Writer, + path, sessionID, userID, deviceTopic string, +) error { latestOffset := int64(0) numberOfChunks := 0 - var fileInfo wsft.FileInfo + req := wsft.GetFile{ + Path: &path, + } + if err := h.publishFileTransferProtoMessage( + sessionID, + userID, + deviceTopic, + wsft.MessageTypeGet, + req, 0); err != nil { + return err + } for { select { case wsMessage, ok := <-msgChan: if !ok { - responseError = errFileTransferTimeout - return + return errFileTransferTimeout } + // process the message - err := h.downloadFileResponseProcessMessage(c, params, request, - wsMessage, deviceTopic, &latestOffset, &numberOfChunks, - &responseHeaderSent, &fileInfo, ticker) - if err == io.EOF { - return - } else if err != nil { - responseError = err - return - } - // send a Ping message to keep the session alive - case <-ticker.C: - responseError = h.publishControlMessage( - params.SessionID, deviceTopic, ws.MessageTypePing, nil, - ) - if responseError != nil { - return + msg, msgBody, err := h.decodeFileTransferProtoMessage(wsMessage.Data) + if err != nil { + return err } - } - } -} -func (h ManagementController) downloadFileResponseProcessMessage(c *gin.Context, - params *fileTransferParams, request *model.DownloadFileRequest, wsMessage *natsio.Msg, - deviceTopic string, latestOffset *int64, numberOfChunks *int, responseHeaderSent *bool, - fileInfo *wsft.FileInfo, ticker *time.Ticker) error { - msg, msgBody, err := h.decodeFileTransferProtoMessage(wsMessage.Data) - if err != nil { - return err - } - - // process incoming messages from the device by type - switch msg.Header.MsgType { - - // error message, stop here - case wsft.MessageTypeError: - errorMsg := msgBody.(*wsft.Error) - if *errorMsg.MessageType == wsft.MessageTypeStat { - return errors.Wrap(errors.New(*errorMsg.Error), - errFileTransferFailed.Error()) - } else { - return errors.New(*errorMsg.Error) - } + // process incoming messages from the device by type + switch msg.Header.MsgType { - // file stat response, if okay, let's get the file - case wsft.MessageTypeFileInfo: - req := wsft.GetFile{ - Path: request.Path, - } - if err := h.publishFileTransferProtoMessage(params.SessionID, - params.UserID, deviceTopic, wsft.MessageTypeGet, - req, 0); err != nil { - return err - } - *fileInfo = *msgBody.(*wsft.FileInfo) - if (os.FileMode(*fileInfo.Mode) & os.ModeType) != 0 { - err := errors.New("path is not a regular file") - return errors.Wrap(err, errFileTransferFailed.Error()) - } + // error message, stop here + case wsft.MessageTypeError: + errorMsg := msgBody.(*wsft.Error) + return errors.New(*errorMsg.Error) + + // file data chunk + case wsft.MessageTypeChunk: + if msg.Body == nil { + if err := h.publishFileTransferProtoMessage( + sessionID, userID, deviceTopic, + wsft.MessageTypeACK, nil, + latestOffset); err != nil { + return err + } + return io.EOF + } - // file data chunk - case wsft.MessageTypeChunk: - if !*responseHeaderSent { - writeHeaders(c, fileInfo) - *responseHeaderSent = true - } - if msg.Body == nil { - if err := h.publishFileTransferProtoMessage( - params.SessionID, params.UserID, deviceTopic, - wsft.MessageTypeACK, nil, - *latestOffset); err != nil { - return err - } - return io.EOF - } + // verify the offset property + propOffset, _ := msg.Header.Properties[PropertyOffset].(int64) + if propOffset != latestOffset { + return errors.Wrap(errFileTransferFailed, + "wrong offset received") + } + latestOffset += int64(len(msg.Body)) - // verify the offset property - propOffset, _ := msg.Header.Properties[PropertyOffset].(int64) - if propOffset != *latestOffset { - return errors.Wrap(errFileTransferFailed, - "wrong offset received") - } - *latestOffset += int64(len(msg.Body)) + _, err := dst.Write(msg.Body) + if err != nil { + return err + } - _, err := c.Writer.Write(msg.Body) - if err != nil { - return err - } + numberOfChunks++ + if numberOfChunks >= ackSlidingWindowSend { + if err := h.publishFileTransferProtoMessage( + sessionID, userID, deviceTopic, + wsft.MessageTypeACK, nil, + latestOffset); err != nil { + return err + } + numberOfChunks = 0 + } - (*numberOfChunks)++ - if *numberOfChunks >= ackSlidingWindowSend { - if err := h.publishFileTransferProtoMessage( - params.SessionID, params.UserID, deviceTopic, - wsft.MessageTypeACK, nil, - *latestOffset); err != nil { - return err + case ws.MessageTypePing: + if err := h.publishFileTransferProtoMessage( + sessionID, userID, deviceTopic, + ws.MessageTypePong, nil, + -1); err != nil { + return err + } } - *numberOfChunks = 0 - } - - case ws.MessageTypePing: - if err := h.publishFileTransferProtoMessage( - params.SessionID, params.UserID, deviceTopic, - ws.MessageTypePong, nil, - -1); err != nil { - return err + case <-ctx.Done(): + return ctx.Err() } - fallthrough - - case ws.MessageTypePong: - ticker.Reset(fileTransferPingInterval) } - - return nil } func (h ManagementController) DownloadFile(c *gin.Context) { diff --git a/api/http/management_filetransfer_test.go b/api/http/management_filetransfer_test.go index 7984d2a1..f1459f25 100644 --- a/api/http/management_filetransfer_test.go +++ b/api/http/management_filetransfer_test.go @@ -58,17 +58,14 @@ func int642pointer(v int64) *int64 { func TestManagementDownloadFile(t *testing.T) { originalNewFileTransferSessionID := newFileTransferSessionID originalFileTransferTimeout := fileTransferTimeout - originalFileTransferPingInterval := fileTransferPingInterval originalAckSlidingWindowSend := ackSlidingWindowSend defer func() { newFileTransferSessionID = originalNewFileTransferSessionID fileTransferTimeout = originalFileTransferTimeout - fileTransferPingInterval = originalFileTransferPingInterval ackSlidingWindowSend = originalAckSlidingWindowSend }() fileTransferTimeout = 2 * time.Second - fileTransferPingInterval = 500 * time.Millisecond ackSlidingWindowSend = 1 sessionID, _ := uuid.NewRandom() @@ -136,6 +133,7 @@ func TestManagementDownloadFile(t *testing.T) { bodyData, err := msgpack.Marshal(body) msg = &ws.ProtoMsg{ Header: ws.ProtoHdr{ + Proto: ws.ProtoTypeFileTransfer, MsgType: wsft.MessageTypeFileInfo, SessionID: sessionID.String(), }, @@ -149,6 +147,7 @@ func TestManagementDownloadFile(t *testing.T) { // first chunk msg = &ws.ProtoMsg{ Header: ws.ProtoHdr{ + Proto: ws.ProtoTypeFileTransfer, MsgType: wsft.MessageTypeChunk, SessionID: sessionID.String(), Properties: map[string]interface{}{ @@ -165,6 +164,7 @@ func TestManagementDownloadFile(t *testing.T) { // final chunk msg = &ws.ProtoMsg{ Header: ws.ProtoHdr{ + Proto: ws.ProtoTypeFileTransfer, MsgType: wsft.MessageTypeChunk, SessionID: sessionID.String(), Properties: map[string]interface{}{ @@ -259,6 +259,7 @@ func TestManagementDownloadFile(t *testing.T) { bodyData, err := msgpack.Marshal(body) msg = &ws.ProtoMsg{ Header: ws.ProtoHdr{ + Proto: ws.ProtoTypeFileTransfer, MsgType: wsft.MessageTypeFileInfo, SessionID: sessionID.String(), }, @@ -272,6 +273,7 @@ func TestManagementDownloadFile(t *testing.T) { // first chunk msg = &ws.ProtoMsg{ Header: ws.ProtoHdr{ + Proto: ws.ProtoTypeFileTransfer, MsgType: wsft.MessageTypeChunk, SessionID: sessionID.String(), Properties: map[string]interface{}{ @@ -288,6 +290,7 @@ func TestManagementDownloadFile(t *testing.T) { // second chunk msg = &ws.ProtoMsg{ Header: ws.ProtoHdr{ + Proto: ws.ProtoTypeFileTransfer, MsgType: wsft.MessageTypeChunk, SessionID: sessionID.String(), Properties: map[string]interface{}{ @@ -304,6 +307,7 @@ func TestManagementDownloadFile(t *testing.T) { // final chunk msg = &ws.ProtoMsg{ Header: ws.ProtoHdr{ + Proto: ws.ProtoTypeFileTransfer, MsgType: wsft.MessageTypeChunk, SessionID: sessionID.String(), }, @@ -394,6 +398,7 @@ func TestManagementDownloadFile(t *testing.T) { bodyData, err := msgpack.Marshal(body) msg = &ws.ProtoMsg{ Header: ws.ProtoHdr{ + Proto: ws.ProtoTypeFileTransfer, MsgType: wsft.MessageTypeError, SessionID: sessionID.String(), }, @@ -486,6 +491,7 @@ func TestManagementDownloadFile(t *testing.T) { bodyData, err := msgpack.Marshal(body) msg = &ws.ProtoMsg{ Header: ws.ProtoHdr{ + Proto: ws.ProtoTypeFileTransfer, MsgType: wsft.MessageTypeFileInfo, SessionID: sessionID.String(), }, @@ -575,6 +581,7 @@ func TestManagementDownloadFile(t *testing.T) { bodyData, err := msgpack.Marshal(body) msg = &ws.ProtoMsg{ Header: ws.ProtoHdr{ + Proto: ws.ProtoTypeFileTransfer, MsgType: wsft.MessageTypeFileInfo, SessionID: sessionID.String(), }, @@ -588,6 +595,7 @@ func TestManagementDownloadFile(t *testing.T) { // first chunk msg = &ws.ProtoMsg{ Header: ws.ProtoHdr{ + Proto: ws.ProtoTypeFileTransfer, MsgType: wsft.MessageTypeChunk, SessionID: sessionID.String(), Properties: map[string]interface{}{ @@ -609,6 +617,7 @@ func TestManagementDownloadFile(t *testing.T) { bodyData, err = msgpack.Marshal(errBody) msg = &ws.ProtoMsg{ Header: ws.ProtoHdr{ + Proto: ws.ProtoTypeFileTransfer, MsgType: wsft.MessageTypeError, SessionID: sessionID.String(), }, @@ -770,6 +779,7 @@ func TestManagementDownloadFile(t *testing.T) { bodyData, err := msgpack.Marshal(body) msg = &ws.ProtoMsg{ Header: ws.ProtoHdr{ + Proto: ws.ProtoTypeFileTransfer, MsgType: wsft.MessageTypeFileInfo, SessionID: sessionID.String(), }, @@ -783,6 +793,7 @@ func TestManagementDownloadFile(t *testing.T) { // first chunk msg = &ws.ProtoMsg{ Header: ws.ProtoHdr{ + Proto: ws.ProtoTypeFileTransfer, MsgType: wsft.MessageTypeChunk, SessionID: sessionID.String(), Properties: map[string]interface{}{ @@ -876,6 +887,7 @@ func TestManagementDownloadFile(t *testing.T) { bodyData, err := msgpack.Marshal(body) msg = &ws.ProtoMsg{ Header: ws.ProtoHdr{ + Proto: ws.ProtoTypeFileTransfer, MsgType: wsft.MessageTypeFileInfo, SessionID: sessionID.String(), }, @@ -889,6 +901,7 @@ func TestManagementDownloadFile(t *testing.T) { // first chunk msg = &ws.ProtoMsg{ Header: ws.ProtoHdr{ + Proto: ws.ProtoTypeFileTransfer, MsgType: wsft.MessageTypeChunk, SessionID: sessionID.String(), Properties: map[string]interface{}{ @@ -905,6 +918,7 @@ func TestManagementDownloadFile(t *testing.T) { // second chunk msg = &ws.ProtoMsg{ Header: ws.ProtoHdr{ + Proto: ws.ProtoTypeFileTransfer, MsgType: wsft.MessageTypeChunk, SessionID: sessionID.String(), Properties: map[string]interface{}{ diff --git a/api/http/router.go b/api/http/router.go index 0d7d2341..03b14edc 100644 --- a/api/http/router.go +++ b/api/http/router.go @@ -101,6 +101,7 @@ func NewRouter( router.GET(APIURLManagementDevice, management.GetDevice) router.GET(APIURLManagementDeviceConnect, management.Connect) router.GET(APIURLManagementDeviceDownload, management.DownloadFile) + router.HEAD(APIURLManagementDeviceDownload, management.DownloadFile) router.POST(APIURLManagementDeviceCheckUpdate, management.CheckUpdate) router.POST(APIURLManagementDeviceSendInventory, management.SendInventory) router.PUT(APIURLManagementDeviceUpload, management.UploadFile)