Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fast generator for cog build #2108

Merged
merged 21 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/cli/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func cmdDockerfile(cmd *cobra.Command, args []string) error {
return err
}

generator, err := dockerfile.NewGenerator(cfg, projectDir)
generator, err := dockerfile.NewGenerator(cfg, projectDir, false)
if err != nil {
return fmt.Errorf("Error creating Dockerfile generator: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/config/compatibility.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ func CUDABaseImageFor(cuda string, cuDNN string) (string, error) {
func tfGPUPackage(ver string, cuda string) (name string, cpuVersion string, err error) {
for _, compat := range TFCompatibilityMatrix {
if compat.TF == ver && version.Equal(compat.CUDA, cuda) {
name, cpuVersion, _, _, err = splitPinnedPythonRequirement(compat.TFGPUPackage)
name, cpuVersion, _, _, err = SplitPinnedPythonRequirement(compat.TFGPUPackage)
return name, cpuVersion, err
}
}
Expand Down
64 changes: 8 additions & 56 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ func (c *Config) cudaFromTF() (tfVersion string, tfCUDA string, tfCuDNN string,

func (c *Config) pythonPackageVersion(name string) (version string, ok bool) {
for _, pkg := range c.Build.pythonRequirementsContent {
pkgName, version, _, _, err := splitPinnedPythonRequirement(pkg)
pkgName, version, _, _, err := SplitPinnedPythonRequirement(pkg)
if err != nil {
// package is not in package==version format
continue
Expand Down Expand Up @@ -331,7 +331,11 @@ func (c *Config) PythonRequirementsForArch(goos string, goarch string, includePa

includePackageNames := []string{}
for _, pkg := range includePackages {
includePackageNames = append(includePackageNames, packageName(pkg))
packageName, err := PackageName(pkg)
if err != nil {
return "", err
}
includePackageNames = append(includePackageNames, packageName)
}

// Include all the requirements and remove our include packages if they exist
Expand All @@ -352,7 +356,7 @@ func (c *Config) PythonRequirementsForArch(goos string, goarch string, includePa
}
}

packageName := packageName(archPkg)
packageName, _ := PackageName(archPkg)
if packageName != "" {
foundIdx := -1
for i, includePkg := range includePackageNames {
Expand Down Expand Up @@ -390,7 +394,7 @@ func (c *Config) PythonRequirementsForArch(goos string, goarch string, includePa
// pythonPackageForArch takes a package==version line and
// returns a package==version and index URL resolved to the correct GPU package for the given OS and architecture
func (c *Config) pythonPackageForArch(pkg, goos, goarch string) (actualPackage string, findLinksList []string, extraIndexURLs []string, err error) {
name, version, findLinksList, extraIndexURLs, err := splitPinnedPythonRequirement(pkg)
name, version, findLinksList, extraIndexURLs, err := SplitPinnedPythonRequirement(pkg)
if err != nil {
// It's not pinned, so just return the line verbatim
return pkg, []string{}, []string{}, nil
Expand Down Expand Up @@ -562,50 +566,6 @@ Compatible cuDNN version is: %s`, c.Build.CuDNN, tfVersion, tfCuDNN)
return nil
}

// splitPythonPackage returns the name, version, findLinks, and extraIndexURLs from a requirements.txt line
// in the form name==version [--find-links=<findLink>] [-f <findLink>] [--extra-index-url=<extraIndexURL>]
func splitPinnedPythonRequirement(requirement string) (name string, version string, findLinks []string, extraIndexURLs []string, err error) {
pinnedPackageRe := regexp.MustCompile(`(?:([a-zA-Z0-9\-_]+)==([^ ]+)|--find-links=([^\s]+)|-f\s+([^\s]+)|--extra-index-url=([^\s]+))`)

matches := pinnedPackageRe.FindAllStringSubmatch(requirement, -1)
if matches == nil {
return "", "", nil, nil, fmt.Errorf("Package %s is not in the expected format", requirement)
}

nameFound := false
versionFound := false

for _, match := range matches {
if match[1] != "" {
name = match[1]
nameFound = true
}

if match[2] != "" {
version = match[2]
versionFound = true
}

if match[3] != "" {
findLinks = append(findLinks, match[3])
}

if match[4] != "" {
findLinks = append(findLinks, match[4])
}

if match[5] != "" {
extraIndexURLs = append(extraIndexURLs, match[5])
}
}

if !nameFound || !versionFound {
return "", "", nil, nil, fmt.Errorf("Package name or version is missing in %s", requirement)
}

return name, version, findLinks, extraIndexURLs, nil
}

func sliceContains(slice []string, s string) bool {
for _, el := range slice {
if el == s {
Expand All @@ -614,11 +574,3 @@ func sliceContains(slice []string, s string) bool {
}
return false
}

func packageName(pipRequirement string) string {
match := PipPackageNameRegex.FindStringSubmatch(pipRequirement)
if len(match) <= 1 {
return ""
}
return match[1]
}
2 changes: 1 addition & 1 deletion pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ func TestSplitPinnedPythonRequirement(t *testing.T) {
}

for _, tc := range testCases {
name, version, findLinks, extraIndexURLs, err := splitPinnedPythonRequirement(tc.input)
name, version, findLinks, extraIndexURLs, err := SplitPinnedPythonRequirement(tc.input)

if tc.expectedError {
require.Error(t, err)
Expand Down
130 changes: 130 additions & 0 deletions pkg/config/requirements.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package config

import (
"bufio"
"errors"
"fmt"
"os"
"path/filepath"
"regexp"
"sort"
)

func GenerateRequirements(tmpDir string, config *Config) (string, error) {
// Deduplicate packages between the requirements.txt and the python packages directive.
packageNames := make(map[string]string)

// Read the python packages configuration.
for _, requirement := range config.Build.PythonPackages {
packageName, err := PackageName(requirement)
if err != nil {
return "", err
}
packageNames[packageName] = requirement
}

// Read the python requirements.
if config.Build.PythonRequirements != "" {
fh, err := os.Open(config.Build.PythonRequirements)
if err != nil {
return "", err
}
scanner := bufio.NewScanner(fh)
for scanner.Scan() {
requirement := scanner.Text()
packageName, err := PackageName(requirement)
if err != nil {
return "", err
}
packageNames[packageName] = requirement
}
}

// If we don't have any packages skip further processing
if len(packageNames) == 0 {
return "", nil
}

// Sort the package names by alphabetical order.
keys := make([]string, 0, len(packageNames))
for k := range packageNames {
keys = append(keys, k)
}
sort.Strings(keys)

// Render the expected contents
requirementsContent := ""
for _, k := range keys {
requirementsContent += packageNames[k] + "\n"
}

// Check against the old requirements contents
requirementsFile := filepath.Join(tmpDir, "requirements.txt")
_, err := os.Stat(requirementsFile)
if !errors.Is(err, os.ErrNotExist) {
bytes, err := os.ReadFile(requirementsFile)
if err != nil {
return "", err
}
oldRequirementsContents := string(bytes)
if oldRequirementsContents == requirementsFile {
return requirementsFile, nil
}
}

// Write out a new requirements file
err = os.WriteFile(requirementsFile, []byte(requirementsContent), 0o644)
if err != nil {
return "", err
}
return requirementsFile, nil
}

// SplitPinnedPythonRequirement returns the name, version, findLinks, and extraIndexURLs from a requirements.txt line
// in the form name==version [--find-links=<findLink>] [-f <findLink>] [--extra-index-url=<extraIndexURL>]
func SplitPinnedPythonRequirement(requirement string) (name string, version string, findLinks []string, extraIndexURLs []string, err error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just realized this. Do we still need this, now that monobase.user dedups the requirements and handles Torch index anyway?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeh we need to perform a dedup ourselves on the cog.yaml, it offers both python_packages and python_requirements as an option.

pinnedPackageRe := regexp.MustCompile(`(?:([a-zA-Z0-9\-_]+)==([^ ]+)|--find-links=([^\s]+)|-f\s+([^\s]+)|--extra-index-url=([^\s]+))`)

matches := pinnedPackageRe.FindAllStringSubmatch(requirement, -1)
if matches == nil {
return "", "", nil, nil, fmt.Errorf("Package %s is not in the expected format", requirement)
}

nameFound := false
versionFound := false

for _, match := range matches {
if match[1] != "" {
name = match[1]
nameFound = true
}

if match[2] != "" {
version = match[2]
versionFound = true
}

if match[3] != "" {
findLinks = append(findLinks, match[3])
}

if match[4] != "" {
findLinks = append(findLinks, match[4])
}

if match[5] != "" {
extraIndexURLs = append(extraIndexURLs, match[5])
}
}

if !nameFound || !versionFound {
return "", "", nil, nil, fmt.Errorf("Package name or version is missing in %s", requirement)
}

return name, version, findLinks, extraIndexURLs, nil
}

func PackageName(pipRequirement string) (string, error) {
name, _, _, _, err := SplitPinnedPythonRequirement(pipRequirement)
return name, err
}
21 changes: 21 additions & 0 deletions pkg/config/requirements_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package config

import (
"path/filepath"
"testing"

"github.com/stretchr/testify/require"
)

func TestGenerateRequirements(t *testing.T) {
tmpDir := t.TempDir()
build := Build{
PythonPackages: []string{"torch==2.5.1"},
}
config := Config{
Build: &build,
}
requirementsFile, err := GenerateRequirements(tmpDir, &config)
require.NoError(t, err)
require.Equal(t, filepath.Join(tmpDir, "requirements.txt"), requirementsFile)
}
12 changes: 9 additions & 3 deletions pkg/docker/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,22 @@ import (
"strings"

"github.com/replicate/cog/pkg/config"
"github.com/replicate/cog/pkg/dockerfile"

"github.com/replicate/cog/pkg/util"
"github.com/replicate/cog/pkg/util/console"
)

func Build(dir, dockerfile, imageName string, secrets []string, noCache bool, progressOutput string, epoch int64) error {
func Build(dir, dockerfileContents, imageName string, secrets []string, noCache bool, progressOutput string, epoch int64) error {
var args []string

userCache, err := dockerfile.UserCache()
if err != nil {
return err
}

args = append(args,
"buildx", "build",
"buildx", "build", "--build-context", "usercache="+userCache,
)

if util.IsAppleSiliconMac(runtime.GOOS, runtime.GOARCH) {
Expand Down Expand Up @@ -65,7 +71,7 @@ func Build(dir, dockerfile, imageName string, secrets []string, noCache bool, pr
cmd.Dir = dir
cmd.Stdout = os.Stderr // redirect stdout to stderr - build output is all messaging
cmd.Stderr = os.Stderr
cmd.Stdin = strings.NewReader(dockerfile)
cmd.Stdin = strings.NewReader(dockerfileContents)

console.Debug("$ " + strings.Join(cmd.Args, " "))
return cmd.Run()
Expand Down
2 changes: 1 addition & 1 deletion pkg/dockerfile/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func (g *BaseImageGenerator) GenerateDockerfile() (string, error) {
return "", err
}

generator, err := NewGenerator(conf, "")
generator, err := NewGenerator(conf, "", false)
if err != nil {
return "", err
}
Expand Down
33 changes: 33 additions & 0 deletions pkg/dockerfile/build_tempdir.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package dockerfile

import (
"os"
"path"
"time"
)

func BuildCogTempDir(dir string) (string, error) {
rootTmp := path.Join(dir, ".cog/tmp")
if err := os.MkdirAll(rootTmp, 0o755); err != nil {
return "", err
}
return rootTmp, nil
}

func BuildTempDir(dir string) (string, error) {
rootTmp, err := BuildCogTempDir(dir)
if err != nil {
return "", err
}

if err := os.MkdirAll(rootTmp, 0o755); err != nil {
return "", err
}
// tmpDir ends up being something like dir/.cog/tmp/build20240620123456.000000
now := time.Now().Format("20060102150405.000000")
tmpDir, err := os.MkdirTemp(rootTmp, "build"+now)
if err != nil {
return "", err
}
return tmpDir, nil
}
15 changes: 15 additions & 0 deletions pkg/dockerfile/build_tempdir_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package dockerfile

import (
"path/filepath"
"testing"

"github.com/stretchr/testify/require"
)

func TestBuildCogTempDir(t *testing.T) {
tmpDir := t.TempDir()
cogTmpDir, err := BuildCogTempDir(tmpDir)
require.NoError(t, err)
require.Equal(t, filepath.Join(tmpDir, ".cog/tmp"), cogTmpDir)
}
6 changes: 6 additions & 0 deletions pkg/dockerfile/cog_embed.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package dockerfile

import "embed"

//go:embed embed/*.whl
var CogEmbed embed.FS
Loading
Loading