Skip to content

Commit

Permalink
cog predict: check output path writability before predict/train (#1389)
Browse files Browse the repository at this point in the history
* cog predict: check output path writability before predict/train

Co-authored-by: Yorick van Pelt <[email protected]>
Signed-off-by: Yorick van Pelt <[email protected]>

* Use unix.Access to check writability without creating or opening the file

* Move output path writeable check earlier

* Wrap error

* Refactor

* Formatting

---------

Signed-off-by: Yorick van Pelt <[email protected]>
Co-authored-by: Mattt Zmuda <[email protected]>
  • Loading branch information
yorickvP and mattt authored Jul 17, 2024
1 parent f7759a0 commit c95733b
Showing 1 changed file with 105 additions and 51 deletions.
156 changes: 105 additions & 51 deletions pkg/cli/predict.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ import (
"fmt"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"

"github.com/getkin/kin-openapi/openapi3"
"github.com/mitchellh/go-homedir"
"github.com/spf13/cobra"
"github.com/vincent-petithory/dataurl"
"golang.org/x/sys/unix"

"github.com/replicate/cog/pkg/config"
"github.com/replicate/cog/pkg/docker"
Expand Down Expand Up @@ -167,6 +169,10 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
return predictIndividualInputs(predictor, inputFlags, outPath)
}

// isURI reports whether ref is a non-nil schema whose type is "string" and
// whose format is "uri".
func isURI(ref *openapi3.Schema) bool {
	if ref == nil {
		return false
	}
	if !ref.Type.Is("string") {
		return false
	}
	return ref.Format == "uri"
}

func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, outputPath string) error {
console.Info("Running prediction...")
schema, err := predictor.GetSchema()
Expand All @@ -179,48 +185,85 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o
return err
}

prediction, err := predictor.Predict(inputs)
if err != nil {
return err
// If outputPath != "", then we now know the output path for sure
if outputPath != "" {
// Ignore @, to make it behave the same as -i
outputPath = strings.TrimPrefix(outputPath, "@")

if err := checkOutputWritable(outputPath); err != nil {
return fmt.Errorf("Output path is not writable: %w", err)
}
}

// Generate output depending on type in schema
var out []byte
responseSchema := schema.Paths.Value("/predictions").Post.Responses.Value("200").Value.Content["application/json"].Schema.Value
outputSchema := responseSchema.Properties["output"].Value

// Multiple outputs!
if outputSchema.Type.Is("array") && outputSchema.Items.Value != nil && outputSchema.Items.Value.Type.Is("string") && outputSchema.Items.Value.Format == "uri" {
return handleMultipleFileOutput(prediction, outputSchema)
prediction, err := predictor.Predict(inputs)
if err != nil {
return fmt.Errorf("Failed to predict: %w", err)
}

if prediction.Output == nil {
console.Warn("No output generated")
return nil
}

switch {
case outputSchema.Type.Is("string") && outputSchema.Format == "uri":
dataurlObj, err := dataurl.DecodeString((*prediction.Output).(string))
if err != nil {
return fmt.Errorf("Failed to decode dataurl: %w", err)
}
out = dataurlObj.Data
case isURI(outputSchema):
addExtension := false
if outputPath == "" {
outputPath = "output"
extension := mime.ExtensionByType(dataurlObj.ContentType())
if extension != "" {
outputPath += extension
}
addExtension = true
}
case outputSchema.Type.Is("string"):
// Handle strings separately because if we encode it to JSON it will be surrounded by quotes.
if prediction.Output == nil {
console.Warnf("No output generated")
return nil

outputStr, ok := (*prediction.Output).(string)
if !ok {
return fmt.Errorf("Failed to convert prediction output to string")
}

if err := writeDataURLOutput(outputStr, outputPath, addExtension); err != nil {
return fmt.Errorf("Failed to write output: %w", err)
}

return nil
case outputSchema.Type.Is("array") && isURI(outputSchema.Items.Value):
outputs, ok := (*prediction.Output).([]interface{})
if !ok {
return fmt.Errorf("Failed to decode output")
}

for i, output := range outputs {
outputPath := fmt.Sprintf("output.%d", i)
addExtension := true

outputStr, ok := output.(string)
if !ok {
return fmt.Errorf("Failed to convert prediction output to string")
}

if err := writeDataURLOutput(outputStr, outputPath, addExtension); err != nil {
return fmt.Errorf("Failed to write output %d: %w", i, err)
}
}

return nil
case outputSchema.Type.Is("string"):
s, ok := (*prediction.Output).(string)
if !ok {
return fmt.Errorf("Failed to convert prediction output to string")
}

out = []byte(s)
if outputPath == "" {
console.Output(s)
} else {
err := writeOutput(outputPath, []byte(s))
if err != nil {
return fmt.Errorf("Failed to write output: %w", err)
}
}

return nil
default:
// Treat everything else as JSON -- ints, floats, bools will all convert correctly.
rawJSON, err := json.Marshal(prediction.Output)
Expand All @@ -231,26 +274,39 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o
if err := json.Indent(&indentedJSON, rawJSON, "", " "); err != nil {
return err
}
out = indentedJSON.Bytes()

// FIXME: this stopped working
// f := colorjson.NewFormatter()
// f.Indent = 2
// s, _ := f.Marshal(obj)
}
if outputPath == "" {
console.Output(indentedJSON.String())
} else {
err := writeOutput(outputPath, indentedJSON.Bytes())
if err != nil {
return fmt.Errorf("Failed to write output: %w", err)
}
}

// Write to stdout
if outputPath == "" {
console.Output(string(out))
return nil
}
}

// Fall back to writing file
// checkOutputWritable reports whether a file can be written at outputPath
// without creating or opening it. It returns nil when the write is expected
// to succeed and a non-nil error otherwise.
//
// The path is first passed through homedir.Expand. If the path already
// exists, it must not be a directory and unix.Access must grant write
// permission on it. If it does not exist, unix.Access must grant write and
// search permission on its parent directory, since both are needed to create
// a new entry there. Any other os.Stat failure is returned unchanged.
func checkOutputWritable(outputPath string) error {
	outputPath, err := homedir.Expand(outputPath)
	if err != nil {
		return err
	}

	info, err := os.Stat(outputPath)
	switch {
	case err == nil:
		// A directory can pass the W_OK check but can never be opened as the
		// output file, so reject it here rather than after the prediction ran.
		if info.IsDir() {
			return fmt.Errorf("%q is a directory", outputPath)
		}
		return unix.Access(outputPath, unix.W_OK)
	case os.IsNotExist(err):
		// The file will be created, so the parent directory decides.
		return unix.Access(filepath.Dir(outputPath), unix.W_OK|unix.X_OK)
	default:
		// Some other error occurred (e.g. a path component is not searchable).
		return err
	}
}

func writeOutput(outputPath string, output []byte) error {
Expand All @@ -260,7 +316,7 @@ func writeOutput(outputPath string, output []byte) error {
}

// Write to file
outFile, err := os.OpenFile(outputPath, os.O_WRONLY|os.O_CREATE, 0o755)
outFile, err := os.OpenFile(outputPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
if err != nil {
return err
}
Expand All @@ -275,26 +331,24 @@ func writeOutput(outputPath string, output []byte) error {
return nil
}

func handleMultipleFileOutput(prediction *predict.Response, outputSchema *openapi3.Schema) error {
outputs, ok := (*prediction.Output).([]interface{})
if !ok {
return fmt.Errorf("Failed to decode output")
func writeDataURLOutput(outputString string, outputPath string, addExtension bool) error {
dataurlObj, err := dataurl.DecodeString(outputString)
if err != nil {
return fmt.Errorf("Failed to decode dataurl: %w", err)
}
output := dataurlObj.Data

for i, output := range outputs {
outputString := output.(string)
dataurlObj, err := dataurl.DecodeString(outputString)
if err != nil {
return fmt.Errorf("Failed to decode dataurl: %w", err)
}
out := dataurlObj.Data
if addExtension {
extension := mime.ExtensionByType(dataurlObj.ContentType())
outputPath := fmt.Sprintf("output.%d%s", i, extension)
if err := writeOutput(outputPath, out); err != nil {
return err
if extension != "" {
outputPath += extension
}
}

if err := writeOutput(outputPath, output); err != nil {
return err
}

return nil
}

Expand Down

0 comments on commit c95733b

Please sign in to comment.