Skip to content

Commit

Permalink
set up for manual testing
Browse files Browse the repository at this point in the history
  • Loading branch information
vpandiarajan20 committed Sep 9, 2024
1 parent 273571a commit 2d1b088
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 17 deletions.
14 changes: 12 additions & 2 deletions cli/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -926,7 +926,7 @@ var app = &cli.App{
UsageText: createUsageText("train submit custom from-registry",
[]string{
datasetFlagDatasetID, generalFlagOrgID, trainFlagModelName,
mlTrainingFlagName, mlTrainingFlagVersion,
mlTrainingFlagName, mlTrainingFlagVersion, trainCliArgs,
}, true),
Flags: []cli.Flag{
&cli.StringFlag{
Expand Down Expand Up @@ -959,6 +959,11 @@ var app = &cli.App{
Usage: "version of the ML training script to use for training.",
Required: true,
},
&cli.StringFlag{
Name: trainCliArgs,
Usage: "optional command line arguments to run the training script with " + "which should be formatted as --option1=value1,--option2=value2",
Required: false,
},
},
Action: MLSubmitCustomTrainingJob,
},
Expand All @@ -968,7 +973,7 @@ var app = &cli.App{
UsageText: createUsageText("train submit custom with-upload",
[]string{
datasetFlagDatasetID, generalFlagOrgID, trainFlagModelName,
mlTrainingFlagPath, mlTrainingFlagName,
mlTrainingFlagPath, mlTrainingFlagName, trainCliArgs,
}, true),
Flags: []cli.Flag{
&cli.StringFlag{
Expand Down Expand Up @@ -1015,6 +1020,11 @@ var app = &cli.App{
Usage: "task type of the ML training script to upload, can be: " + strings.Join(modelTypes, ", "),
Required: false,
},
&cli.StringFlag{
Name: trainCliArgs,
Usage: "optional command line arguments to run the training script with " + "which should be formatted as --option1=value1,--option2=value2",
Required: false,
},
},
Action: MLSubmitCustomTrainingJobWithUpload,
},
Expand Down
40 changes: 28 additions & 12 deletions cli/ml_training.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const (
trainFlagModelVersion = "model-version"
trainFlagModelType = "model-type"
trainFlagModelLabels = "model-labels"
trainCliArgs = "args"

trainingStatusPrefix = "TRAINING_STATUS_"
)
Expand All @@ -37,7 +38,7 @@ func MLSubmitCustomTrainingJob(c *cli.Context) error {

trainingJobID, err := client.mlSubmitCustomTrainingJob(
c.String(datasetFlagDatasetID), c.String(mlTrainingFlagName), c.String(mlTrainingFlagVersion), c.String(generalFlagOrgID),
c.String(trainFlagModelName), c.String(trainFlagModelVersion))
c.String(trainFlagModelName), c.String(trainFlagModelVersion), c.String(trainCliArgs))
if err != nil {
return err
}
Expand Down Expand Up @@ -70,7 +71,7 @@ func MLSubmitCustomTrainingJobWithUpload(c *cli.Context) error {
registryItemID)
trainingJobID, err := client.mlSubmitCustomTrainingJob(
c.String(datasetFlagDatasetID), registryItemID, resp.Version, c.String(trainFlagModelOrgID),
c.String(trainFlagModelName), c.String(trainFlagModelVersion))
c.String(trainFlagModelName), c.String(trainFlagModelVersion), c.String(trainCliArgs))
if err != nil {
return err
}
Expand Down Expand Up @@ -125,7 +126,7 @@ func (c *viamClient) mlSubmitTrainingJob(datasetID, orgID, modelName, modelVersi

// mlSubmitCustomTrainingJob trains on data with the specified dataset and registry item.
func (c *viamClient) mlSubmitCustomTrainingJob(datasetID, registryItemID, registryItemVersion, orgID, modelName,
modelVersion string,
modelVersion, args string,
) (string, error) {
if err := c.ensureLoggedIn(); err != nil {
return "", err
Expand All @@ -140,15 +141,30 @@ func (c *viamClient) mlSubmitCustomTrainingJob(datasetID, registryItemID, regist
modelVersion = time.Now().Format("2006-01-02T15-04-05")
}

resp, err := c.mlTrainingClient.SubmitCustomTrainingJob(context.Background(),
&mltrainingpb.SubmitCustomTrainingJobRequest{
DatasetId: datasetID,
RegistryItemId: registryItemID,
RegistryItemVersion: registryItemVersion,
OrganizationId: orgID,
ModelName: modelName,
ModelVersion: modelVersion,
})
req := &mltrainingpb.SubmitCustomTrainingJobRequest{
DatasetId: datasetID,
RegistryItemId: registryItemID,
RegistryItemVersion: registryItemVersion,
OrganizationId: orgID,
ModelName: modelName,
ModelVersion: modelVersion,
}

var argMap map[string]string
splitArgs := strings.Split(args, ",")

for _, optionVal := range splitArgs {
splitOptionVal := strings.Split(optionVal, "=")
if len(splitOptionVal) != 2 {
return "", errors.Errorf("invalid format for command line arguments, passed:", args)
}
argMap[splitOptionVal[0]] = splitOptionVal[1]
}
if argMap != nil {
req.Arguments = argMap
}

resp, err := c.mlTrainingClient.SubmitCustomTrainingJob(context.Background(), req)
if err != nil {
return "", errors.Wrapf(err, "received error from server")
}
Expand Down
37 changes: 37 additions & 0 deletions cli/ml_training_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package cli

import "testing"

// import (
// "context"
// "testing"

// mltrainingpb "go.viam.com/api/mltrainingpb/v1"

// "google.golang.org/grpc"
// )

func TestMLSubmitCustomTrainingJob(t *testing.T) {
// flags := make(map[string]any)
// flags[datasetFlagDatasetID] = "01"
// flags[mlTrainingFlagName] = "02"
// flags[mlTrainingFlagVersion] = "03"
// flags[generalFlagOrgID] = "04"
// flags[trainFlagModelName] = "05"
// flags[trainFlagModelVersion] = "06"
// flags[trainCliArgs] = "07"

// cCtx, ac, out, errOut := setup(nil, nil, nil, nil, nil, "")

// ac.mlSubmitCustomTrainingJob("456", "123", "789", "ab-cd", "goodModel", "latest",
// "--food==ice_cream,--bfast=bagel,din=pizza")

// submitCustomTrainingJob := func(ctx context.Context, in *mltrainingpb.SubmitCustomTrainingJobRequest, opts ...grpc.CallOption) (*mltrainingpb.SubmitCustomTrainingJobResponse, error) {
// return &mltrainingpb.SubmitCustomerTrainingResponse{Id: 200}
// }

// fake_MLServe := &inject.
// {SubmitCustomTrainingJob: submitCustomTrainingJob}
}

// Attempted to test MLSubmitCustomTrainingJob, but was a little too difficult to setup, going to focus on manual testing
6 changes: 5 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module go.viam.com/rdk

go 1.21
go 1.21.13

toolchain go1.23.1

require (
github.com/AlekSi/gocov-xml v1.0.0
Expand Down Expand Up @@ -402,3 +404,5 @@ require (
github.com/ziutek/mymysql v1.5.4 // indirect
golang.org/x/exp v0.0.0-20230725012225-302865e7556b
)

replace go.viam.com/api => github.com/viamrobotics/api v0.1.338-0.20240906184640-16a974029c9e // <FEATURE-BRANCH-NAME>
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1439,6 +1439,8 @@ github.com/valyala/quicktemplate v1.6.3/go.mod h1:fwPzK2fHuYEODzJ9pkw0ipCPNHZ2tD
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio=
github.com/viam-labs/go-libjpeg v0.3.1 h1:J/byavXHFqRI1PFPrnPbP+wFCr1y+Cn1CwKXrORCPD0=
github.com/viam-labs/go-libjpeg v0.3.1/go.mod h1:b0ISpf9lJv9MO1h1gXAmSA/osG19cKGYjfYc6aeEjqs=
github.com/viamrobotics/api v0.1.338-0.20240906184640-16a974029c9e h1:b2zCyRPLgKpGBkmsf/OYC42sIcVjSNTyQ7lfRB4rKak=
github.com/viamrobotics/api v0.1.338-0.20240906184640-16a974029c9e/go.mod h1:5lpVRxMsKFCaahqsnJfPGwJ9baoQ6PIKQu3lxvy6Wtw=
github.com/viamrobotics/evdev v0.1.3 h1:mR4HFafvbc5Wx4Vp1AUJp6/aITfVx9AKyXWx+rWjpfc=
github.com/viamrobotics/evdev v0.1.3/go.mod h1:N6nuZmPz7HEIpM7esNWwLxbYzqWqLSZkfI/1Sccckqk=
github.com/viamrobotics/webrtc/v3 v3.99.9 h1:5FCctlMhO9lr4SJ1TC2WCFocBIriUMb3Sw7i9oDlz2o=
Expand Down Expand Up @@ -1539,8 +1541,6 @@ go.uber.org/zap v1.18.1/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI=
go.uber.org/zap v1.23.0/go.mod h1:D+nX8jyLsMHMYrln8A0rJjFt/T/9/bGgIhAqxv5URuY=
go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60=
go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg=
go.viam.com/api v0.1.336 h1:mcz3Y5rivgXhsTu/bXkAVDw0/otarq3lCPIRcxhNnIY=
go.viam.com/api v0.1.336/go.mod h1:msa4TPrMVeRDcG4YzKA/S6wLEUC7GyHQE973JklrQ10=
go.viam.com/test v1.1.1-0.20220913152726-5da9916c08a2 h1:oBiK580EnEIzgFLU4lHOXmGAE3MxnVbeR7s1wp/F3Ps=
go.viam.com/test v1.1.1-0.20220913152726-5da9916c08a2/go.mod h1:XM0tej6riszsiNLT16uoyq1YjuYPWlRBweTPRDanIts=
go.viam.com/utils v0.1.96 h1:lZ7m1jm6pkjn60e9OWeHyHcivFIRrFvS/5NGrPXFaDM=
Expand Down

0 comments on commit 2d1b088

Please sign in to comment.