Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIM image artefacts #1335

Merged
merged 6 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pkg/api/aim/api/request/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ type GetRunMetricsRequest []struct {
Context map[string]string `json:"context"`
}

// GetRunImagesRequest is a request object for `POST /runs/:id/images/get-batch` endpoint.
type GetRunImagesRequest []struct {
Name string `json:"name"`
Context map[string]string `json:"context"`
}

// GetRunImagesBatchRequest is a request object for `POST /runs/images/get-batch` endpoint.
type GetRunImagesBatchRequest []string

// GetRunsActiveRequest is a request object for `GET /runs/active` endpoint.
type GetRunsActiveRequest struct {
BaseSearchRequest
Expand Down
131 changes: 126 additions & 5 deletions pkg/api/aim/api/response/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import (
"encoding/json"
"fmt"
"math"
"path/filepath"
"slices"
"strings"
"time"

"github.com/gofiber/fiber/v2"
Expand All @@ -19,6 +21,8 @@ import (
"github.com/G-Research/fasttrackml/pkg/api/aim/dao/models"
"github.com/G-Research/fasttrackml/pkg/api/aim/dao/repositories"
"github.com/G-Research/fasttrackml/pkg/api/aim/encoding"
mlflowCommon "github.com/G-Research/fasttrackml/pkg/api/mlflow/common"
"github.com/G-Research/fasttrackml/pkg/api/mlflow/services/artifact/storage"
suprjinx marked this conversation as resolved.
Show resolved Hide resolved
"github.com/G-Research/fasttrackml/pkg/common"
"github.com/G-Research/fasttrackml/pkg/database"
)
Expand All @@ -37,10 +41,10 @@ type GetRunInfoParamsPartial map[string]any
type GetRunInfoTracesPartial struct {
Tags map[string]string `json:"tags"`
Logs map[string]string `json:"logs"`
Texts map[string]string `json:"texts"`
Texts []GetRunInfoTracesMetricPartial `json:"texts"`
Audios map[string]string `json:"audios"`
Metric []GetRunInfoTracesMetricPartial `json:"metric"`
Images map[string]string `json:"images"`
Images []GetRunInfoTracesMetricPartial `json:"images"`
Figures map[string]string `json:"figures"`
LogRecords map[string]string `json:"log_records"`
Distributions map[string]string `json:"distributions"`
Expand Down Expand Up @@ -74,7 +78,7 @@ type GetRunInfoResponse struct {
}

// NewGetRunInfoResponse creates new response object for `GER runs/:id/info` endpoint.
func NewGetRunInfoResponse(run *models.Run) *GetRunInfoResponse {
func NewGetRunInfoResponse(run *models.Run, artifacts []storage.ArtifactObject) *GetRunInfoResponse {
metrics := make([]GetRunInfoTracesMetricPartial, len(run.LatestMetrics))
for i, metric := range run.LatestMetrics {
metrics[i] = GetRunInfoTracesMetricPartial{
Expand All @@ -84,6 +88,43 @@ func NewGetRunInfoResponse(run *models.Run) *GetRunInfoResponse {
}
}

imagesCounter := 0
textsCounter := 0
const imageMimeType = "image/"
const textMimeType = "text/"
for _, artifact := range artifacts {
filename := filepath.Base(artifact.Path)
mime := mlflowCommon.GetContentType(filename)
if strings.HasPrefix(mime, imageMimeType) {
imagesCounter++
} else if strings.HasPrefix(mime, textMimeType) {
textsCounter++
}
}

images := make([]GetRunInfoTracesMetricPartial, imagesCounter)
texts := make([]GetRunInfoTracesMetricPartial, textsCounter)
imagesCounter = 0
textsCounter = 0
for _, artifact := range artifacts {
filename := filepath.Base(artifact.Path)
mime := mlflowCommon.GetContentType(filename)
if strings.HasPrefix(mime, imageMimeType) {
images[imagesCounter] = GetRunInfoTracesMetricPartial{
Name: artifact.Path,
Context: nil,
LastValue: 0,
}
imagesCounter++
} else if strings.HasPrefix(mime, textMimeType) {
texts[textsCounter] = GetRunInfoTracesMetricPartial{
Name: artifact.Path,
Context: nil,
LastValue: 0,
}
textsCounter++
}
}
params := make(GetRunInfoParamsPartial, len(run.Params)+1)
for _, p := range run.Params {
params[p.Key] = p.ValueAny()
Expand All @@ -99,10 +140,10 @@ func NewGetRunInfoResponse(run *models.Run) *GetRunInfoResponse {
Traces: GetRunInfoTracesPartial{
Tags: map[string]string{},
Logs: map[string]string{},
Texts: map[string]string{},
Texts: texts,
Audios: map[string]string{},
Metric: metrics,
Images: map[string]string{},
Images: images,
Figures: map[string]string{},
LogRecords: map[string]string{},
Distributions: map[string]string{},
Expand Down Expand Up @@ -918,3 +959,83 @@ func NewActiveRunsStreamResponse(ctx *fiber.Ctx, runs []models.Run, reportProgre
})
return nil
}

// NewRunImagesStreamResponse streams the provided images to the fiber context.
func NewRunImagesStreamResponse(ctx *fiber.Ctx, images []models.Image) error {
ctx.Set("Content-Type", "application/octet-stream")
ctx.Context().SetBodyStreamWriter(func(w *bufio.Writer) {
start := time.Now()
if err := func() error {
var values [][]map[string]interface{}
var valuesResult []map[string]interface{}

for _, image := range images {
for _, valueArray := range image.Values {
for _, val := range valueArray {
valMap := map[string]interface{}{
"blob_uri": val.BlobURI,
"caption": val.Caption,
"context": val.Context,
"format": val.Format,
"height": val.Height,
"index": val.Index,
"key": val.Key,
"seqKey": val.SeqKey,
"name": val.Name,
"run": val.Run,
"step": val.Step,
"width": val.Width,
}
valuesResult = append(valuesResult, valMap)
}
}

values = append(values, valuesResult)
imgMap := map[string]interface{}{
"record_range": image.RecordRange,
"index_range": image.IndexRange,
"name": image.Name,
"context": image.Context,
"values": values,
"iters": image.Iters,
}

if err := encoding.EncodeTree(w, imgMap); err != nil {
return err
}
}

return nil
}(); err != nil {
log.Errorf("Error encountered in %s %s: error streaming active runs: %s", ctx.Method(), ctx.Path(), err)
}

log.Infof("body - %s %s %s", time.Since(start), ctx.Method(), ctx.Path())
})
return nil
}

// NewRunImagesBatchStreamResponse streams the provided images to the fiber context.
func NewRunImagesBatchStreamResponse(ctx *fiber.Ctx, imagesMap map[string]any) error {
ctx.Context().Response.SetBodyStreamWriter(func(w *bufio.Writer) {
start := time.Now()
if err := func() error {
if err := encoding.EncodeTree(w, imagesMap); err != nil {
return err
}
if err := w.Flush(); err != nil {
return eris.Wrap(err, "error flushing output stream")
}
return nil
}(); err != nil {
log.Errorf(
"error encountered in %s %s: error streaming artifact: %s",
ctx.Method(),
ctx.Path(),
err,
)
}
log.Infof("body - %s %s %s", time.Since(start), ctx.Method(), ctx.Path())
})
return nil
}
4 changes: 4 additions & 0 deletions pkg/api/aim/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ import (
"github.com/G-Research/fasttrackml/pkg/api/aim/services/project"
"github.com/G-Research/fasttrackml/pkg/api/aim/services/run"
"github.com/G-Research/fasttrackml/pkg/api/aim/services/tag"
"github.com/G-Research/fasttrackml/pkg/api/mlflow/services/artifact"
)

// Controller handles all the input HTTP requests.
type Controller struct {
tagService *tag.Service
appService *app.Service
runService *run.Service
artifactService *artifact.Service
projectService *project.Service
dashboardService *dashboard.Service
experimentService *experiment.Service
Expand All @@ -24,6 +26,7 @@ func NewController(
tagService *tag.Service,
appService *app.Service,
runService *run.Service,
artifactService *artifact.Service,
projectService *project.Service,
dashboardService *dashboard.Service,
experimentService *experiment.Service,
Expand All @@ -32,6 +35,7 @@ func NewController(
tagService: tagService,
appService: appService,
runService: runService,
artifactService: artifactService,
projectService: projectService,
dashboardService: dashboardService,
experimentService: experimentService,
Expand Down
21 changes: 21 additions & 0 deletions pkg/api/aim/controller/helpers.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
package controller

import (
"bytes"
"io"

"github.com/gofiber/fiber/v2"
"github.com/rotisserie/eris"

"github.com/G-Research/fasttrackml/pkg/api/aim/api/request"
"github.com/G-Research/fasttrackml/pkg/common/api"
)

Expand All @@ -16,3 +21,19 @@ func convertError(err error) error {
}
return err
}

func convertImagesToMap(
images []io.ReadCloser, req request.GetRunImagesBatchRequest,
) (map[string]any, error) {
imagesMap := make(map[string]any)

for i, image := range images {
var buffer bytes.Buffer
_, err := io.CopyBuffer(&buffer, image, make([]byte, 4096))
if err != nil {
return nil, eris.Wrap(err, "error copying artifact Reader to output stream")
}
imagesMap[req[i]] = buffer.Bytes()
}
return imagesMap, nil
}
60 changes: 58 additions & 2 deletions pkg/api/aim/controller/runs.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/G-Research/fasttrackml/pkg/api/aim/api/request"
"github.com/G-Research/fasttrackml/pkg/api/aim/api/response"
"github.com/G-Research/fasttrackml/pkg/api/aim/services/run"
mlflowRequest "github.com/G-Research/fasttrackml/pkg/api/mlflow/api/request"
"github.com/G-Research/fasttrackml/pkg/common/api"
"github.com/G-Research/fasttrackml/pkg/common/middleware"
)
Expand All @@ -34,12 +35,21 @@ func (c Controller) GetRunInfo(ctx *fiber.Ctx) error {
return err
}

resp := response.NewGetRunInfoResponse(runInfo)
artifactReq := mlflowRequest.ListArtifactsRequest{
RunUUID: req.ID,
}

_, artifacts, err := c.artifactService.ListArtifacts(ctx.Context(), ns, &artifactReq)
if err != nil {
return err
}

resp := response.NewGetRunInfoResponse(runInfo, artifacts)
log.Debugf("getRunInfo response: %#v", resp)
return ctx.JSON(resp)
}

// GetRunMetrics handles `GET /runs/:id/metric/get-batch` endpoint.
// GetRunMetrics handles `POST /runs/:id/metric/get-batch` endpoint.
func (c Controller) GetRunMetrics(ctx *fiber.Ctx) error {
ns, err := middleware.GetNamespaceFromContext(ctx.Context())
if err != nil {
Expand All @@ -62,6 +72,52 @@ func (c Controller) GetRunMetrics(ctx *fiber.Ctx) error {
return ctx.JSON(resp)
}

// GetRunImages handles `POST /runs/:id/images/get-batch` endpoint.
func (c Controller) GetRunImages(ctx *fiber.Ctx) error {
ns, err := middleware.GetNamespaceFromContext(ctx.Context())
if err != nil {
return api.NewInternalError("error getting namespace from context")
}
log.Debugf("getRunImages namespace: %s", ns.Code)

req := request.GetRunImagesRequest{}
if err := ctx.BodyParser(&req); err != nil {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}

images, err := c.runService.GetRunImages(ctx.Context(), ns.ID, ctx.Params("id"), &req)
if err != nil {
return fiber.NewError(fiber.StatusInternalServerError, err.Error())
}

return response.NewRunImagesStreamResponse(ctx, images)
}

// GetRunImagesBatch handles `POST /runs/images/get-batch` endpoint.
func (c Controller) GetRunImagesBatch(ctx *fiber.Ctx) error {
ns, err := middleware.GetNamespaceFromContext(ctx.Context())
if err != nil {
return api.NewInternalError("error getting namespace from context")
}
log.Debugf("getRunImages namespace: %s", ns.Code)

req := request.GetRunImagesBatchRequest{}
if err := ctx.BodyParser(&req); err != nil {
return fiber.NewError(fiber.StatusUnprocessableEntity, err.Error())
}

images, err := c.runService.GetRunImagesBatch(ctx.Context(), &req)
if err != nil {
return err
}
imagesMap, err := convertImagesToMap(images, req)
if err != nil {
return err
}

return response.NewRunImagesBatchStreamResponse(ctx, imagesMap)
}

// GetRunsActive handles `GET /runs/active` endpoint.
func (c Controller) GetRunsActive(ctx *fiber.Ctx) error {
ns, err := middleware.GetNamespaceFromContext(ctx.Context())
Expand Down
25 changes: 25 additions & 0 deletions pkg/api/aim/dao/models/image.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package models

type ImageValues struct {
BlobURI string `json:"blob_uri"`
Caption string `json:"caption"`
Context interface{} `json:"context"`
Format string `json:"format"`
Height interface{} `json:"height"`
Index interface{} `json:"index"`
Key string `json:"key"`
SeqKey string `json:"seqKey"`
Name string `json:"name"`
Run interface{} `json:"run"`
Step int `json:"step"`
Width interface{} `json:"width"`
}

type Image struct {
fabiovincenzi marked this conversation as resolved.
Show resolved Hide resolved
RecordRange interface{} `json:"record_range"`
IndexRange interface{} `json:"index_range"`
Name string `json:"name"`
Context interface{} `json:"context"`
Values [][]ImageValues `json:"values"`
suprjinx marked this conversation as resolved.
Show resolved Hide resolved
Iters int `json:"iters"`
}
2 changes: 1 addition & 1 deletion pkg/api/aim/dao/repositories/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func (r RunRepository) GetRunByNamespaceIDAndRunID(
) (*models.Run, error) {
var run models.Run
if err := r.GetDB().WithContext(ctx).Select(
"ID",
"ID", "ArtifactURI",
).InnerJoins(
"Experiment",
database.DB.Select(
Expand Down
2 changes: 2 additions & 0 deletions pkg/api/aim/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ func (r *Router) Init(server fiber.Router) {
runs.Post("/:id/tags/new", r.controller.AddRunTag)
runs.Delete("/:id/tags/:tagID", r.controller.DeleteRunTag)
runs.Post("/:id/metric/get-batch/", r.controller.GetRunMetrics)
runs.Post("/:id/images/get-batch/", r.controller.GetRunImages)
runs.Post("/images/get-batch/", r.controller.GetRunImagesBatch)
runs.Put("/:id/", r.controller.UpdateRun)
runs.Get("/:id/logs", r.controller.GetRunLogs)
runs.Delete("/:id/", r.controller.DeleteRun)
Expand Down
Loading