Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to list models in remote repository via jmm models #15

Merged
merged 1 commit into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 37 additions & 46 deletions pkg/cmd/models/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@ package models

import (
"fmt"
"io/fs"
"jmm/pkg/lib/storage"
"os"
"path"
"path/filepath"
"text/tabwriter"

"github.com/spf13/cobra"
"github.com/spf13/viper"
"oras.land/oras-go/v2/registry"
)

const (
Expand All @@ -19,17 +17,33 @@ const (
)

var (
opts *ModelsOptions
flags *ModelsFlags
opts *ModelsOptions
)

type ModelsFlags struct {
UseHTTP bool
}

type ModelsOptions struct {
configHome string
storageHome string
remoteRef *registry.Reference
usehttp bool
}

func (opts *ModelsOptions) complete() {
func (opts *ModelsOptions) complete(flags *ModelsFlags, args []string) error {
opts.configHome = viper.GetString("config")
opts.storageHome = path.Join(opts.configHome, "storage")
if len(args) > 0 {
remoteRef, err := registry.ParseReference(args[0])
if err != nil {
return fmt.Errorf("invalid reference: %w", err)
}
opts.remoteRef = &remoteRef
}
opts.usehttp = flags.UseHTTP
return nil
}

func (opts *ModelsOptions) validate() error {
Expand All @@ -38,77 +52,54 @@ func (opts *ModelsOptions) validate() error {

// ModelsCommand represents the models command
func ModelsCommand() *cobra.Command {
flags = &ModelsFlags{}
opts = &ModelsOptions{}

cmd := &cobra.Command{
Use: "models",
Use: "models [repository]",
Short: shortDesc,
Long: longDesc,
Run: RunCommand(opts),
}

cmd.Args = cobra.MaximumNArgs(1)
cmd.Flags().BoolVar(&flags.UseHTTP, "http", false, "Use plain HTTP when connecting to remote registries")
return cmd
}

func RunCommand(options *ModelsOptions) func(*cobra.Command, []string) {
return func(cmd *cobra.Command, args []string) {
options.complete()
err := options.validate()
if err != nil {
fmt.Println(err)
if err := options.complete(flags, args); err != nil {
fmt.Printf("Failed to parse argument: %s", err)
return
}

storeDirs, err := findRepos(opts.storageHome)
if err != nil {
if err := options.validate(); err != nil {
fmt.Println(err)
return
}

var allInfoLines []string
for _, storeDir := range storeDirs {
store := storage.NewLocalStore(opts.storageHome, storeDir)

infolines, err := listModels(store)
if opts.remoteRef == nil {
lines, err := listLocalModels(opts.storageHome)
if err != nil {
fmt.Println(err)
return
}
allInfoLines = lines
} else {
lines, err := listRemoteModels(cmd.Context(), opts.remoteRef, opts.usehttp)
if err != nil {
fmt.Println(err)
return
}
allInfoLines = append(allInfoLines, infolines...)
allInfoLines = lines
}

printSummary(allInfoLines)

}
}

func findRepos(storePath string) ([]string, error) {
var indexPaths []string
err := filepath.WalkDir(storePath, func(file string, info fs.DirEntry, err error) error {
if err != nil {
return err
}
if info.Name() == "index.json" && !info.IsDir() {
dir := filepath.Dir(file)
relDir, err := filepath.Rel(storePath, dir)
if err != nil {
return err
}
if relDir == "." {
relDir = ""
}
indexPaths = append(indexPaths, relDir)
}
return nil
})
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, fmt.Errorf("failed to read local storage: %w", err)
}
return indexPaths, nil
}

func printSummary(lines []string) {
tw := tabwriter.NewWriter(os.Stdout, 0, 2, 3, ' ', 0)
fmt.Fprintln(tw, ModelsTableHeader)
Expand Down
59 changes: 53 additions & 6 deletions pkg/cmd/models/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ import (
"context"
"encoding/json"
"fmt"
"io/fs"
"jmm/pkg/artifact"
"jmm/pkg/lib/storage"
"math"
"os"
"path/filepath"
"strings"

ocispec "github.com/opencontainers/image-spec/specs-go/v1"
Expand All @@ -20,6 +23,25 @@ const (
ModelsTableFmt = "%s\t%s\t%s\t%s\t%s\t%s\t"
)

func listLocalModels(storageRoot string) ([]string, error) {
storeDirs, err := findRepos(storageRoot)
if err != nil {
return nil, err
}

var allInfoLines []string
for _, storeDir := range storeDirs {
store := storage.NewLocalStore(storageRoot, storeDir)

infolines, err := listModels(store)
if err != nil {
return nil, err
}
allInfoLines = append(allInfoLines, infolines...)
}
return allInfoLines, nil
}

func listModels(store storage.Store) ([]string, error) {
index, err := store.ParseIndexJson()
if err != nil {
Expand All @@ -37,10 +59,7 @@ func listModels(store storage.Store) ([]string, error) {
return nil, err
}
// TODO: filter list for our manifests only, ignore other artifacts
infoline, err := getManifestInfoLine(store.GetRepository(), manifestDesc, manifest, manifestConf)
if err != nil {
return nil, err
}
infoline := getManifestInfoLine(store.GetRepository(), manifestDesc, manifest, manifestConf)
infolines = append(infolines, infoline)
}

Expand Down Expand Up @@ -71,7 +90,7 @@ func readManifestConfig(store storage.Store, manifest *ocispec.Manifest) (*artif
return config, nil
}

func getManifestInfoLine(repo string, desc ocispec.Descriptor, manifest *ocispec.Manifest, config *artifact.JozuFile) (string, error) {
func getManifestInfoLine(repo string, desc ocispec.Descriptor, manifest *ocispec.Manifest, config *artifact.JozuFile) string {
ref := desc.Annotations[ocispec.AnnotationRefName]
if ref == "" {
ref = "<none>"
Expand All @@ -90,7 +109,35 @@ func getManifestInfoLine(repo string, desc ocispec.Descriptor, manifest *ocispec
sizeStr := formatBytes(size)

info := fmt.Sprintf(ModelsTableFmt, repo, ref, config.Package.Authors[0], config.Package.Name, sizeStr, desc.Digest)
return info, nil
return info
}

func findRepos(storePath string) ([]string, error) {
var indexPaths []string
err := filepath.WalkDir(storePath, func(file string, info fs.DirEntry, err error) error {
if err != nil {
return err
}
if info.Name() == "index.json" && !info.IsDir() {
dir := filepath.Dir(file)
relDir, err := filepath.Rel(storePath, dir)
if err != nil {
return err
}
if relDir == "." {
relDir = ""
}
indexPaths = append(indexPaths, relDir)
}
return nil
})
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, fmt.Errorf("failed to read local storage: %w", err)
}
return indexPaths, nil
}

func formatBytes(i int64) string {
Expand Down
96 changes: 96 additions & 0 deletions pkg/cmd/models/remote.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package models

import (
"context"
"encoding/json"
"fmt"
"io"
"jmm/pkg/artifact"

ocispec "github.com/opencontainers/image-spec/specs-go/v1"
"oras.land/oras-go/v2/registry"
"oras.land/oras-go/v2/registry/remote"
)

func listRemoteModels(ctx context.Context, remoteRef *registry.Reference, useHttp bool) ([]string, error) {
remoteRegistry, err := remote.NewRegistry(remoteRef.Registry)
if err != nil {
return nil, fmt.Errorf("failed to read registry: %w", err)
}
remoteRegistry.PlainHTTP = useHttp

repo, err := remoteRegistry.Repository(ctx, remoteRef.Repository)
if err != nil {
return nil, fmt.Errorf("failed to read repository: %w", err)
}
if remoteRef.Reference != "" {
return listImageTag(ctx, repo, remoteRef)
}
return listTags(ctx, repo, remoteRef)
}

func listTags(ctx context.Context, repo registry.Repository, ref *registry.Reference) ([]string, error) {
var tags []string
err := repo.Tags(ctx, "", func(tagsPage []string) error {
tags = append(tags, tagsPage...)
return nil
})
if err != nil {
return nil, fmt.Errorf("failed to list tags on repostory: %w", err)
}

var allLines []string
for _, tag := range tags {
tagRef := &registry.Reference{
Registry: ref.Registry,
Repository: ref.Repository,
Reference: tag,
}
infoLines, err := listImageTag(ctx, repo, tagRef)
if err != nil {
return nil, err
}
allLines = append(allLines, infoLines...)
}

return allLines, nil
}

func listImageTag(ctx context.Context, repo registry.Repository, ref *registry.Reference) ([]string, error) {
manifestDesc, manifestReader, err := repo.FetchReference(ctx, ref.Reference)
if err != nil {
return nil, fmt.Errorf("failed to read reference: %w", err)
}
manifestBytes, err := io.ReadAll(manifestReader)
if err != nil {
return nil, fmt.Errorf("failed to read manifest: %w", err)
}
manifest := &ocispec.Manifest{}
if err := json.Unmarshal(manifestBytes, manifest); err != nil {
return nil, fmt.Errorf("failed to parse manifest: %w", err)
}

configReader, err := repo.Fetch(ctx, manifest.Config)
if err != nil {
return nil, fmt.Errorf("failed to read config reference: %w", err)
}
configBytes, err := io.ReadAll(configReader)
if err != nil {
return nil, fmt.Errorf("failed to read config: %w", err)
}
config := &artifact.JozuFile{}
if err := json.Unmarshal(configBytes, config); err != nil {
return nil, fmt.Errorf("failed to parse config: %w", err)
}

// Manifest descriptor may not have annotation for tag, add it here for safety
if _, ok := manifestDesc.Annotations[ocispec.AnnotationRefName]; !ok {
if manifestDesc.Annotations == nil {
manifestDesc.Annotations = map[string]string{}
}
manifestDesc.Annotations[ocispec.AnnotationRefName] = ref.Reference
}

info := getManifestInfoLine(ref.Repository, manifestDesc, manifest, config)
return []string{info}, nil
}
2 changes: 1 addition & 1 deletion pkg/cmd/pull/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func PullCommand() *cobra.Command {
}

cmd.Args = cobra.ExactArgs(1)
cmd.Flags().BoolVar(&flags.UseHTTP, "http", false, "Push to http registry")
cmd.Flags().BoolVar(&flags.UseHTTP, "http", false, "Use plain HTTP when connecting to remote registries")
return cmd
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/cmd/push/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func PushCommand() *cobra.Command {
}

cmd.Args = cobra.ExactArgs(1)
cmd.Flags().BoolVar(&flags.UseHTTP, "http", false, "Push to http registry")
cmd.Flags().BoolVar(&flags.UseHTTP, "http", false, "Use plain HTTP when connecting to remote registries")
return cmd
}

Expand Down