diff --git a/pkg/api/aim/api/request/run.go b/pkg/api/aim/api/request/run.go index 73922bebd..3d962f05e 100644 --- a/pkg/api/aim/api/request/run.go +++ b/pkg/api/aim/api/request/run.go @@ -22,6 +22,15 @@ type GetRunMetricsRequest []struct { Context map[string]string `json:"context"` } +// GetRunImagesRequest is a request object for `POST /runs/:id/images/get-batch` endpoint. +type GetRunImagesRequest []struct { + Name string `json:"name"` + Context map[string]string `json:"context"` +} + +// GetRunImagesBatchRequest is a request object for `POST /runs/images/get-batch` endpoint. +type GetRunImagesBatchRequest []string + // GetRunsActiveRequest is a request object for `GET /runs/active` endpoint. type GetRunsActiveRequest struct { BaseSearchRequest diff --git a/pkg/api/aim/api/response/run.go b/pkg/api/aim/api/response/run.go index 97985914c..bb537d430 100644 --- a/pkg/api/aim/api/response/run.go +++ b/pkg/api/aim/api/response/run.go @@ -7,7 +7,9 @@ import ( "encoding/json" "fmt" "math" + "path/filepath" "slices" + "strings" "time" "github.com/gofiber/fiber/v2" @@ -19,6 +21,8 @@ import ( "github.com/G-Research/fasttrackml/pkg/api/aim/dao/models" "github.com/G-Research/fasttrackml/pkg/api/aim/dao/repositories" "github.com/G-Research/fasttrackml/pkg/api/aim/encoding" + mlflowCommon "github.com/G-Research/fasttrackml/pkg/api/mlflow/common" + "github.com/G-Research/fasttrackml/pkg/api/mlflow/services/artifact/storage" "github.com/G-Research/fasttrackml/pkg/common" "github.com/G-Research/fasttrackml/pkg/database" ) @@ -37,10 +41,10 @@ type GetRunInfoParamsPartial map[string]any type GetRunInfoTracesPartial struct { Tags map[string]string `json:"tags"` Logs map[string]string `json:"logs"` - Texts map[string]string `json:"texts"` + Texts []GetRunInfoTracesMetricPartial `json:"texts"` Audios map[string]string `json:"audios"` Metric []GetRunInfoTracesMetricPartial `json:"metric"` - Images map[string]string `json:"images"` + Images []GetRunInfoTracesMetricPartial `json:"images"` Figures map[string]string `json:"figures"` LogRecords map[string]string `json:"log_records"` Distributions map[string]string `json:"distributions"` @@ -74,7 +78,7 @@ type GetRunInfoResponse struct { } // NewGetRunInfoResponse creates new response object for `GER runs/:id/info` endpoint. -func NewGetRunInfoResponse(run *models.Run) *GetRunInfoResponse { +func NewGetRunInfoResponse(run *models.Run, artifacts []storage.ArtifactObject) *GetRunInfoResponse { metrics := make([]GetRunInfoTracesMetricPartial, len(run.LatestMetrics)) for i, metric := range run.LatestMetrics { metrics[i] = GetRunInfoTracesMetricPartial{ @@ -84,6 +88,43 @@ func NewGetRunInfoResponse(run *models.Run) *GetRunInfoResponse { } } + imagesCounter := 0 + textsCounter := 0 + const imageMimeType = "image/" + const textMimeType = "text/" + for _, artifact := range artifacts { + filename := filepath.Base(artifact.Path) + mime := mlflowCommon.GetContentType(filename) + if strings.HasPrefix(mime, imageMimeType) { + imagesCounter++ + } else if strings.HasPrefix(mime, textMimeType) { + textsCounter++ + } + } + + images := make([]GetRunInfoTracesMetricPartial, imagesCounter) + texts := make([]GetRunInfoTracesMetricPartial, textsCounter) + imagesCounter = 0 + textsCounter = 0 + for _, artifact := range artifacts { + filename := filepath.Base(artifact.Path) + mime := mlflowCommon.GetContentType(filename) + if strings.HasPrefix(mime, imageMimeType) { + images[imagesCounter] = GetRunInfoTracesMetricPartial{ + Name: artifact.Path, + Context: nil, + LastValue: 0, + } + imagesCounter++ + } else if strings.HasPrefix(mime, textMimeType) { + texts[textsCounter] = GetRunInfoTracesMetricPartial{ + Name: artifact.Path, + Context: nil, + LastValue: 0, + } + textsCounter++ + } + } params := make(GetRunInfoParamsPartial, len(run.Params)+1) for _, p := range run.Params { params[p.Key] = p.ValueAny() @@ -99,10 +140,10 @@ func NewGetRunInfoResponse(run *models.Run) *GetRunInfoResponse { Traces: GetRunInfoTracesPartial{ Tags: map[string]string{}, Logs: map[string]string{}, - Texts: map[string]string{}, + Texts: texts, Audios: map[string]string{}, Metric: metrics, - Images: map[string]string{}, + Images: images, Figures: map[string]string{}, LogRecords: map[string]string{}, Distributions: map[string]string{}, @@ -918,3 +959,83 @@ func NewActiveRunsStreamResponse(ctx *fiber.Ctx, runs []models.Run, reportProgre }) return nil } + +// NewRunImagesStreamResponse streams the provided images to the fiber context. +func NewRunImagesStreamResponse(ctx *fiber.Ctx, images []models.Image) error { + ctx.Set("Content-Type", "application/octet-stream") + ctx.Context().SetBodyStreamWriter(func(w *bufio.Writer) { + start := time.Now() + if err := func() error { + var values [][]map[string]interface{} + var valuesResult []map[string]interface{} + + for _, image := range images { + for _, valueArray := range image.Values { + for _, val := range valueArray { + valMap := map[string]interface{}{ + "blob_uri": val.BlobURI, + "caption": val.Caption, + "context": val.Context, + "format": val.Format, + "height": val.Height, + "index": val.Index, + "key": val.Key, + "seqKey": val.SeqKey, + "name": val.Name, + "run": val.Run, + "step": val.Step, + "width": val.Width, + } + valuesResult = append(valuesResult, valMap) + } + } + + values = append(values, valuesResult) + imgMap := map[string]interface{}{ + "record_range": image.RecordRange, + "index_range": image.IndexRange, + "name": image.Name, + "context": image.Context, + "values": values, + "iters": image.Iters, + } + + if err := encoding.EncodeTree(w, imgMap); err != nil { + return err + } + } + + return nil + }(); err != nil { + log.Errorf("Error encountered in %s %s: error streaming active runs: %s", ctx.Method(), ctx.Path(), err) + } + + log.Infof("body - %s %s %s", time.Since(start), ctx.Method(), ctx.Path()) + }) + return nil +} + +// NewRunImagesBatchStreamResponse streams the provided images to the fiber context. +func NewRunImagesBatchStreamResponse(ctx *fiber.Ctx, imagesMap map[string]any) error { + ctx.Context().Response.SetBodyStreamWriter(func(w *bufio.Writer) { + start := time.Now() + if err := func() error { + if err := encoding.EncodeTree(w, imagesMap); err != nil { + return err + } + if err := w.Flush(); err != nil { + return eris.Wrap(err, "error flushing output stream") + } + return nil + }(); err != nil { + log.Errorf( + "error encountered in %s %s: error streaming artifact: %s", + ctx.Method(), + ctx.Path(), + err, + ) + } + log.Infof("body - %s %s %s", time.Since(start), ctx.Method(), ctx.Path()) + }) + return nil +} diff --git a/pkg/api/aim/controller/controller.go b/pkg/api/aim/controller/controller.go index e6c72e615..b5f950358 100644 --- a/pkg/api/aim/controller/controller.go +++ b/pkg/api/aim/controller/controller.go @@ -7,6 +7,7 @@ import ( "github.com/G-Research/fasttrackml/pkg/api/aim/services/project" "github.com/G-Research/fasttrackml/pkg/api/aim/services/run" "github.com/G-Research/fasttrackml/pkg/api/aim/services/tag" + "github.com/G-Research/fasttrackml/pkg/api/mlflow/services/artifact" ) // Controller handles all the input HTTP requests. @@ -14,6 +15,7 @@ type Controller struct { tagService *tag.Service appService *app.Service runService *run.Service + artifactService *artifact.Service projectService *project.Service dashboardService *dashboard.Service experimentService *experiment.Service @@ -24,6 +26,7 @@ func NewController( tagService *tag.Service, appService *app.Service, runService *run.Service, + artifactService *artifact.Service, projectService *project.Service, dashboardService *dashboard.Service, experimentService *experiment.Service, @@ -32,6 +35,7 @@ func NewController( tagService: tagService, appService: appService, runService: runService, + artifactService: artifactService, projectService: projectService, dashboardService: dashboardService, experimentService: experimentService, diff --git a/pkg/api/aim/controller/helpers.go b/pkg/api/aim/controller/helpers.go index 2f5398e99..ee6a4e621 100644 --- a/pkg/api/aim/controller/helpers.go +++ b/pkg/api/aim/controller/helpers.go @@ -1,8 +1,13 @@ package controller import ( + "bytes" + "io" + "github.com/gofiber/fiber/v2" + "github.com/rotisserie/eris" + "github.com/G-Research/fasttrackml/pkg/api/aim/api/request" "github.com/G-Research/fasttrackml/pkg/common/api" ) @@ -16,3 +21,19 @@ func convertError(err error) error { } return err } + +func convertImagesToMap( + images []io.ReadCloser, req request.GetRunImagesBatchRequest, +) (map[string]any, error) { + imagesMap := make(map[string]any) + + for i, image := range images { + var buffer bytes.Buffer + _, err := io.CopyBuffer(&buffer, image, make([]byte, 4096)) + if err != nil { + return nil, eris.Wrap(err, "error copying artifact Reader to output stream") + } + imagesMap[req[i]] = buffer.Bytes() + } + return imagesMap, nil +} diff --git a/pkg/api/aim/controller/runs.go b/pkg/api/aim/controller/runs.go index 0e1b9394c..3dbfe10de 100644 --- a/pkg/api/aim/controller/runs.go +++ b/pkg/api/aim/controller/runs.go @@ -9,6 +9,7 @@ import ( "github.com/G-Research/fasttrackml/pkg/api/aim/api/request" "github.com/G-Research/fasttrackml/pkg/api/aim/api/response" "github.com/G-Research/fasttrackml/pkg/api/aim/services/run" + mlflowRequest "github.com/G-Research/fasttrackml/pkg/api/mlflow/api/request" "github.com/G-Research/fasttrackml/pkg/common/api" "github.com/G-Research/fasttrackml/pkg/common/middleware" ) @@ -34,12 +35,21 @@ func (c Controller) GetRunInfo(ctx *fiber.Ctx) error { return err } - resp := response.NewGetRunInfoResponse(runInfo) + artifactReq := mlflowRequest.ListArtifactsRequest{ + RunUUID: req.ID, + } + + _, artifacts, err := c.artifactService.ListArtifacts(ctx.Context(), ns, &artifactReq) + if err != nil { + return err + } + + resp := response.NewGetRunInfoResponse(runInfo, artifacts) log.Debugf("getRunInfo response: %#v", resp) return ctx.JSON(resp) } -// GetRunMetrics handles `GET /runs/:id/metric/get-batch` endpoint. +// GetRunMetrics handles `POST /runs/:id/metric/get-batch` endpoint. func (c Controller) GetRunMetrics(ctx *fiber.Ctx) error { ns, err := middleware.GetNamespaceFromContext(ctx.Context()) if err != nil { @@ -62,6 +72,52 @@ func (c Controller) GetRunMetrics(ctx *fiber.Ctx) error { return ctx.JSON(resp) } +// GetRunImages handles `POST /runs/:id/images/get-batch` endpoint. +func (c Controller) GetRunImages(ctx *fiber.Ctx) error { + ns, err := middleware.GetNamespaceFromContext(ctx.Context()) + if err != nil { + return api.NewInternalError("error getting namespace from context") + } + log.Debugf("getRunImages namespace: %s", ns.Code) + + req := request.GetRunImagesRequest{} + if err := ctx.BodyParser(&req); err != nil { + return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error()) + } + + images, err := c.runService.GetRunImages(ctx.Context(), ns.ID, ctx.Params("id"), &req) + if err != nil { + return fiber.NewError(fiber.StatusInternalServerError, err.Error()) + } + + return response.NewRunImagesStreamResponse(ctx, images) +} + +// GetRunImagesBatch handles `POST /runs/images/get-batch` endpoint. +func (c Controller) GetRunImagesBatch(ctx *fiber.Ctx) error { + ns, err := middleware.GetNamespaceFromContext(ctx.Context()) + if err != nil { + return api.NewInternalError("error getting namespace from context") + } + log.Debugf("getRunImages namespace: %s", ns.Code) + + req := request.GetRunImagesBatchRequest{} + if err := ctx.BodyParser(&req); err != nil { + return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error()) + } + + images, err := c.runService.GetRunImagesBatch(ctx.Context(), &req) + if err != nil { + return err + } + imagesMap, err := convertImagesToMap(images, req) + if err != nil { + return err + } + + return response.NewRunImagesBatchStreamResponse(ctx, imagesMap) +} + // GetRunsActive handles `GET /runs/active` endpoint. func (c Controller) GetRunsActive(ctx *fiber.Ctx) error { ns, err := middleware.GetNamespaceFromContext(ctx.Context()) diff --git a/pkg/api/aim/dao/models/image.go b/pkg/api/aim/dao/models/image.go new file mode 100644 index 000000000..3197bdc8a --- /dev/null +++ b/pkg/api/aim/dao/models/image.go @@ -0,0 +1,25 @@ +package models + +type ImageValues struct { + BlobURI string `json:"blob_uri"` + Caption string `json:"caption"` + Context interface{} `json:"context"` + Format string `json:"format"` + Height interface{} `json:"height"` + Index interface{} `json:"index"` + Key string `json:"key"` + SeqKey string `json:"seqKey"` + Name string `json:"name"` + Run interface{} `json:"run"` + Step int `json:"step"` + Width interface{} `json:"width"` +} + +type Image struct { + RecordRange interface{} `json:"record_range"` + IndexRange interface{} `json:"index_range"` + Name string `json:"name"` + Context interface{} `json:"context"` + Values [][]ImageValues `json:"values"` + Iters int `json:"iters"` +} diff --git a/pkg/api/aim/dao/repositories/run.go b/pkg/api/aim/dao/repositories/run.go index 7218ac4de..091d7548a 100644 --- a/pkg/api/aim/dao/repositories/run.go +++ b/pkg/api/aim/dao/repositories/run.go @@ -191,7 +191,7 @@ func (r RunRepository) GetRunByNamespaceIDAndRunID( ) (*models.Run, error) { var run models.Run if err := r.GetDB().WithContext(ctx).Select( - "ID", + "ID", "ArtifactURI", ).InnerJoins( "Experiment", database.DB.Select( diff --git a/pkg/api/aim/routes.go b/pkg/api/aim/routes.go index 6bb7725e3..f51008a78 100644 --- a/pkg/api/aim/routes.go +++ b/pkg/api/aim/routes.go @@ -69,6 +69,8 @@ func (r *Router) Init(server fiber.Router) { runs.Post("/:id/tags/new", r.controller.AddRunTag) runs.Delete("/:id/tags/:tagID", r.controller.DeleteRunTag) runs.Post("/:id/metric/get-batch/", r.controller.GetRunMetrics) + runs.Post("/:id/images/get-batch/", r.controller.GetRunImages) + runs.Post("/images/get-batch/", r.controller.GetRunImagesBatch) runs.Put("/:id/", r.controller.UpdateRun) runs.Get("/:id/logs", r.controller.GetRunLogs) runs.Delete("/:id/", r.controller.DeleteRun) diff --git a/pkg/api/aim/services/run/service.go b/pkg/api/aim/services/run/service.go index d897a9ecf..46012427d 100644 --- a/pkg/api/aim/services/run/service.go +++ b/pkg/api/aim/services/run/service.go @@ -4,6 +4,11 @@ import ( "context" "database/sql" "encoding/json" + "errors" + "fmt" + "io" + "io/fs" + "net/url" "github.com/rotisserie/eris" @@ -11,6 +16,7 @@ import ( "github.com/G-Research/fasttrackml/pkg/api/aim/common" "github.com/G-Research/fasttrackml/pkg/api/aim/dao/models" "github.com/G-Research/fasttrackml/pkg/api/aim/dao/repositories" + "github.com/G-Research/fasttrackml/pkg/api/mlflow/services/artifact/storage" "github.com/G-Research/fasttrackml/pkg/common/api" "github.com/G-Research/fasttrackml/pkg/common/dao/types" ) @@ -24,12 +30,13 @@ const ( // Service provides service layer to work with `run` business logic. type Service struct { - runRepository repositories.RunRepositoryProvider - logRepository repositories.LogRepositoryProvider - metricRepository repositories.MetricRepositoryProvider - tagRepository repositories.TagRepositoryProvider - sharedTagRepository repositories.SharedTagRepositoryProvider - artifactRepository repositories.ArtifactRepositoryProvider + runRepository repositories.RunRepositoryProvider + logRepository repositories.LogRepositoryProvider + metricRepository repositories.MetricRepositoryProvider + tagRepository repositories.TagRepositoryProvider + sharedTagRepository repositories.SharedTagRepositoryProvider + artifactStorageFactory storage.ArtifactStorageFactoryProvider + artifactRepository repositories.ArtifactRepositoryProvider } // NewService creates new Service instance. @@ -39,15 +46,17 @@ func NewService( metricRepository repositories.MetricRepositoryProvider, tagRepository repositories.TagRepositoryProvider, sharedTagRepository repositories.SharedTagRepositoryProvider, + artifactStorageFactory storage.ArtifactStorageFactoryProvider, artifactRepository repositories.ArtifactRepositoryProvider, ) *Service { return &Service{ - runRepository: runRepository, - logRepository: logRepository, - metricRepository: metricRepository, - tagRepository: tagRepository, - sharedTagRepository: sharedTagRepository, - artifactRepository: artifactRepository, + runRepository: runRepository, + logRepository: logRepository, + metricRepository: metricRepository, + tagRepository: tagRepository, + sharedTagRepository: sharedTagRepository, + artifactStorageFactory: artifactStorageFactory, + artifactRepository: artifactRepository, } } @@ -108,6 +117,86 @@ func (s Service) GetRunMetrics( return metrics, metricKeysMap, nil } +// GetRunImages returns run images. +func (s Service) GetRunImages( + ctx context.Context, namespaceID uint, runID string, req *request.GetRunImagesRequest, +) ([]models.Image, error) { + run, err := s.runRepository.GetRunByNamespaceIDAndRunID(ctx, namespaceID, runID) + if err != nil { + return nil, api.NewInternalError("error getting run by id %s: %s", runID, err) + } + + if run == nil { + return nil, api.NewResourceDoesNotExistError("run '%s' not found", runID) + } + if err != nil { + return nil, api.NewBadRequestError("unable to convert request: %s", err) + } + var images []models.Image + + for _, image := range *req { + var values [][]models.ImageValues + blobURI, err := url.JoinPath(run.ArtifactURI, image.Name) + if err != nil { + return nil, eris.Wrap(err, "error constructing blobURI") + } + imageValuesArray := []models.ImageValues{ + { + BlobURI: blobURI, + Caption: "", + Context: nil, + Format: "", + Height: nil, + Index: nil, + Key: "", + SeqKey: "", + Name: image.Name, + Run: nil, + Step: 0, + Width: nil, + }, + } + values = append(values, imageValuesArray) + + image := models.Image{ + RecordRange: nil, + IndexRange: nil, + Name: image.Name, + Context: nil, + Values: values, + Iters: 1, + } + images = append(images, image) + } + + return images, nil +} + +// GetRunImagesBatch returns run images. +func (s Service) GetRunImagesBatch( + ctx context.Context, req *request.GetRunImagesBatchRequest, +) ([]io.ReadCloser, error) { + readers := make([]io.ReadCloser, len(*req)) + for i, image := range *req { + artifactStorage, err := s.artifactStorageFactory.GetStorage(ctx, image) + if err != nil { + return nil, api.NewInternalError("Unsupported artifact storage") + } + artifactReader, err := artifactStorage.Get( + ctx, image, "", + ) + if err != nil { + msg := fmt.Sprintf("error getting artifact object for URI: %s", image) + if errors.Is(err, fs.ErrNotExist) { + return nil, api.NewResourceDoesNotExistError(msg) + } + return nil, api.NewInternalError(msg) + } + readers[i] = artifactReader + } + return readers, nil +} + // GetRunsActive returns the active runs. func (s Service) GetRunsActive( ctx context.Context, namespaceID uint, req *request.GetRunsActiveRequest, diff --git a/pkg/server/server.go b/pkg/server/server.go index 21bad64c3..9a65c15cf 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -282,8 +282,13 @@ func createApp( aimRepositories.NewMetricRepository(db.GormDB()), aimRepositories.NewTagRepository(db.GormDB()), aimRepositories.NewSharedTagRepository(db.GormDB()), + artifactStorageFactory, aimRepositories.NewArtifactRepository(db.GormDB()), ), + mlflowArtifactService.NewService( + mlflowRepositories.NewRunRepository(db.GormDB()), + artifactStorageFactory, + ), aimProjectService.NewService( aimRepositories.NewTagRepository(db.GormDB()), aimRepositories.NewRunRepository(db.GormDB()),