diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index 911a427072..d11ae9c00b 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -160,6 +160,10 @@ func cmdPredict(cmd *cobra.Command, args []string) error { return predictIndividualInputs(predictor, inputFlags, outPath) } +func isURI(ref *openapi3.Schema) bool { + return ref != nil && ref.Type == "string" && ref.Format == "uri" +} + func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, outputPath string) error { console.Info("Running prediction...") schema, err := predictor.GetSchema() @@ -172,35 +176,54 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o return err } - prediction, err := predictor.Predict(inputs) - if err != nil { - return err - } + // Ignore @, to make it behave the same as -i + outputPath = strings.TrimPrefix(outputPath, "@") // Generate output depending on type in schema - var out []byte responseSchema := schema.Paths["/predictions"].Post.Responses["200"].Value.Content["application/json"].Schema.Value outputSchema := responseSchema.Properties["output"].Value - // Multiple outputs! - if outputSchema.Type == "array" && outputSchema.Items.Value != nil && outputSchema.Items.Value.Type == "string" && outputSchema.Items.Value.Format == "uri" { - return handleMultipleFileOutput(prediction, outputSchema) - } - if outputSchema.Type == "string" && outputSchema.Format == "uri" { - dataurlObj, err := dataurl.DecodeString((*prediction.Output).(string)) + // Multiple outputs! + if outputSchema.Type == "array" && isURI(outputSchema.Items.Value) { + prediction, err := predictor.Predict(inputs) if err != nil { - return fmt.Errorf("Failed to decode dataurl: %w", err) + return err } - out = dataurlObj.Data - if outputPath == "" { - outputPath = "output" - extension := mime.ExtensionByType(dataurlObj.ContentType()) - if extension != "" { - outputPath += extension + outputs, ok := (*prediction.Output).([]interface{}) + if !ok { + return fmt.Errorf("Failed to decode output") + } + + for i, output := range outputs { + if err := writeDataURLOutput(output.(string), fmt.Sprintf("output.%d", i), true); err != nil { + return err } } - } else if outputSchema.Type == "string" { + return nil + } + + // If outputPath != "", then we now know the output path for sure + if outputPath != "" && checkOutputWritable(outputPath) != nil { + return err + } + + prediction, err := predictor.Predict(inputs) + if err != nil { + return err + } + + if isURI(outputSchema) { + if outputPath == "" { + return writeDataURLOutput((*prediction.Output).(string), "output", true) + } else { + return writeDataURLOutput((*prediction.Output).(string), outputPath, false) + } + } + + var out []byte + + if outputSchema.Type == "string" { // Handle strings separately because if we encode it to JSON it will be surrounded by quotes. s := (*prediction.Output).(string) out = []byte(s) @@ -230,11 +253,26 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o } // Fall back to writing file + return writeOutput(outputPath, out) +} - // Ignore @, to make it behave the same as -i - outputPath = strings.TrimPrefix(outputPath, "@") +// Try to open the file, prevents getting errors after long prediction/training +func checkOutputWritable(outputPath string) error { + outputPath, err := homedir.Expand(outputPath) + if err != nil { + return err + } - return writeOutput(outputPath, out) + // Try to open the file + outFile, err := os.OpenFile(outputPath, os.O_WRONLY|os.O_CREATE, 0o755) + if err != nil { + return err + } + + if err := outFile.Close(); err != nil { + return err + } + return nil } func writeOutput(outputPath string, output []byte) error { @@ -244,7 +282,7 @@ func writeOutput(outputPath string, output []byte) error { } // Write to file - outFile, err := os.OpenFile(outputPath, os.O_WRONLY|os.O_CREATE, 0o755) + outFile, err := os.OpenFile(outputPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755) if err != nil { return err } @@ -259,26 +297,21 @@ func writeOutput(outputPath string, output []byte) error { return nil } -func handleMultipleFileOutput(prediction *predict.Response, outputSchema *openapi3.Schema) error { - outputs, ok := (*prediction.Output).([]interface{}) - if !ok { - return fmt.Errorf("Failed to decode output") +func writeDataURLOutput(outputString string, outputPath string, addExtension bool) error { + dataurlObj, err := dataurl.DecodeString(outputString) + if err != nil { + return fmt.Errorf("Failed to decode dataurl: %w", err) } - - for i, output := range outputs { - outputString := output.(string) - dataurlObj, err := dataurl.DecodeString(outputString) - if err != nil { - return fmt.Errorf("Failed to decode dataurl: %w", err) - } - out := dataurlObj.Data + out := dataurlObj.Data + if addExtension { extension := mime.ExtensionByType(dataurlObj.ContentType()) - outputPath := fmt.Sprintf("output.%d%s", i, extension) - if err := writeOutput(outputPath, out); err != nil { - return err + if extension != "" { + outputPath += extension } } - + if err := writeOutput(outputPath, out); err != nil { + return err + } return nil }