Skip to content

Commit

Permalink
Add ability to list remote repository models via 'jmm models' command
Browse files Browse the repository at this point in the history
Add optional argument to jmm models:

  jmm models <reference>

to optionally list models in a remote repository. If <reference>
includes a tag, models only lists the model with that tag, if it exists.
  • Loading branch information
amisevsk committed Feb 14, 2024
1 parent dadcae1 commit a84051c
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 54 deletions.
83 changes: 37 additions & 46 deletions pkg/cmd/models/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@ package models

import (
"fmt"
"io/fs"
"jmm/pkg/lib/storage"
"os"
"path"
"path/filepath"
"text/tabwriter"

"github.com/spf13/cobra"
"github.com/spf13/viper"
"oras.land/oras-go/v2/registry"
)

const (
Expand All @@ -19,17 +17,33 @@ const (
)

var (
opts *ModelsOptions
flags *ModelsFlags
opts *ModelsOptions
)

type ModelsFlags struct {
UseHTTP bool
}

type ModelsOptions struct {
configHome string
storageHome string
remoteRef *registry.Reference
usehttp bool
}

func (opts *ModelsOptions) complete() {
func (opts *ModelsOptions) complete(flags *ModelsFlags, args []string) error {
opts.configHome = viper.GetString("config")
opts.storageHome = path.Join(opts.configHome, "storage")
if len(args) > 0 {
remoteRef, err := registry.ParseReference(args[0])
if err != nil {
return fmt.Errorf("invalid reference: %w", err)
}
opts.remoteRef = &remoteRef
}
opts.usehttp = flags.UseHTTP
return nil
}

func (opts *ModelsOptions) validate() error {
Expand All @@ -38,77 +52,54 @@ func (opts *ModelsOptions) validate() error {

// ModelsCommand represents the models command
func ModelsCommand() *cobra.Command {
flags = &ModelsFlags{}
opts = &ModelsOptions{}

cmd := &cobra.Command{
Use: "models",
Use: "models [repository]",
Short: shortDesc,
Long: longDesc,
Run: RunCommand(opts),
}

cmd.Args = cobra.MaximumNArgs(1)
cmd.Flags().BoolVar(&flags.UseHTTP, "http", false, "Use plain HTTP when connecting to remote registries")
return cmd
}

func RunCommand(options *ModelsOptions) func(*cobra.Command, []string) {
return func(cmd *cobra.Command, args []string) {
options.complete()
err := options.validate()
if err != nil {
fmt.Println(err)
if err := options.complete(flags, args); err != nil {
fmt.Printf("Failed to parse argument: %s", err)
return
}

storeDirs, err := findRepos(opts.storageHome)
if err != nil {
if err := options.validate(); err != nil {
fmt.Println(err)
return
}

var allInfoLines []string
for _, storeDir := range storeDirs {
store := storage.NewLocalStore(opts.storageHome, storeDir)

infolines, err := listModels(store)
if opts.remoteRef == nil {
lines, err := listLocalModels(opts.storageHome)
if err != nil {
fmt.Println(err)
return
}
allInfoLines = lines
} else {
lines, err := listRemoteModels(cmd.Context(), opts.remoteRef, opts.usehttp)
if err != nil {
fmt.Println(err)
return
}
allInfoLines = append(allInfoLines, infolines...)
allInfoLines = lines
}

printSummary(allInfoLines)

}
}

func findRepos(storePath string) ([]string, error) {
var indexPaths []string
err := filepath.WalkDir(storePath, func(file string, info fs.DirEntry, err error) error {
if err != nil {
return err
}
if info.Name() == "index.json" && !info.IsDir() {
dir := filepath.Dir(file)
relDir, err := filepath.Rel(storePath, dir)
if err != nil {
return err
}
if relDir == "." {
relDir = ""
}
indexPaths = append(indexPaths, relDir)
}
return nil
})
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, fmt.Errorf("failed to read local storage: %w", err)
}
return indexPaths, nil
}

func printSummary(lines []string) {
tw := tabwriter.NewWriter(os.Stdout, 0, 2, 3, ' ', 0)
fmt.Fprintln(tw, ModelsTableHeader)
Expand Down
59 changes: 53 additions & 6 deletions pkg/cmd/models/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ import (
"context"
"encoding/json"
"fmt"
"io/fs"
"jmm/pkg/artifact"
"jmm/pkg/lib/storage"
"math"
"os"
"path/filepath"
"strings"

ocispec "github.com/opencontainers/image-spec/specs-go/v1"
Expand All @@ -20,6 +23,25 @@ const (
ModelsTableFmt = "%s\t%s\t%s\t%s\t%s\t%s\t"
)

func listLocalModels(storageRoot string) ([]string, error) {
storeDirs, err := findRepos(storageRoot)
if err != nil {
return nil, err
}

var allInfoLines []string
for _, storeDir := range storeDirs {
store := storage.NewLocalStore(storageRoot, storeDir)

infolines, err := listModels(store)
if err != nil {
return nil, err
}
allInfoLines = append(allInfoLines, infolines...)
}
return allInfoLines, nil
}

func listModels(store storage.Store) ([]string, error) {
index, err := store.ParseIndexJson()
if err != nil {
Expand All @@ -37,10 +59,7 @@ func listModels(store storage.Store) ([]string, error) {
return nil, err
}
// TODO: filter list for our manifests only, ignore other artifacts
infoline, err := getManifestInfoLine(store.GetRepository(), manifestDesc, manifest, manifestConf)
if err != nil {
return nil, err
}
infoline := getManifestInfoLine(store.GetRepository(), manifestDesc, manifest, manifestConf)
infolines = append(infolines, infoline)
}

Expand Down Expand Up @@ -71,7 +90,7 @@ func readManifestConfig(store storage.Store, manifest *ocispec.Manifest) (*artif
return config, nil
}

func getManifestInfoLine(repo string, desc ocispec.Descriptor, manifest *ocispec.Manifest, config *artifact.JozuFile) (string, error) {
func getManifestInfoLine(repo string, desc ocispec.Descriptor, manifest *ocispec.Manifest, config *artifact.JozuFile) string {
ref := desc.Annotations[ocispec.AnnotationRefName]
if ref == "" {
ref = "<none>"
Expand All @@ -90,7 +109,35 @@ func getManifestInfoLine(repo string, desc ocispec.Descriptor, manifest *ocispec
sizeStr := formatBytes(size)

info := fmt.Sprintf(ModelsTableFmt, repo, ref, config.Package.Authors[0], config.Package.Name, sizeStr, desc.Digest)
return info, nil
return info
}

func findRepos(storePath string) ([]string, error) {
var indexPaths []string
err := filepath.WalkDir(storePath, func(file string, info fs.DirEntry, err error) error {
if err != nil {
return err
}
if info.Name() == "index.json" && !info.IsDir() {
dir := filepath.Dir(file)
relDir, err := filepath.Rel(storePath, dir)
if err != nil {
return err
}
if relDir == "." {
relDir = ""
}
indexPaths = append(indexPaths, relDir)
}
return nil
})
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, fmt.Errorf("failed to read local storage: %w", err)
}
return indexPaths, nil
}

func formatBytes(i int64) string {
Expand Down
96 changes: 96 additions & 0 deletions pkg/cmd/models/remote.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package models

import (
"context"
"encoding/json"
"fmt"
"io"
"jmm/pkg/artifact"

ocispec "github.com/opencontainers/image-spec/specs-go/v1"
"oras.land/oras-go/v2/registry"
"oras.land/oras-go/v2/registry/remote"
)

func listRemoteModels(ctx context.Context, remoteRef *registry.Reference, useHttp bool) ([]string, error) {
remoteRegistry, err := remote.NewRegistry(remoteRef.Registry)
if err != nil {
return nil, fmt.Errorf("failed to read registry: %w", err)
}
remoteRegistry.PlainHTTP = useHttp

repo, err := remoteRegistry.Repository(ctx, remoteRef.Repository)
if err != nil {
return nil, fmt.Errorf("failed to read repository: %w", err)
}
if remoteRef.Reference != "" {
return listImageTag(ctx, repo, remoteRef)
}
return listTags(ctx, repo, remoteRef)
}

func listTags(ctx context.Context, repo registry.Repository, ref *registry.Reference) ([]string, error) {
var tags []string
err := repo.Tags(ctx, "", func(tagsPage []string) error {
tags = append(tags, tagsPage...)
return nil
})
if err != nil {
return nil, fmt.Errorf("failed to list tags on repostory: %w", err)
}

var allLines []string
for _, tag := range tags {
tagRef := &registry.Reference{
Registry: ref.Registry,
Repository: ref.Repository,
Reference: tag,
}
infoLines, err := listImageTag(ctx, repo, tagRef)
if err != nil {
return nil, err
}
allLines = append(allLines, infoLines...)
}

return allLines, nil
}

func listImageTag(ctx context.Context, repo registry.Repository, ref *registry.Reference) ([]string, error) {
manifestDesc, manifestReader, err := repo.FetchReference(ctx, ref.Reference)
if err != nil {
return nil, fmt.Errorf("failed to read reference: %w", err)
}
manifestBytes, err := io.ReadAll(manifestReader)
if err != nil {
return nil, fmt.Errorf("failed to read manifest: %w", err)
}
manifest := &ocispec.Manifest{}
if err := json.Unmarshal(manifestBytes, manifest); err != nil {
return nil, fmt.Errorf("failed to parse manifest: %w", err)
}

configReader, err := repo.Fetch(ctx, manifest.Config)
if err != nil {
return nil, fmt.Errorf("failed to read config reference: %w", err)
}
configBytes, err := io.ReadAll(configReader)
if err != nil {
return nil, fmt.Errorf("failed to read config: %w", err)
}
config := &artifact.JozuFile{}
if err := json.Unmarshal(configBytes, config); err != nil {
return nil, fmt.Errorf("failed to parse config: %w", err)
}

// Manifest descriptor may not have annotation for tag, add it here for safety
if _, ok := manifestDesc.Annotations[ocispec.AnnotationRefName]; !ok {
if manifestDesc.Annotations == nil {
manifestDesc.Annotations = map[string]string{}
}
manifestDesc.Annotations[ocispec.AnnotationRefName] = ref.Reference
}

info := getManifestInfoLine(ref.Repository, manifestDesc, manifest, config)
return []string{info}, nil
}
2 changes: 1 addition & 1 deletion pkg/cmd/pull/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func PullCommand() *cobra.Command {
}

cmd.Args = cobra.ExactArgs(1)
cmd.Flags().BoolVar(&flags.UseHTTP, "http", false, "Push to http registry")
cmd.Flags().BoolVar(&flags.UseHTTP, "http", false, "Use plain HTTP when connecting to remote registries")
return cmd
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/cmd/push/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func PushCommand() *cobra.Command {
}

cmd.Args = cobra.ExactArgs(1)
cmd.Flags().BoolVar(&flags.UseHTTP, "http", false, "Push to http registry")
cmd.Flags().BoolVar(&flags.UseHTTP, "http", false, "Use plain HTTP when connecting to remote registries")
return cmd
}

Expand Down

0 comments on commit a84051c

Please sign in to comment.