Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go/plugins/compat_oai/compat_oai.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (o *OpenAICompatible) DefineModel(provider, id string, opts ai.ModelOptions
generator := NewModelGenerator(o.client, modelName).WithMessages(input.Messages).WithConfig(input.Config).WithTools(input.Tools)

// Generate response
resp, err := generator.Generate(ctx, cb)
resp, err := generator.Generate(ctx, input, cb)
if err != nil {
return nil, err
}
Expand Down
77 changes: 51 additions & 26 deletions go/plugins/compat_oai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,11 @@ func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator {

am := openai.ChatCompletionAssistantMessageParam{}
am.Content.OfString = param.NewOpt(content)
toolCalls := convertToolCalls(msg.Content)
toolCalls, err := convertToolCalls(msg.Content)
if err != nil {
g.err = err
return g
}
if len(toolCalls) > 0 {
am.ToolCalls = (toolCalls)
}
Expand All @@ -100,10 +104,12 @@ func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator {
toolCallID = p.ToolResponse.Name
}

tm := openai.ToolMessage(
anyToJSONString(p.ToolResponse.Output),
toolCallID,
)
toolOutput, err := anyToJSONString(p.ToolResponse.Output)
if err != nil {
g.err = err
return g
}
tm := openai.ToolMessage(toolOutput, toolCallID)
oaiMessages = append(oaiMessages, tm)
}
case ai.RoleUser:
Expand Down Expand Up @@ -210,7 +216,7 @@ func (g *ModelGenerator) WithTools(tools []*ai.ToolDefinition) *ModelGenerator {
}

// Generate executes the generation request
func (g *ModelGenerator) Generate(ctx context.Context, handleChunk func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
func (g *ModelGenerator) Generate(ctx context.Context, req *ai.ModelRequest, handleChunk func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
// Check for any errors that occurred during building
if g.err != nil {
return nil, g.err
Expand All @@ -228,7 +234,7 @@ func (g *ModelGenerator) Generate(ctx context.Context, handleChunk func(context.
if handleChunk != nil {
return g.generateStream(ctx, handleChunk)
}
return g.generateComplete(ctx)
return g.generateComplete(ctx, req)
}

// concatenateContent concatenates text content into a single string
Expand Down Expand Up @@ -322,11 +328,19 @@ func (g *ModelGenerator) generateStream(ctx context.Context, handleChunk func(co
if choice.FinishReason == "tool_calls" && currentToolCall != nil {
// parse accumulated arguments string
for _, toolcall := range toolCallCollects {
toolcall.toolCall.Input = jsonStringToMap(toolcall.args)
args, err := jsonStringToMap(toolcall.args)
if err != nil {
return nil, fmt.Errorf("could not parse tool args: %w", err)
}
toolcall.toolCall.Input = args
fullResponse.Message.Content = append(fullResponse.Message.Content, ai.NewToolRequestPart(toolcall.toolCall))
}
if currentArguments != "" {
currentToolCall.Input = jsonStringToMap(currentArguments)
args, err := jsonStringToMap(currentArguments)
if err != nil {
return nil, fmt.Errorf("could not parse tool args: %w", err)
}
currentToolCall.Input = args
}
fullResponse.Message.Content = append(fullResponse.Message.Content, ai.NewToolRequestPart(currentToolCall))
}
Expand Down Expand Up @@ -356,14 +370,14 @@ func (g *ModelGenerator) generateStream(ctx context.Context, handleChunk func(co
}

// generateComplete generates a complete model response
func (g *ModelGenerator) generateComplete(ctx context.Context) (*ai.ModelResponse, error) {
func (g *ModelGenerator) generateComplete(ctx context.Context, req *ai.ModelRequest) (*ai.ModelResponse, error) {
completion, err := g.client.Chat.Completions.New(ctx, *g.request)
if err != nil {
return nil, fmt.Errorf("failed to create completion: %w", err)
}

resp := &ai.ModelResponse{
Request: &ai.ModelRequest{},
Request: req,
Usage: &ai.GenerationUsage{
InputTokens: int(completion.Usage.PromptTokens),
OutputTokens: int(completion.Usage.CompletionTokens),
Expand Down Expand Up @@ -392,10 +406,14 @@ func (g *ModelGenerator) generateComplete(ctx context.Context) (*ai.ModelRespons
// handle tool calls
var toolRequestParts []*ai.Part
for _, toolCall := range choice.Message.ToolCalls {
args, err := jsonStringToMap(toolCall.Function.Arguments)
if err != nil {
return nil, err
}
toolRequestParts = append(toolRequestParts, ai.NewToolRequestPart(&ai.ToolRequest{
Ref: toolCall.ID,
Name: toolCall.Function.Name,
Input: jsonStringToMap(toolCall.Function.Arguments),
Input: args,
}))
}

Expand All @@ -412,50 +430,57 @@ func (g *ModelGenerator) generateComplete(ctx context.Context) (*ai.ModelRespons
return resp, nil
}

func convertToolCalls(content []*ai.Part) []openai.ChatCompletionMessageToolCallParam {
func convertToolCalls(content []*ai.Part) ([]openai.ChatCompletionMessageToolCallParam, error) {
var toolCalls []openai.ChatCompletionMessageToolCallParam
for _, p := range content {
if !p.IsToolRequest() {
continue
}
toolCall := convertToolCall(p)
toolCalls = append(toolCalls, toolCall)
toolCall, err := convertToolCall(p)
if err != nil {
return nil, err
}
toolCalls = append(toolCalls, *toolCall)
}
return toolCalls
return toolCalls, nil
}

func convertToolCall(part *ai.Part) openai.ChatCompletionMessageToolCallParam {
func convertToolCall(part *ai.Part) (*openai.ChatCompletionMessageToolCallParam, error) {
toolCallID := part.ToolRequest.Ref
if toolCallID == "" {
toolCallID = part.ToolRequest.Name
}

param := openai.ChatCompletionMessageToolCallParam{
param := &openai.ChatCompletionMessageToolCallParam{
ID: (toolCallID),
Function: (openai.ChatCompletionMessageToolCallFunctionParam{
Name: (part.ToolRequest.Name),
}),
}

args, err := anyToJSONString(part.ToolRequest.Input)
if err != nil {
return nil, err
}
if part.ToolRequest.Input != nil {
param.Function.Arguments = (anyToJSONString(part.ToolRequest.Input))
param.Function.Arguments = args
}

return param
return param, nil
}

// jsonStringToMap unmarshals a JSON object string into a map.
// It returns an error (rather than panicking) when the input is not a
// valid JSON object, so callers can surface the failure to the user.
func jsonStringToMap(jsonString string) (map[string]any, error) {
	var result map[string]any
	if err := json.Unmarshal([]byte(jsonString), &result); err != nil {
		// %q quotes and escapes the offending input, which keeps the
		// message readable even for empty or whitespace-only strings.
		return nil, fmt.Errorf("parsing JSON string %q: %w", jsonString, err)
	}
	return result, nil
}

// anyToJSONString marshals an arbitrary value to its JSON string form.
// It returns an error (rather than panicking) when the value cannot be
// marshaled, e.g. for funcs, channels, or cyclic structures.
func anyToJSONString(data any) (string, error) {
	jsonBytes, err := json.Marshal(data)
	if err != nil {
		return "", fmt.Errorf("marshaling %#v to JSON string: %w", data, err)
	}
	return string(jsonBytes), nil
}
16 changes: 12 additions & 4 deletions go/plugins/compat_oai/generate_live_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,11 @@ func TestGenerator_Complete(t *testing.T) {
},
},
}
req := &ai.ModelRequest{
Messages: messages,
}

resp, err := g.WithMessages(messages).Generate(context.Background(), nil)
resp, err := g.WithMessages(messages).Generate(context.Background(), req, nil)
if err != nil {
t.Error(err)
}
Expand All @@ -79,7 +82,6 @@ func TestGenerator_Complete(t *testing.T) {

func TestGenerator_Stream(t *testing.T) {
g := setupTestClient(t)

messages := []*ai.Message{
{
Role: ai.RoleUser,
Expand All @@ -88,6 +90,9 @@ func TestGenerator_Stream(t *testing.T) {
},
},
}
req := &ai.ModelRequest{
Messages: messages,
}

var chunks []string
handleChunk := func(ctx context.Context, chunk *ai.ModelResponseChunk) error {
Expand All @@ -97,7 +102,7 @@ func TestGenerator_Stream(t *testing.T) {
return nil
}

_, err := g.WithMessages(messages).Generate(context.Background(), handleChunk)
_, err := g.WithMessages(messages).Generate(context.Background(), req, handleChunk)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -229,11 +234,14 @@ func TestWithConfig(t *testing.T) {
},
},
}
req := &ai.ModelRequest{
Messages: messages,
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
generator := setupTestClient(t)
result, err := generator.WithMessages(messages).WithConfig(tt.config).Generate(context.Background(), nil)
result, err := generator.WithMessages(messages).WithConfig(tt.config).Generate(context.Background(), req, nil)

if tt.err != nil {
assert.Error(t, err)
Expand Down
29 changes: 29 additions & 0 deletions go/plugins/compat_oai/openai/openai_live_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,4 +250,33 @@ func TestPlugin(t *testing.T) {
}
t.Logf("invalid config type error: %v", err)
})

t.Run("check history", func(t *testing.T) {
	// Verify that a response carries its originating request and that the
	// accumulated history can be threaded into a follow-up generation.
	resp, err := genkit.Generate(ctx, g,
		ai.WithPrompt("Tell me a joke"))
	if err != nil {
		// t.Fatal does not format its arguments; %w is only valid in
		// fmt.Errorf. Use t.Fatalf with %v instead.
		t.Fatalf("got error: %v", err)
	}
	if resp.Request == nil {
		t.Fatal("unexpected nil pointer for request")
	}
	if len(resp.Request.Messages) == 0 {
		t.Fatal("expecting user messages in request")
	}
	resp, err = genkit.Generate(ctx, g,
		ai.WithMessages(resp.History()...),
		ai.WithPrompt("explain the joke that you just provided me"))
	if err != nil {
		t.Fatalf("got error: %v", err)
	}
	// Both prompts should now appear as user messages in the history.
	userMsgCount := 0
	for _, m := range resp.History() {
		if m.Role == ai.RoleUser {
			userMsgCount++
		}
	}
	if userMsgCount != 2 {
		t.Fatalf("expecting 2 user messages, got: %d", userMsgCount)
	}
})
}
Loading