Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
mattt committed Jul 17, 2024
1 parent 0c0b19d commit 45f2f23
Showing 1 changed file with 60 additions and 47 deletions.
107 changes: 60 additions & 47 deletions pkg/cli/predict.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,11 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o
return err
}

// Ignore @, to make it behave the same as -i
outputPath = strings.TrimPrefix(outputPath, "@")

// If outputPath != "", then we now know the output path for sure
if outputPath != "" {
// Ignore @, to make it behave the same as -i
outputPath = strings.TrimPrefix(outputPath, "@")

if err := checkOutputWritable(outputPath); err != nil {
return fmt.Errorf("Output path is not writable: %w", err)
}
Expand All @@ -199,55 +199,72 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o
responseSchema := schema.Paths.Value("/predictions").Post.Responses.Value("200").Value.Content["application/json"].Schema.Value
outputSchema := responseSchema.Properties["output"].Value

// Multiple outputs!
if outputSchema.Type.Is("array") && isURI(outputSchema.Items.Value) {
prediction, err := predictor.Predict(inputs)
if err != nil {
return err
prediction, err := predictor.Predict(inputs)
if err != nil {
return fmt.Errorf("Failed to predict: %w", err)
}

if prediction.Output == nil {
console.Warn("No output generated")
return nil
}

switch {
case isURI(outputSchema):
addExtension := false
if outputPath == "" {
outputPath = "output"
addExtension = true
}

outputs, ok := (*prediction.Output).([]interface{})
outputStr, ok := (*prediction.Output).(string)
if !ok {
return fmt.Errorf("Failed to decode output")
return fmt.Errorf("Failed to convert prediction output to string")
}

for i, output := range outputs {
if err := writeDataURLOutput(output.(string), fmt.Sprintf("output.%d", i), true); err != nil {
return err
}
if err := writeDataURLOutput(outputStr, outputPath, addExtension); err != nil {
return fmt.Errorf("Failed to write output: %w", err)
}
return nil
}

prediction, err := predictor.Predict(inputs)
if err != nil {
return fmt.Errorf("Failed to predict: %w", err)
}

if isURI(outputSchema) {
if outputPath == "" {
return writeDataURLOutput((*prediction.Output).(string), "output", true)
} else {
return writeDataURLOutput((*prediction.Output).(string), outputPath, false)
return nil
case outputSchema.Type.Is("array") && isURI(outputSchema.Items.Value):
outputs, ok := (*prediction.Output).([]interface{})
if !ok {
return fmt.Errorf("Failed to decode output")
}
}

var out []byte
for i, output := range outputs {
outputPath := fmt.Sprintf("output.%d", i)
addExtension := true

outputStr, ok := output.(string)
if !ok {
return fmt.Errorf("Failed to convert prediction output to string")
}

if outputSchema.Type.Is("string") {
// Handle strings separately because if we encode it to JSON it will be surrounded by quotes.
if prediction.Output == nil {
console.Warnf("No output generated")
return nil
if err := writeDataURLOutput(outputStr, outputPath, addExtension); err != nil {
return fmt.Errorf("Failed to write output %d: %w", i, err)
}
}

return nil
case outputSchema.Type.Is("string"):
s, ok := (*prediction.Output).(string)
if !ok {
return fmt.Errorf("Failed to convert prediction output to string")
}

out = []byte(s)
} else {
if outputPath == "" {
console.Output(s)
} else {
err := writeOutput(outputPath, []byte(s))
if err != nil {
return fmt.Errorf("Failed to write output: %w", err)
}
}

return nil
default:
// Treat everything else as JSON -- ints, floats, bools will all convert correctly.
rawJSON, err := json.Marshal(prediction.Output)
if err != nil {
Expand All @@ -257,22 +274,18 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o
if err := json.Indent(&indentedJSON, rawJSON, "", " "); err != nil {
return err
}
out = indentedJSON.Bytes()

// FIXME: this stopped working
// f := colorjson.NewFormatter()
// f.Indent = 2
// s, _ := f.Marshal(obj)
}
if outputPath == "" {
console.Output(indentedJSON.String())
} else {
err := writeOutput(outputPath, indentedJSON.Bytes())
if err != nil {
return fmt.Errorf("Failed to write output: %w", err)
}
}

// Write to stdout
if outputPath == "" {
console.Output(string(out))
return nil
}

// Fall back to writing file
return writeOutput(outputPath, out)
}

func checkOutputWritable(outputPath string) error {
Expand Down

0 comments on commit 45f2f23

Please sign in to comment.