Skip to content

Commit

Permalink
Initial logs logic, TODO - Refactor dupe code
Browse files Browse the repository at this point in the history
  • Loading branch information
tahiyasalam committed Aug 13, 2024
1 parent 0ba1ede commit 52af5f9
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 9 deletions.
23 changes: 20 additions & 3 deletions cli/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,7 @@ var app = &cli.App{
Required: true,
},
},
Action: DataGetTrainingJob,
Action: MLGetTrainingJob,
},
{
Name: "cancel",
Expand All @@ -1043,7 +1043,7 @@ var app = &cli.App{
Required: true,
},
},
Action: DataCancelTrainingJob,
Action: MLCancelTrainingJob,
},
{
Name: "list",
Expand All @@ -1061,7 +1061,24 @@ var app = &cli.App{
Required: true,
},
},
Action: DataListTrainingJobs,
Action: MLListTrainingJobs,
},
{
Name: "logs",
Usage: "returns logs for specified training job ID",
UsageText: createUsageText("train logs", []string{trainFlagJobID}, true),
Flags: []cli.Flag{
&cli.StringFlag{
Name: trainFlagJobID,
DefaultText: "training job ID",
},
&cli.IntFlag{
Name: logsFlagCount,
Usage: fmt.Sprintf("number of logs to fetch (max %v)", maxNumLogs),
DefaultText: fmt.Sprintf("%v", defaultNumLogs),
},
},
Action: MLGetLogs,
},
},
},
Expand Down
72 changes: 66 additions & 6 deletions cli/ml_training.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ func (c *viamClient) mlSubmitCustomTrainingJob(datasetID, registryItemID, regist
return resp.Id, nil
}

// DataGetTrainingJob is the corresponding action for 'data train get'.
func DataGetTrainingJob(c *cli.Context) error {
// MLGetTrainingJob is the corresponding action for 'train get'.
func MLGetTrainingJob(c *cli.Context) error {
client, err := newViamClient(c)
if err != nil {
return err
Expand All @@ -181,8 +181,8 @@ func (c *viamClient) dataGetTrainingJob(trainingJobID string) (*mltrainingpb.Tra
return resp.Metadata, nil
}

// DataCancelTrainingJob is the corresponding action for 'data train cancel'.
func DataCancelTrainingJob(c *cli.Context) error {
// MLCancelTrainingJob is the corresponding action for 'train cancel'.
func MLCancelTrainingJob(c *cli.Context) error {
client, err := newViamClient(c)
if err != nil {
return err
Expand All @@ -207,8 +207,8 @@ func (c *viamClient) dataCancelTrainingJob(trainingJobID string) error {
return nil
}

// DataListTrainingJobs is the corresponding action for 'data train list'.
func DataListTrainingJobs(c *cli.Context) error {
// MLListTrainingJobs is the corresponding action for 'train list'.
func MLListTrainingJobs(c *cli.Context) error {
client, err := newViamClient(c)
if err != nil {
return err
Expand Down Expand Up @@ -475,3 +475,63 @@ func convertVisibilityToProto(visibility string) (*v1.Visibility, error) {

return &visibilityProto, nil
}

// MLGetLogs is the corresponding action for 'train logs'.
//
// It builds a viam client from the CLI context, resolves how many log entries
// to request via getNumLogs, fetches that many entries for the training job
// named by the trainFlagJobID flag, and writes one line per entry to the
// app's writer. Any error from those steps is returned unchanged.
func MLGetLogs(c *cli.Context) error {
	viam, err := newViamClient(c)
	if err != nil {
		return err
	}

	count, err := getNumLogs(c)
	if err != nil {
		return err
	}

	jobID := c.String(trainFlagJobID)
	entries, err := viam.mlLogsForTrainingJob(jobID, count)
	if err != nil {
		return err
	}

	// Each entry is a log record for the job, printed with the existing label.
	for idx := range entries {
		printf(c.App.Writer, "Training job: %s\n", entries[idx])
	}
	return nil
}

// mlLogsForTrainingJob gets up to numLogs log entries for the given training job.
//
// After confirming the client is logged in, it calls GetTrainingJobLogs
// repeatedly, passing each response's NextPageToken to the following request,
// and accumulates the returned entries. Collection stops when numLogs entries
// have been gathered, when a response carries no logs, or when the response
// has no next page token. Entries beyond numLogs in the last page are dropped.
//
// numLogs is used as a slice capacity, so it must not be negative.
// It returns the collected entries, or the first error encountered.
func (c *viamClient) mlLogsForTrainingJob(trainingJobID string, numLogs int) ([]*mltrainingpb.TrainingJobLogEntry, error) {
	if err := c.ensureLoggedIn(); err != nil {
		return nil, err
	}

	logs := make([]*mltrainingpb.TrainingJobLogEntry, 0, numLogs)
	var pageToken string
	for len(logs) < numLogs {
		resp, err := c.mlTrainingClient.GetTrainingJobLogs(c.c.Context, &mltrainingpb.GetTrainingJobLogsRequest{
			Id:        trainingJobID,
			PageToken: &pageToken,
		})
		if err != nil {
			return nil, err
		}

		// A page without logs cannot make progress toward numLogs; stop here
		// rather than keep requesting.
		if len(resp.Logs) == 0 {
			break
		}

		// Keep only as many entries from this page as are still needed.
		page := resp.Logs
		if remaining := numLogs - len(logs); remaining < len(page) {
			page = page[:remaining]
		}
		logs = append(logs, page...)

		// Check the token only after consuming the page: the last page has an
		// empty next-page token but its logs still belong in the result.
		pageToken = resp.NextPageToken
		if pageToken == "" {
			break
		}
	}

	return logs, nil
}

0 comments on commit 52af5f9

Please sign in to comment.