diff --git a/cli/app.go b/cli/app.go index eed81cf7cca..0f9d02f192e 100644 --- a/cli/app.go +++ b/cli/app.go @@ -1030,7 +1030,7 @@ var app = &cli.App{ Required: true, }, }, - Action: DataGetTrainingJob, + Action: MLGetTrainingJob, }, { Name: "cancel", @@ -1043,7 +1043,7 @@ var app = &cli.App{ Required: true, }, }, - Action: DataCancelTrainingJob, + Action: MLCancelTrainingJob, }, { Name: "list", @@ -1061,7 +1061,24 @@ var app = &cli.App{ Required: true, }, }, - Action: DataListTrainingJobs, + Action: MLListTrainingJobs, + }, + { + Name: "logs", + Usage: "returns logs for specified training job ID", + UsageText: createUsageText("train logs", []string{trainFlagJobID}, true), + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: trainFlagJobID, + DefaultText: "training job ID", + }, + &cli.IntFlag{ + Name: logsFlagCount, + Usage: fmt.Sprintf("number of logs to fetch (max %v)", maxNumLogs), + DefaultText: fmt.Sprintf("%v", defaultNumLogs), + }, + }, + Action: MLGetLogs, }, }, }, diff --git a/cli/ml_training.go b/cli/ml_training.go index 87ba84ca853..9b53df8df61 100644 --- a/cli/ml_training.go +++ b/cli/ml_training.go @@ -155,8 +155,8 @@ func (c *viamClient) mlSubmitCustomTrainingJob(datasetID, registryItemID, regist return resp.Id, nil } -// DataGetTrainingJob is the corresponding action for 'data train get'. -func DataGetTrainingJob(c *cli.Context) error { +// MLGetTrainingJob is the corresponding action for 'train get'. +func MLGetTrainingJob(c *cli.Context) error { client, err := newViamClient(c) if err != nil { return err @@ -181,8 +181,8 @@ func (c *viamClient) dataGetTrainingJob(trainingJobID string) (*mltrainingpb.Tra return resp.Metadata, nil } -// DataCancelTrainingJob is the corresponding action for 'data train cancel'. -func DataCancelTrainingJob(c *cli.Context) error { +// MLCancelTrainingJob is the corresponding action for 'train cancel'. +func MLCancelTrainingJob(c *cli.Context) error { client, err := newViamClient(c) if err != nil { return err @@ -207,8 +207,8 @@ func (c *viamClient) dataCancelTrainingJob(trainingJobID string) error { return nil } -// DataListTrainingJobs is the corresponding action for 'data train list'. -func DataListTrainingJobs(c *cli.Context) error { +// MLListTrainingJobs is the corresponding action for 'train list'. +func MLListTrainingJobs(c *cli.Context) error { client, err := newViamClient(c) if err != nil { return err @@ -475,3 +475,63 @@ func convertVisibilityToProto(visibility string) (*v1.Visibility, error) { return &visibilityProto, nil } + +// MLGetLogs is the corresponding action for 'train logs'. +func MLGetLogs(c *cli.Context) error { + client, err := newViamClient(c) + if err != nil { + return err + } + numLogs, err := getNumLogs(c) + if err != nil { + return err + } + jobs, err := client.mlLogsForTrainingJob(c.String(trainFlagJobID), numLogs) + if err != nil { + return err + } + for _, job := range jobs { + printf(c.App.Writer, "Training job: %s\n", job) + } + return nil +} + +// mlLogsForTrainingJob get the logs for the given training job. +func (c *viamClient) mlLogsForTrainingJob(trainingJobID string, numLogs int) ([]*mltrainingpb.TrainingJobLogEntry, error) { + if err := c.ensureLoggedIn(); err != nil { + return nil, err + } + + // Use page tokens to get batches of 100 up to numLogs and throw away any + // extra logs in last batch. + logs := make([]*mltrainingpb.TrainingJobLogEntry, 0, numLogs) + var pageToken string + for i := 0; i < numLogs; { + resp, err := c.mlTrainingClient.GetTrainingJobLogs(c.c.Context, &mltrainingpb.GetTrainingJobLogsRequest{ + Id: trainingJobID, + PageToken: &pageToken, + }) + if err != nil { + return nil, err + } + + pageToken = resp.NextPageToken + // Break in the event of no logs in Get*LogsResponse or when + // page token is empty (no more pages). + if resp.Logs == nil || pageToken == "" { + break + } + + // Truncate this intermediate slice of resp.Logs based on how many logs + // are still required by numLogs. + remainingLogsNeeded := numLogs - i + if remainingLogsNeeded < len(resp.Logs) { + resp.Logs = resp.Logs[:remainingLogsNeeded] + } + logs = append(logs, resp.Logs...) + + i += len(resp.Logs) + } + + return logs, nil +}