diff --git a/pkg/cmd/models/cmd.go b/pkg/cmd/models/cmd.go index 492a501d..1ee8f450 100644 --- a/pkg/cmd/models/cmd.go +++ b/pkg/cmd/models/cmd.go @@ -2,15 +2,13 @@ package models import ( "fmt" - "io/fs" - "jmm/pkg/lib/storage" "os" "path" - "path/filepath" "text/tabwriter" "github.com/spf13/cobra" "github.com/spf13/viper" + "oras.land/oras-go/v2/registry" ) const ( @@ -19,17 +17,33 @@ const ( ) var ( - opts *ModelsOptions + flags *ModelsFlags + opts *ModelsOptions ) +type ModelsFlags struct { + UseHTTP bool +} + type ModelsOptions struct { configHome string storageHome string + remoteRef *registry.Reference + usehttp bool } -func (opts *ModelsOptions) complete() { +func (opts *ModelsOptions) complete(flags *ModelsFlags, args []string) error { opts.configHome = viper.GetString("config") opts.storageHome = path.Join(opts.configHome, "storage") + if len(args) > 0 { + remoteRef, err := registry.ParseReference(args[0]) + if err != nil { + return fmt.Errorf("invalid reference: %w", err) + } + opts.remoteRef = &remoteRef + } + opts.usehttp = flags.UseHTTP + return nil } func (opts *ModelsOptions) validate() error { @@ -38,42 +52,47 @@ func (opts *ModelsOptions) validate() error { // ModelsCommand represents the models command func ModelsCommand() *cobra.Command { + flags = &ModelsFlags{} opts = &ModelsOptions{} cmd := &cobra.Command{ - Use: "models", + Use: "models [repository]", Short: shortDesc, Long: longDesc, Run: RunCommand(opts), } + cmd.Args = cobra.MaximumNArgs(1) + cmd.Flags().BoolVar(&flags.UseHTTP, "http", false, "Use plain HTTP when connecting to remote registries") return cmd } func RunCommand(options *ModelsOptions) func(*cobra.Command, []string) { return func(cmd *cobra.Command, args []string) { - options.complete() - err := options.validate() - if err != nil { - fmt.Println(err) + if err := options.complete(flags, args); err != nil { + fmt.Printf("Failed to parse argument: %s", err) return } - - storeDirs, err := findRepos(opts.storageHome) - if err != nil { + if err := options.validate(); err != nil { fmt.Println(err) + return } var allInfoLines []string - for _, storeDir := range storeDirs { - store := storage.NewLocalStore(opts.storageHome, storeDir) - - infolines, err := listModels(store) + if opts.remoteRef == nil { + lines, err := listLocalModels(opts.storageHome) + if err != nil { + fmt.Println(err) + return + } + allInfoLines = lines + } else { + lines, err := listRemoteModels(cmd.Context(), opts.remoteRef, opts.usehttp) if err != nil { fmt.Println(err) return } - allInfoLines = append(allInfoLines, infolines...) + allInfoLines = lines } printSummary(allInfoLines) @@ -81,34 +100,6 @@ func RunCommand(options *ModelsOptions) func(*cobra.Command, []string) { } } -func findRepos(storePath string) ([]string, error) { - var indexPaths []string - err := filepath.WalkDir(storePath, func(file string, info fs.DirEntry, err error) error { - if err != nil { - return err - } - if info.Name() == "index.json" && !info.IsDir() { - dir := filepath.Dir(file) - relDir, err := filepath.Rel(storePath, dir) - if err != nil { - return err - } - if relDir == "." { - relDir = "" - } - indexPaths = append(indexPaths, relDir) - } - return nil - }) - if err != nil { - if os.IsNotExist(err) { - return nil, nil - } - return nil, fmt.Errorf("failed to read local storage: %w", err) - } - return indexPaths, nil -} - func printSummary(lines []string) { tw := tabwriter.NewWriter(os.Stdout, 0, 2, 3, ' ', 0) fmt.Fprintln(tw, ModelsTableHeader) diff --git a/pkg/cmd/models/models.go b/pkg/cmd/models/models.go index 718ea91d..4a17b1ab 100644 --- a/pkg/cmd/models/models.go +++ b/pkg/cmd/models/models.go @@ -7,9 +7,12 @@ import ( "context" "encoding/json" "fmt" + "io/fs" "jmm/pkg/artifact" "jmm/pkg/lib/storage" "math" + "os" + "path/filepath" "strings" ocispec "github.com/opencontainers/image-spec/specs-go/v1" @@ -20,6 +23,25 @@ const ( ModelsTableFmt = "%s\t%s\t%s\t%s\t%s\t%s\t" ) +func listLocalModels(storageRoot string) ([]string, error) { + storeDirs, err := findRepos(storageRoot) + if err != nil { + return nil, err + } + + var allInfoLines []string + for _, storeDir := range storeDirs { + store := storage.NewLocalStore(storageRoot, storeDir) + + infolines, err := listModels(store) + if err != nil { + return nil, err + } + allInfoLines = append(allInfoLines, infolines...) + } + return allInfoLines, nil +} + func listModels(store storage.Store) ([]string, error) { index, err := store.ParseIndexJson() if err != nil { @@ -37,10 +59,7 @@ func listModels(store storage.Store) ([]string, error) { return nil, err } // TODO: filter list for our manifests only, ignore other artifacts - infoline, err := getManifestInfoLine(store.GetRepository(), manifestDesc, manifest, manifestConf) - if err != nil { - return nil, err - } + infoline := getManifestInfoLine(store.GetRepository(), manifestDesc, manifest, manifestConf) infolines = append(infolines, infoline) } @@ -71,7 +90,7 @@ func readManifestConfig(store storage.Store, manifest *ocispec.Manifest) (*artif return config, nil } -func getManifestInfoLine(repo string, desc ocispec.Descriptor, manifest *ocispec.Manifest, config *artifact.JozuFile) (string, error) { +func getManifestInfoLine(repo string, desc ocispec.Descriptor, manifest *ocispec.Manifest, config *artifact.JozuFile) string { ref := desc.Annotations[ocispec.AnnotationRefName] if ref == "" { ref = "" @@ -90,7 +109,35 @@ func getManifestInfoLine(repo string, desc ocispec.Descriptor, manifest *ocispec sizeStr := formatBytes(size) info := fmt.Sprintf(ModelsTableFmt, repo, ref, config.Package.Authors[0], config.Package.Name, sizeStr, desc.Digest) - return info, nil + return info +} + +func findRepos(storePath string) ([]string, error) { + var indexPaths []string + err := filepath.WalkDir(storePath, func(file string, info fs.DirEntry, err error) error { + if err != nil { + return err + } + if info.Name() == "index.json" && !info.IsDir() { + dir := filepath.Dir(file) + relDir, err := filepath.Rel(storePath, dir) + if err != nil { + return err + } + if relDir == "." { + relDir = "" + } + indexPaths = append(indexPaths, relDir) + } + return nil + }) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("failed to read local storage: %w", err) + } + return indexPaths, nil } func formatBytes(i int64) string { diff --git a/pkg/cmd/models/remote.go b/pkg/cmd/models/remote.go new file mode 100644 index 00000000..eaac0880 --- /dev/null +++ b/pkg/cmd/models/remote.go @@ -0,0 +1,96 @@ +package models + +import ( + "context" + "encoding/json" + "fmt" + "io" + "jmm/pkg/artifact" + + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "oras.land/oras-go/v2/registry" + "oras.land/oras-go/v2/registry/remote" +) + +func listRemoteModels(ctx context.Context, remoteRef *registry.Reference, useHttp bool) ([]string, error) { + remoteRegistry, err := remote.NewRegistry(remoteRef.Registry) + if err != nil { + return nil, fmt.Errorf("failed to read registry: %w", err) + } + remoteRegistry.PlainHTTP = useHttp + + repo, err := remoteRegistry.Repository(ctx, remoteRef.Repository) + if err != nil { + return nil, fmt.Errorf("failed to read repository: %w", err) + } + if remoteRef.Reference != "" { + return listImageTag(ctx, repo, remoteRef) + } + return listTags(ctx, repo, remoteRef) +} + +func listTags(ctx context.Context, repo registry.Repository, ref *registry.Reference) ([]string, error) { + var tags []string + err := repo.Tags(ctx, "", func(tagsPage []string) error { + tags = append(tags, tagsPage...) + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to list tags on repostory: %w", err) + } + + var allLines []string + for _, tag := range tags { + tagRef := ®istry.Reference{ + Registry: ref.Registry, + Repository: ref.Repository, + Reference: tag, + } + infoLines, err := listImageTag(ctx, repo, tagRef) + if err != nil { + return nil, err + } + allLines = append(allLines, infoLines...) + } + + return allLines, nil +} + +func listImageTag(ctx context.Context, repo registry.Repository, ref *registry.Reference) ([]string, error) { + manifestDesc, manifestReader, err := repo.FetchReference(ctx, ref.Reference) + if err != nil { + return nil, fmt.Errorf("failed to read reference: %w", err) + } + manifestBytes, err := io.ReadAll(manifestReader) + if err != nil { + return nil, fmt.Errorf("failed to read manifest: %w", err) + } + manifest := &ocispec.Manifest{} + if err := json.Unmarshal(manifestBytes, manifest); err != nil { + return nil, fmt.Errorf("failed to parse manifest: %w", err) + } + + configReader, err := repo.Fetch(ctx, manifest.Config) + if err != nil { + return nil, fmt.Errorf("failed to read config reference: %w", err) + } + configBytes, err := io.ReadAll(configReader) + if err != nil { + return nil, fmt.Errorf("failed to read config: %w", err) + } + config := &artifact.JozuFile{} + if err := json.Unmarshal(configBytes, config); err != nil { + return nil, fmt.Errorf("failed to parse config: %w", err) + } + + // Manifest descriptor may not have annotation for tag, add it here for safety + if _, ok := manifestDesc.Annotations[ocispec.AnnotationRefName]; !ok { + if manifestDesc.Annotations == nil { + manifestDesc.Annotations = map[string]string{} + } + manifestDesc.Annotations[ocispec.AnnotationRefName] = ref.Reference + } + + info := getManifestInfoLine(ref.Repository, manifestDesc, manifest, config) + return []string{info}, nil +} diff --git a/pkg/cmd/pull/cmd.go b/pkg/cmd/pull/cmd.go index 3dbb176c..70ba9568 100644 --- a/pkg/cmd/pull/cmd.go +++ b/pkg/cmd/pull/cmd.go @@ -60,7 +60,7 @@ func PullCommand() *cobra.Command { } cmd.Args = cobra.ExactArgs(1) - cmd.Flags().BoolVar(&flags.UseHTTP, "http", false, "Push to http registry") + cmd.Flags().BoolVar(&flags.UseHTTP, "http", false, "Use plain HTTP when connecting to remote registries") return cmd } diff --git a/pkg/cmd/push/cmd.go b/pkg/cmd/push/cmd.go index 536c244d..9efc5f3e 100644 --- a/pkg/cmd/push/cmd.go +++ b/pkg/cmd/push/cmd.go @@ -60,7 +60,7 @@ func PushCommand() *cobra.Command { } cmd.Args = cobra.ExactArgs(1) - cmd.Flags().BoolVar(&flags.UseHTTP, "http", false, "Push to http registry") + cmd.Flags().BoolVar(&flags.UseHTTP, "http", false, "Use plain HTTP when connecting to remote registries") return cmd }