From 39dbf3ce31458cc80184ea265c16504a9b0a1c2d Mon Sep 17 00:00:00 2001 From: Brent Salisbury Date: Sun, 26 Jan 2025 01:03:15 -0500 Subject: [PATCH] Add any served models in a column in the jobs table Enables the user to leave and return to the model chat eval without having to launch a new job, instead using the existing job for either pre-train or post-train. Signed-off-by: Brent Salisbury --- api-server/.gitignore | 3 ++ api-server/handlers.go | 108 ++++++++++++++++++++++++++++++++++++----- api-server/jobs.go | 14 ++++-- api-server/main.go | 71 +++++++++++++++++++++------ 4 files changed, 163 insertions(+), 33 deletions(-) diff --git a/api-server/.gitignore b/api-server/.gitignore index cd9e8a34..ff87a0cf 100644 --- a/api-server/.gitignore +++ b/api-server/.gitignore @@ -21,6 +21,9 @@ go.work.sum # env file .env +# binary +api-server + # app specific logs/ jobs.json diff --git a/api-server/handlers.go b/api-server/handlers.go index 147669ce..e5728d21 100644 --- a/api-server/handlers.go +++ b/api-server/handlers.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "database/sql" "encoding/json" "fmt" "github.com/gorilla/mux" @@ -243,12 +244,21 @@ func (srv *ILabServer) getVllmStatusHandler(w http.ResponseWriter, r *http.Reque return } - srv.jobIDsMutex.RLock() - jobID, ok := srv.servedModelJobIDs[modelName] - srv.jobIDsMutex.RUnlock() + // Directly query the DB for the job associated with this model + var jobID string + err = srv.db.QueryRow(` + SELECT job_id + FROM jobs + WHERE served_model_name = ? 
AND status = 'running' + LIMIT 1 + `, modelName).Scan(&jobID) - if !ok { - srv.log.Infof("WTF jobid not found for model '%s'", modelName) + if err == sql.ErrNoRows { + srv.log.Infof("No running job found for model '%s'", modelName) + _ = json.NewEncoder(w).Encode(map[string]string{"status": "loading"}) + return + } else if err != nil { + srv.log.Errorf("Error querying job for model '%s': %v", modelName, err) _ = json.NewEncoder(w).Encode(map[string]string{"status": "loading"}) return } @@ -629,6 +639,26 @@ func (srv *ILabServer) runVllmContainerHandler( gpuIndex int, hostVolume, containerVolume string, w http.ResponseWriter, ) { + // Check if a job is already running for the requested model + existingJob, err := srv.getRunningJobByModel(servedModelName) + if err != nil { + srv.log.Errorf("Error checking existing jobs for model '%s': %v", servedModelName, err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + if existingJob != nil { + srv.log.Infof("A job is already running for model '%s' with job_id: %s", servedModelName, existingJob.JobID) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "status": "already_running", + "job_id": existingJob.JobID, + "message": fmt.Sprintf("Model '%s' is already being served.", servedModelName), + }) + return + } + + srv.log.Infof("No existing job found for model '%s'. 
Starting a new job.", servedModelName) + cmdArgs := []string{ "run", "--rm", fmt.Sprintf("--device=nvidia.com/gpu=%d", gpuIndex), @@ -681,13 +711,14 @@ func (srv *ILabServer) runVllmContainerHandler( // Create a Job record and store it in the DB newJob := &Job{ - JobID: jobID, - Cmd: "podman", - Args: cmdArgs, - Status: "running", - PID: cmd.Process.Pid, - LogFile: logFilePath, - StartTime: time.Now(), + JobID: jobID, + Cmd: "podman", + Args: cmdArgs, + Status: "running", + PID: cmd.Process.Pid, + LogFile: logFilePath, + StartTime: time.Now(), + ServedModelName: servedModelName, } if err := srv.createJob(newJob); err != nil { srv.log.Errorf("Failed to create job in DB for %s: %v", jobID, err) @@ -859,6 +890,59 @@ func (srv *ILabServer) serveModelHandler(modelPath, port string, w http.Response _ = json.NewEncoder(w).Encode(map[string]string{"status": "model process started", "job_id": jobID}) } +// getRunningJobByModel retrieves a running job for the specified served_model_name. +// Returns nil if no such job exists. +func (srv *ILabServer) getRunningJobByModel(servedModelName string) (*Job, error) { + var job Job + var argsJSON string + var startTimeStr, endTimeStr sql.NullString + + row := srv.db.QueryRow(` + SELECT job_id, cmd, args, status, pid, log_file, start_time, end_time, branch, served_model_name + FROM jobs + WHERE served_model_name = ? 
AND status = 'running' + LIMIT 1 + `, servedModelName) + + err := row.Scan( + &job.JobID, + &job.Cmd, + &argsJSON, + &job.Status, + &job.PID, + &job.LogFile, + &startTimeStr, + &endTimeStr, + &job.Branch, + &job.ServedModelName, + ) + if err == sql.ErrNoRows { + return nil, nil + } else if err != nil { + return nil, err + } + + if err := json.Unmarshal([]byte(argsJSON), &job.Args); err != nil { + srv.log.Errorf("Failed to unmarshal Args for job '%s': %v", job.JobID, err) + return nil, fmt.Errorf("failed to unmarshal Args for job '%s': %v", job.JobID, err) + } + + if startTimeStr.Valid { + t, err := time.Parse(time.RFC3339, startTimeStr.String) + if err == nil { + job.StartTime = t + } + } + if endTimeStr.Valid && endTimeStr.String != "" { + t, err := time.Parse(time.RFC3339, endTimeStr.String) + if err == nil { + job.EndTime = &t + } + } + + return &job, nil +} + // listServedModelJobIDsHandler is a debug endpoint to list current model to jobID mappings. func (srv *ILabServer) listServedModelJobIDsHandler(w http.ResponseWriter, r *http.Request) { srv.jobIDsMutex.RLock() diff --git a/api-server/jobs.go b/api-server/jobs.go index 48d6d164..a42788da 100644 --- a/api-server/jobs.go +++ b/api-server/jobs.go @@ -33,7 +33,8 @@ func (srv *ILabServer) initDB() { log_file TEXT, start_time TEXT, end_time TEXT, - branch TEXT + branch TEXT, + served_model_name TEXT ); ` _, err = srv.db.Exec(createTableSQL) @@ -58,8 +59,8 @@ func (srv *ILabServer) createJob(job *Job) error { endTimeStr = &s } _, err = srv.db.Exec(` - INSERT INTO jobs (job_id, cmd, args, status, pid, log_file, start_time, end_time, branch) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO jobs (job_id, cmd, args, status, pid, log_file, start_time, end_time, branch, served_model_name) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
`, job.JobID, job.Cmd, @@ -70,6 +71,7 @@ func (srv *ILabServer) createJob(job *Job) error { job.StartTime.Format(time.RFC3339), endTimeStr, job.Branch, + job.ServedModelName, ) if err != nil { return fmt.Errorf("failed to insert job: %v", err) @@ -79,7 +81,7 @@ func (srv *ILabServer) createJob(job *Job) error { // getJob fetches a single job by job_id. func (srv *ILabServer) getJob(jobID string) (*Job, error) { - row := srv.db.QueryRow("SELECT job_id, cmd, args, status, pid, log_file, start_time, end_time, branch FROM jobs WHERE job_id = ?", jobID) + row := srv.db.QueryRow("SELECT job_id, cmd, args, status, pid, log_file, start_time, end_time, branch, served_model_name FROM jobs WHERE job_id = ?", jobID) var j Job var argsJSON string @@ -95,6 +97,7 @@ func (srv *ILabServer) getJob(jobID string) (*Job, error) { &startTimeStr, &endTimeStr, &j.Branch, + &j.ServedModelName, ) if err == sql.ErrNoRows { return nil, nil // not found @@ -133,7 +136,7 @@ func (srv *ILabServer) updateJob(job *Job) error { } _, err = srv.db.Exec(` UPDATE jobs - SET cmd = ?, args = ?, status = ?, pid = ?, log_file = ?, start_time = ?, end_time = ?, branch = ? + SET cmd = ?, args = ?, status = ?, pid = ?, log_file = ?, start_time = ?, end_time = ?, branch = ?, served_model_name = ? WHERE job_id = ? `, job.Cmd, @@ -144,6 +147,7 @@ func (srv *ILabServer) updateJob(job *Job) error { job.StartTime.Format(time.RFC3339), endTimeStr, job.Branch, + job.ServedModelName, job.JobID, ) if err != nil { diff --git a/api-server/main.go b/api-server/main.go index f6420aef..564c3ede 100644 --- a/api-server/main.go +++ b/api-server/main.go @@ -39,15 +39,16 @@ type Data struct { // Job represents a background job, including train/generate/pipeline/vllm-run jobs. 
type Job struct { - JobID string `json:"job_id"` - Cmd string `json:"cmd"` - Args []string `json:"args"` - Status string `json:"status"` // "running", "finished", "failed" - PID int `json:"pid"` - LogFile string `json:"log_file"` - StartTime time.Time `json:"start_time"` - EndTime *time.Time `json:"end_time,omitempty"` - Branch string `json:"branch"` + JobID string `json:"job_id"` + Cmd string `json:"cmd"` + Args []string `json:"args"` + Status string `json:"status"` // "running", "finished", "failed" + PID int `json:"pid"` + LogFile string `json:"log_file"` + StartTime time.Time `json:"start_time"` + EndTime *time.Time `json:"end_time,omitempty"` + Branch string `json:"branch"` + ServedModelName string `json:"served_model_name"` // Lock is not serialized; it protects updates to the Job in memory. Lock sync.Mutex `json:"-"` @@ -94,7 +95,7 @@ type ILabServer struct { useVllm bool pipelineType string debugEnabled bool - homeDir string // New field added + homeDir string // Logger logger *zap.Logger @@ -119,12 +120,7 @@ type ILabServer struct { modelCache ModelCache } -// ----------------------------------------------------------------------------- -// main(), flags and Cobra -// ----------------------------------------------------------------------------- - func main() { - // We create an instance of ILabServer to hold all state and methods. srv := &ILabServer{ baseModel: "instructlab/granite-7b-lab", servedModelJobIDs: make(map[string]string), @@ -135,7 +131,6 @@ func main() { Use: "ilab-server", Short: "ILab Server Application", Run: func(cmd *cobra.Command, args []string) { - // Now that flags are set, run the server method on the struct. 
srv.runServer(cmd, args) }, } @@ -248,6 +243,8 @@ func (srv *ILabServer) runServer(cmd *cobra.Command, args []string) { // Initialize the model cache srv.initializeModelCache() + srv.reconstructServedModelJobIDs() + // Create the logs directory if it doesn't exist err = os.MkdirAll("logs", os.ModePerm) if err != nil { @@ -348,6 +345,48 @@ func (srv *ILabServer) refreshModelCache() { srv.log.Infof("Model cache refreshed at %v with %d models.", srv.modelCache.Time, len(models)) } +// reconstructServedModelJobIDs rebuilds the servedModelJobIDs map by querying the database +func (srv *ILabServer) reconstructServedModelJobIDs() { + srv.log.Info("Reconstructing servedModelJobIDs from the database...") + + rows, err := srv.db.Query(` + SELECT job_id, served_model_name + FROM jobs + WHERE cmd = 'podman' AND status = 'running' + `) + if err != nil { + srv.log.Errorf("Error querying running vLLM jobs: %v", err) + return + } + defer rows.Close() + + for rows.Next() { + var jobID, servedModelName string + if err := rows.Scan(&jobID, &servedModelName); err != nil { + srv.log.Errorf("Error scanning row: %v", err) + continue + } + + // Validate servedModelName + if servedModelName != "pre-train" && servedModelName != "post-train" { + srv.log.Warnf("Invalid served_model_name '%s' for job_id '%s'", servedModelName, jobID) + continue + } + + // Update the servedModelJobIDs map + srv.jobIDsMutex.Lock() + srv.servedModelJobIDs[servedModelName] = jobID + srv.jobIDsMutex.Unlock() + srv.log.Infof("Mapped model '%s' to job_id '%s'", servedModelName, jobID) + } + + if err := rows.Err(); err != nil { + srv.log.Errorf("Error iterating over rows: %v", err) + } + + srv.log.Info("Reconstruction of servedModelJobIDs completed.") +} + // ----------------------------------------------------------------------------- // Start Generate Data Job // -----------------------------------------------------------------------------