From 062d7033b90b4aff36b71fee243c7d4401848503 Mon Sep 17 00:00:00 2001
From: Daishan Peng
Date: Wed, 22 Jan 2025 10:00:51 -0700
Subject: [PATCH] Chore: Support OpenAI o1 model

Signed-off-by: Daishan Peng
---
 pkg/openai/client.go    | 56 +++++++++++++++++++++++++++++++----------
 pkg/types/completion.go |  1 +
 2 files changed, 44 insertions(+), 13 deletions(-)

diff --git a/pkg/openai/client.go b/pkg/openai/client.go
index db911962..197ce53e 100644
--- a/pkg/openai/client.go
+++ b/pkg/openai/client.go
@@ -240,7 +240,7 @@ func toToolCall(call types.CompletionToolCall) openai.ToolCall {
 	}
 }
 
-func toMessages(request types.CompletionRequest, compat bool) (result []openai.ChatCompletionMessage, err error) {
+func toMessages(request types.CompletionRequest, compat, useO1Model bool) (result []openai.ChatCompletionMessage, err error) {
 	var (
 		systemPrompts []string
 		msgs          []types.CompletionMessage
@@ -259,8 +259,12 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
 	}
 
 	if len(systemPrompts) > 0 {
+		role := types.CompletionMessageRoleTypeSystem
+		if useO1Model {
+			role = types.CompletionMessageRoleTypeDeveloper
+		}
 		msgs = slices.Insert(msgs, 0, types.CompletionMessage{
-			Role:    types.CompletionMessageRoleTypeSystem,
+			Role:    role,
 			Content: types.Text(strings.Join(systemPrompts, "\n")),
 		})
 	}
@@ -306,9 +310,9 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
 	return
 }
 
-func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
+func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, envs []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
 	if err := c.ValidAuth(); err != nil {
-		if err := c.RetrieveAPIKey(ctx, env); err != nil {
+		if err := c.RetrieveAPIKey(ctx, envs); err != nil {
 			return nil, err
 		}
 	}
@@ -317,7 +321,9 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 		messageRequest.Model = c.defaultModel
 	}
 
-	msgs, err := toMessages(messageRequest, !c.setSeed)
+	useO1Model := isO1Model(messageRequest.Model, envs)
+
+	msgs, err := toMessages(messageRequest, !c.setSeed, useO1Model)
 	if err != nil {
 		return nil, err
 	}
@@ -348,10 +354,13 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 		MaxTokens: messageRequest.MaxTokens,
 	}
 
-	if messageRequest.Temperature == nil {
-		request.Temperature = new(float32)
-	} else {
-		request.Temperature = messageRequest.Temperature
+	// OpenAI o1 doesn't support setting temperature, so leave it unset for o1 requests.
+	if !useO1Model {
+		if messageRequest.Temperature == nil {
+			request.Temperature = new(float32)
+		} else {
+			request.Temperature = messageRequest.Temperature
+		}
 	}
 
 	if messageRequest.JSONResponse {
@@ -404,7 +413,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 	if err != nil {
 		return nil, err
 	} else if !ok {
-		result, err = c.call(ctx, request, id, env, status)
+		result, err = c.call(ctx, request, id, envs, status)
 
 		// If we got back a context length exceeded error, keep retrying and shrinking the message history until we pass.
 		var apiError *openai.APIError
@@ -412,8 +421,8 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 			// Decrease maxTokens by 10% to make garbage collection more aggressive.
 			// The retry loop will further decrease maxTokens if needed.
 			maxTokens := decreaseTenPercent(messageRequest.MaxTokens)
-			result, err = c.contextLimitRetryLoop(ctx, request, id, env, maxTokens, status)
+			result, err = c.contextLimitRetryLoop(ctx, request, id, envs, maxTokens, status)
 		}
 		if err != nil {
 			return nil, err
 		}
@@ -446,6 +455,22 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 	return &result, nil
 }
 
+func isO1Model(model string, envs []string) bool {
+	if model == "o1" {
+		return true
+	}
+
+	// Fall back to the model name advertised through the environment.
+	for _, env := range envs {
+		k, v, _ := strings.Cut(env, "=")
+		if k == "OPENAI_MODEL_NAME" && v == "o1" {
+			return true
+		}
+	}
+
+	return false
+}
+
 func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, env []string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
 	var (
 		response types.CompletionMessage
@@ -545,9 +570,14 @@ func override(left, right string) string {
 	return left
 }
 
-func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, env []string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) {
+func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, envs []string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) {
 	streamResponse := os.Getenv("GPTSCRIPT_INTERNAL_OPENAI_STREAMING") != "false"
 
+	useO1Model := isO1Model(request.Model, envs)
+	if useO1Model {
+		streamResponse = false
+	}
+
 	partial <- types.CompletionStatus{
 		CompletionID: transactionID,
 		PartialResponse: &types.CompletionMessage{
@@ -567,7 +597,7 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
 			},
 		}
 	)
-	for _, e := range env {
+	for _, e := range envs {
 		if strings.HasPrefix(e, "GPTSCRIPT_MODEL_PROVIDER_") {
 			modelProviderEnv = append(modelProviderEnv, e)
 		} else if strings.HasPrefix(e, "GPTSCRIPT_DISABLE_RETRIES") {
diff --git a/pkg/types/completion.go b/pkg/types/completion.go
index 2362071f..439178c9 100644
--- a/pkg/types/completion.go
+++ b/pkg/types/completion.go
@@ -41,6 +41,7 @@ type CompletionFunctionDefinition struct {
 const (
 	CompletionMessageRoleTypeUser      = CompletionMessageRoleType("user")
 	CompletionMessageRoleTypeSystem    = CompletionMessageRoleType("system")
+	CompletionMessageRoleTypeDeveloper = CompletionMessageRoleType("developer")
 	CompletionMessageRoleTypeAssistant = CompletionMessageRoleType("assistant")
 	CompletionMessageRoleTypeTool      = CompletionMessageRoleType("tool")
 )
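
Reviewer notes (not part of the patch):

The new isO1Model helper is small and fully specified above, so it pins down
nicely with a table-driven test. The sketch below is illustrative only (the
test name and cases are mine, not gptscript's), but it exercises both
detection paths the patch adds: an explicit "o1" model name and an
OPENAI_MODEL_NAME entry in the environment slice.

	package openai

	import "testing"

	// Covers both detection paths in isO1Model: the explicit model name and
	// the OPENAI_MODEL_NAME variable in the env slice.
	func TestIsO1Model(t *testing.T) {
		cases := []struct {
			name  string
			model string
			envs  []string
			want  bool
		}{
			{"explicit o1 model", "o1", nil, true},
			{"env override", "gpt-4o", []string{"OPENAI_MODEL_NAME=o1"}, true},
			{"env names another model", "gpt-4o", []string{"OPENAI_MODEL_NAME=gpt-4o"}, false},
			{"no o1 signal", "gpt-4o", nil, false},
		}
		for _, tc := range cases {
			if got := isO1Model(tc.model, tc.envs); got != tc.want {
				t.Errorf("%s: isO1Model(%q, %v) = %v, want %v", tc.name, tc.model, tc.envs, got, tc.want)
			}
		}
	}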
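The temperature hunk is the one place where non-o1 behavior must stay exactly
as before: a nil temperature is defaulted to 0 rather than omitted, while o1
requests must leave the parameter unset entirely. Here is a minimal,
self-contained sketch of that logic; chatRequest and shapeTemperature are
hypothetical stand-ins for the vendored client's ChatCompletionRequest and
the inline logic in Call.

	package main

	import "fmt"

	// chatRequest stands in for the real ChatCompletionRequest, which carries
	// Temperature as a pointer so that nil leaves the field out of the request.
	type chatRequest struct {
		Temperature *float32
	}

	// shapeTemperature mirrors the patched logic: skip the parameter for o1
	// (which rejects it), otherwise keep the old default-to-zero behavior.
	func shapeTemperature(useO1Model bool, req *chatRequest, temp *float32) {
		if useO1Model {
			return // o1 doesn't support setting temperature
		}
		if temp == nil {
			req.Temperature = new(float32) // zero value, the old default
		} else {
			req.Temperature = temp
		}
	}

	func main() {
		var o1Req, gptReq chatRequest
		shapeTemperature(true, &o1Req, nil)
		shapeTemperature(false, &gptReq, nil)
		fmt.Println(o1Req.Temperature == nil) // true: parameter omitted for o1
		fmt.Println(*gptReq.Temperature)      // 0: prior behavior preserved
	}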