diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index d328d0b246..c106678d54 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -7,6 +7,7 @@ import ( "fmt" "os" "os/signal" + "path/filepath" "strings" "syscall" @@ -14,6 +15,7 @@ import ( "github.com/mitchellh/go-homedir" "github.com/spf13/cobra" "github.com/vincent-petithory/dataurl" + "golang.org/x/sys/unix" "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker" @@ -167,6 +169,10 @@ func cmdPredict(cmd *cobra.Command, args []string) error { return predictIndividualInputs(predictor, inputFlags, outPath) } +func isURI(ref *openapi3.Schema) bool { + return ref != nil && ref.Type.Is("string") && ref.Format == "uri" +} + func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, outputPath string) error { console.Info("Running prediction...") schema, err := predictor.GetSchema() @@ -179,48 +185,85 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o return err } - prediction, err := predictor.Predict(inputs) - if err != nil { - return err + // If outputPath != "", then we now know the output path for sure + if outputPath != "" { + // Ignore @, to make it behave the same as -i + outputPath = strings.TrimPrefix(outputPath, "@") + + if err := checkOutputWritable(outputPath); err != nil { + return fmt.Errorf("Output path is not writable: %w", err) + } } // Generate output depending on type in schema - var out []byte responseSchema := schema.Paths.Value("/predictions").Post.Responses.Value("200").Value.Content["application/json"].Schema.Value outputSchema := responseSchema.Properties["output"].Value - // Multiple outputs! - if outputSchema.Type.Is("array") && outputSchema.Items.Value != nil && outputSchema.Items.Value.Type.Is("string") && outputSchema.Items.Value.Format == "uri" { - return handleMultipleFileOutput(prediction, outputSchema) + prediction, err := predictor.Predict(inputs) + if err != nil { + return fmt.Errorf("Failed to predict: %w", err) + } + + if prediction.Output == nil { + console.Warn("No output generated") + return nil } switch { - case outputSchema.Type.Is("string") && outputSchema.Format == "uri": - dataurlObj, err := dataurl.DecodeString((*prediction.Output).(string)) - if err != nil { - return fmt.Errorf("Failed to decode dataurl: %w", err) - } - out = dataurlObj.Data + case isURI(outputSchema): + addExtension := false if outputPath == "" { outputPath = "output" - extension := mime.ExtensionByType(dataurlObj.ContentType()) - if extension != "" { - outputPath += extension - } + addExtension = true } - case outputSchema.Type.Is("string"): - // Handle strings separately because if we encode it to JSON it will be surrounded by quotes. - if prediction.Output == nil { - console.Warnf("No output generated") - return nil + + outputStr, ok := (*prediction.Output).(string) + if !ok { + return fmt.Errorf("Failed to convert prediction output to string") + } + + if err := writeDataURLOutput(outputStr, outputPath, addExtension); err != nil { + return fmt.Errorf("Failed to write output: %w", err) + } + + return nil + case outputSchema.Type.Is("array") && isURI(outputSchema.Items.Value): + outputs, ok := (*prediction.Output).([]interface{}) + if !ok { + return fmt.Errorf("Failed to decode output") } + for i, output := range outputs { + outputPath := fmt.Sprintf("output.%d", i) + addExtension := true + + outputStr, ok := output.(string) + if !ok { + return fmt.Errorf("Failed to convert prediction output to string") + } + + if err := writeDataURLOutput(outputStr, outputPath, addExtension); err != nil { + return fmt.Errorf("Failed to write output %d: %w", i, err) + } + } + + return nil + case outputSchema.Type.Is("string"): s, ok := (*prediction.Output).(string) if !ok { return fmt.Errorf("Failed to convert prediction output to string") } - out = []byte(s) + if outputPath == "" { + console.Output(s) + } else { + err := writeOutput(outputPath, []byte(s)) + if err != nil { + return fmt.Errorf("Failed to write output: %w", err) + } + } + + return nil default: // Treat everything else as JSON -- ints, floats, bools will all convert correctly. rawJSON, err := json.Marshal(prediction.Output) @@ -231,26 +274,39 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o if err := json.Indent(&indentedJSON, rawJSON, "", " "); err != nil { return err } - out = indentedJSON.Bytes() - // FIXME: this stopped working - // f := colorjson.NewFormatter() - // f.Indent = 2 - // s, _ := f.Marshal(obj) - } + if outputPath == "" { + console.Output(indentedJSON.String()) + } else { + err := writeOutput(outputPath, indentedJSON.Bytes()) + if err != nil { + return fmt.Errorf("Failed to write output: %w", err) + } + } - // Write to stdout - if outputPath == "" { - console.Output(string(out)) return nil } +} - // Fall back to writing file +func checkOutputWritable(outputPath string) error { + outputPath, err := homedir.Expand(outputPath) + if err != nil { + return err + } - // Ignore @, to make it behave the same as -i - outputPath = strings.TrimPrefix(outputPath, "@") + // Check if the file exists + _, err = os.Stat(outputPath) + if err == nil { + // File exists, check if it's writable + return unix.Access(outputPath, unix.W_OK) + } else if os.IsNotExist(err) { + // File doesn't exist, check if the directory is writable + dir := filepath.Dir(outputPath) + return unix.Access(dir, unix.W_OK) + } - return writeOutput(outputPath, out) + // Some other error occurred + return err } func writeOutput(outputPath string, output []byte) error { @@ -260,7 +316,7 @@ func writeOutput(outputPath string, output []byte) error { } // Write to file - outFile, err := os.OpenFile(outputPath, os.O_WRONLY|os.O_CREATE, 0o755) + outFile, err := os.OpenFile(outputPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755) if err != nil { return err } @@ -275,26 +331,24 @@ func writeOutput(outputPath string, output []byte) error { return nil } -func handleMultipleFileOutput(prediction *predict.Response, outputSchema *openapi3.Schema) error { - outputs, ok := (*prediction.Output).([]interface{}) - if !ok { - return fmt.Errorf("Failed to decode output") +func writeDataURLOutput(outputString string, outputPath string, addExtension bool) error { + dataurlObj, err := dataurl.DecodeString(outputString) + if err != nil { + return fmt.Errorf("Failed to decode dataurl: %w", err) } + output := dataurlObj.Data - for i, output := range outputs { - outputString := output.(string) - dataurlObj, err := dataurl.DecodeString(outputString) - if err != nil { - return fmt.Errorf("Failed to decode dataurl: %w", err) - } - out := dataurlObj.Data + if addExtension { extension := mime.ExtensionByType(dataurlObj.ContentType()) - outputPath := fmt.Sprintf("output.%d%s", i, extension) - if err := writeOutput(outputPath, out); err != nil { - return err + if extension != "" { + outputPath += extension } } + if err := writeOutput(outputPath, output); err != nil { + return err + } + return nil }