Skip to content

Commit

Permalink
fixup handling of response_format multiple type options
Browse files Browse the repository at this point in the history
  • Loading branch information
skyscrapr committed Aug 23, 2024
1 parent cff9ecf commit 49a5f03
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 6 deletions.
23 changes: 23 additions & 0 deletions openai/assistant.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package openai

import (
"encoding/json"
"fmt"
"net/url"
"strconv"
Expand Down Expand Up @@ -114,6 +115,7 @@ type AssistantToolResources struct {
}

type AssistantResponseFormat struct {
StringValue string `json:"-"`
Type string `json:"type"`
JsonSchema *struct {
Description *string `json:"description,omitempty"`
Expand All @@ -123,6 +125,27 @@ type AssistantResponseFormat struct {
} `json:"json_schema,omitempty"`
}

func (f *AssistantResponseFormat) UnmarshalJSON(data []byte) error {
// First try to unmarshal it as string
var s string
if err := json.Unmarshal(data, &s); err == nil {
// No error, fill the struct
f.StringValue = s
return nil
}

// Otherwise, try to unmarshal as AssistantResponseFormat
type Alias AssistantResponseFormat
var responseFormat Alias
if err := json.Unmarshal(data, &responseFormat); err != nil {
return err
}
f.Type = responseFormat.Type
f.JsonSchema = responseFormat.JsonSchema

return nil
}

// Create an assistant with a model and instructions.
// [OpenAI Documentation]: https://platform.openai.com/docs/api-reference/assistants/createAssistant
func (e *AssistantsEndpoint) CreateAssistant(req *AssistantRequest) (*Assistant, error) {
Expand Down
16 changes: 10 additions & 6 deletions openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"time"
Expand Down Expand Up @@ -96,11 +95,16 @@ func decodeResponse(body io.Reader, v any) error {
if v == nil {
return nil
}
err := json.NewDecoder(body).Decode(v)
if err != nil {
log.Fatal(err)
}
return nil
return json.NewDecoder(body).Decode(v)

//
// DEBUG - Use below if you want to see the raw response body
//
// rawBody, err := io.ReadAll(body)
// if err != nil {
// return err
// }
// return json.Unmarshal(rawBody, v)
}

func (c *Client) handleErrorResp(resp *http.Response) error {
Expand Down

0 comments on commit 49a5f03

Please sign in to comment.