Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to ModelKit format #635

Merged
merged 8 commits into from
Dec 6, 2024
30 changes: 21 additions & 9 deletions pkg/artifact/kitfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@ type (
Docs []Docs `json:"docs,omitempty" yaml:"docs,omitempty"`
}

Docs struct {
Path string `json:"path" yaml:"path"`
Description string `json:"description" yaml:"description"`
}

Package struct {
Name string `json:"name,omitempty" yaml:"name,omitempty"`
Version string `json:"version,omitempty" yaml:"version,omitempty"`
Expand All @@ -47,6 +42,12 @@ type (
Authors []string `json:"authors,omitempty" yaml:"authors,omitempty,flow"`
}

Docs struct {
Path string `json:"path" yaml:"path"`
Description string `json:"description" yaml:"description"`
*LayerInfo `json:",inline" yaml:",inline"`
}

Model struct {
Name string `json:"name,omitempty" yaml:"name,omitempty"`
Path string `json:"path,omitempty" yaml:"path,omitempty"`
Expand All @@ -62,19 +63,22 @@ type (
// * Numbers will be converted to decimal representations (0xFF -> 255, 1.2e+3 -> 1200)
// * Maps will be sorted alphabetically by key
Parameters any `json:"parameters,omitempty" yaml:"parameters,omitempty"`
*LayerInfo `json:",inline" yaml:",inline"`
}

ModelPart struct {
Name string `json:"name,omitempty" yaml:"name,omitempty"`
Path string `json:"path,omitempty" yaml:"path,omitempty"`
License string `json:"license,omitempty" yaml:"license,omitempty"`
Type string `json:"type,omitempty" yaml:"type,omitempty"`
Name string `json:"name,omitempty" yaml:"name,omitempty"`
Path string `json:"path,omitempty" yaml:"path,omitempty"`
License string `json:"license,omitempty" yaml:"license,omitempty"`
Type string `json:"type,omitempty" yaml:"type,omitempty"`
*LayerInfo `json:",inline" yaml:",inline"`
}

Code struct {
Path string `json:"path,omitempty" yaml:"path,omitempty"`
Description string `json:"description,omitempty" yaml:"description,omitempty"`
License string `json:"license,omitempty" yaml:"license,omitempty"`
*LayerInfo `json:",inline" yaml:",inline"`
}

DataSet struct {
Expand All @@ -90,6 +94,14 @@ type (
// * Maps will be sorted alphabetically by key
// * It's recommended to store metadata like preprocessing steps, formats, etc.
Parameters any `json:"parameters,omitempty" yaml:"parameters,omitempty"`
*LayerInfo `json:",inline" yaml:",inline"`
}

LayerInfo struct {
// Digest for the layer corresponding to this element
Digest string `json:"digest,omitempty" yaml:"-"`
// Diff ID (uncompressed digest) for the layer corresponding to this element
DiffId string `json:"diffId,omitempty" yaml:"-"`
}
)

Expand Down
99 changes: 72 additions & 27 deletions pkg/cmd/unpack/unpack.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,32 +77,52 @@ func runUnpackRecursive(ctx context.Context, opts *unpackOptions, visitedRefs []

// Since there might be multiple datasets, etc. we need to synchronously iterate
// through the config's relevant field to get the correct path for unpacking
// We need to support older ModelKits (that were packed without diffIDs and digest
// in the config) for now, so we need to continue using the old structure.
var modelPartIdx, codeIdx, datasetIdx, docsIdx int
for _, layerDesc := range manifest.Layers {
// This variable supports older-format tar layers (that don't include the
// layer path). For current ModelKits, this will be empty
var relPath string

mediaType := constants.ParseMediaType(layerDesc.MediaType)
switch mediaType.BaseType {
case constants.ModelType:
if !shouldUnpackLayer(config.Model, opts.filterConfs) {
continue
}
_, relPath, err = filesystem.VerifySubpath(opts.unpackDir, config.Model.Path)
if err != nil {
return fmt.Errorf("error resolving model path: %w", err)
if config.Model.LayerInfo != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this also check whether config.Model is set?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only enter this branch of the switch if we encounter a model layer, which implies the existence of a model field in the Kitfile (i.e. how did we pack a model layer without a model section in the Kitfile?).

if config.Model.LayerInfo.Digest != layerDesc.Digest.String() {
return fmt.Errorf("digest in config and manifest do not match in model")
}
relPath = ""
} else {
_, relPath, err = filesystem.VerifySubpath(opts.unpackDir, config.Model.Path)
if err != nil {
return fmt.Errorf("error resolving model path: %w", err)
}
}
output.Infof("Unpacking model %s to %s", config.Model.Name, relPath)

output.Infof("Unpacking model %s to %s", config.Model.Name, config.Model.Path)

case constants.ModelPartType:
part := config.Model.Parts[modelPartIdx]
if !shouldUnpackLayer(part, opts.filterConfs) {
modelPartIdx += 1
continue
}
_, relPath, err = filesystem.VerifySubpath(opts.unpackDir, part.Path)
if err != nil {
return fmt.Errorf("error resolving code path: %w", err)
if part.LayerInfo != nil {
if part.LayerInfo.Digest != layerDesc.Digest.String() {
return fmt.Errorf("digest in config and manifest do not match in modelpart")
}
relPath = ""
} else {
_, relPath, err = filesystem.VerifySubpath(opts.unpackDir, part.Path)
if err != nil {
return fmt.Errorf("error resolving code path: %w", err)
}
}
output.Infof("Unpacking model part %s to %s", part.Name, relPath)
output.Infof("Unpacking model part %s to %s", part.Name, part.Path)
modelPartIdx += 1

case constants.CodeType:
Expand All @@ -111,11 +131,18 @@ func runUnpackRecursive(ctx context.Context, opts *unpackOptions, visitedRefs []
codeIdx += 1
continue
}
_, relPath, err = filesystem.VerifySubpath(opts.unpackDir, codeEntry.Path)
if err != nil {
return fmt.Errorf("error resolving code path: %w", err)
if codeEntry.LayerInfo != nil {
if codeEntry.LayerInfo.Digest != layerDesc.Digest.String() {
return fmt.Errorf("digest in config and manifest do not match in code layer")
}
relPath = ""
} else {
_, relPath, err = filesystem.VerifySubpath(opts.unpackDir, codeEntry.Path)
if err != nil {
return fmt.Errorf("error resolving code path: %w", err)
}
}
output.Infof("Unpacking code to %s", relPath)
output.Infof("Unpacking code to %s", codeEntry.Path)
codeIdx += 1

case constants.DatasetType:
Expand All @@ -124,11 +151,18 @@ func runUnpackRecursive(ctx context.Context, opts *unpackOptions, visitedRefs []
datasetIdx += 1
continue
}
_, relPath, err = filesystem.VerifySubpath(opts.unpackDir, datasetEntry.Path)
if err != nil {
return fmt.Errorf("error resolving dataset path for dataset %s: %w", datasetEntry.Name, err)
if datasetEntry.LayerInfo != nil {
if datasetEntry.LayerInfo.Digest != layerDesc.Digest.String() {
return fmt.Errorf("digest in config and manifest do not match in dataset layer")
}
relPath = ""
} else {
_, relPath, err = filesystem.VerifySubpath(opts.unpackDir, datasetEntry.Path)
if err != nil {
return fmt.Errorf("error resolving dataset path for dataset %s: %w", datasetEntry.Name, err)
}
}
output.Infof("Unpacking dataset %s to %s", datasetEntry.Name, relPath)
output.Infof("Unpacking dataset %s to %s", datasetEntry.Name, datasetEntry.Path)
datasetIdx += 1

case constants.DocsType:
Expand All @@ -137,9 +171,16 @@ func runUnpackRecursive(ctx context.Context, opts *unpackOptions, visitedRefs []
docsIdx += 1
continue
}
_, relPath, err = filesystem.VerifySubpath(opts.unpackDir, docsEntry.Path)
if err != nil {
return fmt.Errorf("error resolving path %s for docs: %w", docsEntry.Path, err)
if docsEntry.LayerInfo != nil {
if docsEntry.LayerInfo.Digest != layerDesc.Digest.String() {
return fmt.Errorf("digest in config and manifest do not match in docs layer")
}
relPath = ""
} else {
_, relPath, err = filesystem.VerifySubpath(opts.unpackDir, docsEntry.Path)
if err != nil {
return fmt.Errorf("error resolving path %s for docs: %w", docsEntry.Path, err)
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The handling code is duplicated for every layer type. Should it be refactored?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just kind of the way things end up happening in Go. I can try to rework it, but in my experience that ends up decreasing readability :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Latest commit should make this a little clearer.

}
output.Infof("Unpacking docs to %s", docsEntry.Path)
docsIdx += 1
Expand Down Expand Up @@ -225,19 +266,21 @@ func unpackLayer(ctx context.Context, store content.Storage, desc ocispec.Descri
defer cr.Close()
tr := tar.NewReader(cr)

unpackDir := filepath.Dir(unpackPath)
if err := os.MkdirAll(unpackDir, 0755); err != nil {
return fmt.Errorf("failed to create directory %s: %w", unpackDir, err)
if unpackPath != "" {
unpackPath = filepath.Dir(unpackPath)
if err := os.MkdirAll(unpackPath, 0755); err != nil {
return fmt.Errorf("failed to create directory %s: %w", unpackPath, err)
}
}

if err := extractTar(tr, unpackDir, overwrite, logger); err != nil {
if err := extractTar(tr, unpackPath, overwrite, logger); err != nil {
return err
}
logger.Wait()
return nil
}

func extractTar(tr *tar.Reader, dir string, overwrite bool, logger *output.ProgressLogger) (err error) {
func extractTar(tr *tar.Reader, extractDir string, overwrite bool, logger *output.ProgressLogger) (err error) {
for {
header, err := tr.Next()
if err == io.EOF {
Expand All @@ -246,9 +289,12 @@ func extractTar(tr *tar.Reader, dir string, overwrite bool, logger *output.Progr
if err != nil {
return err
}
outPath := filepath.Join(dir, header.Name)
outPath := header.Name
if extractDir != "" {
outPath = filepath.Join(extractDir, header.Name)
}
// Check if the outPath is within the target directory
_, _, err = filesystem.VerifySubpath(dir, outPath)
_, _, err = filesystem.VerifySubpath(extractDir, outPath)
if err != nil {
return fmt.Errorf("illegal file path: %s: %w", outPath, err)
}
Expand All @@ -259,7 +305,6 @@ func extractTar(tr *tar.Reader, dir string, overwrite bool, logger *output.Progr
if !fi.IsDir() {
return fmt.Errorf("path '%s' already exists and is not a directory", outPath)
}
logger.Debugf("Path %s already exists", outPath)
} else {
logger.Debugf("Creating directory %s", outPath)
if err := os.MkdirAll(outPath, header.FileInfo().Mode()); err != nil {
Expand Down
Loading
Loading