Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement streaming in googleai generator #25

Merged
merged 2 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions go/ai/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@ import (

// Generator is the interface used to query an AI model.
type Generator interface {
// TODO(randall77): define a type for streaming Generate calls.
Generate(context.Context, *GenerateRequest, genkit.NoStream) (*GenerateResponse, error)
// If the streaming callback is non-nil:
// - Each response candidate will be passed to that callback instead of
// populating the result's Candidates field.
// - If the streaming callback returns a non-nil error, generation will stop
// and Generate immediately returns that error (and a nil response).
Generate(context.Context, *GenerateRequest, genkit.StreamingCallback[*Candidate]) (*GenerateResponse, error)
}

// RegisterGenerator registers the generator in the global registry.
Expand All @@ -35,8 +39,7 @@ func RegisterGenerator(name string, generator Generator) {

// generatorActionType is the instantiated genkit.Action type registered
// by RegisterGenerator.
// TODO(ianlancetaylor, randall77): add streaming support
type generatorActionType = genkit.Action[*GenerateRequest, *GenerateResponse, struct{}]
type generatorActionType = genkit.Action[*GenerateRequest, *GenerateResponse, *Candidate]

// LookupGeneratorAction looks up an action registered by [RegisterGenerator]
// and returns a generator that invokes the action.
Expand All @@ -58,6 +61,6 @@ type generatorAction struct {
}

// Generate implements Generator.
func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, cb genkit.NoStream) (*GenerateResponse, error) {
func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, cb genkit.StreamingCallback[*Candidate]) (*GenerateResponse, error) {
return ga.action.Run(ctx, input, cb)
}
4 changes: 2 additions & 2 deletions go/genkit/dotprompt/genkit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ import (
"github.com/google/genkit/go/genkit"
)

type testGenerator struct {}
type testGenerator struct{}

func (testGenerator) Generate(ctx context.Context, req *ai.GenerateRequest, cb genkit.NoStream) (*ai.GenerateResponse, error) {
func (testGenerator) Generate(ctx context.Context, req *ai.GenerateRequest, cb genkit.StreamingCallback[*ai.Candidate]) (*ai.GenerateResponse, error) {
input := req.Messages[0].Content[0].Text()
output := fmt.Sprintf("AI reply to %q", input)

Expand Down
122 changes: 82 additions & 40 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/google/generative-ai-go/genai"
"github.com/google/genkit/go/ai"
"github.com/google/genkit/go/genkit"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
)

Expand Down Expand Up @@ -60,7 +61,7 @@ type generator struct {
client *genai.Client
}

func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, _ genkit.NoStream) (*ai.GenerateResponse, error) {
func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, cb genkit.StreamingCallback[*ai.Candidate]) (*ai.GenerateResponse, error) {
gm := g.client.GenerativeModel(g.model)

// Translate from a ai.GenerateRequest to a genai request.
Expand Down Expand Up @@ -94,54 +95,95 @@ func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, _ g
//TODO: convert input.Tools and append to gm.Tools

// Send out the actual request.
// TODO(randall77): if the streaming callback is non-nil, use SendMessageStream
// and pass the streamed results to the callback.
resp, err := cs.SendMessage(ctx, parts...)
if err != nil {
return nil, err
if cb == nil {
resp, err := cs.SendMessage(ctx, parts...)
if err != nil {
return nil, err
}
return translateResponse(resp), nil
}

// Translate from a genai.GenerateContentResponse to a ai.GenerateResponse.
r := &ai.GenerateResponse{}
for _, cand := range resp.Candidates {
c := &ai.Candidate{}
c.Index = int(cand.Index)
switch cand.FinishReason {
case genai.FinishReasonStop:
c.FinishReason = ai.FinishReasonStop
case genai.FinishReasonMaxTokens:
c.FinishReason = ai.FinishReasonLength
case genai.FinishReasonSafety:
c.FinishReason = ai.FinishReasonBlocked
case genai.FinishReasonRecitation:
c.FinishReason = ai.FinishReasonBlocked
case genai.FinishReasonOther:
c.FinishReason = ai.FinishReasonOther
default: // Unspecified
c.FinishReason = ai.FinishReasonUnknown
// Streaming version.
iter := cs.SendMessageStream(ctx, parts...)
var r *ai.GenerateResponse
for {
chunk, err := iter.Next()
if err == iterator.Done {
break
}
m := &ai.Message{}
m.Role = ai.Role(cand.Content.Role)
for _, part := range cand.Content.Parts {
var p *ai.Part
switch part := part.(type) {
case genai.Text:
p = ai.NewTextPart(string(part))
case genai.Blob:
p = ai.NewBlobPart(part.MIMEType, string(part.Data))
case genai.FunctionResponse:
p = ai.NewBlobPart("TODO", string(part.Name))
default:
panic("unknown part type")
if err != nil {
return nil, err
}
// Send candidates to the callback.
for _, c := range chunk.Candidates {
err := cb(ctx, translateCandidate(c))
if err != nil {
return nil, err
}
m.Content = append(m.Content, p)
}
c.Message = m
r.Candidates = append(r.Candidates, c)
if r == nil {
// Save all other fields of first response
// so we can surface them at the end.
// TODO: necessary? Use last instead of first? merge somehow?
chunk.Candidates = nil
r = translateResponse(chunk)
}
}
if r == nil {
// No chunks were received from the stream. This is probably rare, but
// returning an empty result instead of nil spares callers a nil pointer dereference.
r = &ai.GenerateResponse{}
}
return r, nil
}

// translateCandidate translates a genai.Candidate into an ai.Candidate.
//
// The candidate index and finish reason are always copied. The returned
// candidate always carries a non-nil Message; if the genai candidate has no
// content, the message is left empty (no role, no parts).
//
// It panics if the content contains a part type it does not know how to
// translate.
func translateCandidate(cand *genai.Candidate) *ai.Candidate {
	c := &ai.Candidate{}
	c.Index = int(cand.Index)
	switch cand.FinishReason {
	case genai.FinishReasonStop:
		c.FinishReason = ai.FinishReasonStop
	case genai.FinishReasonMaxTokens:
		c.FinishReason = ai.FinishReasonLength
	case genai.FinishReasonSafety, genai.FinishReasonRecitation:
		// Both safety and recitation stops are surfaced as "blocked".
		c.FinishReason = ai.FinishReasonBlocked
	case genai.FinishReasonOther:
		c.FinishReason = ai.FinishReasonOther
	default: // Unspecified
		c.FinishReason = ai.FinishReasonUnknown
	}
	m := &ai.Message{}
	c.Message = m
	if cand.Content == nil {
		// Content is a pointer and may be unset; avoid dereferencing nil
		// and hand back the candidate with an empty message.
		return c
	}
	m.Role = ai.Role(cand.Content.Role)
	for _, part := range cand.Content.Parts {
		var p *ai.Part
		switch part := part.(type) {
		case genai.Text:
			p = ai.NewTextPart(string(part))
		case genai.Blob:
			p = ai.NewBlobPart(part.MIMEType, string(part.Data))
		case genai.FunctionResponse:
			// Placeholder translation: only the function name is carried
			// over, under a "TODO" MIME type.
			p = ai.NewBlobPart("TODO", string(part.Name))
		default:
			panic("unknown part type")
		}
		m.Content = append(m.Content, p)
	}
	return c
}

// translateResponse converts a genai.GenerateContentResponse into an
// ai.GenerateResponse by translating each of its candidates, in order,
// with translateCandidate. Only the Candidates field of the result is
// populated.
func translateResponse(resp *genai.GenerateContentResponse) *ai.GenerateResponse {
	out := new(ai.GenerateResponse)
	for i := range resp.Candidates {
		translated := translateCandidate(resp.Candidates[i])
		out.Candidates = append(out.Candidates, translated)
	}
	return out
}

// NewGenerator returns an action which sends a request to
// the google AI model and returns the response.
func NewGenerator(ctx context.Context, model, apiKey string) (ai.Generator, error) {
Expand Down
39 changes: 39 additions & 0 deletions go/plugins/googleai/googleai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package googleai_test
import (
"context"
"flag"
"strings"
"testing"

"github.com/google/genkit/go/ai"
Expand Down Expand Up @@ -84,3 +85,41 @@ func TestGenerator(t *testing.T) {
t.Errorf("got \"%s\", expecting \"France\"", out)
}
}

// TestGeneratorStreaming exercises the streaming path of the googleai
// generator: it passes a callback to Generate, concatenates the text of the
// first part of every streamed candidate, and checks both the content of the
// reply and that more than one chunk was delivered.
//
// The test is skipped unless an API key is supplied via the -key flag.
func TestGeneratorStreaming(t *testing.T) {
	if *apiKey == "" {
		t.Skipf("no -key provided")
	}
	ctx := context.Background()
	g, err := googleai.NewGenerator(ctx, "gemini-1.0-pro", *apiKey)
	if err != nil {
		t.Fatal(err)
	}
	req := &ai.GenerateRequest{
		Candidates: 1,
		Messages: []*ai.Message{
			{
				Content: []*ai.Part{ai.NewTextPart("Write one paragraph about the Golden State Warriors.")},
				Role:    ai.RoleUser,
			},
		},
	}

	var out strings.Builder
	parts := 0
	_, err = g.Generate(ctx, req, func(ctx context.Context, c *ai.Candidate) error {
		parts++
		// A streamed chunk may carry no content parts; indexing Content[0]
		// unconditionally would panic on such a chunk.
		if c.Message == nil || len(c.Message.Content) == 0 {
			return nil
		}
		out.WriteString(c.Message.Content[0].Text())
		return nil
	})
	if err != nil {
		t.Fatal(err)
	}
	if got := out.String(); !strings.Contains(got, "San Francisco") {
		t.Errorf("got \"%s\", expecting it to contain \"San Francisco\"", got)
	}
	if parts < 2 {
		// Check if streaming actually occurred: zero or one callback
		// invocation means the response was not delivered in chunks.
		t.Errorf("expecting more than one part")
	}
}
5 changes: 4 additions & 1 deletion go/plugins/vertexai/vertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ type generator struct {
client *genai.Client
}

func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, _ genkit.NoStream) (*ai.GenerateResponse, error) {
func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, cb genkit.StreamingCallback[*ai.Candidate]) (*ai.GenerateResponse, error) {
if cb != nil {
panic("streaming not supported yet") // TODO: streaming
}
gm := g.client.GenerativeModel(g.model)

// Translate from a ai.GenerateRequest to a genai request.
Expand Down
Loading