-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
28 changed files
with
2,542 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
mdctl | ||
vendor | ||
*.safetensors | ||
*.model | ||
gemma-2b:* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
package cmd | ||
|
||
import ( | ||
"fmt" | ||
|
||
v2 "github.com/CloudNativeAI/model-spec/specs-go/v2" | ||
"github.com/CloudNativeAI/model-spec/tools/mdctl/format" | ||
"github.com/CloudNativeAI/model-spec/tools/mdctl/models" | ||
oci "github.com/opencontainers/image-spec/specs-go/v1" | ||
) | ||
|
||
func BuildModel(commands []format.Command) error { | ||
manifest := v2.Manifest{MediaType: v2.MediaTypeModelManifest} | ||
config := v2.Config{} | ||
weights := v2.Weights{} | ||
engine := v2.Engine{} | ||
|
||
if len(commands) == 0 { | ||
return fmt.Errorf("modelfile has no command") | ||
} | ||
|
||
var modelName string | ||
if commands[0].Name == format.CREATE { | ||
modelName = commands[0].Args | ||
fmt.Println("Create", modelName) | ||
} else if commands[0].Name == format.FROM { | ||
modelName = commands[0].Args | ||
fmt.Println("From ", modelName) | ||
if err := PullModel(modelName); err != nil { | ||
return fmt.Errorf("failed to pull base model") | ||
} | ||
_, err := FetchManifest(modelName, &manifest, &config) | ||
if err != nil { | ||
return fmt.Errorf("failed to get remote manifest") | ||
} | ||
} else { | ||
return fmt.Errorf("first command should be %s or %s", format.CREATE, format.FROM) | ||
} | ||
for _, c := range commands { | ||
switch c.Name { | ||
case format.CREATE, format.FROM: | ||
config.Name = c.Args | ||
|
||
case format.NAME: | ||
config.Name = c.Args | ||
|
||
case format.DESCRIPTION: | ||
layer, err := models.BuildDescriptor(models.TAR, c.Args, v2.MediaTypeModelDescription, "Description") | ||
if err != nil { | ||
return fmt.Errorf("failed to build description layer: %w", err) | ||
} | ||
config.Description = append(config.Description, *layer) | ||
fmt.Printf("Add description [%s]\n", c.Args) | ||
|
||
case format.LICENSE: | ||
layer, err := models.BuildDescriptor(models.TAR, c.Args, v2.MediaTypeModelLicense, "License") | ||
if err != nil { | ||
return fmt.Errorf("failed to build license layer: %w", err) | ||
} | ||
config.License = append(config.License, *layer) | ||
fmt.Printf("Add license [%s]\n", c.Args) | ||
|
||
case format.ARCHITECTURE: | ||
config.Architecture = c.Args | ||
|
||
case format.FAMILY: | ||
config.Family = c.Args | ||
|
||
case format.CONFIG: | ||
layer, err := models.BuildDescriptor(models.TAR, c.Args, v2.MediaTypeModelConfig, "") | ||
if err != nil { | ||
return fmt.Errorf("failed to build config layer: %w", err) | ||
} | ||
config.Extensions = append(config.Extensions, *layer) | ||
fmt.Printf("Add config [%s]\n", c.Args) | ||
|
||
case format.PARAM_SIZE: | ||
engine.Name = c.Args | ||
|
||
case format.FORMAT: | ||
weights.Format = c.Args | ||
|
||
case format.WEIGHTS: | ||
layer, err := models.BuildDescriptor(models.TAR, c.Args, v2.MediaTypeModelWeights, "") | ||
if err != nil { | ||
return fmt.Errorf("failed to build weights layer: %w", err) | ||
} | ||
weights.File = append(weights.File, *layer) | ||
fmt.Printf("Add weights [%s]\n", c.Args) | ||
|
||
case format.TOKENIZER: | ||
layer, err := models.BuildDescriptor(models.TAR, c.Args, v2.MediaTypeModelProcessorText, "") | ||
if err != nil { | ||
return fmt.Errorf("failed to build tokenizer layer: %w", err) | ||
} | ||
manifest.Processor = append(manifest.Processor, *layer) | ||
fmt.Printf("Add tokenizer [%s]\n", c.Args) | ||
|
||
default: | ||
fmt.Printf("WARN: [%s] - [%s] not handled\n", c.Name, c.Args) | ||
} | ||
} | ||
|
||
manifest.Config = config | ||
manifest.Weights = weights | ||
manifest.Engine = engine | ||
|
||
// Commit layers | ||
_, err := Commit(&manifest) | ||
if err != nil { | ||
return fmt.Errorf("failed to commit layers: %w", err) | ||
} | ||
|
||
// Commit manifest layer | ||
if err := models.WriteManifest(modelName, &manifest); err != nil { | ||
return fmt.Errorf("failed to write manifest: %w", err) | ||
} | ||
|
||
fmt.Println("Build succeed") | ||
return nil | ||
} | ||
|
||
func Commit(m *v2.Manifest) (bool, error) { | ||
layerGroups := []struct { | ||
name string | ||
layers []oci.Descriptor | ||
}{ | ||
{"Description", m.Config.Description}, | ||
{"License", m.Config.License}, | ||
{"Extensions", m.Config.Extensions}, | ||
{"Weights", m.Weights.File}, | ||
{"Tokenizer", m.Processor}, | ||
} | ||
|
||
var committed bool | ||
for _, group := range layerGroups { | ||
if len(group.layers) == 0 { | ||
continue // Skip empty layer groups | ||
} | ||
groupCommitted, err := commitLayers(group.name, group.layers) | ||
if err != nil { | ||
return false, fmt.Errorf("failed to commit %s layers: %w", group.name, err) | ||
} | ||
committed = committed || groupCommitted | ||
} | ||
return committed, nil | ||
} | ||
|
||
func commitLayers(groupName string, layers []oci.Descriptor) (bool, error) { | ||
var groupCommitted bool | ||
for _, layer := range layers { | ||
layerCommitted, err := commitSingleLayer(groupName, layer) | ||
if err != nil { | ||
return false, err | ||
} | ||
groupCommitted = groupCommitted || layerCommitted | ||
} | ||
return groupCommitted, nil | ||
} | ||
|
||
func commitSingleLayer(groupName string, layer oci.Descriptor) (bool, error) { | ||
committed, err := models.Commit(layer) | ||
if err != nil { | ||
return false, fmt.Errorf("failed to commit %s layer: %w", groupName, err) | ||
} | ||
|
||
return committed, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
package cmd | ||
|
||
import ( | ||
"bytes" | ||
"fmt" | ||
"log" | ||
"os" | ||
"path/filepath" | ||
|
||
"github.com/CloudNativeAI/model-spec/tools/mdctl/format" | ||
"github.com/CloudNativeAI/model-spec/tools/mdctl/progress" | ||
"github.com/spf13/cobra" | ||
) | ||
|
||
func BuildHandler(cmd *cobra.Command, args []string) error { | ||
filename, _ := cmd.Flags().GetString("file") | ||
filename, err := filepath.Abs(filename) | ||
if err != nil { | ||
return fmt.Errorf("failed to get absolute path: %w", err) | ||
} | ||
|
||
p := progress.NewProgress(os.Stderr) | ||
// defer p.Stop() | ||
// bars := make(map[string]*progress.Bar) | ||
|
||
modelFile, err := os.ReadFile(filename) | ||
if err != nil { | ||
return fmt.Errorf("failed to read modelfile: %w", err) | ||
} | ||
|
||
commands, err := format.Parse(bytes.NewReader(modelFile)) | ||
if err != nil { | ||
return fmt.Errorf("failed to parse modelfile: %w", err) | ||
} | ||
|
||
// status := "building" | ||
// spinner := progress.NewSpinner(status) | ||
// p.Add(status, spinner) | ||
|
||
if err := BuildModel(commands); err != nil { | ||
return fmt.Errorf("failed to build model: %w", err) | ||
} | ||
p.StopAndClear() | ||
|
||
return nil | ||
} | ||
|
||
func RunHandler(cmd *cobra.Command, _ []string) error { | ||
name, _ := cmd.Flags().GetString("name") | ||
fmt.Println("Unpack Model: ", name) | ||
if err := RunModel(name); err != nil { | ||
return fmt.Errorf("failed to unpack model: %w", err) | ||
} | ||
fmt.Println("Unpack succeed") | ||
return nil | ||
} | ||
|
||
func PushHandler(cmd *cobra.Command, _ []string) error { | ||
name, _ := cmd.Flags().GetString("name") | ||
fmt.Println("Push Model:", name) | ||
if err := PushModel(name); err != nil { | ||
return fmt.Errorf("failed to push model: %w", err) | ||
} | ||
return nil | ||
} | ||
|
||
func PullHandler(cmd *cobra.Command, _ []string) error { | ||
name, _ := cmd.Flags().GetString("name") | ||
fmt.Println("Pull Model:", name) | ||
if err := PullModel(name); err != nil { | ||
return fmt.Errorf("failed to pull model: %w", err) | ||
} | ||
return nil | ||
} | ||
|
||
func ListHandler(cmd *cobra.Command, args []string) error { | ||
return ListModel() | ||
} | ||
|
||
func NewCLI() *cobra.Command { | ||
log.SetFlags(log.LstdFlags | log.Lshortfile) | ||
cobra.EnableCommandSorting = false | ||
|
||
rootCmd := &cobra.Command{ | ||
Use: "mdctl", | ||
Short: "Model management tool", | ||
SilenceUsage: true, | ||
SilenceErrors: true, | ||
CompletionOptions: cobra.CompletionOptions{ | ||
DisableDefaultCmd: true, | ||
}, | ||
Run: func(cmd *cobra.Command, args []string) { | ||
cmd.Print(cmd.UsageString()) | ||
}, | ||
} | ||
|
||
buildCmd := &cobra.Command{ | ||
Use: "build", | ||
Short: "build models from a Modelfile", | ||
Args: cobra.ExactArgs(0), | ||
RunE: BuildHandler, | ||
} | ||
buildCmd.Flags().StringP("file", "f", "Modelfile", "Path to the Modelfile") | ||
|
||
runCmd := &cobra.Command{ | ||
Use: "unpack", | ||
Short: "run a model", | ||
Args: cobra.ExactArgs(0), | ||
RunE: RunHandler, | ||
} | ||
runCmd.Flags().StringP("name", "n", "", "URL of the model") | ||
|
||
pushCmd := &cobra.Command{ | ||
Use: "push", | ||
Short: "push a model", | ||
Args: cobra.ExactArgs(0), | ||
RunE: PushHandler, | ||
} | ||
pushCmd.Flags().StringP("name", "n", "", "URL of the model") | ||
|
||
pullCmd := &cobra.Command{ | ||
Use: "pull", | ||
Short: "pull a model", | ||
Args: cobra.ExactArgs(0), | ||
RunE: PullHandler, | ||
} | ||
pullCmd.Flags().StringP("name", "n", "", "URL of the model") | ||
|
||
listCmd := &cobra.Command{ | ||
Use: "list", | ||
Short: "list models", | ||
Args: cobra.ExactArgs(0), | ||
RunE: ListHandler, | ||
} | ||
|
||
rootCmd.AddCommand( | ||
buildCmd, | ||
listCmd, | ||
runCmd, | ||
pushCmd, | ||
pullCmd, | ||
) | ||
|
||
return rootCmd | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package cmd | ||
|
||
import ( | ||
"fmt" | ||
"os" | ||
"path/filepath" | ||
"strings" | ||
|
||
"github.com/CloudNativeAI/model-spec/tools/mdctl/models" | ||
) | ||
|
||
func ListModel() error { | ||
dir, err := models.GetManifestRoot() | ||
if err != nil { | ||
return err | ||
} | ||
err = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { | ||
if err != nil { | ||
return fmt.Errorf("failed to walk model: %w", err) | ||
} | ||
if !info.IsDir() { | ||
term := dir + string(os.PathSeparator) | ||
name := strings.TrimPrefix(path, term) | ||
lastSeparatorIndex := strings.LastIndex(name, string(os.PathSeparator)) | ||
if lastSeparatorIndex != -1 { | ||
name = name[:lastSeparatorIndex] + ":" + name[lastSeparatorIndex+1:] | ||
} | ||
fmt.Println(name) | ||
} | ||
return nil | ||
}) | ||
if err != nil { | ||
return fmt.Errorf("failed to list model: %w", err) | ||
} | ||
return nil | ||
} |
Oops, something went wrong.