diff --git a/ai-training-api/Dockerfile b/ai-training-api/Dockerfile index 5b9071e..4a164bf 100644 --- a/ai-training-api/Dockerfile +++ b/ai-training-api/Dockerfile @@ -28,6 +28,7 @@ COPY --link ai-training-api ./ FROM prep AS development +# Air fixed at 1.52.3 because 1.60.0 would force upgrading go to 1.23 and this is easier RUN --mount=type=cache,id=go-cache-ai-training-api,target=/opt/go go install github.com/air-verse/air@v1.52.3 FROM prep as build diff --git a/ai-training-api/app/api.go b/ai-training-api/app/api.go index d8acaf6..a74ec41 100644 --- a/ai-training-api/app/api.go +++ b/ai-training-api/app/api.go @@ -1,7 +1,6 @@ package api import ( - "bytes" "encoding/json" "errors" "fmt" @@ -34,7 +33,7 @@ func (app *App) registerAPI(router *mux.Router) { router.HandleFunc("/processes", requestMiddleware(app.listProcess)).Methods("GET") router.HandleFunc("/process/{id}/update-metadata", requestMiddleware(app.updateProcessMetadata)).Methods("POST") router.HandleFunc("/process/{id}/model-metrics", requestMiddleware(app.addModelMetrics)).Methods("POST") - + router.HandleFunc("/process/{id}/model-metrics", requestMiddleware(app.getModelMetrics)).Methods("GET") router.HandleFunc("/group/new", requestMiddleware(app.registerNewGroup)).Methods("POST") router.HandleFunc("/group/{id}", requestMiddleware(app.getGroup)).Methods("GET") router.HandleFunc("/groups", requestMiddleware(app.getGroups)).Methods("GET") @@ -401,49 +400,6 @@ func (a *App) deleteGroup(tenantID string, req *http.Request) (interface{}, erro return nil, err } -// addModelMetrics proxies logs related model-metrics to Loki. -func (a *App) addModelMetrics(tenantID string, req *http.Request) (interface{}, error) { - // TODO: Integrate with GCom API to find the corresponding Loki TenantID associated - // with the tenantID. 
- - body, err := io.ReadAll(req.Body) - if err != nil { - return nil, middleware.ErrBadRequest(err) - } - defer req.Body.Close() - - level.Info(a.logger).Log("msg", "forwarding model-metrics to Loki", "tenantID", tenantID, "body", string(body)) - - // Forward the request to the Loki endpoint. - httpClient := &http.Client{} - lokiEndpoint := a.lokiAddress - lokiReq, err := http.NewRequest("POST", lokiEndpoint, bytes.NewBuffer(body)) - if err != nil { - return nil, middleware.ErrBadRequest(err) - } - lokiReq.Header.Set("Content-Type", "application/json") - if a.lokiTenant != "" { - level.Info(a.logger).Log("msg", "adding X-Scope-OrgID header to loki request", "received_org_id", req.Header.Get("X-Scope-OrgID"), "forwarded_org_id", a.lokiTenant) - lokiReq.Header.Set("X-Scope-OrgID", a.lokiTenant) - } - lokiResp, err := httpClient.Do(lokiReq) - if err != nil { - level.Error(a.logger).Log("msg", "error forwarding model-metrics to Loki", "err", err) - return nil, middleware.ErrBadRequest(err) - } - defer lokiResp.Body.Close() - - // Read the response body. - lokiRespBody, err := io.ReadAll(lokiResp.Body) - if err != nil { - level.Error(a.logger).Log("msg", "error reading response body from Loki", "err", err) - return nil, middleware.ErrBadRequest(err) - } - - // Return the response body. 
- return string(lokiRespBody), nil -} - func namedParam(req *http.Request, name string) string { return mux.Vars(req)[name] } diff --git a/ai-training-api/app/app.go b/ai-training-api/app/app.go index d8ff7be..18d6fcf 100644 --- a/ai-training-api/app/app.go +++ b/ai-training-api/app/app.go @@ -70,17 +70,25 @@ func New( return nil, fmt.Errorf("error migrating Process table: %w", err) } level.Info(logger).Log("msg", "checking tables", "process_table_exists", db.Migrator().HasTable(&model.Process{})) + err = db.AutoMigrate(&model.Group{}) if err != nil { return nil, fmt.Errorf("error migrating Group table: %w", err) } level.Info(logger).Log("msg", "checking tables", "group_table_exists", db.Migrator().HasTable(&model.Group{})) + err = db.AutoMigrate(&model.MetadataKV{}) if err != nil { return nil, fmt.Errorf("error migrating MetadataKV table: %w", err) } level.Info(logger).Log("msg", "checking tables", "metadata_kv_table_exists", db.Migrator().HasTable(&model.MetadataKV{})) + err = db.AutoMigrate(&model.ModelMetrics{}) + if err != nil { + return nil, fmt.Errorf("error migrating ModelMetrics table: %w", err) + } + level.Info(logger).Log("msg", "checking tables", "model_metrics_table_exists", db.Migrator().HasTable(&model.ModelMetrics{})) + // Create server and router. 
serverLogLevel := &dskit_log.Level{} serverLogLevel.Set(promlogConfig.Level.String()) diff --git a/ai-training-api/app/model_metrics.go b/ai-training-api/app/model_metrics.go new file mode 100644 index 0000000..5fe2dd8 --- /dev/null +++ b/ai-training-api/app/model_metrics.go @@ -0,0 +1,284 @@ +package api + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + + "github.com/google/uuid" + "github.com/gorilla/mux" + "gorm.io/gorm" + + "github.com/grafana/ai-training-o11y/ai-training-api/middleware" + "github.com/grafana/ai-training-o11y/ai-training-api/model" +) + +// Incoming format is an array of these +type ModelMetricsSeries struct { + MetricName string `json:"metric_name"` + StepName string `json:"step_name"` + Points []struct { + Step uint32 `json:"step"` + Value json.Number `json:"value"` + } `json:"points"` +} + +type AddModelMetricsResponse struct { + Message string `json:"message"` + MetricsCreated uint32 `json:"metricsCreated"` +} + +// This is for return +// We want an array of objects that contain grafana dataframes +// For visualizing +type DataFrame struct { + Name string `json:"name"` + Type string `json:"type"` + Values []interface{} `json:"values"` +} + +// To make it less painful to unmarshal and group them +type DataFrameWrapper struct { + MetricName string `json:"MetricName"` + StepName string `json:"StepName"` + Fields []DataFrame `json:"fields"` +} + +type GetModelMetricsResponse []DataFrameWrapper + + +func (a *App) addModelMetrics(tenantID string, req *http.Request) (interface{}, error) { + // Extract and validate ProcessID + processID, err := extractAndValidateProcessID(req) + if err != nil { + return nil, err + } + + // Validate ProcessID exists + if err := a.validateProcessExists(req.Context(), processID); err != nil { + return nil, err + } + + // Parse and validate the request body + metricsData, err := parseAndValidateModelMetricsRequest(req) + if err != nil { + return nil, err + } + + // Convert tenantID 
to uint64 for StackID + stackID, err := strconv.ParseUint(tenantID, 10, 64) + if err != nil { + return nil, middleware.ErrBadRequest(fmt.Errorf("invalid tenant ID: %w", err)) + } + + // Save the metrics and get the count of created metrics + createdCount, err := a.saveModelMetrics(req.Context(), stackID, processID, metricsData) + if err != nil { + return nil, err + } + + // Return a JSON response with success message and count of metrics inserted + response := map[string]interface{}{ + "message": "Metrics successfully added", + "metricsCreated": createdCount, + } + + return response, nil +} + +func extractAndValidateProcessID(req *http.Request) (uuid.UUID, error) { + vars := mux.Vars(req) + if vars == nil { + return uuid.Nil, fmt.Errorf("mux.Vars(req) returned nil") + } + + processIDStr, ok := vars["id"] + if !ok { + return uuid.Nil, middleware.ErrBadRequest(fmt.Errorf("process ID not provided in URL")) + } + + // This case handles when the ID is provided in the URL but is empty + if processIDStr == "" { + return uuid.Nil, middleware.ErrBadRequest(fmt.Errorf("process ID is empty")) + } + + processID, err := uuid.Parse(processIDStr) + if err != nil { + return uuid.Nil, middleware.ErrBadRequest(fmt.Errorf("invalid process ID: %w", err)) + } + + return processID, nil +} + +func (a *App) validateProcessExists(ctx context.Context, processID uuid.UUID) error { + var process model.Process + if err := a.db(ctx).First(&process, "id = ?", processID).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return middleware.ErrNotFound(fmt.Errorf("process not found")) + } + return fmt.Errorf("error checking process: %w", err) + } + return nil +} + +func parseAndValidateModelMetricsRequest(req *http.Request) ([]ModelMetricsSeries, error) { + var metricsData []ModelMetricsSeries + + decoder := json.NewDecoder(req.Body) + + if err := decoder.Decode(&metricsData); err != nil { + return nil, middleware.ErrBadRequest(fmt.Errorf("invalid JSON: %v", err)) + } + + 
// Reject the whole batch if any single series fails validation. + + for _, metric := range metricsData { + if err := validateModelMetricRequest(&metric); err != nil { + return nil, middleware.ErrBadRequest(err) + } + } + + return metricsData, nil +} + +func validateModelMetricRequest(m *ModelMetricsSeries) error { + if len(m.MetricName) == 0 || len(m.MetricName) > 32 { + return fmt.Errorf("metric name must be between 1 and 32 characters") + } + if len(m.StepName) == 0 || len(m.StepName) > 32 { + return fmt.Errorf("step name must be between 1 and 32 characters") + } + for _, point := range m.Points { + if point.Step == 0 { + return fmt.Errorf("step must be a positive number") + } + if point.Value.String() == "" { + return fmt.Errorf("metric value cannot be empty") + } + // Validate that Value is a valid number + if _, err := point.Value.Float64(); err != nil { + return fmt.Errorf("invalid numeric value: %v", err) + } + } + return nil +} + +func (a *App) saveModelMetrics(ctx context.Context, stackID uint64, processID uuid.UUID, metricsData []ModelMetricsSeries) (int, error) { + var createdCount int + + // Start a transaction + tx := a.db(ctx).Begin() + if tx.Error != nil { + return 0, fmt.Errorf("error starting transaction: %w", tx.Error) + } + + for _, metricData := range metricsData { + for _, point := range metricData.Points { + metric := model.ModelMetrics{ + StackID: stackID, + ProcessID: processID, + MetricName: metricData.MetricName, + StepName: metricData.StepName, + Step: point.Step, + MetricValue: point.Value.String(), + } + + // Save to database + if err := tx.Create(&metric).Error; err != nil { + tx.Rollback() + return 0, fmt.Errorf("error creating model metric: %w", err) + } + createdCount++ + } + } + + // Commit the transaction + if err := tx.Commit().Error; err != nil { + return 0, fmt.Errorf("error committing transaction: %w", err) + } + + return createdCount, nil +} + +func (a *App) getModelMetrics(tenantID string, req *http.Request) (interface{}, error) { + + // Extract and validate 
ProcessID + processID, err := extractAndValidateProcessID(req) + if err != nil { + return nil, err + } + + // Convert tenantID to uint64 for StackID + stackID, err := strconv.ParseUint(tenantID, 10, 64) + if err != nil { + return nil, middleware.ErrBadRequest(fmt.Errorf("invalid tenant ID: %w", err)) + } + + // Retrieved from DB + var rows []model.ModelMetrics + + // Retrieve all relevant metrics from the database + err = a.db(req.Context()). + Where("stack_id = ? AND process_id = ?", stackID, processID). + Order("metric_name ASC, step_name ASC, step ASC"). + Find(&rows).Error + + if err != nil { + return nil, fmt.Errorf("error retrieving model metrics: %w", err) + } + + // Iterate over the metrics and build the series data + response := GetModelMetricsResponse{} + var currentWrapper *DataFrameWrapper + var stepSlice []interface{} + var valueSlice []interface{} + + for _, row := range rows { + // Compare both fields directly: a joined "metric_step" key is ambiguous when either name contains "_" + + if currentWrapper == nil || row.MetricName != currentWrapper.MetricName || row.StepName != currentWrapper.StepName { + // We've encountered a new series, so append the current wrapper (if it exists) and create a new one + if currentWrapper != nil { + response = append(response, *currentWrapper) + } + + stepSlice = make([]interface{}, 0) + valueSlice = make([]interface{}, 0) + + currentWrapper = &DataFrameWrapper{ + MetricName: row.MetricName, + StepName: row.StepName, + Fields: []DataFrame{ + { + Name: row.StepName, + Type: "number", + Values: stepSlice, + }, + { + Name: row.MetricName, + Type: "number", + Values: valueSlice, + }, + }, + } + } + + // Append the step and metricValue to the slices + stepSlice = append(stepSlice, row.Step) + valueSlice = append(valueSlice, row.MetricValue) + + // Update the Values in the DataFrameWrapper + currentWrapper.Fields[0].Values = stepSlice + currentWrapper.Fields[1].Values = valueSlice + } + + // Append the last wrapper if it exists + if currentWrapper != nil { + 
response = append(response, *currentWrapper) + } + + return response, nil +} diff --git a/ai-training-api/app/model_metrics_test.go b/ai-training-api/app/model_metrics_test.go new file mode 100644 index 0000000..8202d46 --- /dev/null +++ b/ai-training-api/app/model_metrics_test.go @@ -0,0 +1,401 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "sync" + "testing" + + "github.com/go-kit/log" + "github.com/google/uuid" + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + + "github.com/grafana/ai-training-o11y/ai-training-api/model" +) + +type testApp struct { + App +} + +func (a *testApp) db(ctx context.Context) *gorm.DB { + return a.App._db +} + +func setupTestDB(t *testing.T) (*gorm.DB, func()) { + db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) + require.NoError(t, err) + + err = db.AutoMigrate(&model.Process{}, &model.ModelMetrics{}) + require.NoError(t, err) + + return db, func() { + sqlDB, err := db.DB() + require.NoError(t, err) + sqlDB.Close() + } +} + +func TestExtractAndValidateProcessID(t *testing.T) { + tests := []struct { + name string + url string + expectedID uuid.UUID + expectedErrMsg string + }{ + { + name: "Valid UUID", + url: "/process/123e4567-e89b-12d3-a456-426614174000/model-metrics", + expectedID: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"), + }, + { + name: "Invalid UUID", + url: "/process/invalid-uuid/model-metrics", + expectedErrMsg: "invalid process ID", + }, + { + name: "Empty ID", + url: "/process//model-metrics", + expectedErrMsg: "process ID is empty", + }, + { + name: "No ID in URL", + url: "/process/model-metrics", + expectedErrMsg: "process ID not provided in URL", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + router := mux.NewRouter() + router.HandleFunc("/process/{id}/model-metrics", func(w 
http.ResponseWriter, r *http.Request) { + processID, err := extractAndValidateProcessID(r) + + if tt.expectedErrMsg != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedErrMsg) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedID, processID) + } + }).Methods("POST") + + req, err := http.NewRequest("POST", tt.url, nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + }) + } +} + +func TestValidateProcessExists(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + if db == nil { + t.Fatal("setupTestDB returned a nil database") + } + + app := &testApp{ + App: App{ + _db: db, + dbMux: &sync.Mutex{}, + logger: log.NewLogfmtLogger(log.NewSyncWriter(os.Stderr)), + }, + } + + tests := []struct { + name string + setupDB func(*gorm.DB) + expectedErrMsg string + }{ + { + name: "Process exists", + setupDB: func(db *gorm.DB) { + process := model.Process{ID: uuid.New()} + result := db.Create(&process) + if result.Error != nil { + t.Fatalf("Failed to create process: %v", result.Error) + } + }, + }, + { + name: "Process does not exist", + setupDB: func(db *gorm.DB) {}, + expectedErrMsg: "process not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupDB(db) + + // Use a fixed UUID for testing to ensure we're looking for the correct process + testUUID := uuid.New() + if tt.name == "Process exists" { + process := model.Process{ID: testUUID} + result := db.Create(&process) + if result.Error != nil { + t.Fatalf("Failed to create process: %v", result.Error) + } + } + + err := app.validateProcessExists(context.Background(), testUUID) + + if tt.expectedErrMsg != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedErrMsg) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestParseAndValidateModelMetricsRequest(t *testing.T) { + tests := []struct { + name string + requestBody 
interface{} + expectedLen int + expectedErrMsg string + }{ + { + name: "Valid request", + requestBody: []ModelMetricsSeries{ + { + MetricName: "accuracy", + StepName: "training", + Points: []struct { + Step uint32 `json:"step"` + Value json.Number `json:"value"` + }{ + {Step: 1, Value: "0.75"}, + {Step: 2, Value: "0.85"}, + }, + }, + }, + expectedLen: 1, + }, + { + name: "Invalid metric name", + requestBody: []ModelMetricsSeries{ + { + MetricName: "", + StepName: "training", + Points: []struct { + Step uint32 `json:"step"` + Value json.Number `json:"value"` + }{{Step: 1, Value: "0.75"}}, + }, + }, + expectedErrMsg: "metric name must be between 1 and 32 characters", + }, + { + name: "Invalid step name", + requestBody: []ModelMetricsSeries{ + { + MetricName: "accuracy", + StepName: "", + Points: []struct { + Step uint32 `json:"step"` + Value json.Number `json:"value"` + }{{Step: 1, Value: "0.75"}}, + }, + }, + expectedErrMsg: "step name must be between 1 and 32 characters", + }, + { + name: "Invalid step value", + requestBody: []ModelMetricsSeries{ + { + MetricName: "accuracy", + StepName: "training", + Points: []struct { + Step uint32 `json:"step"` + Value json.Number `json:"value"` + }{{Step: 0, Value: "0.75"}}, + }, + }, + expectedErrMsg: "step must be a positive number", + }, + { + name: "Invalid metric value (empty string)", + requestBody: []interface{}{ + map[string]interface{}{ + "metric_name": "accuracy", + "step_name": "training", + "points": []interface{}{ + map[string]interface{}{ + "step": 1, + "value": "", + }, + }, + }, + }, + expectedErrMsg: "invalid JSON: json: invalid number literal, trying to unmarshal \"\\\"\\\"\" into Number", + }, + { + name: "Invalid metric value (not a number)", + requestBody: []interface{}{ + map[string]interface{}{ + "metric_name": "accuracy", + "step_name": "training", + "points": []interface{}{ + map[string]interface{}{ + "step": 1, + "value": "not a number", + }, + }, + }, + }, + expectedErrMsg: "invalid JSON: json: 
invalid number literal, trying to unmarshal \"\\\"not a number\\\"\" into Number", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body, err := json.Marshal(tt.requestBody) + require.NoError(t, err) + + req, err := http.NewRequest("POST", "/process/123/model-metrics", bytes.NewBuffer(body)) + require.NoError(t, err) + + result, err := parseAndValidateModelMetricsRequest(req) + + if tt.expectedErrMsg != "" { + if err == nil { + t.Errorf("Expected error containing '%s', but got nil error", tt.expectedErrMsg) + } else { + assert.Contains(t, err.Error(), tt.expectedErrMsg) + } + } else { + assert.NoError(t, err) + assert.Len(t, result, tt.expectedLen) + } + + // Add this line for debugging + t.Logf("Test case '%s': error = %v, result = %+v", tt.name, err, result) + }) + } +} + +func TestGetModelMetrics(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + app := &testApp{ + App: App{_db: db}, + } + + type testCase struct { + name string + metrics []model.ModelMetrics + check func(*testing.T, GetModelMetricsResponse) + } + + testCases := []testCase{ + { + name: "Basic case", + metrics: []model.ModelMetrics{ + {MetricName: "accuracy", StepName: "train", Step: 1, MetricValue: "0.75"}, + {MetricName: "accuracy", StepName: "train", Step: 2, MetricValue: "0.80"}, + {MetricName: "loss", StepName: "train", Step: 1, MetricValue: "0.5"}, + {MetricName: "loss", StepName: "train", Step: 2, MetricValue: "0.4"}, + }, + check: func(t *testing.T, response GetModelMetricsResponse) { + require.Len(t, response, 2) // Two DataFrameWrappers: one for accuracy, one for loss + + // Check accuracy metrics + require.Equal(t, "accuracy", response[0].MetricName) + require.Equal(t, "train", response[0].StepName) + require.Len(t, response[0].Fields, 2) + require.Equal(t, []interface{}{uint32(1), uint32(2)}, response[0].Fields[0].Values) + require.Equal(t, []interface{}{"0.75", "0.80"}, response[0].Fields[1].Values) + + // Check loss metrics + 
require.Equal(t, "loss", response[1].MetricName) + require.Equal(t, "train", response[1].StepName) + require.Len(t, response[1].Fields, 2) + require.Equal(t, []interface{}{uint32(1), uint32(2)}, response[1].Fields[0].Values) + require.Equal(t, []interface{}{"0.5", "0.4"}, response[1].Fields[1].Values) + }, + }, + { + name: "No metrics", + metrics: []model.ModelMetrics{}, + check: func(t *testing.T, response GetModelMetricsResponse) { + require.Len(t, response, 0) + }, + }, + { + name: "Single metric", + metrics: []model.ModelMetrics{ + {MetricName: "accuracy", StepName: "train", Step: 1, MetricValue: "0.75"}, + }, + check: func(t *testing.T, response GetModelMetricsResponse) { + require.Len(t, response, 1) + require.Equal(t, "accuracy", response[0].MetricName) + require.Equal(t, "train", response[0].StepName) + require.Len(t, response[0].Fields, 2) + require.Equal(t, []interface{}{uint32(1)}, response[0].Fields[0].Values) + require.Equal(t, []interface{}{"0.75"}, response[0].Fields[1].Values) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Clear the database + db.Exec("DELETE FROM model_metrics") + + processID := uuid.New() + for i := range tc.metrics { + tc.metrics[i].ProcessID = processID + } + insertTestMetrics(t, db, tc.metrics) + + req := setupTestRequest(processID.String()) + result, err := app.getModelMetrics("0", req) + require.NoError(t, err) + response, ok := result.(GetModelMetricsResponse) + require.True(t, ok) + + // Print out the entire response for debugging + t.Logf("Response: %+v", response) + + if tc.name == "Basic case" { + // Print out the Values slices for debugging + t.Logf("Step Values: %+v", response[0].Fields[0].Values) + t.Logf("Metric Values: %+v", response[0].Fields[1].Values) + } + + // Run the check function + tc.check(t, response) + }) + } +} + +func insertTestMetrics(t *testing.T, db *gorm.DB, metrics []model.ModelMetrics) { + for _, metric := range metrics { + err := 
db.Create(&metric).Error + require.NoError(t, err) + } +} + +func setupTestRequest(processID string) *http.Request { + req, _ := http.NewRequest("GET", "/process/"+processID+"/model-metrics", nil) + vars := map[string]string{ + "id": processID, + } + return mux.SetURLVars(req, vars) +} diff --git a/ai-training-api/go.mod b/ai-training-api/go.mod index 83c4c09..f1a73c2 100644 --- a/ai-training-api/go.mod +++ b/ai-training-api/go.mod @@ -10,6 +10,7 @@ require ( github.com/gorilla/mux v1.8.1 github.com/grafana/dskit v0.0.0-20240411172511-de4086540f6f github.com/jeremywohl/flatten/v2 v2.0.0-20211013061545-07e4a09fb8e4 + github.com/mattn/go-sqlite3 v1.14.17 github.com/prometheus/client_golang v1.19.0 github.com/prometheus/common v0.52.3 github.com/stretchr/testify v1.9.0 @@ -43,7 +44,6 @@ require ( github.com/jinzhu/now v1.1.5 // indirect github.com/jpillora/backoff v1.0.0 // indirect github.com/klauspost/compress v1.17.3 // indirect - github.com/mattn/go-sqlite3 v1.14.17 // indirect github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f // indirect github.com/opentracing-contrib/go-grpc v0.0.0-20210225150812-73cb765af46e // indirect github.com/opentracing-contrib/go-stdlib v1.0.0 // indirect @@ -55,6 +55,7 @@ require ( github.com/prometheus/exporter-toolkit v0.10.1-0.20230714054209-2f4150c63f97 // indirect github.com/prometheus/procfs v0.12.0 // indirect github.com/sercand/kuberesolver/v5 v5.1.1 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/uber/jaeger-client-go v2.28.0+incompatible // indirect github.com/uber/jaeger-lib v2.2.0+incompatible // indirect github.com/xhit/go-str2duration/v2 v2.1.0 // indirect diff --git a/ai-training-api/model/model_metrics.go b/ai-training-api/model/model_metrics.go new file mode 100644 index 0000000..8da335b --- /dev/null +++ b/ai-training-api/model/model_metrics.go @@ -0,0 +1,24 @@ +package model + +import ( + "github.com/google/uuid" + "gorm.io/gorm" +) + +type ModelMetrics struct { + StackID uint64 
`json:"stack_id" gorm:"not null;primaryKey"` + ProcessID uuid.UUID `json:"process_id" gorm:"type:char(36);not null;primaryKey;foreignKey:ProcessID;references:ID"` // Foreign key + MetricName string `json:"metric_name" gorm:"size:32;not null;primaryKey"` + StepName string `json:"step_name" gorm:"size:32;not null;primaryKey"` + Step uint32 `json:"step" gorm:"not null;primaryKey"` + MetricValue string `json:"metric_value" gorm:"size:64;not null"` + + Process Process `gorm:"foreignKey:ProcessID;references:ID"` // Relationship definition +} +// Add a custom hook if necessary for additional logic. +// Example: AfterCreate hook for custom logic +func (m *ModelMetrics) AfterCreate(tx *gorm.DB) error { + // Custom logic after creating a metric entry + tx.Logger.Info(tx.Statement.Context, "AfterCreate hook called for ModelMetrics") + return nil +} diff --git a/grafana-aitraining-app/src/components/App/App.tsx b/grafana-aitraining-app/src/components/App/App.tsx index 009bc8d..e821241 100644 --- a/grafana-aitraining-app/src/components/App/App.tsx +++ b/grafana-aitraining-app/src/components/App/App.tsx @@ -10,7 +10,7 @@ import { ReactRouter5Adapter } from 'use-query-params/adapters/react-router-5'; import { PluginPropsContext } from '../../utils/utils.plugin'; import { Routes } from '../Routes'; import { TrainingApiDatasource } from '../../datasource/Datasource'; -import { makeProcessGetter } from 'utils/api'; +import { doRequest, makeProcessGetter } from 'utils/api'; export class App extends React.PureComponent { componentDidMount() { @@ -31,6 +31,14 @@ export class App extends React.PureComponent { getProcesses = makeProcessGetter(this.props.meta.id); + getModelMetrics = (processUuid: string) => { + const response = doRequest({ + url: `/api/plugins/${this.props.meta.id}/resources/metadata/api/v1/process/${processUuid}/model-metrics`, + method: 'GET', + }); + return response; + } + render() { return ( { objectToSearchString: stringify, }} > - + diff --git 
a/grafana-aitraining-app/src/pages/Home/Home.tsx b/grafana-aitraining-app/src/pages/Home/Home.tsx index f2e8109..24984d3 100644 --- a/grafana-aitraining-app/src/pages/Home/Home.tsx +++ b/grafana-aitraining-app/src/pages/Home/Home.tsx @@ -64,9 +64,12 @@ export const Home = () => { case 404: setProcessesQueryStatus('notFound'); break; - case 500 || 502: - setProcessesQueryStatus('serverError'); - break; + case 500: + setProcessesQueryStatus('serverError'); + break; + case 502: + setProcessesQueryStatus('serverError'); + break; default: setProcessesQueryStatus('error'); } @@ -115,7 +118,6 @@ export const Home = () => { ) : ( - -
-

Organized Data:

- {organizedLokiData ? ( -
{JSON.stringify(organizedLokiData, null, 2)}
- ) : ( -

No organized data available

- )} -
- -
-

Query Data:

- {Object.keys(lokiQueryData).map((key) => ( - -

Results for process: {key}

-
{JSON.stringify(lokiQueryData[key].lokiData?.series[0].fields, null, 2)}
-
- ))} -
- -
-

Selected Rows:

-
{JSON.stringify(rows, null, 2)}
-
- + {organizedData && Object.keys(organizedData).map((section) => { + console.log(`Rendering section: ${section}`); + const panels = createPanelList(section); + console.log(`Panels for section ${section}:`, panels); + return ( + + + + ); + })} ); }; diff --git a/grafana-aitraining-app/src/utils/api.ts b/grafana-aitraining-app/src/utils/api.ts index c216103..f21cf81 100644 --- a/grafana-aitraining-app/src/utils/api.ts +++ b/grafana-aitraining-app/src/utils/api.ts @@ -8,7 +8,7 @@ type JSONArray = JSONValue[]; type JSONValue = JSONPrimitive | JSONObject | JSONArray; // Actually make a request from the plugin backend -function doRequest(fetchOptions: any): Promise { +export function doRequest(fetchOptions: any): Promise { return lastValueFrom(getBackendSrv().fetch(fetchOptions)).then((response) => { if (!response.ok) { throw response.data; diff --git a/grafana-aitraining-app/src/utils/utils.plugin.ts b/grafana-aitraining-app/src/utils/utils.plugin.ts index 5d0472a..c96c442 100644 --- a/grafana-aitraining-app/src/utils/utils.plugin.ts +++ b/grafana-aitraining-app/src/utils/utils.plugin.ts @@ -3,6 +3,7 @@ import { AppRootProps, DataQueryResponseData } from '@grafana/data'; export interface PluginProps extends AppRootProps { getProcesses: () => Promise; + getModelMetrics: (processUuid: string) => Promise; } export const PluginPropsContext = React.createContext(null); @@ -19,3 +20,8 @@ export const useGetProcesses = () => { const pluginProps = usePluginProps(); return pluginProps.getProcesses; }; + +export const useGetModelMetrics = () => { + const pluginProps = usePluginProps(); + return pluginProps.getModelMetrics; +} diff --git a/o11y/src/o11y/_internal/client.py b/o11y/src/o11y/_internal/client.py index 1beb805..864ff3f 100644 --- a/o11y/src/o11y/_internal/client.py +++ b/o11y/src/o11y/_internal/client.py @@ -23,6 +23,7 @@ def __init__(self): self.url = None self.token = None self.tenant_id = None + self.step = 1 login_string = os.environ.get('GF_AI_TRAINING_CREDS') 
self.set_credentials(login_string) @@ -76,6 +77,7 @@ def register_process(self, data): if self.process_uuid: self.process_uuid = None self.user_metadata = None + self.step = 1 if not self.tenant_id or not self.token: logger.error("User ID or token is not set.") @@ -146,35 +148,26 @@ def send_model_metrics(self, log: Dict[str, Any], *, x_axis: Optional[Dict[str, if not self.process_uuid: logger.error("No process registered, unable to send logs") return False - - timestamp = str(time.time_ns()) - metadata: Dict[str, Any] = { - "process_uuid": self.process_uuid, - "type": "model-metrics" - } + if not x_axis: + x_axis = { + "step": self.step + } + self.step += 1 - if x_axis: - x_key = list(x_axis.keys())[0] - metadata['x_axis'] = x_key - metadata['x_value'] = str(x_axis[x_key]) + step_name, step_value = next(iter(x_axis.items())) + # Emit one series per logged metric; every series shares this step. - json_data = { - "streams": [ + json_data = [{ + "metric_name": metric_name, + "step_name": step_name, + "points": [ { - "stream": { - "job": "o11y", - }, - "values": [ - [ - timestamp, - json.dumps(log), - metadata, - ] - ] + "step": step_value, + "value": metric_value } ] - } + } for metric_name, metric_value in log.items()] url = f'{self.url}/api/v1/process/{self.process_uuid}/model-metrics' diff --git a/o11y/src/o11y/exported/log.py b/o11y/src/o11y/exported/log.py index ab3778f..b16a9af 100644 --- a/o11y/src/o11y/exported/log.py +++ b/o11y/src/o11y/exported/log.py @@ -1,29 +1,35 @@ from typing import Dict, Union, Optional +from decimal import Decimal # SPDX-License-Identifier: Apache-2.0 from .. import client from .. import logger -def log(log: Dict[str, Union[int, float]], *, x_axis: Optional[Dict[str, Union[int, float]]] = None) -> bool: +def log( + log_data: Dict[str, Union[int, float, Decimal]], + *, + x_axis: Optional[Dict[str, int]] = None + ) -> bool: """ Sends a log to the Loki server. Args: - log (Dict[str, Union[int, float]]): The log message as a dictionary with string keys and numeric values. 
+ log_data (Dict[str, Union[int, float, Decimal]]): The log message as a dictionary with string keys and numeric values. x_axis (Optional[Dict[str, Union[int, float]]], optional): A single-item dictionary representing the x-axis. Defaults to None. Returns: bool: True if the log was sent successfully, False otherwise. """ - if not isinstance(log, dict): + + if not isinstance(log_data, dict): logger.error("Log must be a dict") return False - if not all(isinstance(key, str) and isinstance(value, (int, float)) for key, value in log.items()): + if not all(isinstance(key, str) and isinstance(value, (int, float, Decimal)) for key, value in log_data.items()): logger.error("Log must contain only string keys and numeric values") return False - - if x_axis is None: - return bool(client.send_model_metrics(log)) + + if not x_axis: + return bool(client.send_model_metrics(log_data)) if not isinstance(x_axis, dict) or len(x_axis) != 1: logger.error("x_axis must be a dict with one key") @@ -31,12 +37,12 @@ def log(log: Dict[str, Union[int, float]], *, x_axis: Optional[Dict[str, Union[i x_key, x_value = next(iter(x_axis.items())) - if not isinstance(x_key, str) or not isinstance(x_value, (int, float)): + if not isinstance(x_key, str) or not isinstance(x_value, int): logger.error("x_axis must have a string key and a numeric value") return False - if x_key in log and x_value != log[x_key]: + if x_key in log_data and x_value != log_data[x_key]: logger.error("x_axis key must not be in your metrics, or must have the same value") return False - return bool(client.send_model_metrics(log, x_axis=x_axis)) + return bool(client.send_model_metrics(log_data, x_axis=x_axis))