diff --git a/cmd/root.go b/cmd/root.go index d88dd7c0..783e625d 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -5,6 +5,7 @@ package cmd import ( "jmm/pkg/cmd/build" + "jmm/pkg/cmd/export" "jmm/pkg/cmd/login" "jmm/pkg/cmd/models" "jmm/pkg/cmd/pull" @@ -42,6 +43,7 @@ func init() { rootCmd.AddCommand(pull.PullCommand()) rootCmd.AddCommand(push.PushCommand()) rootCmd.AddCommand(models.ModelsCommand()) + rootCmd.AddCommand(export.ExportCommand()) rootCmd.AddCommand(version.NewCmdVersion()) } diff --git a/pkg/artifact/model-layer.go b/pkg/artifact/model-layer.go index b6dd5003..d1365bdc 100644 --- a/pkg/artifact/model-layer.go +++ b/pkg/artifact/model-layer.go @@ -7,26 +7,16 @@ import ( "os" "path/filepath" "strings" - - ocispec "github.com/opencontainers/image-spec/specs-go/v1" ) type ModelLayer struct { - contextDir string - MediaType string - Descriptor ocispec.Descriptor -} - -func NewLayer(rootpath string, mediaType string) *ModelLayer { - return &ModelLayer{ - contextDir: rootpath, - MediaType: mediaType, - } + BaseDir string + MediaType string } func (layer *ModelLayer) Apply(writers ...io.Writer) error { // Check if path exists - _, err := os.Stat(layer.contextDir) + _, err := os.Stat(layer.BaseDir) if err != nil { return err } @@ -40,46 +30,39 @@ func (layer *ModelLayer) Apply(writers ...io.Writer) error { defer tw.Close() // walk the context dir and tar everything - err = filepath.Walk(layer.contextDir, func(file string, fi os.FileInfo, err error) error { - + err = filepath.Walk(layer.BaseDir, func(file string, fi os.FileInfo, err error) error { if err != nil { return err } - - if !fi.Mode().IsRegular() { + // Skip anything that's not a regular file or directory + if !fi.Mode().IsRegular() && !fi.Mode().IsDir() { + return nil + } + // Skip the baseDir itself + if file == layer.BaseDir { return nil } - // create a new dir/file header header, err := tar.FileInfoHeader(fi, fi.Name()) if err != nil { return err } - parentDir := filepath.Dir(layer.contextDir) - - // update the name to correctly reflect the desired destination when untaring - header.Name = strings.TrimPrefix( - strings.Replace(file, parentDir, "", -1), string(filepath.Separator)) + // We want the path in the tarball to be relative to the layer's base directory + subPath := strings.TrimPrefix(strings.Replace(file, layer.BaseDir, "", -1), string(filepath.Separator)) + header.Name = subPath - // write the header if err := tw.WriteHeader(header); err != nil { return err } - // open files for taring - f, err := os.Open(file) - if err != nil { - return err - } - - // copy file data into tar writer - if _, err := io.Copy(tw, f); err != nil { - return err + if fi.Mode().IsRegular() { + err := writeFileToTar(file, tw) + if err != nil { + return err + } } - f.Close() - return nil }) @@ -88,3 +71,16 @@ func (layer *ModelLayer) Apply(writers ...io.Writer) error { } return nil } + +func writeFileToTar(file string, tw *tar.Writer) error { + f, err := os.Open(file) + if err != nil { + return err + } + defer f.Close() + + if _, err := io.Copy(tw, f); err != nil { + return err + } + return nil +} diff --git a/pkg/cmd/build/build.go b/pkg/cmd/build/build.go index 8c330c35..1f18085b 100644 --- a/pkg/cmd/build/build.go +++ b/pkg/cmd/build/build.go @@ -7,10 +7,10 @@ import ( "fmt" "os" "path" - "path/filepath" "jmm/pkg/artifact" "jmm/pkg/lib/constants" + "jmm/pkg/lib/filesystem" "jmm/pkg/lib/storage" "github.com/spf13/cobra" @@ -18,10 +18,6 @@ import ( "oras.land/oras-go/v2/registry" ) -const ( - DEFAULT_MODEL_FILE = "Jozufile" -) - var ( shortDesc = `Build a model` longDesc = `A longer description that spans multiple lines and likely contains examples @@ -83,7 +79,7 @@ func NewCmdBuild() *cobra.Command { func (options *BuildOptions) Complete(cmd *cobra.Command, argsIn []string) error { options.ContextDir = argsIn[0] if options.ModelFile == "" { - options.ModelFile = options.ContextDir + "/" + DEFAULT_MODEL_FILE + options.ModelFile = path.Join(options.ContextDir, constants.DefaultModelFileName) } options.configHome = viper.GetString("config") fmt.Println("config: ", options.configHome) @@ -114,31 +110,39 @@ func (options *BuildOptions) RunBuild() error { // 2. package the Code for _, code := range jozufile.Code { - codePath, err := toAbsPath(options.ContextDir, code.Path) + codePath, err := filesystem.VerifySubpath(options.ContextDir, code.Path) if err != nil { return err } - layer := artifact.NewLayer(codePath, constants.CodeLayerMediaType) + layer := &artifact.ModelLayer{ + BaseDir: codePath, + MediaType: constants.CodeLayerMediaType, + } model.Layers = append(model.Layers, *layer) } // 3. package the DataSets - datasetPath := "" for _, dataset := range jozufile.DataSets { - datasetPath, err = toAbsPath(options.ContextDir, dataset.Path) + datasetPath, err := filesystem.VerifySubpath(options.ContextDir, dataset.Path) if err != nil { return err } - layer := artifact.NewLayer(datasetPath, constants.DataSetLayerMediaType) + layer := &artifact.ModelLayer{ + BaseDir: datasetPath, + MediaType: constants.DataSetLayerMediaType, + } model.Layers = append(model.Layers, *layer) } // 4. package the TrainedModels for _, trainedModel := range jozufile.Models { - modelPath, err := toAbsPath(options.ContextDir, trainedModel.Path) + modelPath, err := filesystem.VerifySubpath(options.ContextDir, trainedModel.Path) if err != nil { return err } - layer := artifact.NewLayer(modelPath, constants.ModelLayerMediaType) + layer := &artifact.ModelLayer{ + BaseDir: modelPath, + MediaType: constants.ModelLayerMediaType, + } model.Layers = append(model.Layers, *layer) } @@ -192,16 +196,3 @@ func (flags *BuildFlags) AddFlags(cmd *cobra.Command) { func NewBuildFlags() *BuildFlags { return &BuildFlags{} } -func toAbsPath(context string, relativePath string) (string, error) { - - absContext, err := filepath.Abs(context) - if err != nil { - fmt.Println("Error resolving base path:", err) - return "", err - } - combinedPath := filepath.Join(absContext, relativePath) - - cleanPath := filepath.Clean(combinedPath) - return cleanPath, nil - -} diff --git a/pkg/cmd/build/build_test.go b/pkg/cmd/build/build_test.go index f9c46345..ae7c8bcf 100644 --- a/pkg/cmd/build/build_test.go +++ b/pkg/cmd/build/build_test.go @@ -1,6 +1,8 @@ package build import ( + "jmm/pkg/lib/constants" + "path" "testing" "github.com/spf13/cobra" @@ -26,7 +28,7 @@ func TestBuildOptions_Complete(t *testing.T) { assert.NoError(t, err) assert.Equal(t, args[0], options.ContextDir) - assert.Equal(t, options.ContextDir+"/"+DEFAULT_MODEL_FILE, options.ModelFile) + assert.Equal(t, path.Join(options.ContextDir, constants.DefaultModelFileName), options.ModelFile) } func TestBuildOptions_Validate(t *testing.T) { diff --git a/pkg/cmd/export/cmd.go b/pkg/cmd/export/cmd.go new file mode 100644 index 00000000..5f91337c --- /dev/null +++ b/pkg/cmd/export/cmd.go @@ -0,0 +1,175 @@ +package export + +import ( + "context" + "errors" + "fmt" + "jmm/pkg/lib/storage" + + "github.com/spf13/cobra" + "github.com/spf13/viper" + "oras.land/oras-go/v2" + "oras.land/oras-go/v2/content/oci" + "oras.land/oras-go/v2/errdef" + "oras.land/oras-go/v2/registry" + "oras.land/oras-go/v2/registry/remote" +) + +const ( + shortDesc = `Export model from registry` + longDesc = `Export model from registry TODO` +) + +var ( + flags *ExportFlags + opts *ExportOptions +) + +type ExportFlags struct { + Overwrite bool + UseHTTP bool + ExportConfig bool + ExportModels bool + ExportDatasets bool + ExportCode bool + ExportDir string +} + +type ExportOptions struct { + configHome string + storageHome string + exportDir string + overwrite bool + exportConf ExportConf + modelRef *registry.Reference + usehttp bool +} + +type ExportConf struct { + ExportConfig bool + ExportModels bool + ExportCode bool + ExportDatasets bool +} + +func (opts *ExportOptions) complete(args []string) error { + opts.configHome = viper.GetString("config") + opts.storageHome = storage.StorageHome(opts.configHome) + modelRef, extraTags, err := storage.ParseReference(args[0]) + if err != nil { + return fmt.Errorf("failed to parse reference %s: %w", args[0], err) + } + if len(extraTags) > 0 { + return fmt.Errorf("can not export multiple tags") + } + opts.modelRef = modelRef + opts.overwrite = flags.Overwrite + opts.usehttp = flags.UseHTTP + opts.exportDir = flags.ExportDir + + if !flags.ExportConfig && !flags.ExportModels && !flags.ExportCode && !flags.ExportDatasets { + opts.exportConf.ExportConfig = true + opts.exportConf.ExportModels = true + opts.exportConf.ExportCode = true + opts.exportConf.ExportDatasets = true + } else { + opts.exportConf.ExportConfig = flags.ExportConfig + opts.exportConf.ExportModels = flags.ExportModels + opts.exportConf.ExportCode = flags.ExportCode + opts.exportConf.ExportDatasets = flags.ExportDatasets + } + + return nil +} + +func (opts *ExportOptions) validate() error { + return nil +} + +func ExportCommand() *cobra.Command { + opts = &ExportOptions{} + flags = &ExportFlags{} + + cmd := &cobra.Command{ + Use: "export", + Short: shortDesc, + Long: longDesc, + Run: runCommand(opts), + } + + cmd.Args = cobra.ExactArgs(1) + cmd.Flags().StringVarP(&flags.ExportDir, "dir", "d", "", "Directory to export into. Will be created if it does not exist") + cmd.Flags().BoolVarP(&flags.Overwrite, "overwrite", "o", false, "Overwrite existing files and directories in the export dir") + cmd.Flags().BoolVar(&flags.ExportConfig, "config", false, "Export only config file") + cmd.Flags().BoolVar(&flags.ExportModels, "models", false, "Export only models") + cmd.Flags().BoolVar(&flags.ExportCode, "code", false, "Export only code") + cmd.Flags().BoolVar(&flags.ExportDatasets, "datasets", false, "Export only datasets") + cmd.Flags().BoolVar(&flags.UseHTTP, "http", false, "Use plain HTTP when connecting to remote registries") + + return cmd +} + +func runCommand(opts *ExportOptions) func(*cobra.Command, []string) { + return func(cmd *cobra.Command, args []string) { + if err := opts.complete(args); err != nil { + fmt.Printf("Failed to process arguments: %s", err) + return + } + err := opts.validate() + if err != nil { + fmt.Println(err) + return + } + + store, err := getStoreForRef(cmd.Context(), opts) + if err != nil { + fmt.Println(err) + return + } + + exportTo := opts.exportDir + if exportTo == "" { + exportTo = "current directory" + } + fmt.Printf("Exporting to %s\n", exportTo) + err = ExportModel(cmd.Context(), store, opts.modelRef, opts) + if err != nil { + fmt.Println(err) + return + } + } +} + +func getStoreForRef(ctx context.Context, opts *ExportOptions) (oras.Target, error) { + localStore, err := oci.New(storage.LocalStorePath(opts.storageHome, opts.modelRef)) + if err != nil { + return nil, fmt.Errorf("failed to read local storage: %s\n", err) + } + + if _, err := localStore.Resolve(ctx, opts.modelRef.Reference); err == nil { + // Reference is present in local storage + return localStore, nil + } + + // Not in local storage, check remote + remoteRegistry, err := remote.NewRegistry(opts.modelRef.Registry) + if err != nil { + return nil, fmt.Errorf("could not resolve registry %s: %w", opts.modelRef.Registry, err) + } + if opts.usehttp { + remoteRegistry.PlainHTTP = true + } + + repo, err := remoteRegistry.Repository(ctx, opts.modelRef.Repository) + if err != nil { + return nil, fmt.Errorf("could not resolve repository %s in registry %s", opts.modelRef.Repository, opts.modelRef.Registry) + } + if _, err := repo.Resolve(ctx, opts.modelRef.Reference); err != nil { + if errors.Is(err, errdef.ErrNotFound) { + return nil, fmt.Errorf("reference %s is not present in local storage and could not be found in remote", opts.modelRef.String()) + } + return nil, fmt.Errorf("unexpected error retrieving reference from remote: %w", err) + } + + return repo, nil +} diff --git a/pkg/cmd/export/export.go b/pkg/cmd/export/export.go new file mode 100644 index 00000000..aabc31cf --- /dev/null +++ b/pkg/cmd/export/export.go @@ -0,0 +1,190 @@ +package export + +import ( + "archive/tar" + "compress/gzip" + "context" + "fmt" + "io" + "jmm/pkg/artifact" + "jmm/pkg/lib/constants" + "jmm/pkg/lib/filesystem" + "jmm/pkg/lib/repo" + "os" + "path" + "path/filepath" + + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "oras.land/oras-go/v2" + "oras.land/oras-go/v2/content" + "oras.land/oras-go/v2/registry" + "sigs.k8s.io/yaml" +) + +func ExportModel(ctx context.Context, store oras.Target, ref *registry.Reference, options *ExportOptions) error { + manifestDesc, err := store.Resolve(ctx, ref.Reference) + if err != nil { + return fmt.Errorf("failed to resolve local reference: %w", err) + } + manifest, config, err := repo.GetManifestAndConfig(ctx, store, manifestDesc) + if err != nil { + return fmt.Errorf("failed to read local model: %s", err) + } + + if options.exportConf.ExportConfig { + if err := ExportConfig(config, options.exportDir, options.overwrite); err != nil { + return err + } + } + + // Since there might be multiple models, etc. we need to synchronously iterate + // through the config's relevant field to get the correct path for exporting + var modelIdx, codeIdx, datasetIdx int + for _, layerDesc := range manifest.Layers { + layerDir := "" + switch layerDesc.MediaType { + case constants.ModelLayerMediaType: + if !options.exportConf.ExportModels { + continue + } + modelEntry := config.Models[modelIdx] + layerDir = filepath.Join(options.exportDir, modelEntry.Path) + fmt.Printf("Exporting model %s to %s\n", modelEntry.Name, layerDir) + modelIdx += 1 + + case constants.CodeLayerMediaType: + if !options.exportConf.ExportCode { + continue + } + codeEntry := config.Code[codeIdx] + layerDir = filepath.Join(options.exportDir, codeEntry.Path) + fmt.Printf("Exporting code to %s\n", layerDir) + codeIdx += 1 + + case constants.DataSetLayerMediaType: + if !options.exportConf.ExportDatasets { + continue + } + datasetEntry := config.DataSets[datasetIdx] + layerDir = filepath.Join(options.exportDir, datasetEntry.Path) + fmt.Printf("Exporting dataset %s to %s\n", datasetEntry.Name, layerDir) + datasetIdx += 1 + } + if _, err := filesystem.VerifySubpath(options.exportDir, layerDir); err != nil { + return err + } + if err := ExportLayer(ctx, store, layerDesc, layerDir, options.overwrite); err != nil { + return err + } + } + + return nil +} + +func ExportConfig(config *artifact.JozuFile, exportDir string, overwrite bool) error { + configPath := path.Join(exportDir, constants.DefaultModelFileName) + if fi, exists := filesystem.PathExists(configPath); exists { + if !overwrite { + return fmt.Errorf("failed to export config: path %s already exists", exportDir) + } else if !fi.Mode().IsRegular() { + return fmt.Errorf("failed to export config: path %s exists and is not a regular file", exportDir) + } + } + + configBytes, err := yaml.Marshal(config) + if err != nil { + return fmt.Errorf("failed to export config: %w", err) + } + + fmt.Printf("Exporting config to %s\n", configPath) + if err := os.WriteFile(configPath, configBytes, 0644); err != nil { + return fmt.Errorf("failed to write config file: %w", err) + } + return nil +} + +func ExportLayer(ctx context.Context, store content.Storage, desc ocispec.Descriptor, exportDir string, overwrite bool) error { + rc, err := store.Fetch(ctx, desc) + if err != nil { + return fmt.Errorf("failed get layer %s: %w", desc.Digest, err) + } + defer rc.Close() + + gzr, err := gzip.NewReader(rc) + if err != nil { + return fmt.Errorf("error extracting gzipped file: %w", err) + } + defer gzr.Close() + tr := tar.NewReader(gzr) + + if fi, exists := filesystem.PathExists(exportDir); exists { + if !overwrite { + return fmt.Errorf("failed to export: path %s already exists", exportDir) + } else if !fi.IsDir() { + return fmt.Errorf("failed to export: path %s exists and is not a directory", exportDir) + } + } + if err := os.MkdirAll(exportDir, 0755); err != nil { + return fmt.Errorf("failed to create directory %s: %w", exportDir, err) + } + + return extractTar(tr, exportDir, overwrite) +} + +func extractTar(tr *tar.Reader, dir string, overwrite bool) error { + for { + header, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return err + } + outPath := path.Join(dir, header.Name) + fmt.Printf("Extracting %s\n", outPath) + + switch header.Typeflag { + case tar.TypeDir: + if fi, exists := filesystem.PathExists(outPath); exists { + if !overwrite { + return fmt.Errorf("path '%s' already exists", outPath) + } + if !fi.IsDir() { + return fmt.Errorf("path '%s' already exists and is not a directory", outPath) + } + } + fmt.Printf("Creating directory %s\n", outPath) + if err := os.MkdirAll(outPath, header.FileInfo().Mode()); err != nil { + return fmt.Errorf("failed to create directory %s: %w", outPath, err) + } + + case tar.TypeReg: + if fi, exists := filesystem.PathExists(outPath); exists { + if !overwrite { + return fmt.Errorf("path '%s' already exists", outPath) + } + if !fi.Mode().IsRegular() { + return fmt.Errorf("path '%s' already exists and is not a regular file", outPath) + } + } + fmt.Printf("Extracting file %s\n", outPath) + file, err := os.OpenFile(outPath, os.O_TRUNC|os.O_RDWR|os.O_EXCL, header.FileInfo().Mode()) + if err != nil { + return fmt.Errorf("failed to create file %s: %w", outPath, err) + } + defer file.Close() + + written, err := io.Copy(file, tr) + if err != nil { + return fmt.Errorf("failed to write file %s: %w", outPath, err) + } + if written != header.Size { + return fmt.Errorf("could not extract file %s", outPath) + } + + default: + return fmt.Errorf("Unrecognized type in archive: %s", header.Name) + } + } + return nil +} diff --git a/pkg/cmd/pull/cmd.go b/pkg/cmd/pull/cmd.go index 96471f68..1a9dea43 100644 --- a/pkg/cmd/pull/cmd.go +++ b/pkg/cmd/pull/cmd.go @@ -97,7 +97,7 @@ func runCommand(opts *PullOptions) func(*cobra.Command, []string) { } fmt.Printf("Pulling %s\n", opts.modelRef.String()) - desc, err := doPull(cmd.Context(), remoteRegistry, localStore, opts.modelRef) + desc, err := PullModel(cmd.Context(), remoteRegistry, localStore, opts.modelRef) if err != nil { fmt.Printf("Failed to pull: %s\n", err) return diff --git a/pkg/cmd/pull/pull.go b/pkg/cmd/pull/pull.go index a3bc100c..ac840cc9 100644 --- a/pkg/cmd/pull/pull.go +++ b/pkg/cmd/pull/pull.go @@ -17,7 +17,7 @@ import ( "oras.land/oras-go/v2/registry/remote" ) -func doPull(ctx context.Context, remoteRegistry *remote.Registry, localStore *oci.Store, ref *registry.Reference) (ocispec.Descriptor, error) { +func PullModel(ctx context.Context, remoteRegistry *remote.Registry, localStore *oci.Store, ref *registry.Reference) (ocispec.Descriptor, error) { repo, err := remoteRegistry.Repository(ctx, ref.Repository) if err != nil { return ocispec.DescriptorEmptyJSON, fmt.Errorf("failed to read repository: %w", err) diff --git a/pkg/cmd/push/cmd.go b/pkg/cmd/push/cmd.go index 66a2928a..fabbd842 100644 --- a/pkg/cmd/push/cmd.go +++ b/pkg/cmd/push/cmd.go @@ -97,7 +97,7 @@ func runCommand(opts *PushOptions) func(*cobra.Command, []string) { } fmt.Printf("Pushing %s\n", opts.modelRef.String()) - desc, err := doPush(cmd.Context(), localStore, remoteRegistry, opts.modelRef) + desc, err := PushModel(cmd.Context(), localStore, remoteRegistry, opts.modelRef) if err != nil { fmt.Printf("Failed to push: %s\n", err) return diff --git a/pkg/cmd/push/push.go b/pkg/cmd/push/push.go index 4763aad7..d47dd91f 100644 --- a/pkg/cmd/push/push.go +++ b/pkg/cmd/push/push.go @@ -14,7 +14,7 @@ import ( "oras.land/oras-go/v2/registry/remote" ) -func doPush(ctx context.Context, localStore *oci.Store, remoteRegistry *remote.Registry, ref *registry.Reference) (ocispec.Descriptor, error) { +func PushModel(ctx context.Context, localStore *oci.Store, remoteRegistry *remote.Registry, ref *registry.Reference) (ocispec.Descriptor, error) { repo, err := remoteRegistry.Repository(ctx, ref.Repository) if err != nil { return ocispec.DescriptorEmptyJSON, fmt.Errorf("failed to read repository: %w", err) diff --git a/pkg/lib/constants/consts.go b/pkg/lib/constants/consts.go index 5048e1bc..361f66a0 100644 --- a/pkg/lib/constants/consts.go +++ b/pkg/lib/constants/consts.go @@ -1,6 +1,8 @@ package constants const ( + DefaultModelFileName = "Jozufile" + // Media type for the model layer ModelLayerMediaType = "application/vnd.jozu.model.content.v1.tar+gzip" // Media type for the dataset layer diff --git a/pkg/lib/filesystem/paths.go b/pkg/lib/filesystem/paths.go new file mode 100644 index 00000000..852043ba --- /dev/null +++ b/pkg/lib/filesystem/paths.go @@ -0,0 +1,50 @@ +package filesystem + +import ( + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" +) + +// VerifySubpath checks that path.Join(context, subDir) is a subdirectory of context, following +// symlinks if present. +func VerifySubpath(context, subDir string) (absPath string, err error) { + // Get absolute path for context and context + subDir + absContext, err := filepath.Abs(context) + if err != nil { + return "", fmt.Errorf("failed to resolve absolute path for %s: %w", context, err) + } + fullPath := filepath.Clean(filepath.Join(absContext, subDir)) + + // Get actual paths, ignoring symlinks along the way + resolvedContext, err := filepath.EvalSymlinks(absContext) + if err != nil { + return "", fmt.Errorf("error resolving %s: %w", absContext, err) + } + resolvedFullPath, err := filepath.EvalSymlinks(fullPath) + if err != nil { + return "", fmt.Errorf("error resolving %s: %w", absContext, err) + } + + // Get relative path between context and the full path to check if the + // actual full, absolute path is a subdirectory of context + relPath, err := filepath.Rel(resolvedContext, resolvedFullPath) + if err != nil { + return "", fmt.Errorf("failed to get relative path: %w", err) + } + if strings.Contains(relPath, "..") { + return "", fmt.Errorf("paths must be within context directory") + } + + return resolvedFullPath, nil +} + +func PathExists(path string) (fs.FileInfo, bool) { + fi, err := os.Stat(path) + if err != nil && os.IsNotExist(err) { + return nil, false + } + return fi, true +} diff --git a/pkg/lib/repo/repo.go b/pkg/lib/repo/repo.go new file mode 100644 index 00000000..3de579bc --- /dev/null +++ b/pkg/lib/repo/repo.go @@ -0,0 +1,52 @@ +package repo + +import ( + "context" + "encoding/json" + "fmt" + "jmm/pkg/artifact" + "jmm/pkg/lib/constants" + + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "oras.land/oras-go/v2/content" +) + +func GetManifestAndConfig(ctx context.Context, store content.Storage, manifestDesc ocispec.Descriptor) (*ocispec.Manifest, *artifact.JozuFile, error) { + manifest, err := GetManifest(ctx, store, manifestDesc) + if err != nil { + return nil, nil, err + } + config, err := GetConfig(ctx, store, manifest.Config) + if err != nil { + return nil, nil, err + } + return manifest, config, nil +} + +func GetManifest(ctx context.Context, store content.Storage, manifestDesc ocispec.Descriptor) (*ocispec.Manifest, error) { + manifestBytes, err := content.FetchAll(ctx, store, manifestDesc) + if err != nil { + return nil, fmt.Errorf("failed to read manifest %s: %w", manifestDesc.Digest, err) + } + manifest := &ocispec.Manifest{} + if err := json.Unmarshal(manifestBytes, &manifest); err != nil { + return nil, fmt.Errorf("failed to parse manifest %s: %w", manifestDesc.Digest, err) + } + if manifest.Config.MediaType != constants.ModelConfigMediaType { + return nil, fmt.Errorf("reference exists but is not a model") + } + + return manifest, nil +} + +func GetConfig(ctx context.Context, store content.Storage, configDesc ocispec.Descriptor) (*artifact.JozuFile, error) { + configBytes, err := content.FetchAll(ctx, store, configDesc) + if err != nil { + return nil, fmt.Errorf("failed to read config: %w", err) + } + config := &artifact.JozuFile{} + if err := json.Unmarshal(configBytes, config); err != nil { + return nil, fmt.Errorf("failed to parse config: %w", err) + } + return config, nil +} diff --git a/pkg/lib/storage/local.go b/pkg/lib/storage/local.go index e0c5326e..5aa86af9 100644 --- a/pkg/lib/storage/local.go +++ b/pkg/lib/storage/local.go @@ -43,19 +43,20 @@ func NewLocalStore(storeRoot, repo string) Store { } func (store *LocalStore) SaveModel(model *artifact.Model, tag string) (*ocispec.Descriptor, error) { - config, err := store.saveConfigFile(model.Config) + configDesc, err := store.saveConfigFile(model.Config) if err != nil { return nil, err } - for idx, layer := range model.Layers { + var layerDescs []ocispec.Descriptor + for _, layer := range model.Layers { layerDesc, err := store.saveContentLayer(&layer) if err != nil { return nil, err } - model.Layers[idx].Descriptor = *layerDesc + layerDescs = append(layerDescs, layerDesc) } - manifestDesc, err := store.saveModelManifest(model, config, tag) + manifestDesc, err := store.saveModelManifest(layerDescs, configDesc, tag) if err != nil { return nil, err } @@ -97,13 +98,13 @@ func (store *LocalStore) GetRepository() string { return store.repo } -func (store *LocalStore) saveContentLayer(layer *artifact.ModelLayer) (*ocispec.Descriptor, error) { +func (store *LocalStore) saveContentLayer(layer *artifact.ModelLayer) (ocispec.Descriptor, error) { ctx := context.Background() buf := &bytes.Buffer{} err := layer.Apply(buf) if err != nil { - return nil, err + return ocispec.DescriptorEmptyJSON, err } // Create a descriptor for the layer @@ -115,7 +116,7 @@ func (store *LocalStore) saveContentLayer(layer *artifact.ModelLayer) (*ocispec. exists, err := store.storage.Exists(ctx, desc) if err != nil { - return nil, err + return ocispec.DescriptorEmptyJSON, err } if exists { fmt.Println("Model layer already saved: ", desc.Digest) @@ -123,19 +124,19 @@ func (store *LocalStore) saveContentLayer(layer *artifact.ModelLayer) (*ocispec. // Does not exist in storage, need to push err = store.storage.Push(ctx, desc, buf) if err != nil { - return nil, err + return ocispec.DescriptorEmptyJSON, err } fmt.Println("Saved model layer: ", desc.Digest) } - return &desc, nil + return desc, nil } -func (store *LocalStore) saveConfigFile(model *artifact.JozuFile) (*ocispec.Descriptor, error) { +func (store *LocalStore) saveConfigFile(model *artifact.JozuFile) (ocispec.Descriptor, error) { ctx := context.Background() modelBytes, err := model.MarshalToJSON() if err != nil { - return nil, err + return ocispec.DescriptorEmptyJSON, err } desc := ocispec.Descriptor{ MediaType: constants.ModelConfigMediaType, @@ -145,30 +146,30 @@ func (store *LocalStore) saveConfigFile(model *artifact.JozuFile) (*ocispec.Desc exists, err := store.storage.Exists(ctx, desc) if err != nil { - return nil, err + return ocispec.DescriptorEmptyJSON, err } if !exists { // Does not exist in storage, need to push err = store.storage.Push(ctx, desc, bytes.NewReader(modelBytes)) if err != nil { - return nil, err + return ocispec.DescriptorEmptyJSON, err } } - return &desc, nil + return desc, nil } -func (store *LocalStore) saveModelManifest(model *artifact.Model, config *ocispec.Descriptor, tag string) (*ocispec.Descriptor, error) { +func (store *LocalStore) saveModelManifest(layerDescs []ocispec.Descriptor, config ocispec.Descriptor, tag string) (*ocispec.Descriptor, error) { ctx := context.Background() manifest := ocispec.Manifest{ Versioned: specs.Versioned{SchemaVersion: 2}, - Config: *config, + Config: config, Layers: []ocispec.Descriptor{}, Annotations: map[string]string{}, } // Add the layers to the manifest - for _, layer := range model.Layers { - manifest.Layers = append(manifest.Layers, layer.Descriptor) + for _, layerDesc := range layerDescs { + manifest.Layers = append(manifest.Layers, layerDesc) } manifestBytes, err := json.Marshal(manifest)