Skip to content

Commit

Permalink
[Go] ai.Embedder is a separate type from its action
Browse files Browse the repository at this point in the history
See #402 and #458.
  • Loading branch information
jba committed Jun 24, 2024
1 parent 0db5280 commit c6a6d95
Show file tree
Hide file tree
Showing 9 changed files with 32 additions and 31 deletions.
19 changes: 10 additions & 9 deletions go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ import (
"github.com/firebase/genkit/go/internal/atype"
)

// EmbedderAction is used to convert a document to a
// An Embedder is used to convert a document to a
// multidimensional vector.
type EmbedderAction = core.Action[*EmbedRequest, []float32, struct{}]
type Embedder core.Action[*EmbedRequest, []float32, struct{}]

// EmbedRequest is the data we pass to convert a document
// to a multidimensional vector.
Expand All @@ -34,21 +34,22 @@ type EmbedRequest struct {

// DefineEmbedder registers the given embed function as an action, and returns an
// [EmbedderAction] that runs it.
func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) ([]float32, error)) *EmbedderAction {
return core.DefineAction(provider, name, atype.Embedder, nil, embed)
func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) ([]float32, error)) *Embedder {
return (*Embedder)(core.DefineAction(provider, name, atype.Embedder, nil, embed))
}

// LookupEmbedder looks up an [EmbedderAction] registered by [DefineEmbedder].
// It returns nil if the embedder was not defined.
func LookupEmbedder(provider, name string) *EmbedderAction {
func LookupEmbedder(provider, name string) *Embedder {
action := core.LookupActionFor[*EmbedRequest, []float32, struct{}](atype.Embedder, provider, name)
if action == nil {
return nil
}
return action
return (*Embedder)(action)
}

// Embed runs the given [EmbedderAction].
func Embed(ctx context.Context, emb *EmbedderAction, req *EmbedRequest) ([]float32, error) {
return emb.Run(ctx, req, nil)
// Embed runs the given [Embedder].
func (e *Embedder) Embed(ctx context.Context, req *EmbedRequest) ([]float32, error) {
a := (*core.Action[*EmbedRequest, []float32, struct{}])(e)
return a.Run(ctx, req, nil)
}
2 changes: 1 addition & 1 deletion go/internal/fakeembedder/fakeembedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
"github.com/firebase/genkit/go/ai"
)

// Embedder is a fake implementation of genkit.Embedder.
// Embedder is a fake implementation of an Embedder.
type Embedder struct {
registry map[*ai.Document][]float32
}
Expand Down
6 changes: 3 additions & 3 deletions go/internal/fakeembedder/fakeembedder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (

func TestFakeEmbedder(t *testing.T) {
embed := New()
embedAction := ai.DefineEmbedder("fake", "embed", embed.Embed)
emb := ai.DefineEmbedder("fake", "embed", embed.Embed)
d := ai.DocumentFromText("fakeembedder test", nil)

vals := []float32{1, 2}
Expand All @@ -34,7 +34,7 @@ func TestFakeEmbedder(t *testing.T) {
Document: d,
}
ctx := context.Background()
got, err := ai.Embed(ctx, embedAction, req)
got, err := emb.Embed(ctx, req)
if err != nil {
t.Fatal(err)
}
Expand All @@ -43,7 +43,7 @@ func TestFakeEmbedder(t *testing.T) {
}

req.Document = ai.DocumentFromText("missing document", nil)
if _, err = ai.Embed(ctx, embedAction, req); err == nil {
if _, err = emb.Embed(ctx, req); err == nil {
t.Error("embedding unknown document succeeded unexpectedly")
}
}
8 changes: 4 additions & 4 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func defineModel(name string, caps ai.ModelCapabilities) *ai.Model {
}

// DefineEmbedder defines an embedder with a given name.
func DefineEmbedder(name string) *ai.EmbedderAction {
func DefineEmbedder(name string) *ai.Embedder {
state.mu.Lock()
defer state.mu.Unlock()
if !state.initted {
Expand All @@ -130,7 +130,7 @@ func DefineEmbedder(name string) *ai.EmbedderAction {
}

// requires state.mu
func defineEmbedder(name string) *ai.EmbedderAction {
func defineEmbedder(name string) *ai.Embedder {
return ai.DefineEmbedder(provider, name, func(ctx context.Context, input *ai.EmbedRequest) ([]float32, error) {
em := state.client.EmbeddingModel(name)
parts, err := convertParts(input.Document.Content)
Expand All @@ -151,9 +151,9 @@ func Model(name string) *ai.Model {
return ai.LookupModel(provider, name)
}

// Embedder returns the [ai.EmbedderAction] with the given name.
// Embedder returns the [ai.Embedder] with the given name.
// It returns nil if the embedder was not configured.
func Embedder(name string) *ai.EmbedderAction {
func Embedder(name string) *ai.Embedder {
return ai.LookupEmbedder(provider, name)
}

Expand Down
2 changes: 1 addition & 1 deletion go/plugins/googleai/googleai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func TestLive(t *testing.T) {
},
)
t.Run("embedder", func(t *testing.T) {
out, err := ai.Embed(ctx, embedder, &ai.EmbedRequest{
out, err := embedder.Embed(ctx, &ai.EmbedRequest{
Document: ai.DocumentFromText("yellow banana", nil),
})
if err != nil {
Expand Down
10 changes: 5 additions & 5 deletions go/plugins/localvec/localvec.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ const provider = "devLocalVectorStore"
type Config struct {
// Where to store the data. Defaults to os.TempDir.
Dir string
Embedder *ai.EmbedderAction
Embedder *ai.Embedder
EmbedderOptions any
}

Expand Down Expand Up @@ -73,7 +73,7 @@ func Retriever(name string) *ai.Retriever {
// This is based on js/plugins/dev-local-vectorstore/src/index.ts.
type docStore struct {
filename string
embedder *ai.EmbedderAction
embedder *ai.Embedder
embedderOptions any
data map[string]dbValue
}
Expand All @@ -85,7 +85,7 @@ type dbValue struct {
}

// newDocStore returns a new ai.DocumentStore to register.
func newDocStore(dir, name string, embedder *ai.EmbedderAction, embedderOptions any) (*docStore, error) {
func newDocStore(dir, name string, embedder *ai.Embedder, embedderOptions any) (*docStore, error) {
if dir == "" {
dir = os.TempDir()
}
Expand Down Expand Up @@ -124,7 +124,7 @@ func (ds *docStore) index(ctx context.Context, req *ai.IndexerRequest) error {
Document: doc,
Options: ds.embedderOptions,
}
vals, err := ai.Embed(ctx, ds.embedder, ereq)
vals, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return fmt.Errorf("localvec index embedding failed: %v", err)
}
Expand Down Expand Up @@ -186,7 +186,7 @@ func (ds *docStore) retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
Document: req.Document,
Options: ds.embedderOptions,
}
vals, err := ai.Embed(ctx, ds.embedder, ereq)
vals, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return nil, fmt.Errorf("localvec retrieve embedding failed: %v", err)
}
Expand Down
8 changes: 4 additions & 4 deletions go/plugins/pinecone/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ type Config struct {
// The index ID to use.
IndexID string
// Embedder to use. Required.
Embedder *ai.EmbedderAction
Embedder *ai.Embedder
EmbedderOptions any
// The metadata key to use to store document text
// in Pinecone; the default is "_content".
Expand Down Expand Up @@ -160,7 +160,7 @@ type RetrieverOptions struct {
// docStore implements the genkit [ai.DocumentStore] interface.
type docStore struct {
index *index
embedder *ai.EmbedderAction
embedder *ai.Embedder
embedderOptions any
textKey string
}
Expand Down Expand Up @@ -190,7 +190,7 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
Document: doc,
Options: ds.embedderOptions,
}
vals, err := ai.Embed(ctx, ds.embedder, ereq)
vals, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return fmt.Errorf("pinecone index embedding failed: %v", err)
}
Expand Down Expand Up @@ -285,7 +285,7 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
Document: req.Document,
Options: ds.embedderOptions,
}
vals, err := ai.Embed(ctx, ds.embedder, ereq)
vals, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return nil, fmt.Errorf("pinecone retrieve embedding failed: %v", err)
}
Expand Down
6 changes: 3 additions & 3 deletions go/plugins/vertexai/vertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func DefineModel(name string) *ai.Model {
}

// DefineModel defines an embedder with the given name.
func DefineEmbedder(name string) *ai.EmbedderAction {
func DefineEmbedder(name string) *ai.Embedder {
state.mu.Lock()
defer state.mu.Unlock()
if !state.initted {
Expand All @@ -107,9 +107,9 @@ func Model(name string) *ai.Model {
return ai.LookupModel(provider, name)
}

// Embedder returns the [ai.EmbedderAction] with the given name.
// Embedder returns the [ai.Embedder] with the given name.
// It returns nil if the embedder was not configured.
func Embedder(name string) *ai.EmbedderAction {
func Embedder(name string) *ai.Embedder {
return ai.LookupEmbedder(provider, name)
}

Expand Down
2 changes: 1 addition & 1 deletion go/plugins/vertexai/vertexai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func TestLive(t *testing.T) {
}
})
t.Run("embedder", func(t *testing.T) {
out, err := ai.Embed(ctx, embedder, &ai.EmbedRequest{
out, err := embedder.Embed(ctx, &ai.EmbedRequest{
Document: ai.DocumentFromText("time flies like an arrow", nil),
})
if err != nil {
Expand Down

0 comments on commit c6a6d95

Please sign in to comment.