Skip to content

Commit

Permalink
cog predict: check output path writability before predict/train
Browse files Browse the repository at this point in the history
Signed-off-by: Yorick van Pelt <[email protected]>
  • Loading branch information
yorickvP committed Nov 20, 2023
1 parent 21b7b9c commit a22ed38
Showing 1 changed file with 72 additions and 39 deletions.
111 changes: 72 additions & 39 deletions pkg/cli/predict.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
return predictIndividualInputs(predictor, inputFlags, outPath)
}

func isURI(ref *openapi3.Schema) bool {
return ref != nil && ref.Type == "string" && ref.Format == "uri"
}

func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, outputPath string) error {
console.Info("Running prediction...")
schema, err := predictor.GetSchema()
Expand All @@ -172,35 +176,54 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o
return err
}

prediction, err := predictor.Predict(inputs)
if err != nil {
return err
}
// Ignore @, to make it behave the same as -i
outputPath = strings.TrimPrefix(outputPath, "@")

// Generate output depending on type in schema
var out []byte
responseSchema := schema.Paths["/predictions"].Post.Responses["200"].Value.Content["application/json"].Schema.Value
outputSchema := responseSchema.Properties["output"].Value

// Multiple outputs!
if outputSchema.Type == "array" && outputSchema.Items.Value != nil && outputSchema.Items.Value.Type == "string" && outputSchema.Items.Value.Format == "uri" {
return handleMultipleFileOutput(prediction, outputSchema)
}

if outputSchema.Type == "string" && outputSchema.Format == "uri" {
dataurlObj, err := dataurl.DecodeString((*prediction.Output).(string))
// Multiple outputs!
if outputSchema.Type == "array" && isURI(outputSchema.Items.Value) {
prediction, err := predictor.Predict(inputs)
if err != nil {
return fmt.Errorf("Failed to decode dataurl: %w", err)
return err
}
out = dataurlObj.Data
if outputPath == "" {
outputPath = "output"
extension := mime.ExtensionByType(dataurlObj.ContentType())
if extension != "" {
outputPath += extension
outputs, ok := (*prediction.Output).([]interface{})
if !ok {
return fmt.Errorf("Failed to decode output")
}

for i, output := range outputs {
if err := writeDataURLOutput(output.(string), fmt.Sprintf("output.%d", i), true); err != nil {
return err
}
}
} else if outputSchema.Type == "string" {
return nil
}

// If outputPath != "", then we now know the output path for sure
if outputPath != "" && checkOutputWritable(outputPath) != nil {
return err
}

prediction, err := predictor.Predict(inputs)
if err != nil {
return err
}

if isURI(outputSchema) {
if outputPath == "" {
return writeDataURLOutput((*prediction.Output).(string), "output", true)
} else {
return writeDataURLOutput((*prediction.Output).(string), outputPath, false)
}
}

var out []byte

if outputSchema.Type == "string" {
// Handle strings separately because if we encode it to JSON it will be surrounded by quotes.
s := (*prediction.Output).(string)
out = []byte(s)
Expand Down Expand Up @@ -230,11 +253,26 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o
}

// Fall back to writing file
return writeOutput(outputPath, out)
}

// Ignore @, to make it behave the same as -i
outputPath = strings.TrimPrefix(outputPath, "@")
// Try to open the file, prevents getting errors after long prediction/training
func checkOutputWritable(outputPath string) error {
outputPath, err := homedir.Expand(outputPath)
if err != nil {
return err
}

return writeOutput(outputPath, out)
// Try to open the file
outFile, err := os.OpenFile(outputPath, os.O_WRONLY|os.O_CREATE, 0o755)
if err != nil {
return err
}

if err := outFile.Close(); err != nil {
return err
}
return nil
}

func writeOutput(outputPath string, output []byte) error {
Expand All @@ -244,7 +282,7 @@ func writeOutput(outputPath string, output []byte) error {
}

// Write to file
outFile, err := os.OpenFile(outputPath, os.O_WRONLY|os.O_CREATE, 0o755)
outFile, err := os.OpenFile(outputPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
if err != nil {
return err
}
Expand All @@ -259,26 +297,21 @@ func writeOutput(outputPath string, output []byte) error {
return nil
}

func handleMultipleFileOutput(prediction *predict.Response, outputSchema *openapi3.Schema) error {
outputs, ok := (*prediction.Output).([]interface{})
if !ok {
return fmt.Errorf("Failed to decode output")
func writeDataURLOutput(outputString string, outputPath string, addExtension bool) error {
dataurlObj, err := dataurl.DecodeString(outputString)
if err != nil {
return fmt.Errorf("Failed to decode dataurl: %w", err)
}

for i, output := range outputs {
outputString := output.(string)
dataurlObj, err := dataurl.DecodeString(outputString)
if err != nil {
return fmt.Errorf("Failed to decode dataurl: %w", err)
}
out := dataurlObj.Data
out := dataurlObj.Data
if addExtension {
extension := mime.ExtensionByType(dataurlObj.ContentType())
outputPath := fmt.Sprintf("output.%d%s", i, extension)
if err := writeOutput(outputPath, out); err != nil {
return err
if extension != "" {
outputPath += extension
}
}

if err := writeOutput(outputPath, out); err != nil {
return err
}
return nil
}

Expand Down

0 comments on commit a22ed38

Please sign in to comment.