diff --git a/go/plugins/compat_oai/compat_oai.go b/go/plugins/compat_oai/compat_oai.go index 41adb4b44e..2dde807589 100644 --- a/go/plugins/compat_oai/compat_oai.go +++ b/go/plugins/compat_oai/compat_oai.go @@ -111,7 +111,7 @@ func (o *OpenAICompatible) DefineModel(provider, id string, opts ai.ModelOptions generator := NewModelGenerator(o.client, modelName).WithMessages(input.Messages).WithConfig(input.Config).WithTools(input.Tools) // Generate response - resp, err := generator.Generate(ctx, cb) + resp, err := generator.Generate(ctx, input, cb) if err != nil { return nil, err } diff --git a/go/plugins/compat_oai/generate.go b/go/plugins/compat_oai/generate.go index 72cbf8d17b..c37878949f 100644 --- a/go/plugins/compat_oai/generate.go +++ b/go/plugins/compat_oai/generate.go @@ -82,7 +82,11 @@ func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator { am := openai.ChatCompletionAssistantMessageParam{} am.Content.OfString = param.NewOpt(content) - toolCalls := convertToolCalls(msg.Content) + toolCalls, err := convertToolCalls(msg.Content) + if err != nil { + g.err = err + return g + } if len(toolCalls) > 0 { am.ToolCalls = (toolCalls) } @@ -100,10 +104,12 @@ func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator { toolCallID = p.ToolResponse.Name } - tm := openai.ToolMessage( - anyToJSONString(p.ToolResponse.Output), - toolCallID, - ) + toolOutput, err := anyToJSONString(p.ToolResponse.Output) + if err != nil { + g.err = err + return g + } + tm := openai.ToolMessage(toolOutput, toolCallID) oaiMessages = append(oaiMessages, tm) } case ai.RoleUser: @@ -210,7 +216,7 @@ func (g *ModelGenerator) WithTools(tools []*ai.ToolDefinition) *ModelGenerator { } // Generate executes the generation request -func (g *ModelGenerator) Generate(ctx context.Context, handleChunk func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) { +func (g *ModelGenerator) Generate(ctx context.Context, req *ai.ModelRequest, 
handleChunk func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) { // Check for any errors that occurred during building if g.err != nil { return nil, g.err @@ -228,7 +234,7 @@ func (g *ModelGenerator) Generate(ctx context.Context, handleChunk func(context. if handleChunk != nil { return g.generateStream(ctx, handleChunk) } - return g.generateComplete(ctx) + return g.generateComplete(ctx, req) } // concatenateContent concatenates text content into a single string @@ -322,11 +328,19 @@ func (g *ModelGenerator) generateStream(ctx context.Context, handleChunk func(co if choice.FinishReason == "tool_calls" && currentToolCall != nil { // parse accumulated arguments string for _, toolcall := range toolCallCollects { - toolcall.toolCall.Input = jsonStringToMap(toolcall.args) + args, err := jsonStringToMap(toolcall.args) + if err != nil { + return nil, fmt.Errorf("could not parse tool args: %w", err) + } + toolcall.toolCall.Input = args fullResponse.Message.Content = append(fullResponse.Message.Content, ai.NewToolRequestPart(toolcall.toolCall)) } if currentArguments != "" { - currentToolCall.Input = jsonStringToMap(currentArguments) + args, err := jsonStringToMap(currentArguments) + if err != nil { + return nil, fmt.Errorf("could not parse tool args: %w", err) + } + currentToolCall.Input = args } fullResponse.Message.Content = append(fullResponse.Message.Content, ai.NewToolRequestPart(currentToolCall)) } @@ -356,14 +370,14 @@ func (g *ModelGenerator) generateStream(ctx context.Context, handleChunk func(co } // generateComplete generates a complete model response -func (g *ModelGenerator) generateComplete(ctx context.Context) (*ai.ModelResponse, error) { +func (g *ModelGenerator) generateComplete(ctx context.Context, req *ai.ModelRequest) (*ai.ModelResponse, error) { completion, err := g.client.Chat.Completions.New(ctx, *g.request) if err != nil { return nil, fmt.Errorf("failed to create completion: %w", err) } resp := &ai.ModelResponse{ - 
Request: &ai.ModelRequest{}, + Request: req, Usage: &ai.GenerationUsage{ InputTokens: int(completion.Usage.PromptTokens), OutputTokens: int(completion.Usage.CompletionTokens), @@ -392,10 +406,14 @@ func (g *ModelGenerator) generateComplete(ctx context.Context) (*ai.ModelRespons // handle tool calls var toolRequestParts []*ai.Part for _, toolCall := range choice.Message.ToolCalls { + args, err := jsonStringToMap(toolCall.Function.Arguments) + if err != nil { + return nil, err + } toolRequestParts = append(toolRequestParts, ai.NewToolRequestPart(&ai.ToolRequest{ Ref: toolCall.ID, Name: toolCall.Function.Name, - Input: jsonStringToMap(toolCall.Function.Arguments), + Input: args, })) } @@ -412,50 +430,57 @@ func (g *ModelGenerator) generateComplete(ctx context.Context) (*ai.ModelRespons return resp, nil } -func convertToolCalls(content []*ai.Part) []openai.ChatCompletionMessageToolCallParam { +func convertToolCalls(content []*ai.Part) ([]openai.ChatCompletionMessageToolCallParam, error) { var toolCalls []openai.ChatCompletionMessageToolCallParam for _, p := range content { if !p.IsToolRequest() { continue } - toolCall := convertToolCall(p) - toolCalls = append(toolCalls, toolCall) + toolCall, err := convertToolCall(p) + if err != nil { + return nil, err + } + toolCalls = append(toolCalls, *toolCall) } - return toolCalls + return toolCalls, nil } -func convertToolCall(part *ai.Part) openai.ChatCompletionMessageToolCallParam { +func convertToolCall(part *ai.Part) (*openai.ChatCompletionMessageToolCallParam, error) { toolCallID := part.ToolRequest.Ref if toolCallID == "" { toolCallID = part.ToolRequest.Name } - param := openai.ChatCompletionMessageToolCallParam{ + param := &openai.ChatCompletionMessageToolCallParam{ ID: (toolCallID), Function: (openai.ChatCompletionMessageToolCallFunctionParam{ Name: (part.ToolRequest.Name), }), } + args, err := anyToJSONString(part.ToolRequest.Input) + if err != nil { + return nil, err + } if part.ToolRequest.Input != nil { - 
param.Function.Arguments = (anyToJSONString(part.ToolRequest.Input)) + param.Function.Arguments = args } - return param + return param, nil } -func jsonStringToMap(jsonString string) map[string]any { +func jsonStringToMap(jsonString string) (map[string]any, error) { var result map[string]any if err := json.Unmarshal([]byte(jsonString), &result); err != nil { - panic(fmt.Errorf("unmarshal failed to parse json string %s: %w", jsonString, err)) + return nil, fmt.Errorf("unmarshal failed to parse json string %s: %w", jsonString, err) } - return result + return result, nil } -func anyToJSONString(data any) string { +func anyToJSONString(data any) (string, error) { jsonBytes, err := json.Marshal(data) if err != nil { - panic(fmt.Errorf("failed to marshal any to JSON string: data, %#v %w", data, err)) + return "", fmt.Errorf("failed to marshal any to JSON string: data, %#v %w", data, err) } - return string(jsonBytes) + return string(jsonBytes), nil } diff --git a/go/plugins/compat_oai/generate_live_test.go b/go/plugins/compat_oai/generate_live_test.go index f3ef0b588c..3e2e340b19 100644 --- a/go/plugins/compat_oai/generate_live_test.go +++ b/go/plugins/compat_oai/generate_live_test.go @@ -64,8 +64,11 @@ func TestGenerator_Complete(t *testing.T) { }, }, } + req := &ai.ModelRequest{ + Messages: messages, + } - resp, err := g.WithMessages(messages).Generate(context.Background(), nil) + resp, err := g.WithMessages(messages).Generate(context.Background(), req, nil) if err != nil { t.Error(err) } @@ -79,7 +82,6 @@ func TestGenerator_Complete(t *testing.T) { func TestGenerator_Stream(t *testing.T) { g := setupTestClient(t) - messages := []*ai.Message{ { Role: ai.RoleUser, @@ -88,6 +90,9 @@ func TestGenerator_Stream(t *testing.T) { }, }, } + req := &ai.ModelRequest{ + Messages: messages, + } var chunks []string handleChunk := func(ctx context.Context, chunk *ai.ModelResponseChunk) error { @@ -97,7 +102,7 @@ func TestGenerator_Stream(t *testing.T) { return nil } - _, err := 
g.WithMessages(messages).Generate(context.Background(), handleChunk) + _, err := g.WithMessages(messages).Generate(context.Background(), req, handleChunk) if err != nil { t.Error(err) } @@ -229,11 +234,14 @@ func TestWithConfig(t *testing.T) { }, }, } + req := &ai.ModelRequest{ + Messages: messages, + } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { generator := setupTestClient(t) - result, err := generator.WithMessages(messages).WithConfig(tt.config).Generate(context.Background(), nil) + result, err := generator.WithMessages(messages).WithConfig(tt.config).Generate(context.Background(), req, nil) if tt.err != nil { assert.Error(t, err) diff --git a/go/plugins/compat_oai/openai/openai_live_test.go b/go/plugins/compat_oai/openai/openai_live_test.go index 916b7afa24..d14fcd6fed 100644 --- a/go/plugins/compat_oai/openai/openai_live_test.go +++ b/go/plugins/compat_oai/openai/openai_live_test.go @@ -250,4 +250,33 @@ func TestPlugin(t *testing.T) { } t.Logf("invalid config type error: %v", err) }) + + t.Run("check history", func(t *testing.T) { + resp, err := genkit.Generate(ctx, g, + ai.WithPrompt("Tell me a joke")) + if err != nil { + t.Fatalf("got error: %v", err) + } + if resp.Request == nil { + t.Fatal("unexpected nil pointer for request") + } + if len(resp.Request.Messages) == 0 { + t.Fatal("expecting user messages in request") + } + resp, err = genkit.Generate(ctx, g, + ai.WithMessages(resp.History()...), + ai.WithPrompt("explain the joke that you just provided me")) + if err != nil { + t.Fatalf("got error: %v", err) + } + userMsgCount := 0 + for _, m := range resp.History() { + if m.Role == ai.RoleUser { + userMsgCount += 1 + } + } + if userMsgCount != 2 { + t.Fatalf("expecting 2 user messages, got: %d", userMsgCount) + } + }) }