-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sandersaarond/mysql model metrics #98
Draft
SandersAaronD
wants to merge
20
commits into
main
Choose a base branch
from
sandersaarond/mysql-model-metrics
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
ee71829
add model metrics table
SandersAaronD d2879e0
Pin air version to continue supporting go 1.22
SandersAaronD 2b3326a
switch model_metrics endpoint to write to mysql
SandersAaronD dfd7ce8
factor model metrics endpoint into its own file, add tests
SandersAaronD f26f957
First stab at a reader for model metrics from mysql
SandersAaronD 6c041ab
add model metrics table
SandersAaronD 858c354
Merge branch 'main' into sandersaarond/mysql-model-metrics
SandersAaronD 34a59ce
Unbroken, almost there ...
SandersAaronD e068d1e
Some cleanup
SandersAaronD 2ef0aca
working tests, bugfix
SandersAaronD 0f7d9c8
Refactor some tests
SandersAaronD 84f708a
Tidy up frontend and backend contracts to store model metrics
SandersAaronD c1e05bc
Remove debug logging
SandersAaronD d9c5ea3
Merge branch 'main' into sandersaarond/mysql-model-metrics
SandersAaronD 33552fd
Add getter for model metrics from backend service
SandersAaronD 530baea
Committing a clean-ish WIP
SandersAaronD 07046ff
Add config field
SandersAaronD e178512
Remove debug logging
SandersAaronD be998de
Slight fix to error handler
SandersAaronD 83dd594
Viz working somewhat finally
SandersAaronD File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ COPY --link ai-training-api ./ | |
|
||
FROM prep AS development | ||
|
||
# Air fixed at 1.52.3 because 1.60.0 would force upgrading go to 1.23 and this is easier | ||
RUN --mount=type=cache,id=go-cache-ai-training-api,target=/opt/go go install github.com/air-verse/[email protected] | ||
|
||
FROM prep as build | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,284 @@ | ||
package api | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"net/http" | ||
"strconv" | ||
|
||
"github.com/google/uuid" | ||
"github.com/gorilla/mux" | ||
"gorm.io/gorm" | ||
|
||
"github.com/grafana/ai-training-o11y/ai-training-api/middleware" | ||
"github.com/grafana/ai-training-o11y/ai-training-api/model" | ||
) | ||
|
||
// Incoming format is an array of these | ||
type ModelMetricsSeries struct { | ||
MetricName string `json:"metric_name"` | ||
StepName string `json:"step_name"` | ||
Points []struct { | ||
Step uint32 `json:"step"` | ||
Value json.Number `json:"value"` | ||
} `json:"points"` | ||
} | ||
|
||
type AddModelMetricsResponse struct { | ||
Message string `json:"message"` | ||
MetricsCreated uint32 `json:"metricsCreated"` | ||
} | ||
|
||
// This is for return | ||
// We want an array of objects that contain grafana dataframes | ||
// For visualizing | ||
type DataFrame struct { | ||
Name string `json:"name"` | ||
Type string `json:"type"` | ||
Values []interface{} `json:"values"` | ||
} | ||
|
||
// To make it less painful to unmarshal and group them | ||
type DataFrameWrapper struct { | ||
MetricName string `json:"MetricName"` | ||
StepName string `json:"StepName"` | ||
Fields []DataFrame `json:"fields"` | ||
} | ||
|
||
type GetModelMetricsResponse []DataFrameWrapper | ||
|
||
|
||
func (a *App) addModelMetrics(tenantID string, req *http.Request) (interface{}, error) { | ||
// Extract and validate ProcessID | ||
processID, err := extractAndValidateProcessID(req) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
// Validate ProcessID exists | ||
if err := a.validateProcessExists(req.Context(), processID); err != nil { | ||
return nil, err | ||
} | ||
|
||
// Parse and validate the request body | ||
metricsData, err := parseAndValidateModelMetricsRequest(req) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
// Convert tenantID to uint64 for StackID | ||
stackID, err := strconv.ParseUint(tenantID, 10, 64) | ||
if err != nil { | ||
return nil, middleware.ErrBadRequest(fmt.Errorf("invalid tenant ID: %w", err)) | ||
} | ||
|
||
// Save the metrics and get the count of created metrics | ||
createdCount, err := a.saveModelMetrics(req.Context(), stackID, processID, metricsData) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
// Return a JSON response with success message and count of metrics inserted | ||
response := map[string]interface{}{ | ||
"message": "Metrics successfully added", | ||
"metricsCreated": createdCount, | ||
} | ||
|
||
return response, nil | ||
} | ||
|
||
func extractAndValidateProcessID(req *http.Request) (uuid.UUID, error) { | ||
vars := mux.Vars(req) | ||
if vars == nil { | ||
return uuid.Nil, fmt.Errorf("mux.Vars(req) returned nil") | ||
} | ||
|
||
processIDStr, ok := vars["id"] | ||
if !ok { | ||
return uuid.Nil, middleware.ErrBadRequest(fmt.Errorf("process ID not provided in URL")) | ||
} | ||
|
||
// This case handles when the ID is provided in the URL but is empty | ||
if processIDStr == "" { | ||
return uuid.Nil, middleware.ErrBadRequest(fmt.Errorf("process ID is empty")) | ||
} | ||
|
||
processID, err := uuid.Parse(processIDStr) | ||
if err != nil { | ||
return uuid.Nil, middleware.ErrBadRequest(fmt.Errorf("invalid process ID: %w", err)) | ||
} | ||
|
||
return processID, nil | ||
} | ||
|
||
func (a *App) validateProcessExists(ctx context.Context, processID uuid.UUID) error { | ||
var process model.Process | ||
if err := a.db(ctx).First(&process, "id = ?", processID).Error; err != nil { | ||
if errors.Is(err, gorm.ErrRecordNotFound) { | ||
return middleware.ErrNotFound(fmt.Errorf("process not found")) | ||
} | ||
return fmt.Errorf("error checking process: %w", err) | ||
} | ||
return nil | ||
} | ||
|
||
func parseAndValidateModelMetricsRequest(req *http.Request) ([]ModelMetricsSeries, error) { | ||
var metricsData []ModelMetricsSeries | ||
|
||
decoder := json.NewDecoder(req.Body) | ||
|
||
if err := decoder.Decode(&metricsData); err != nil { | ||
return nil, middleware.ErrBadRequest(fmt.Errorf("invalid JSON: %v", err)) | ||
} | ||
|
||
fmt.Println(metricsData) | ||
|
||
for _, metric := range metricsData { | ||
if err := validateModelMetricRequest(&metric); err != nil { | ||
return nil, middleware.ErrBadRequest(err) | ||
} | ||
} | ||
|
||
return metricsData, nil | ||
} | ||
|
||
func validateModelMetricRequest(m *ModelMetricsSeries) error { | ||
if len(m.MetricName) == 0 || len(m.MetricName) > 32 { | ||
return fmt.Errorf("metric name must be between 1 and 32 characters") | ||
} | ||
if len(m.StepName) == 0 || len(m.StepName) > 32 { | ||
return fmt.Errorf("step name must be between 1 and 32 characters") | ||
} | ||
for _, point := range m.Points { | ||
if point.Step == 0 { | ||
return fmt.Errorf("step must be a positive number") | ||
} | ||
if point.Value.String() == "" { | ||
return fmt.Errorf("metric value cannot be empty") | ||
} | ||
// Validate that Value is a valid number | ||
if _, err := point.Value.Float64(); err != nil { | ||
return fmt.Errorf("invalid numeric value: %v", err) | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
func (a *App) saveModelMetrics(ctx context.Context, stackID uint64, processID uuid.UUID, metricsData []ModelMetricsSeries) (int, error) { | ||
var createdCount int | ||
|
||
// Start a transaction | ||
tx := a.db(ctx).Begin() | ||
if tx.Error != nil { | ||
return 0, fmt.Errorf("error starting transaction: %w", tx.Error) | ||
} | ||
|
||
for _, metricData := range metricsData { | ||
for _, point := range metricData.Points { | ||
metric := model.ModelMetrics{ | ||
StackID: stackID, | ||
ProcessID: processID, | ||
MetricName: metricData.MetricName, | ||
StepName: metricData.StepName, | ||
Step: point.Step, | ||
MetricValue: point.Value.String(), | ||
} | ||
|
||
// Save to database | ||
if err := tx.Create(&metric).Error; err != nil { | ||
tx.Rollback() | ||
return 0, fmt.Errorf("error creating model metric: %w", err) | ||
} | ||
createdCount++ | ||
} | ||
} | ||
|
||
// Commit the transaction | ||
if err := tx.Commit().Error; err != nil { | ||
return 0, fmt.Errorf("error committing transaction: %w", err) | ||
} | ||
|
||
return createdCount, nil | ||
} | ||
|
||
func (a *App) getModelMetrics(tenantID string, req *http.Request) (interface{}, error) { | ||
|
||
// Extract and validate ProcessID | ||
processID, err := extractAndValidateProcessID(req) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
// Convert tenantID to uint64 for StackID | ||
stackID, err := strconv.ParseUint(tenantID, 10, 64) | ||
if err != nil { | ||
return nil, middleware.ErrBadRequest(fmt.Errorf("invalid tenant ID: %w", err)) | ||
} | ||
|
||
// Retrieved from DB | ||
var rows []model.ModelMetrics | ||
|
||
// Retrieve all relevant metrics from the database | ||
err = a.db(req.Context()). | ||
Where("stack_id = ? AND process_id = ?", stackID, processID). | ||
Order("metric_name ASC, step_name ASC, step ASC"). | ||
Find(&rows).Error | ||
|
||
if err != nil { | ||
return nil, fmt.Errorf("error retrieving model metrics: %w", err) | ||
} | ||
|
||
// Iterate over the metrics and build the series data | ||
var response GetModelMetricsResponse | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think there's some indentation mismatch in this file from this point on |
||
var currentWrapper *DataFrameWrapper | ||
var stepSlice []interface{} | ||
var valueSlice []interface{} | ||
|
||
for _, row := range rows { | ||
currSeriesKey := fmt.Sprintf("%s_%s", row.MetricName, row.StepName) | ||
|
||
if currentWrapper == nil || currSeriesKey != fmt.Sprintf("%s_%s", currentWrapper.MetricName, currentWrapper.StepName) { | ||
// We've encountered a new series, so append the current wrapper (if it exists) and create a new one | ||
if currentWrapper != nil { | ||
response = append(response, *currentWrapper) | ||
} | ||
|
||
stepSlice = make([]interface{}, 0) | ||
valueSlice = make([]interface{}, 0) | ||
|
||
currentWrapper = &DataFrameWrapper{ | ||
MetricName: row.MetricName, | ||
StepName: row.StepName, | ||
Fields: []DataFrame{ | ||
{ | ||
Name: row.StepName, | ||
Type: "number", | ||
Values: stepSlice, | ||
}, | ||
{ | ||
Name: row.MetricName, | ||
Type: "number", | ||
Values: valueSlice, | ||
}, | ||
}, | ||
} | ||
} | ||
|
||
// Append the step and metricValue to the slices | ||
stepSlice = append(stepSlice, row.Step) | ||
valueSlice = append(valueSlice, row.MetricValue) | ||
|
||
// Update the Values in the DataFrameWrapper | ||
currentWrapper.Fields[0].Values = stepSlice | ||
currentWrapper.Fields[1].Values = valueSlice | ||
} | ||
|
||
// Append the last wrapper if it exists | ||
if currentWrapper != nil { | ||
response = append(response, *currentWrapper) | ||
} | ||
|
||
return response, nil | ||
} |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are we storing metrics as individual rows (one row each for a step and value) in mySQL? Why not store an array like in this struct? Is it not performant? Curious if there was some benchmarking done to choose the former.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed, it would also limit us from compressing the points in the future which could end up with a lot more disk and network usage. Do we ever not pull back all of the points at once?