diff --git a/go/plugins/compat_oai/xai/README.md b/go/plugins/compat_oai/xai/README.md new file mode 100644 index 0000000000..b19122de5a --- /dev/null +++ b/go/plugins/compat_oai/xai/README.md @@ -0,0 +1,170 @@ +# OpenAI Plugin + +This plugin provides a simple interface for using xAI's services. + +## Supported Models +The plugin supports the following xAI models. + +### Grok 4 (grok-4-0709) + +Modalities: Text input, text output + +Context window: 256 000 + +Features: +- Function calling +- Structured outputs +- Reasoning + +### Grok Code Fast 1 (grok-code-fast-1). + +Modalities: Text input, text output + +Context window: 256 000 + +Features: +- Function calling +- Structured outputs +- Reasoning + +### Grok 4 Fast (grok-4-fast-reasoning). + +Modalities: Text input, image input, text output + +Context window: 2 000 000 + +Features: +- Function calling +- Structured outputs +- Reasoning + +### Grok 4 Fast (Non-Reasoning) + +Modalities: Text input, image input, text output + +Context window: 2 000 000 + +Features: +- Function calling +- Structured outputs + + +### Grok 3 Mini (grok-3-mini) + +Modalities: Text input, text output + +Context window: 131 072 + +Features: +- Function calling +- Structured outputs +- Reasoning + +### Grok 3 (grok-3) + +Modalities: Text input, text output + +Context window: 131 072 + +Features: +- Function calling +- Structured outputs + +### Grok 2 Vision (grok-2-vision) + +Modalities: Text input, image input, text output + +Context window: 32 768 + +Features: +- Function calling +- Structured outputs + +### Grok 2 Image Gen + +Grok 2 Image Gen + +## Prerequisites + +- Go installed on your system +- An xAI API key + +## Usage + +Here's a simple example of how to use the OpenAI plugin: + +```go +import ( + "context" + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/openai/openai-go" + "github.com/openai/openai-go/option" +) +// Initialize the xAI plugin with your API key +x := &XAi{ + Opts: []option.RequestOption{ + option.WithAPIKey(apiKey), + }, +} + +// Initialize Genkit with the xAI plugin +g := genkit.Init(ctx, + genkit.WithDefaultModel("xai/grok-3-mini"), + genkit.WithPlugins(x), +) + +config := &openai.ChatCompletionNewParams{ +// define optional config fields +} + +resp, err = genkit.Generate(ctx, g, +ai.WithPromptText("Write a short sentence about artificial intelligence."), +ai.WithConfig(config), +) +``` + +## Running Tests + +First, set your xAI API key as an environment variable: + +```bash +export XAI_API_KEY= +``` + +### Running All Tests +To run all tests in the directory: +```bash +go test -v . +``` + +### Running Tests from Specific Files +To run tests from a specific file: +```bash +# Run only generate_live_test.go tests +go test -run "^TestGenerator" + +# Run only openai_live_test.go tests +go test -run "^TestPlugin" +``` + +### Running Individual Tests +To run a specific test case: +```bash +# Run only the streaming test from xai_live_test.go +go test -run "TestPlugin/streaming" + +# Run only the Complete test from generate_live_test.go +go test -run "TestGenerator_Complete" + +# Run only the Stream test from generate_live_test.go +go test -run "TestGenerator_Stream" +``` + +### Test Output Verbosity +Add the `-v` flag for verbose output: +```bash +go test -v -run "TestPlugin/streaming" +``` + +Note: All live tests require the XAI_API_KEY environment variable to be set. Tests will be skipped if the API key is not provided. diff --git a/go/plugins/compat_oai/xai/xai.go b/go/plugins/compat_oai/xai/xai.go new file mode 100644 index 0000000000..d4b48a9e21 --- /dev/null +++ b/go/plugins/compat_oai/xai/xai.go @@ -0,0 +1,129 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package xai + +import ( + "context" + "os" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/compat_oai" + "github.com/openai/openai-go/option" +) + +const ( + provider = "xai" + baseURL = "https://api.x.ai/v1" +) + +var supportedModels = map[string]ai.ModelOptions{ + "grok-code-fast-1": { + Label: "Grok Code Fast 1", + Supports: &compat_oai.BasicText, + Versions: []string{"grok-code-fast-1"}, + }, + "grok-4-fast-reasoning": { + Label: "Grok 4 Fast", + Supports: &compat_oai.Multimodal, + Versions: []string{"grok-4-fast-reasoning", "grok-4-fast", "grok-4-fast-reasoning-latest"}, + }, + "grok-4-fast-non-reasoning": { + Label: "Grok 4 Fast (Non-Reasoning)", + Supports: &compat_oai.Multimodal, + Versions: []string{"grok-4-fast-non-reasoning", "grok-4-fast-non-reasoning-latest"}, + }, + "grok-4-0709": { + Label: "Grok 4", + Supports: &compat_oai.Multimodal, + Versions: []string{"grok-4-0709", "grok-4", "grok-4-latest"}, + }, + "grok-3": { + Label: "Grok 3", + Supports: &compat_oai.BasicText, + Versions: []string{"grok-3"}, + }, + "grok-3-mini": { + Label: "Grok 3 Mini", + Supports: &compat_oai.BasicText, + Versions: []string{"grok-3-mini"}, + }, + "grok-2-vision": { + Label: "Grok 2 Vision", + Supports: &ai.ModelSupports{ + Multiturn: false, + Tools: true, + SystemRole: false, + Media: true, + }, + Versions: []string{"grok-2-vision", "grok-2-vision-1212", "grok-2-vision-latest"}, + }, +} + +type XAi struct { + Opts []option.RequestOption + openAICompatible *compat_oai.OpenAICompatible +} + +func (x *XAi) Name() string { + return provider +} + +func (x *XAi) Init(ctx context.Context) []api.Action { + url := os.Getenv("XAI_BASE_URL") + if url == "" { + url = baseURL + } + x.Opts = append([]option.RequestOption{option.WithBaseURL(url)}, x.Opts...) + + apiKey := os.Getenv("XAI_API_KEY") + if apiKey != "" { + x.Opts = append([]option.RequestOption{option.WithAPIKey(apiKey)}, x.Opts...) + } + + if x.openAICompatible == nil { + x.openAICompatible = &compat_oai.OpenAICompatible{} + } + + x.openAICompatible.Opts = x.Opts + compatActions := x.openAICompatible.Init(ctx) + + var actions []api.Action + actions = append(actions, compatActions...) + + // define default models + for model, opts := range supportedModels { + actions = append(actions, x.DefineModel(model, opts).(api.Action)) + } + + return actions +} + +func (x *XAi) Model(g *genkit.Genkit, id string) ai.Model { + return x.openAICompatible.Model(g, api.NewName(provider, id)) +} + +func (x *XAi) DefineModel(id string, opts ai.ModelOptions) ai.Model { + return x.openAICompatible.DefineModel(provider, id, opts) +} + +func (x *XAi) ListActions(ctx context.Context) []api.ActionDesc { + return x.openAICompatible.ListActions(ctx) +} + +func (x *XAi) ResolveAction(atype api.ActionType, name string) api.Action { + return x.openAICompatible.ResolveAction(atype, name) +} diff --git a/go/plugins/compat_oai/xai/xai_live_test.go b/go/plugins/compat_oai/xai/xai_live_test.go new file mode 100644 index 0000000000..26d1d34203 --- /dev/null +++ b/go/plugins/compat_oai/xai/xai_live_test.go @@ -0,0 +1,230 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package xai + +import ( + "context" + "math" + "os" + "strings" + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/openai/openai-go/option" +) + +func TestPlugin(t *testing.T) { + apiKey := os.Getenv("XAI_API_KEY") + if apiKey == "" { + t.Skip("Skipping test: XAI_API_KEY environment variable not set") + } + + ctx := context.Background() + + x := &XAi{ + Opts: []option.RequestOption{ + option.WithAPIKey(apiKey), + }, + } + + g := genkit.Init(ctx, + genkit.WithDefaultModel("xai/grok-3-mini"), + genkit.WithPlugins(x)) + + gablorkenTool := genkit.DefineTool(g, "gablorken", "use when need to calculate a gablorken", + func(ctx *ai.ToolContext, input struct { + Value float64 + Over float64 + }, + ) (float64, error) { + return math.Pow(input.Value, input.Over), nil + }, + ) + + t.Log("genkit initialized") + + t.Run("basic completion", func(t *testing.T) { + t.Log("generating basic completion response") + resp, err := genkit.Generate(ctx, g, + ai.WithPrompt("What is the capital of France?"), + ) + if err != nil { + t.Fatal("error generating basic completion response: ", err) + } + t.Logf("basic completion response: %+v", resp) + + out := resp.Message.Content[0].Text + if !strings.Contains(strings.ToLower(out), "paris") { + t.Errorf("got %q, expecting it to contain 'Paris'", out) + } + + // Verify usage statistics are present + if resp.Usage == nil || resp.Usage.TotalTokens == 0 { + t.Error("Expected non-zero usage statistics") + } + }) + + t.Run("streaming", func(t *testing.T) { + var streamedOutput string + chunks := 0 + + final, err := genkit.Generate(ctx, g, + ai.WithPrompt("Write a short paragraph about artificial intelligence."), + ai.WithStreaming(func(ctx context.Context, chunk *ai.ModelResponseChunk) error { + chunks++ + for _, content := range chunk.Content { + streamedOutput += content.Text + } + return nil + })) + if err != nil { + t.Fatal(err) + } + + // Verify streaming worked + if chunks <= 1 { + t.Error("Expected multiple chunks for streaming") + } + + // Verify the final output matches streamed content + finalOutput := "" + for _, content := range final.Message.Content { + finalOutput += content.Text + } + if streamedOutput != finalOutput { + t.Errorf("Streaming output doesn't match final output\nStreamed: %s\nFinal: %s", + streamedOutput, finalOutput) + } + + t.Logf("streaming response: %+v", finalOutput) + }) + + t.Run("media part", func(t *testing.T) { + image := "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAHIAAABUAQMAAABk5vEVAAAABlBMVEX///8AAABVwtN+" + + "AAAAI0lEQVR4nGNgGHaA/z8UHIDwOWASDqP8Uf7w56On/1FAQwAAVM0exw1hqwkAAAAASUVORK5CYII=" + + resp, err := genkit.Generate(ctx, g, + ai.WithModelName("xai/grok-4-fast-non-reasoning"), + ai.WithMessages( + ai.NewUserMessage( + ai.NewMediaPart("image/png", image), + ai.NewTextPart("Is there a rectangle in the picture? Yes or not."), + ), + ), + ) + + if err != nil { + t.Fatal(err) + } + + text := resp.Message.Content[0].Text + if !strings.Contains(strings.ToLower(text), "yes") { + t.Errorf("got %q, expecting it to contain 'yes'", text) + } + }) + + t.Run("system message", func(t *testing.T) { + resp, err := genkit.Generate(ctx, g, + ai.WithPrompt("What are you?"), + ai.WithSystem("You are a helpful math tutor who loves numbers."), + ) + if err != nil { + t.Fatal(err) + } + + out := resp.Message.Content[0].Text + if !strings.Contains(strings.ToLower(out), "math") { + t.Errorf("got %q, expecting response to mention being a math tutor", out) + } + + t.Logf("system message response: %+v", out) + }) + + t.Run("tool usage with basic completion", func(t *testing.T) { + resp, err := genkit.Generate(ctx, g, + ai.WithPrompt("what is a gablorken of 2 over 3?"), + ai.WithTools(gablorkenTool)) + if err != nil { + t.Fatal(err) + } + + out := resp.Message.Content[0].Text + const want = "8" + if !strings.Contains(out, want) { + t.Errorf("got %q, expecting it to contain %q", out, want) + } + + t.Logf("tool usage with basic completion response: %+v", out) + }) + + t.Run("tool usage with streaming", func(t *testing.T) { + var streamedOutput string + chunks := 0 + + final, err := genkit.Generate(ctx, g, + ai.WithPrompt("what is a gablorken of 2 over 3?"), + ai.WithTools(gablorkenTool), + ai.WithStreaming(func(ctx context.Context, chunk *ai.ModelResponseChunk) error { + chunks++ + for _, content := range chunk.Content { + streamedOutput += content.Text + } + return nil + })) + if err != nil { + t.Fatal(err) + } + + // Verify streaming worked + if chunks <= 1 { + t.Error("Expected multiple chunks for streaming") + } + + // Verify the final output matches streamed content + finalOutput := "" + for _, content := range final.Message.Content { + finalOutput += content.Text + } + if streamedOutput != finalOutput { + t.Errorf("Streaming output doesn't match final output\nStreamed: %s\nFinal: %s", + streamedOutput, finalOutput) + } + + const want = "8" + if !strings.Contains(finalOutput, want) { + t.Errorf("got %q, expecting it to contain %q", finalOutput, want) + } + + t.Logf("tool usage with streaming response: %+v", finalOutput) + }) + + t.Run("invalid config type", func(t *testing.T) { + // Try to use a string as config instead of *openai.ChatCompletionNewParams + config := "not a config" + + _, err := genkit.Generate(ctx, g, + ai.WithPrompt("Write a short sentence about artificial intelligence."), + ai.WithConfig(config), + ) + if err == nil { + t.Fatal("expected error for invalid config type") + } + if !strings.Contains(err.Error(), "unexpected config type: string") { + t.Errorf("got error %q, want error containing 'unexpected config type: string'", err.Error()) + } + t.Logf("invalid config type error: %v", err) + }) +}