Chore: Support openai o1 model
Signed-off-by: Daishan Peng <[email protected]>
StrongMonkey committed Jan 22, 2025
1 parent 3f876b2 · commit 062d703
Showing 2 changed files with 44 additions and 13 deletions.
56 changes: 43 additions & 13 deletions pkg/openai/client.go
@@ -240,7 +240,7 @@ func toToolCall(call types.CompletionToolCall) openai.ToolCall {
}
}

-func toMessages(request types.CompletionRequest, compat bool) (result []openai.ChatCompletionMessage, err error) {
+func toMessages(request types.CompletionRequest, compat, useO1Model bool) (result []openai.ChatCompletionMessage, err error) {
var (
systemPrompts []string
msgs []types.CompletionMessage
@@ -259,8 +259,12 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.ChatCompletionMessage, err error) {
}

if len(systemPrompts) > 0 {
+		role := types.CompletionMessageRoleTypeSystem
+		if useO1Model {
+			role = types.CompletionMessageRoleTypeDeveloper
+		}
msgs = slices.Insert(msgs, 0, types.CompletionMessage{
-			Role: types.CompletionMessageRoleTypeSystem,
+			Role: role,
Content: types.Text(strings.Join(systemPrompts, "\n")),
})
}
@@ -306,9 +310,9 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.ChatCompletionMessage, err error) {
return
}

-func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
+func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, envs []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
if err := c.ValidAuth(); err != nil {
-		if err := c.RetrieveAPIKey(ctx, env); err != nil {
+		if err := c.RetrieveAPIKey(ctx, envs); err != nil {
return nil, err
}
}
@@ -317,7 +321,9 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
messageRequest.Model = c.defaultModel
}

-	msgs, err := toMessages(messageRequest, !c.setSeed)
+	useO1Model := isO1Model(messageRequest.Model, envs)
+
+	msgs, err := toMessages(messageRequest, !c.setSeed, useO1Model)
if err != nil {
return nil, err
}
@@ -348,10 +354,13 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
MaxTokens: messageRequest.MaxTokens,
}

-	if messageRequest.Temperature == nil {
-		request.Temperature = new(float32)
-	} else {
-		request.Temperature = messageRequest.Temperature
+	// OpenAI o1 doesn't support setting temperature.
+	if !useO1Model {
+		if messageRequest.Temperature == nil {
+			request.Temperature = new(float32)
+		} else {
+			request.Temperature = messageRequest.Temperature
+		}
}

if messageRequest.JSONResponse {
@@ -404,15 +413,15 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
if err != nil {
return nil, err
} else if !ok {
-		result, err = c.call(ctx, request, id, env, status)
+		result, err = c.call(ctx, request, id, envs, status)

// If we got back a context length exceeded error, keep retrying and shrinking the message history until we pass.
var apiError *openai.APIError
if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" && messageRequest.Chat {
// Decrease maxTokens by 10% to make garbage collection more aggressive.
// The retry loop will further decrease maxTokens if needed.
maxTokens := decreaseTenPercent(messageRequest.MaxTokens)
-			result, err = c.contextLimitRetryLoop(ctx, request, id, env, maxTokens, status)
+			result, err = c.contextLimitRetryLoop(ctx, request, id, envs, maxTokens, status)
}
if err != nil {
return nil, err
@@ -446,6 +455,22 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
return &result, nil
}

+func isO1Model(model string, envs []string) bool {
+	if model == "o1" {
+		return true
+	}
+
+	o1Model := false
+	for _, env := range envs {
+		k, v, _ := strings.Cut(env, "=")
+		if k == "OPENAI_MODEL_NAME" && v == "o1" {
+			o1Model = true
+		}
+	}
+
+	return o1Model
+}
+
func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, env []string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
var (
response types.CompletionMessage
@@ -545,9 +570,14 @@ func override(left, right string) string {
return left
}

-func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, env []string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) {
+func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, envs []string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) {
streamResponse := os.Getenv("GPTSCRIPT_INTERNAL_OPENAI_STREAMING") != "false"

+	useO1Model := isO1Model(request.Model, envs)
+	if useO1Model {
+		streamResponse = false
+	}
+
partial <- types.CompletionStatus{
CompletionID: transactionID,
PartialResponse: &types.CompletionMessage{
@@ -567,7 +597,7 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, env []string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) {
},
}
)
-	for _, e := range env {
+	for _, e := range envs {
if strings.HasPrefix(e, "GPTSCRIPT_MODEL_PROVIDER_") {
modelProviderEnv = append(modelProviderEnv, e)
} else if strings.HasPrefix(e, "GPTSCRIPT_DISABLE_RETRIES") {
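Taken together, the client.go changes key every piece of o1 behavior (developer role, no temperature, no streaming) off one helper: isO1Model returns true when the requested model is exactly "o1" or when the environment slice carries an OPENAI_MODEL_NAME=o1 entry. A self-contained sketch of that detection (the main harness is illustrative and not part of the commit; the helper mirrors the added code, with an early return in place of its flag variable):

	package main

	import (
		"fmt"
		"strings"
	)

	// isO1Model mirrors the helper added in pkg/openai/client.go: a model
	// counts as o1 when the model name is exactly "o1" or when the
	// environment slice carries an OPENAI_MODEL_NAME=o1 entry.
	func isO1Model(model string, envs []string) bool {
		if model == "o1" {
			return true
		}
		for _, env := range envs {
			if k, v, _ := strings.Cut(env, "="); k == "OPENAI_MODEL_NAME" && v == "o1" {
				return true
			}
		}
		return false
	}

	func main() {
		fmt.Println(isO1Model("o1", nil))                                  // true: matched by model name
		fmt.Println(isO1Model("gpt-4o", []string{"OPENAI_MODEL_NAME=o1"})) // true: matched by env entry
		fmt.Println(isO1Model("gpt-4o", nil))                              // false
	}

Note that the comparison is exact, so a name like "o1-mini" would not match; the commit only special-cases the literal "o1".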
1 change: 1 addition & 0 deletions pkg/types/completion.go
@@ -41,6 +41,7 @@ type CompletionFunctionDefinition struct {
const (
CompletionMessageRoleTypeUser = CompletionMessageRoleType("user")
CompletionMessageRoleTypeSystem = CompletionMessageRoleType("system")
+	CompletionMessageRoleTypeDeveloper = CompletionMessageRoleType("developer")
CompletionMessageRoleTypeAssistant = CompletionMessageRoleType("assistant")
CompletionMessageRoleTypeTool = CompletionMessageRoleType("tool")
)
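
The new developer role backs the change in toMessages above: per OpenAI's reasoning-model guidance, instructions that ride in a system message for other models are sent as a developer message for o1. A minimal sketch of the resulting message shape on the wire (chatMessage is a local stand-in for the client's message type, and the content string is made up):

	package main

	import (
		"encoding/json"
		"fmt"
	)

	// chatMessage is a local stand-in for the chat completion message type;
	// it exists only to show the JSON produced for an o1 system prompt.
	type chatMessage struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	func main() {
		msg := chatMessage{
			Role:    "developer", // would be "system" for non-o1 models
			Content: "You are a helpful assistant.",
		}
		b, err := json.Marshal(msg)
		if err != nil {
			panic(err)
		}
		fmt.Println(string(b)) // {"role":"developer","content":"You are a helpful assistant."}
	}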
