Skip to content

Commit

Permalink
feat: add tool for demo
Browse files Browse the repository at this point in the history
  • Loading branch information
aftersnow committed Sep 26, 2024
1 parent 2872729 commit 0609fae
Show file tree
Hide file tree
Showing 28 changed files with 2,542 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tools/mdctl/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mdctl
vendor
*.safetensors
*.model
gemma-2b:*
168 changes: 168 additions & 0 deletions tools/mdctl/cmd/build.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
package cmd

import (
"fmt"

v2 "github.com/CloudNativeAI/model-spec/specs-go/v2"
"github.com/CloudNativeAI/model-spec/tools/mdctl/format"
"github.com/CloudNativeAI/model-spec/tools/mdctl/models"
oci "github.com/opencontainers/image-spec/specs-go/v1"
)

// BuildModel assembles a model manifest from parsed Modelfile commands,
// commits every layer it references and writes the manifest under the model
// name taken from the first command.
//
// The first command must be CREATE or FROM. For FROM the named base model is
// pulled first and FetchManifest is called with the manifest and config being
// built (presumably so the build starts from the base model's values —
// confirm against FetchManifest). Commands that are not recognised are
// reported with a warning on stdout and otherwise ignored.
//
// It returns an error when commands is empty, when the first command is not
// CREATE/FROM, or when pulling, layer building, committing or writing the
// manifest fails; underlying errors are wrapped so callers can inspect them.
func BuildModel(commands []format.Command) error {
	manifest := v2.Manifest{MediaType: v2.MediaTypeModelManifest}
	config := v2.Config{}
	weights := v2.Weights{}
	engine := v2.Engine{}

	if len(commands) == 0 {
		return fmt.Errorf("modelfile has no command")
	}

	var modelName string
	switch commands[0].Name {
	case format.CREATE:
		modelName = commands[0].Args
		fmt.Println("Create", modelName)
	case format.FROM:
		modelName = commands[0].Args
		fmt.Println("From ", modelName)
		// Keep the underlying cause in the chain; previously it was dropped,
		// which made registry and network failures impossible to diagnose.
		if err := PullModel(modelName); err != nil {
			return fmt.Errorf("failed to pull base model: %w", err)
		}
		if _, err := FetchManifest(modelName, &manifest, &config); err != nil {
			return fmt.Errorf("failed to get remote manifest: %w", err)
		}
	default:
		return fmt.Errorf("first command should be %s or %s", format.CREATE, format.FROM)
	}

	// The loop deliberately starts at the first command again so that
	// CREATE/FROM also seeds config.Name.
	for _, c := range commands {
		switch c.Name {
		case format.CREATE, format.FROM:
			config.Name = c.Args

		case format.NAME:
			config.Name = c.Args

		case format.DESCRIPTION:
			layer, err := models.BuildDescriptor(models.TAR, c.Args, v2.MediaTypeModelDescription, "Description")
			if err != nil {
				return fmt.Errorf("failed to build description layer: %w", err)
			}
			config.Description = append(config.Description, *layer)
			fmt.Printf("Add description [%s]\n", c.Args)

		case format.LICENSE:
			layer, err := models.BuildDescriptor(models.TAR, c.Args, v2.MediaTypeModelLicense, "License")
			if err != nil {
				return fmt.Errorf("failed to build license layer: %w", err)
			}
			config.License = append(config.License, *layer)
			fmt.Printf("Add license [%s]\n", c.Args)

		case format.ARCHITECTURE:
			config.Architecture = c.Args

		case format.FAMILY:
			config.Family = c.Args

		case format.CONFIG:
			layer, err := models.BuildDescriptor(models.TAR, c.Args, v2.MediaTypeModelConfig, "")
			if err != nil {
				return fmt.Errorf("failed to build config layer: %w", err)
			}
			config.Extensions = append(config.Extensions, *layer)
			fmt.Printf("Add config [%s]\n", c.Args)

		case format.PARAM_SIZE:
			// NOTE(review): a parameter-size command is stored in the engine
			// name; this looks unintended — confirm the target field.
			engine.Name = c.Args

		case format.FORMAT:
			weights.Format = c.Args

		case format.WEIGHTS:
			layer, err := models.BuildDescriptor(models.TAR, c.Args, v2.MediaTypeModelWeights, "")
			if err != nil {
				return fmt.Errorf("failed to build weights layer: %w", err)
			}
			weights.File = append(weights.File, *layer)
			fmt.Printf("Add weights [%s]\n", c.Args)

		case format.TOKENIZER:
			layer, err := models.BuildDescriptor(models.TAR, c.Args, v2.MediaTypeModelProcessorText, "")
			if err != nil {
				return fmt.Errorf("failed to build tokenizer layer: %w", err)
			}
			manifest.Processor = append(manifest.Processor, *layer)
			fmt.Printf("Add tokenizer [%s]\n", c.Args)

		default:
			fmt.Printf("WARN: [%s] - [%s] not handled\n", c.Name, c.Args)
		}
	}

	// NOTE(review): Weights and Engine are replaced wholesale by the locally
	// built values, even when the manifest came from a FROM base model —
	// verify that discarding the base model's weights/engine is intended.
	manifest.Config = config
	manifest.Weights = weights
	manifest.Engine = engine

	// Commit layers
	if _, err := Commit(&manifest); err != nil {
		return fmt.Errorf("failed to commit layers: %w", err)
	}

	// Commit manifest layer
	if err := models.WriteManifest(modelName, &manifest); err != nil {
		return fmt.Errorf("failed to write manifest: %w", err)
	}

	fmt.Println("Build succeed")
	return nil
}

// Commit commits the layers referenced by the manifest, group by group, in
// the order description, license, extensions, weights, tokenizer. Groups
// without layers are skipped.
//
// The boolean result is true when at least one group reported a commit; it is
// false whenever an error is returned. The first failing group aborts the
// run and its error is wrapped with the group name.
func Commit(m *v2.Manifest) (bool, error) {
	type layerGroup struct {
		label string
		descs []oci.Descriptor
	}
	groups := []layerGroup{
		{label: "Description", descs: m.Config.Description},
		{label: "License", descs: m.Config.License},
		{label: "Extensions", descs: m.Config.Extensions},
		{label: "Weights", descs: m.Weights.File},
		{label: "Tokenizer", descs: m.Processor},
	}

	anyCommitted := false
	for _, g := range groups {
		if len(g.descs) > 0 {
			done, err := commitLayers(g.label, g.descs)
			if err != nil {
				return false, fmt.Errorf("failed to commit %s layers: %w", g.label, err)
			}
			if done {
				anyCommitted = true
			}
		}
	}
	return anyCommitted, nil
}

// commitLayers commits each layer of one group in slice order and stops at
// the first failure. It reports true when any single layer commit reported
// true, and false together with the error when a commit fails.
func commitLayers(groupName string, layers []oci.Descriptor) (bool, error) {
	anyCommitted := false
	for i := range layers {
		done, err := commitSingleLayer(groupName, layers[i])
		if err != nil {
			return false, err
		}
		if done {
			anyCommitted = true
		}
	}
	return anyCommitted, nil
}

// commitSingleLayer hands one layer descriptor to models.Commit and passes
// its boolean result through. On failure the error is wrapped with the name
// of the group the layer belongs to and false is returned.
func commitSingleLayer(groupName string, layer oci.Descriptor) (bool, error) {
	done, err := models.Commit(layer)
	if err == nil {
		return done, nil
	}
	return false, fmt.Errorf("failed to commit %s layer: %w", groupName, err)
}
145 changes: 145 additions & 0 deletions tools/mdctl/cmd/cmd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package cmd

import (
"bytes"
"fmt"
"log"
"os"
"path/filepath"

"github.com/CloudNativeAI/model-spec/tools/mdctl/format"
"github.com/CloudNativeAI/model-spec/tools/mdctl/progress"
"github.com/spf13/cobra"
)

// BuildHandler is the cobra handler for the build sub-command. It resolves
// the path given by the "file" flag, reads and parses that Modelfile and
// passes the resulting commands to BuildModel.
//
// Any failure (flag lookup, path resolution, read, parse, build) is returned
// wrapped with a short description of the failing step.
func BuildHandler(cmd *cobra.Command, args []string) error {
	filename, err := cmd.Flags().GetString("file")
	if err != nil {
		return fmt.Errorf("failed to read file flag: %w", err)
	}
	filename, err = filepath.Abs(filename)
	if err != nil {
		return fmt.Errorf("failed to get absolute path: %w", err)
	}

	p := progress.NewProgress(os.Stderr)
	// Stop the progress output on every exit path; previously it was only
	// stopped after a successful build and left running on errors.
	defer p.StopAndClear()
	// bars := make(map[string]*progress.Bar)

	modelFile, err := os.ReadFile(filename)
	if err != nil {
		return fmt.Errorf("failed to read modelfile: %w", err)
	}

	commands, err := format.Parse(bytes.NewReader(modelFile))
	if err != nil {
		return fmt.Errorf("failed to parse modelfile: %w", err)
	}

	// status := "building"
	// spinner := progress.NewSpinner(status)
	// p.Add(status, spinner)

	if err := BuildModel(commands); err != nil {
		return fmt.Errorf("failed to build model: %w", err)
	}

	return nil
}

// RunHandler is the cobra handler for the unpack sub-command. It reads the
// model reference from the "name" flag and passes it to RunModel, printing a
// line before and after the operation.
//
// A failed flag lookup or a RunModel failure is returned wrapped with context.
func RunHandler(cmd *cobra.Command, _ []string) error {
	name, err := cmd.Flags().GetString("name")
	if err != nil {
		return fmt.Errorf("failed to read name flag: %w", err)
	}
	fmt.Println("Unpack Model: ", name)
	if err := RunModel(name); err != nil {
		return fmt.Errorf("failed to unpack model: %w", err)
	}
	fmt.Println("Unpack succeed")
	return nil
}

// PushHandler is the cobra handler for the push sub-command. It reads the
// model reference from the "name" flag and passes it to PushModel.
//
// A failed flag lookup or a PushModel failure is returned wrapped with context.
func PushHandler(cmd *cobra.Command, _ []string) error {
	name, err := cmd.Flags().GetString("name")
	if err != nil {
		return fmt.Errorf("failed to read name flag: %w", err)
	}
	fmt.Println("Push Model:", name)
	if err := PushModel(name); err != nil {
		return fmt.Errorf("failed to push model: %w", err)
	}
	return nil
}

// PullHandler is the cobra handler for the pull sub-command. It reads the
// model reference from the "name" flag and passes it to PullModel.
//
// A failed flag lookup or a PullModel failure is returned wrapped with context.
func PullHandler(cmd *cobra.Command, _ []string) error {
	name, err := cmd.Flags().GetString("name")
	if err != nil {
		return fmt.Errorf("failed to read name flag: %w", err)
	}
	fmt.Println("Pull Model:", name)
	if err := PullModel(name); err != nil {
		return fmt.Errorf("failed to pull model: %w", err)
	}
	return nil
}

// ListHandler is the cobra handler for the list sub-command. It ignores its
// arguments and returns whatever ListModel returns.
func ListHandler(cmd *cobra.Command, args []string) error {
	if err := ListModel(); err != nil {
		return err
	}
	return nil
}

// NewCLI builds the mdctl root command with its build, list, unpack, push and
// pull sub-commands and returns it ready to be executed.
//
// As a side effect it configures the standard logger to include the source
// file and line, and disables cobra's alphabetical command sorting so that
// sub-commands are listed in registration order.
func NewCLI() *cobra.Command {
	log.SetFlags(log.LstdFlags | log.Lshortfile)
	cobra.EnableCommandSorting = false

	rootCmd := &cobra.Command{
		Use:           "mdctl",
		Short:         "Model management tool",
		SilenceUsage:  true,
		SilenceErrors: true,
		CompletionOptions: cobra.CompletionOptions{
			DisableDefaultCmd: true,
		},
		// Running mdctl without a sub-command just prints the usage text.
		Run: func(cmd *cobra.Command, args []string) {
			cmd.Print(cmd.UsageString())
		},
	}

	buildCmd := &cobra.Command{
		Use:   "build",
		Short: "build models from a Modelfile",
		Args:  cobra.ExactArgs(0),
		RunE:  BuildHandler,
	}
	buildCmd.Flags().StringP("file", "f", "Modelfile", "Path to the Modelfile")

	// The sub-command is exposed as "unpack"; the help text now says so
	// instead of the former, misleading "run a model".
	runCmd := &cobra.Command{
		Use:   "unpack",
		Short: "unpack a model",
		Args:  cobra.ExactArgs(0),
		RunE:  RunHandler,
	}
	runCmd.Flags().StringP("name", "n", "", "URL of the model")

	pushCmd := &cobra.Command{
		Use:   "push",
		Short: "push a model",
		Args:  cobra.ExactArgs(0),
		RunE:  PushHandler,
	}
	pushCmd.Flags().StringP("name", "n", "", "URL of the model")

	pullCmd := &cobra.Command{
		Use:   "pull",
		Short: "pull a model",
		Args:  cobra.ExactArgs(0),
		RunE:  PullHandler,
	}
	pullCmd.Flags().StringP("name", "n", "", "URL of the model")

	listCmd := &cobra.Command{
		Use:   "list",
		Short: "list models",
		Args:  cobra.ExactArgs(0),
		RunE:  ListHandler,
	}

	rootCmd.AddCommand(
		buildCmd,
		listCmd,
		runCmd,
		pushCmd,
		pullCmd,
	)

	return rootCmd
}
36 changes: 36 additions & 0 deletions tools/mdctl/cmd/list.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package cmd

import (
"fmt"
"os"
"path/filepath"
"strings"

"github.com/CloudNativeAI/model-spec/tools/mdctl/models"
)

// ListModel prints one line to stdout for every non-directory entry found
// below the manifest root returned by models.GetManifestRoot.
//
// Each line is the entry's path relative to the root with the last path
// separator replaced by ":", i.e. <dirs...>/<file> is shown as
// <dirs...>:<file>. An entry directly under the root is printed unchanged.
//
// Errors from locating the root are returned as is; errors met while walking
// the tree are wrapped and returned.
func ListModel() error {
	dir, err := models.GetManifestRoot()
	if err != nil {
		return err
	}
	err = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return fmt.Errorf("failed to walk model: %w", err)
		}
		if info.IsDir() {
			return nil
		}
		// filepath.Rel strips the root even when dir is not in cleaned form
		// (e.g. has a trailing separator), which the previous
		// dir+separator prefix trim did not handle.
		name, relErr := filepath.Rel(dir, path)
		if relErr != nil {
			return fmt.Errorf("failed to resolve model path %q: %w", path, relErr)
		}
		if i := strings.LastIndex(name, string(os.PathSeparator)); i != -1 {
			name = name[:i] + ":" + name[i+1:]
		}
		fmt.Println(name)
		return nil
	})
	if err != nil {
		return fmt.Errorf("failed to list model: %w", err)
	}
	return nil
}
Loading

0 comments on commit 0609fae

Please sign in to comment.