Make unpack process handle both old and new ModelKits
We need to handle existing ModelKits in addition to the new tar format:
for new ModelKits we can unpack the tars directly (and simply), but for
older ones we need to pre-create the layer's directory and join the
paths from the tarball onto it.

To avoid overcomplicating things (since we're still relatively early),
we detect new ModelKits via the digest/diffId fields in the config; this
is a lot simpler to support than introducing new versions of our media
types.
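To make the old/new split concrete, here is a minimal, self-contained sketch of the detection pattern the diff below applies to each layer type. The `layerEntry`, `layerInfo`, and `resolveExtractDir` names are hypothetical simplifications for illustration only; the real kitops config types and unpack code differ (see the diff).

```go
package main

import (
	"fmt"
	"path/filepath"
)

// Hypothetical, simplified stand-ins for the kitfile config types; the real
// structs in kitops are different.
type layerInfo struct {
	Digest string // digest of the layer, recorded only for new-format ModelKits
}

type layerEntry struct {
	Path      string     // path stored in the config (used by old-format layers)
	LayerInfo *layerInfo // nil for ModelKits packed before digests were stored
}

// resolveExtractDir mirrors the detection logic described in the commit:
// if the config entry carries layer info (new format), the tar already
// contains full relative paths and no extra directory is needed; otherwise
// (old format) the layer's target directory is derived from the config path
// and tar entries are joined onto it.
func resolveExtractDir(unpackDir string, entry layerEntry, manifestDigest string) (string, error) {
	if entry.LayerInfo != nil {
		// New-format layer: sanity-check that config and manifest agree.
		if entry.LayerInfo.Digest != manifestDigest {
			return "", fmt.Errorf("digest in config and manifest do not match")
		}
		return "", nil // empty dir => unpack tar paths as-is
	}
	// Old-format layer: pre-create this directory and join tar paths onto it.
	return filepath.Join(unpackDir, entry.Path), nil
}

func main() {
	newStyle := layerEntry{LayerInfo: &layerInfo{Digest: "sha256:abc"}}
	oldStyle := layerEntry{Path: "models/my-model"}

	d1, _ := resolveExtractDir("/tmp/unpack", newStyle, "sha256:abc")
	d2, _ := resolveExtractDir("/tmp/unpack", oldStyle, "")
	fmt.Printf("new format extract dir: %q\n", d1) // ""
	fmt.Printf("old format extract dir: %q\n", d2) // "/tmp/unpack/models/my-model"
}
```

In the sketch, an empty extract directory signals the new format, so tar entries are written using the paths already stored in the archive; a non-empty directory is pre-created and tar entries are joined onto it, which is the same distinction the `unpackLayer` and `extractTar` changes below make with an empty `unpackPath`/`extractDir`.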
amisevsk committed Dec 3, 2024
1 parent 3b96dc8 commit 153c702
Showing 1 changed file with 71 additions and 25 deletions.
96 changes: 71 additions & 25 deletions pkg/cmd/unpack/unpack.go
@@ -77,32 +77,52 @@ func runUnpackRecursive(ctx context.Context, opts *unpackOptions, visitedRefs []

// Since there might be multiple datasets, etc. we need to synchronously iterate
// through the config's relevant field to get the correct path for unpacking
// We need to support older ModelKits (that were packed without diffIDs and digest
// in the config) for now, so we need to continue using the old structure.
var modelPartIdx, codeIdx, datasetIdx, docsIdx int
for _, layerDesc := range manifest.Layers {
// This variable supports older-format tar layers (that don't include the
// layer path). For current ModelKits, this will be empty
var relPath string

mediaType := constants.ParseMediaType(layerDesc.MediaType)
switch mediaType.BaseType {
case constants.ModelType:
if !shouldUnpackLayer(config.Model, opts.filterConfs) {
continue
}
_, relPath, err = filesystem.VerifySubpath(opts.unpackDir, config.Model.Path)
if err != nil {
return fmt.Errorf("error resolving model path: %w", err)
if config.Model.LayerInfo != nil {
if config.Model.LayerInfo.Digest != layerDesc.Digest.String() {
return fmt.Errorf("digest in config and manifest do not match in model")
}
relPath = ""
} else {
_, relPath, err = filesystem.VerifySubpath(opts.unpackDir, config.Model.Path)
if err != nil {
return fmt.Errorf("error resolving model path: %w", err)
}
}
output.Infof("Unpacking model %s to %s", config.Model.Name, relPath)

output.Infof("Unpacking model %s to %s", config.Model.Name, config.Model.Path)

case constants.ModelPartType:
part := config.Model.Parts[modelPartIdx]
if !shouldUnpackLayer(part, opts.filterConfs) {
modelPartIdx += 1
continue
}
_, relPath, err = filesystem.VerifySubpath(opts.unpackDir, part.Path)
if err != nil {
return fmt.Errorf("error resolving code path: %w", err)
if part.LayerInfo != nil {
if part.LayerInfo.Digest != layerDesc.Digest.String() {
return fmt.Errorf("digest in config and manifest do not match in modelpart")
}
relPath = ""
} else {
_, relPath, err = filesystem.VerifySubpath(opts.unpackDir, part.Path)
if err != nil {
return fmt.Errorf("error resolving code path: %w", err)
}
}
output.Infof("Unpacking model part %s to %s", part.Name, relPath)
output.Infof("Unpacking model part %s to %s", part.Name, part.Path)
modelPartIdx += 1

case constants.CodeType:
@@ -111,11 +131,18 @@ func runUnpackRecursive(ctx context.Context, opts *unpackOptions, visitedRefs []
codeIdx += 1
continue
}
_, relPath, err = filesystem.VerifySubpath(opts.unpackDir, codeEntry.Path)
if err != nil {
return fmt.Errorf("error resolving code path: %w", err)
if codeEntry.LayerInfo != nil {
if codeEntry.LayerInfo.Digest != layerDesc.Digest.String() {
return fmt.Errorf("digest in config and manifest do not match in code layer")
}
relPath = ""
} else {
_, relPath, err = filesystem.VerifySubpath(opts.unpackDir, codeEntry.Path)
if err != nil {
return fmt.Errorf("error resolving code path: %w", err)
}
}
output.Infof("Unpacking code to %s", relPath)
output.Infof("Unpacking code to %s", codeEntry.Path)
codeIdx += 1

case constants.DatasetType:
@@ -124,11 +151,18 @@ func runUnpackRecursive(ctx context.Context, opts *unpackOptions, visitedRefs []
datasetIdx += 1
continue
}
_, relPath, err = filesystem.VerifySubpath(opts.unpackDir, datasetEntry.Path)
if err != nil {
return fmt.Errorf("error resolving dataset path for dataset %s: %w", datasetEntry.Name, err)
if datasetEntry.LayerInfo != nil {
if datasetEntry.LayerInfo.Digest != layerDesc.Digest.String() {
return fmt.Errorf("digest in config and manifest do not match in dataset layer")
}
relPath = ""
} else {
_, relPath, err = filesystem.VerifySubpath(opts.unpackDir, datasetEntry.Path)
if err != nil {
return fmt.Errorf("error resolving dataset path for dataset %s: %w", datasetEntry.Name, err)
}
}
output.Infof("Unpacking dataset %s to %s", datasetEntry.Name, relPath)
output.Infof("Unpacking dataset %s to %s", datasetEntry.Name, datasetEntry.Path)
datasetIdx += 1

case constants.DocsType:
@@ -137,9 +171,16 @@ func runUnpackRecursive(ctx context.Context, opts *unpackOptions, visitedRefs []
docsIdx += 1
continue
}
_, relPath, err = filesystem.VerifySubpath(opts.unpackDir, docsEntry.Path)
if err != nil {
return fmt.Errorf("error resolving path %s for docs: %w", docsEntry.Path, err)
if docsEntry.LayerInfo != nil {
if docsEntry.LayerInfo.Digest != layerDesc.Digest.String() {
return fmt.Errorf("digest in config and manifest do not match in docs layer")
}
relPath = ""
} else {
_, relPath, err = filesystem.VerifySubpath(opts.unpackDir, docsEntry.Path)
if err != nil {
return fmt.Errorf("error resolving path %s for docs: %w", docsEntry.Path, err)
}
}
output.Infof("Unpacking docs to %s", docsEntry.Path)
docsIdx += 1
@@ -225,19 +266,21 @@ func unpackLayer(ctx context.Context, store content.Storage, desc ocispec.Descri
defer cr.Close()
tr := tar.NewReader(cr)

unpackDir := filepath.Dir(unpackPath)
if err := os.MkdirAll(unpackDir, 0755); err != nil {
return fmt.Errorf("failed to create directory %s: %w", unpackDir, err)
if unpackPath != "" {
unpackPath = filepath.Dir(unpackPath)
if err := os.MkdirAll(unpackPath, 0755); err != nil {
return fmt.Errorf("failed to create directory %s: %w", unpackPath, err)
}
}

if err := extractTar(tr, unpackDir, overwrite, logger); err != nil {
if err := extractTar(tr, unpackPath, overwrite, logger); err != nil {
return err
}
logger.Wait()
return nil
}

func extractTar(tr *tar.Reader, dir string, overwrite bool, logger *output.ProgressLogger) (err error) {
func extractTar(tr *tar.Reader, extractDir string, overwrite bool, logger *output.ProgressLogger) (err error) {
for {
header, err := tr.Next()
if err == io.EOF {
@@ -247,8 +290,11 @@ func extractTar(tr *tar.Reader, dir string, overwrite bool, logger *output.Progr
return err
}
outPath := header.Name
if extractDir != "" {
outPath = filepath.Join(extractDir, header.Name)
}
// Check if the outPath is within the target directory
_, _, err = filesystem.VerifySubpath(dir, outPath)
_, _, err = filesystem.VerifySubpath(extractDir, outPath)
if err != nil {
return fmt.Errorf("illegal file path: %s: %w", outPath, err)
}
