diff --git a/go/ai/generator.go b/go/ai/generator.go index 032221789..98a64e33a 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -23,8 +23,12 @@ import ( // Generator is the interface used to query an AI model. type Generator interface { - // TODO(randall77): define a type for streaming Generate calls. - Generate(context.Context, *GenerateRequest, genkit.NoStream) (*GenerateResponse, error) + // If the streaming callback is non-nil: + // - Each response candidate will be passed to that callback instead of + // populating the result's Candidates field. + // - If the streaming callback returns a non-nil error, generation will stop + // and Generate immediately returns that error (and a nil response). + Generate(context.Context, *GenerateRequest, genkit.StreamingCallback[*Candidate]) (*GenerateResponse, error) } // RegisterGenerator registers the generator in the global registry. @@ -35,8 +39,7 @@ func RegisterGenerator(name string, generator Generator) { // generatorActionType is the instantiated genkit.Action type registered // by RegisterGenerator. -// TODO(ianlancetaylor, randall77): add streaming support -type generatorActionType = genkit.Action[*GenerateRequest, *GenerateResponse, struct{}] +type generatorActionType = genkit.Action[*GenerateRequest, *GenerateResponse, *Candidate] // LookupGeneratorAction looks up an action registered by [RegisterGenerator] // and returns a generator that invokes the action. @@ -58,6 +61,6 @@ type generatorAction struct { } // Generate implements Generator. 
-func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, cb genkit.NoStream) (*GenerateResponse, error) { +func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, cb genkit.StreamingCallback[*Candidate]) (*GenerateResponse, error) { return ga.action.Run(ctx, input, cb) } diff --git a/go/genkit/dotprompt/genkit_test.go b/go/genkit/dotprompt/genkit_test.go index b0d4157da..190c32e6c 100644 --- a/go/genkit/dotprompt/genkit_test.go +++ b/go/genkit/dotprompt/genkit_test.go @@ -23,9 +23,9 @@ import ( "github.com/google/genkit/go/genkit" ) -type testGenerator struct {} +type testGenerator struct{} -func (testGenerator) Generate(ctx context.Context, req *ai.GenerateRequest, cb genkit.NoStream) (*ai.GenerateResponse, error) { +func (testGenerator) Generate(ctx context.Context, req *ai.GenerateRequest, cb genkit.StreamingCallback[*ai.Candidate]) (*ai.GenerateResponse, error) { input := req.Messages[0].Content[0].Text() output := fmt.Sprintf("AI reply to %q", input) diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index eecddaa65..fcb826a2d 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -20,6 +20,7 @@ import ( "github.com/google/generative-ai-go/genai" "github.com/google/genkit/go/ai" "github.com/google/genkit/go/genkit" + "google.golang.org/api/iterator" "google.golang.org/api/option" ) @@ -60,7 +61,7 @@ type generator struct { client *genai.Client } -func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, _ genkit.NoStream) (*ai.GenerateResponse, error) { +func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, cb genkit.StreamingCallback[*ai.Candidate]) (*ai.GenerateResponse, error) { gm := g.client.GenerativeModel(g.model) // Translate from a ai.GenerateRequest to a genai request. 
@@ -94,54 +95,95 @@ func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, _ g //TODO: convert input.Tools and append to gm.Tools // Send out the actual request. - // TODO(randall77): if the streaming callback is non-nil, use SendMessageStream - // and pass the streamed results to the callback. - resp, err := cs.SendMessage(ctx, parts...) - if err != nil { - return nil, err + if cb == nil { + resp, err := cs.SendMessage(ctx, parts...) + if err != nil { + return nil, err + } + return translateResponse(resp), nil } - // Translate from a genai.GenerateContentResponse to a ai.GenerateResponse. - r := &ai.GenerateResponse{} - for _, cand := range resp.Candidates { - c := &ai.Candidate{} - c.Index = int(cand.Index) - switch cand.FinishReason { - case genai.FinishReasonStop: - c.FinishReason = ai.FinishReasonStop - case genai.FinishReasonMaxTokens: - c.FinishReason = ai.FinishReasonLength - case genai.FinishReasonSafety: - c.FinishReason = ai.FinishReasonBlocked - case genai.FinishReasonRecitation: - c.FinishReason = ai.FinishReasonBlocked - case genai.FinishReasonOther: - c.FinishReason = ai.FinishReasonOther - default: // Unspecified - c.FinishReason = ai.FinishReasonUnknown + // Streaming version. + iter := cs.SendMessageStream(ctx, parts...) + var r *ai.GenerateResponse + for { + chunk, err := iter.Next() + if err == iterator.Done { + break } - m := &ai.Message{} - m.Role = ai.Role(cand.Content.Role) - for _, part := range cand.Content.Parts { - var p *ai.Part - switch part := part.(type) { - case genai.Text: - p = ai.NewTextPart(string(part)) - case genai.Blob: - p = ai.NewBlobPart(part.MIMEType, string(part.Data)) - case genai.FunctionResponse: - p = ai.NewBlobPart("TODO", string(part.Name)) - default: - panic("unknown part type") + if err != nil { + return nil, err + } + // Send candidates to the callback. 
+ for _, c := range chunk.Candidates { + err := cb(ctx, translateCandidate(c)) + if err != nil { + return nil, err } - m.Content = append(m.Content, p) } - c.Message = m - r.Candidates = append(r.Candidates, c) + if r == nil { + // Save all other fields of first response + // so we can surface them at the end. + // TODO: necessary? Use last instead of first? merge somehow? + chunk.Candidates = nil + r = translateResponse(chunk) + } + } + if r == nil { + // No candidates were returned. Probably rare, but it might avoid an NPE + // to return an empty instead of nil result. + r = &ai.GenerateResponse{} } return r, nil } +// translateCandidate translates from a genai.Candidate to an ai.Candidate. +func translateCandidate(cand *genai.Candidate) *ai.Candidate { + c := &ai.Candidate{} + c.Index = int(cand.Index) + switch cand.FinishReason { + case genai.FinishReasonStop: + c.FinishReason = ai.FinishReasonStop + case genai.FinishReasonMaxTokens: + c.FinishReason = ai.FinishReasonLength + case genai.FinishReasonSafety: + c.FinishReason = ai.FinishReasonBlocked + case genai.FinishReasonRecitation: + c.FinishReason = ai.FinishReasonBlocked + case genai.FinishReasonOther: + c.FinishReason = ai.FinishReasonOther + default: // Unspecified + c.FinishReason = ai.FinishReasonUnknown + } + m := &ai.Message{} + m.Role = ai.Role(cand.Content.Role) + for _, part := range cand.Content.Parts { + var p *ai.Part + switch part := part.(type) { + case genai.Text: + p = ai.NewTextPart(string(part)) + case genai.Blob: + p = ai.NewBlobPart(part.MIMEType, string(part.Data)) + case genai.FunctionResponse: + p = ai.NewBlobPart("TODO", string(part.Name)) + default: + panic("unknown part type") + } + m.Content = append(m.Content, p) + } + c.Message = m + return c +} + +// translateResponse translates from a genai.GenerateContentResponse to an ai.GenerateResponse.
+func translateResponse(resp *genai.GenerateContentResponse) *ai.GenerateResponse { + r := &ai.GenerateResponse{} + for _, c := range resp.Candidates { + r.Candidates = append(r.Candidates, translateCandidate(c)) + } + return r +} + // NewGenerator returns an action which sends a request to // the google AI model and returns the response. func NewGenerator(ctx context.Context, model, apiKey string) (ai.Generator, error) { diff --git a/go/plugins/googleai/googleai_test.go b/go/plugins/googleai/googleai_test.go index 22b1b6146..72c76deb9 100644 --- a/go/plugins/googleai/googleai_test.go +++ b/go/plugins/googleai/googleai_test.go @@ -17,6 +17,7 @@ package googleai_test import ( "context" "flag" + "strings" "testing" "github.com/google/genkit/go/ai" @@ -84,3 +85,41 @@ func TestGenerator(t *testing.T) { t.Errorf("got \"%s\", expecting \"France\"", out) } } + +func TestGeneratorStreaming(t *testing.T) { + if *apiKey == "" { + t.Skipf("no -key provided") + } + ctx := context.Background() + g, err := googleai.NewGenerator(ctx, "gemini-1.0-pro", *apiKey) + if err != nil { + t.Fatal(err) + } + req := &ai.GenerateRequest{ + Candidates: 1, + Messages: []*ai.Message{ + &ai.Message{ + Content: []*ai.Part{ai.NewTextPart("Write one paragraph about the Golden State Warriors.")}, + Role: ai.RoleUser, + }, + }, + } + + out := "" + parts := 0 + _, err = g.Generate(ctx, req, func(ctx context.Context, c *ai.Candidate) error { + parts++ + out += c.Message.Content[0].Text() + return nil + }) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(out, "San Francisco") { + t.Errorf("got \"%s\", expecting it to contain \"San Francisco\"", out) + } + if parts == 1 { + // Check if streaming actually occurred. 
+ t.Errorf("expecting more than one part") + } +} diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index 6e5f92c89..8dd053d82 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -31,7 +31,10 @@ type generator struct { client *genai.Client } -func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, _ genkit.NoStream) (*ai.GenerateResponse, error) { +func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, cb genkit.StreamingCallback[*ai.Candidate]) (*ai.GenerateResponse, error) { + if cb != nil { + panic("streaming not supported yet") // TODO: streaming + } gm := g.client.GenerativeModel(g.model) // Translate from a ai.GenerateRequest to a genai request.