From 8275117d0f30928783bc0c8d696564cd1d96f0b8 Mon Sep 17 00:00:00 2001 From: Ian Lance Taylor Date: Thu, 9 May 2024 11:43:21 -0700 Subject: [PATCH] [Go] update dotprompt model strings to use provider/model (#94) * [Go] update dotprompt model strings to use provider/model Also update coffee-shop to use GOOGLE_GENAI_API_KEY rather than GEMINI_API_KEY. --- go/ai/generator.go | 4 ++-- go/genkit/dotprompt/genkit.go | 6 +++++- go/samples/coffee-shop/main.go | 8 ++++---- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/go/ai/generator.go b/go/ai/generator.go index 8833cd009..423669557 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -74,8 +74,8 @@ type generatorActionType = genkit.Action[*GenerateRequest, *GenerateResponse, *C // LookupGeneratorAction looks up an action registered by [RegisterGenerator] // and returns a generator that invokes the action. -func LookupGeneratorAction(name string) (Generator, error) { - action := genkit.LookupAction(genkit.ActionTypeModel, name, name) +func LookupGeneratorAction(provider, name string) (Generator, error) { + action := genkit.LookupAction(genkit.ActionTypeModel, provider, name) if action == nil { return nil, fmt.Errorf("LookupGeneratorAction: no generator action named %q", name) } diff --git a/go/genkit/dotprompt/genkit.go b/go/genkit/dotprompt/genkit.go index a44cf2c25..841c1acdb 100644 --- a/go/genkit/dotprompt/genkit.go +++ b/go/genkit/dotprompt/genkit.go @@ -184,8 +184,12 @@ func (p *Prompt) Execute(ctx context.Context, input *ActionInput) (*ai.GenerateR if model == "" { return nil, errors.New("dotprompt action: model not specified") } + provider, name, found := strings.Cut(model, "/") + if !found { + return nil, errors.New("dotprompt model not in provider/name format") + } - generator, err = ai.LookupGeneratorAction(model) + generator, err = ai.LookupGeneratorAction(provider, name) if err != nil { return nil, err } diff --git a/go/samples/coffee-shop/main.go b/go/samples/coffee-shop/main.go index f17186a47..953eb8fe5 100644 --- a/go/samples/coffee-shop/main.go +++ b/go/samples/coffee-shop/main.go @@ -72,9 +72,9 @@ type testAllCoffeeFlowsOutput struct { } func main() { - apiKey := os.Getenv("GEMINI_API_KEY") + apiKey := os.Getenv("GOOGLE_GENAI_API_KEY") if apiKey == "" { - fmt.Fprintln(os.Stderr, "coffee-shop example requires setting GEMINI_API_KEY in the environment.") + fmt.Fprintln(os.Stderr, "coffee-shop example requires setting GOOGLE_GENAI_API_KEY in the environment.") fmt.Fprintln(os.Stderr, "You can get an API key at https://ai.google.dev.") os.Exit(1) } @@ -86,7 +86,7 @@ func main() { simpleGreetingPrompt, err := dotprompt.Define("simpleGreeting", &dotprompt.Frontmatter{ Name: "simpleGreeting", - Model: "google-genai", + Model: "google-genai/gemini-1.0-pro", Input: dotprompt.FrontmatterInput{ Schema: jsonschema.Reflect(simpleGreetingInput{}), }, @@ -121,7 +121,7 @@ func main() { greetingWithHistoryPrompt, err := dotprompt.Define("greetingWithHistory", &dotprompt.Frontmatter{ Name: "greetingWithHistory", - Model: "google-genai", + Model: "google-genai/gemini-1.0-pro", Input: dotprompt.FrontmatterInput{ Schema: jsonschema.Reflect(customerTimeAndHistoryInput{}), },