Skip to content

Commit

Permalink
Merge branch 'viamrobotics:main' into rsdk-7753
Browse files Browse the repository at this point in the history
  • Loading branch information
martha-johnston committed Sep 11, 2024
2 parents 1e1b14f + 98d78c3 commit e9def7b
Show file tree
Hide file tree
Showing 94 changed files with 2,443 additions and 2,040 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,6 @@ web/cmd/server/*.core
# gomobile / android
*.aar
*.jar

# direnv (optional dev tool)
.envrc
59 changes: 41 additions & 18 deletions cli/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ const (
moduleFlagLocal = "local"
moduleFlagHomeDir = "home"
moduleCreateLocalOnly = "local-only"
moduleFlagID = "id"

moduleBuildFlagPath = "module"
moduleBuildFlagRef = "ref"
Expand Down Expand Up @@ -1064,7 +1065,7 @@ var app = &cli.App{
{
Name: "list",
Usage: "list training jobs in Viam cloud based on organization ID",
UsageText: createUsageText("train list", []string{generalFlagOrgID, trainFlagJobStatus}, false),
UsageText: createUsageText("train list", []string{generalFlagOrgID}, true),
Flags: []cli.Flag{
&cli.StringFlag{
Name: generalFlagOrgID,
Expand All @@ -1074,7 +1075,8 @@ var app = &cli.App{
&cli.StringFlag{
Name: trainFlagJobStatus,
Usage: "training status to filter for. can be one of " + allTrainingStatusValues(),
Required: true,
Required: false,
Value: defaultTrainingStatus(),
},
},
Action: DataListTrainingJobs,
Expand Down Expand Up @@ -1745,6 +1747,32 @@ This won't work unless you have an existing installation of our GitHub app on yo
},
Action: ReloadModuleAction,
},
{
Name: "download",
Usage: "download a module package from the registry",
UsageText: createUsageText("module download", []string{}, false),
Flags: []cli.Flag{
&cli.PathFlag{
Name: packageFlagDestination,
Usage: "output directory for downloaded package",
Value: ".",
},
&cli.StringFlag{
Name: moduleFlagID,
Usage: "module ID as org-id:name or namespace:name. if missing, will try to read from meta.json",
},
&cli.StringFlag{
Name: packageFlagVersion,
Usage: "version of the requested package, can be `latest` to get the most recent version",
Value: "latest",
},
&cli.StringFlag{
Name: moduleFlagPlatform,
Usage: "platform like 'linux/amd64'. if missing, will use platform of the CLI binary",
},
},
Action: DownloadModuleAction,
},
},
},
{
Expand All @@ -1756,30 +1784,25 @@ This won't work unless you have an existing installation of our GitHub app on yo
Name: "export",
Usage: "download a package from Viam cloud",
UsageText: createUsageText("packages export",
[]string{
packageFlagDestination, generalFlagOrgID, packageFlagName,
packageFlagVersion, packageFlagType,
}, false),
[]string{packageFlagType}, false),
Flags: []cli.Flag{
&cli.PathFlag{
Name: packageFlagDestination,
Required: true,
Usage: "output directory for downloaded package",
Name: packageFlagDestination,
Usage: "output directory for downloaded package",
Value: ".",
},
&cli.StringFlag{
Name: generalFlagOrgID,
Required: true,
Usage: "organization ID of the requested package",
Name: generalFlagOrgID,
Usage: "organization ID or namespace of the requested package. if missing, will try to read from meta.json",
},
&cli.StringFlag{
Name: packageFlagName,
Required: true,
Usage: "name of the requested package",
Name: packageFlagName,
Usage: "name of the requested package. if missing, will try to read from meta.json",
},
&cli.StringFlag{
Name: packageFlagVersion,
Required: true,
Usage: "version of the requested package, can be `latest` to get the most recent version",
Name: packageFlagVersion,
Usage: "version of the requested package, can be `latest` to get the most recent version",
Value: "latest",
},
&cli.StringFlag{
Name: packageFlagType,
Expand Down
44 changes: 34 additions & 10 deletions cli/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ func (c *viamClient) binaryData(dst string, filter *datapb.Filter, parallelDownl

return c.performActionOnBinaryDataFromFilter(
func(id *datapb.BinaryID) error {
return downloadBinary(c.c.Context, c.dataClient, dst, id, c.authFlow.httpClient, c.conf.Auth)
return c.downloadBinary(dst, id)
},
filter, parallelDownloads,
func(i int32) {
Expand Down Expand Up @@ -412,35 +412,39 @@ func getMatchingBinaryIDs(ctx context.Context, client datapb.DataServiceClient,
}
}

func downloadBinary(ctx context.Context, client datapb.DataServiceClient, dst string, id *datapb.BinaryID,
httpClient *http.Client, auth authMethod,
) error {
func (c *viamClient) downloadBinary(dst string, id *datapb.BinaryID) error {
debugf(c.c.App.Writer, c.c.Bool(debugFlag), "Attempting to download binary file %s", id.FileId)

var resp *datapb.BinaryDataByIDsResponse
var err error
largeFile := false
// To begin, we assume the file is small and downloadable, so we try getting the binary directly
for count := 0; count < maxRetryCount; count++ {
resp, err = client.BinaryDataByIDs(ctx, &datapb.BinaryDataByIDsRequest{
resp, err = c.dataClient.BinaryDataByIDs(c.c.Context, &datapb.BinaryDataByIDsRequest{
BinaryIds: []*datapb.BinaryID{id},
IncludeBinary: !largeFile,
})
// If the file is too large, we break and try a different pathway for downloading
if err == nil || status.Code(err) == codes.ResourceExhausted {
debugf(c.c.App.Writer, c.c.Bool(debugFlag), "Small file download file %s: attempt %d/%d succeeded", id.FileId, count+1, maxRetryCount)
break
}
debugf(c.c.App.Writer, c.c.Bool(debugFlag), "Small file download for file %s: attempt %d/%d failed", id.FileId, count+1, maxRetryCount)
}
// For large files, we get the metadata but not the binary itself
// Resource exhausted is returned when the message we're receiving exceeds the GRPC maximum limit
if err != nil && status.Code(err) == codes.ResourceExhausted {
largeFile = true
for count := 0; count < maxRetryCount; count++ {
resp, err = client.BinaryDataByIDs(ctx, &datapb.BinaryDataByIDsRequest{
resp, err = c.dataClient.BinaryDataByIDs(c.c.Context, &datapb.BinaryDataByIDsRequest{
BinaryIds: []*datapb.BinaryID{id},
IncludeBinary: !largeFile,
})
if err == nil {
debugf(c.c.App.Writer, c.c.Bool(debugFlag), "Metadata fetch for file %s: attempt %d/%d succeeded", id.FileId, count+1, maxRetryCount)
break
}
debugf(c.c.App.Writer, c.c.Bool(debugFlag), "Metadata fetch for file %s: attempt %d/%d failed", id.FileId, count+1, maxRetryCount)
}
}
if err != nil {
Expand Down Expand Up @@ -477,29 +481,43 @@ func downloadBinary(ctx context.Context, client datapb.DataServiceClient, dst st

var bin []byte
if largeFile {
debugf(c.c.App.Writer, c.c.Bool(debugFlag), "Attempting file %s as a large file download", id.FileId)
// Make request to the URI for large files since we exceed the message limit for gRPC
req, err := http.NewRequestWithContext(ctx, http.MethodGet, datum.GetMetadata().GetUri(), nil)
req, err := http.NewRequestWithContext(c.c.Context, http.MethodGet, datum.GetMetadata().GetUri(), nil)
if err != nil {
return errors.Wrapf(err, serverErrorMessage)
}

// Set the headers so HTTP requests that are not gRPC calls can still be authenticated in app
// We can authenticate via token or API key, so we try both.
token, ok := auth.(*token)
token, ok := c.conf.Auth.(*token)
if ok {
req.Header.Add(rpc.MetadataFieldAuthorization, rpc.AuthorizationValuePrefixBearer+token.AccessToken)
}
apiKey, ok := auth.(*apiKey)
apiKey, ok := c.conf.Auth.(*apiKey)
if ok {
req.Header.Add("key_id", apiKey.KeyID)
req.Header.Add("key", apiKey.KeyCrypto)
}

res, err := httpClient.Do(req)
var res *http.Response
for count := 0; count < maxRetryCount; count++ {
res, err = c.authFlow.httpClient.Do(req)

if err == nil && res.StatusCode == http.StatusOK {
debugf(c.c.App.Writer, c.c.Bool(debugFlag),
"Large file download for file %s: attempt %d/%d succeeded", id.FileId, count+1, maxRetryCount)
break
}
debugf(c.c.App.Writer, c.c.Bool(debugFlag), "Large file download for file %s: attempt %d/%d failed", id.FileId, count+1, maxRetryCount)
}

if err != nil {
debugf(c.c.App.Writer, c.c.Bool(debugFlag), "Failed downloading large file %s: %s", id.FileId, err)
return errors.Wrapf(err, serverErrorMessage)
}
if res.StatusCode != http.StatusOK {
debugf(c.c.App.Writer, c.c.Bool(debugFlag), "Failed downloading large file %s: Server returned %d response", id.FileId, res.StatusCode)
return errors.New(serverErrorMessage)
}
defer func() {
Expand All @@ -508,6 +526,7 @@ func downloadBinary(ctx context.Context, client datapb.DataServiceClient, dst st

bin, err = io.ReadAll(res.Body)
if err != nil {
debugf(c.c.App.Writer, c.c.Bool(debugFlag), "Failed downloading large file %s, error occurred while reading: %s", id.FileId, err)
return errors.Wrapf(err, serverErrorMessage)
}
} else {
Expand All @@ -525,6 +544,7 @@ func downloadBinary(ctx context.Context, client datapb.DataServiceClient, dst st
if ext == gzFileExt {
r, err = gzip.NewReader(r)
if err != nil {
debugf(c.c.App.Writer, c.c.Bool(debugFlag), "Failed unzipping file %s: %s", id.FileId, err)
return err
}
} else if filepath.Ext(dataPath) != ext {
Expand All @@ -534,18 +554,22 @@ func downloadBinary(ctx context.Context, client datapb.DataServiceClient, dst st
}

if err := os.MkdirAll(filepath.Dir(dataPath), 0o700); err != nil {
debugf(c.c.App.Writer, c.c.Bool(debugFlag), "Failed creating data directory %s: %s", dataPath, err)
return errors.Wrapf(err, "could not create data directory %s", filepath.Dir(dataPath))
}
//nolint:gosec
dataFile, err := os.Create(dataPath)
if err != nil {
debugf(c.c.App.Writer, c.c.Bool(debugFlag), "Failed creating file %s: %s", id.FileId, err)
return errors.Wrapf(err, fmt.Sprintf("could not create file for datum %s", datum.GetMetadata().GetId()))
}
//nolint:gosec
if _, err := io.Copy(dataFile, r); err != nil {
debugf(c.c.App.Writer, c.c.Bool(debugFlag), "Failed writing data to file %s: %s", id.FileId, err)
return err
}
if err := r.Close(); err != nil {
debugf(c.c.App.Writer, c.c.Bool(debugFlag), "Failed closing file %s: %s", id.FileId, err)
return err
}
return nil
Expand Down
2 changes: 1 addition & 1 deletion cli/dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func (c *viamClient) downloadDataset(dst, datasetID string, includeJSONLines boo

return c.performActionOnBinaryDataFromFilter(
func(id *datapb.BinaryID) error {
downloadErr := downloadBinary(c.c.Context, c.dataClient, dst, id, c.authFlow.httpClient, c.conf.Auth)
downloadErr := c.downloadBinary(dst, id)
var datasetErr error
if includeJSONLines {
datasetErr = binaryDataToJSONLines(c.c.Context, c.dataClient, dst, datasetFile, id)
Expand Down
4 changes: 4 additions & 0 deletions cli/ml_training.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,10 @@ func allTrainingStatusValues() string {
return "[" + strings.Join(formattedStatuses, ", ") + "]"
}

func defaultTrainingStatus() string {
return strings.ToLower(strings.TrimPrefix(mltrainingpb.TrainingStatus_TRAINING_STATUS_UNSPECIFIED.String(), trainingStatusPrefix))
}

// MLTrainingUploadAction uploads a new custom training script.
func MLTrainingUploadAction(c *cli.Context) error {
client, err := newViamClient(c)
Expand Down
76 changes: 76 additions & 0 deletions cli/module_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ import (
"math"
"os"
"os/exec"
"path"
"path/filepath"
"runtime"
"slices"
"strings"

"github.com/google/uuid"
Expand Down Expand Up @@ -877,3 +880,76 @@ func getNextModuleUploadRequest(file *os.File) (*apppb.UploadModuleFileRequest,
},
}, nil
}

// DownloadModuleAction downloads a module.
func DownloadModuleAction(c *cli.Context) error {
moduleID := c.String(moduleFlagID)
if moduleID == "" {
manifest, err := loadManifest(defaultManifestFilename)
if err != nil {
return errors.Wrap(err, "trying to get package ID from meta.json")
}
moduleID = manifest.ModuleID
}
client, err := newViamClient(c)
if err != nil {
return err
}
if err := client.ensureLoggedIn(); err != nil {
return err
}
req := &apppb.GetModuleRequest{ModuleId: moduleID}
res, err := client.client.GetModule(c.Context, req)
if err != nil {
return err
}
if len(res.Module.Versions) == 0 {
return errors.New("module has 0 uploaded versions, nothing to download")
}
requestedVersion := c.String(packageFlagVersion)
var ver *apppb.VersionHistory
if requestedVersion == "latest" {
ver = res.Module.Versions[len(res.Module.Versions)-1]
} else {
for _, iVer := range res.Module.Versions {
if iVer.Version == requestedVersion {
ver = iVer
break
}
}
if ver == nil {
return fmt.Errorf("version %s not found in versions for module", requestedVersion)
}
}
infof(c.App.ErrWriter, "found version %s", ver.Version)
if len(ver.Files) == 0 {
return fmt.Errorf("version %s has 0 files uploaded", ver.Version)
}
platform := c.String(moduleFlagPlatform)
if platform == "" {
platform = fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH)
infof(c.App.ErrWriter, "using default platform %s", platform)
}
if !slices.ContainsFunc(ver.Files, func(file *apppb.Uploads) bool { return file.Platform == platform }) {
return fmt.Errorf("platform %s not present for version %s", platform, ver.Version)
}
include := true
packageType := packagespb.PackageType_PACKAGE_TYPE_MODULE
// note: this is working around a GetPackage quirk where platform messes with version
fullVersion := fmt.Sprintf("%s-%s", ver.Version, strings.ReplaceAll(platform, "/", "-"))
pkg, err := client.packageClient.GetPackage(c.Context, &packagespb.GetPackageRequest{
Id: strings.ReplaceAll(moduleID, ":", "/"),
Version: fullVersion,
IncludeUrl: &include,
Type: &packageType,
})
if err != nil {
return err
}
destName := strings.ReplaceAll(moduleID, ":", "-")
infof(c.App.ErrWriter, "saving to %s", path.Join(c.String(packageFlagDestination), fullVersion, destName+".tar.gz"))
return downloadPackageFromURL(c.Context, client.authFlow.httpClient,
c.String(packageFlagDestination), destName,
fullVersion, pkg.Package.Url, client.conf.Auth,
)
}
10 changes: 10 additions & 0 deletions cli/packages.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,16 @@ func (c *viamClient) packageExportAction(orgID, name, version, packageType, dest
if err := c.ensureLoggedIn(); err != nil {
return err
}
if orgID == "" || name == "" {
if orgID != "" || name != "" {
return fmt.Errorf("if either of %s or %s is missing, both must be missing", generalFlagOrgID, packageFlagName)
}
manifest, err := loadManifest(defaultManifestFilename)
if err != nil {
return errors.Wrap(err, "trying to get package ID from meta.json")
}
orgID, name, _ = strings.Cut(manifest.ModuleID, ":")
}
// Package ID is the <organization-ID>/<package-name>
packageID := path.Join(orgID, name)
packageTypeProto, err := convertPackageTypeToProto(packageType)
Expand Down
Loading

0 comments on commit e9def7b

Please sign in to comment.