diff --git a/go/ai/embedder.go b/go/ai/embedder.go index bec1ff492..fca80fa43 100644 --- a/go/ai/embedder.go +++ b/go/ai/embedder.go @@ -18,6 +18,7 @@ import ( "context" "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/internal/atype" ) // EmbedderAction is used to convert a document to a @@ -34,13 +35,13 @@ type EmbedRequest struct { // DefineEmbedder registers the given embed function as an action, and returns an // [EmbedderAction] that runs it. func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) ([]float32, error)) *EmbedderAction { - return core.DefineAction(provider, name, core.ActionTypeEmbedder, nil, embed) + return core.DefineAction(provider, name, atype.Embedder, nil, embed) } // LookupEmbedder looks up an [EmbedderAction] registered by [DefineEmbedder]. // It returns nil if the embedder was not defined. func LookupEmbedder(provider, name string) *EmbedderAction { - action := core.LookupActionFor[*EmbedRequest, []float32, struct{}](core.ActionTypeEmbedder, provider, name) + action := core.LookupActionFor[*EmbedRequest, []float32, struct{}](atype.Embedder, provider, name) if action == nil { return nil } diff --git a/go/ai/generate.go b/go/ai/generate.go index ddf0822da..b27657bae 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -25,6 +25,7 @@ import ( "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/logger" + "github.com/firebase/genkit/go/internal/atype" ) // A ModelAction is used to generate content from an AI model. @@ -63,7 +64,7 @@ func DefineModel(provider, name string, metadata *ModelMetadata, generate func(c } metadataMap["supports"] = supports } - return core.DefineStreamingAction(provider, name, core.ActionTypeModel, map[string]any{ + return core.DefineStreamingAction(provider, name, atype.Model, map[string]any{ "model": metadataMap, }, generate) } @@ -71,7 +72,7 @@ func DefineModel(provider, name string, metadata *ModelMetadata, generate func(c // LookupModel looks up a [ModelAction] registered by [DefineModel]. // It returns nil if the model was not defined. func LookupModel(provider, name string) *ModelAction { - return core.LookupActionFor[*GenerateRequest, *GenerateResponse, *Candidate](core.ActionTypeModel, provider, name) + return core.LookupActionFor[*GenerateRequest, *GenerateResponse, *Candidate](atype.Model, provider, name) } // Generate applies a [ModelAction] to some input, handling tool requests. diff --git a/go/ai/prompt.go b/go/ai/prompt.go index 43d88c65a..a3fca0ab4 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -18,6 +18,7 @@ import ( "context" "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/internal/atype" ) // PromptRequest is a request to execute a prompt template and @@ -50,5 +51,5 @@ func RegisterPrompt(provider, name string, prompt Prompt) { "prompt": prompt, } core.RegisterAction(provider, - core.NewStreamingAction(name, core.ActionTypePrompt, metadata, prompt.Generate)) + core.NewStreamingAction(name, atype.Prompt, metadata, prompt.Generate)) } diff --git a/go/ai/retriever.go b/go/ai/retriever.go index 0579eadec..50b405ffd 100644 --- a/go/ai/retriever.go +++ b/go/ai/retriever.go @@ -18,6 +18,7 @@ import ( "context" "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/internal/atype" ) type ( @@ -52,25 +53,25 @@ func DefineIndexer(provider, name string, index func(context.Context, *IndexerRe f := func(ctx context.Context, req *IndexerRequest) (struct{}, error) { return struct{}{}, index(ctx, req) } - return core.DefineAction(provider, name, core.ActionTypeIndexer, nil, f) + return core.DefineAction(provider, name, atype.Indexer, nil, f) } // LookupIndexer looks up a [IndexerAction] registered by [DefineIndexer]. // It returns nil if the model was not defined. func LookupIndexer(provider, name string) *IndexerAction { - return core.LookupActionFor[*IndexerRequest, struct{}, struct{}](core.ActionTypeIndexer, provider, name) + return core.LookupActionFor[*IndexerRequest, struct{}, struct{}](atype.Indexer, provider, name) } // DefineRetriever registers the given retrieve function as an action, and returns a // [RetrieverAction] that runs it. func DefineRetriever(provider, name string, ret func(context.Context, *RetrieverRequest) (*RetrieverResponse, error)) *RetrieverAction { - return core.DefineAction(provider, name, core.ActionTypeRetriever, nil, ret) + return core.DefineAction(provider, name, atype.Retriever, nil, ret) } // LookupRetriever looks up a [RetrieverAction] registered by [DefineRetriever]. // It returns nil if the model was not defined. func LookupRetriever(provider, name string) *RetrieverAction { - return core.LookupActionFor[*RetrieverRequest, *RetrieverResponse, struct{}](core.ActionTypeRetriever, provider, name) + return core.LookupActionFor[*RetrieverRequest, *RetrieverResponse, struct{}](atype.Retriever, provider, name) } // Index runs the given [IndexerAction]. diff --git a/go/ai/tools.go b/go/ai/tools.go index 53b48ccf6..e3ef6735a 100644 --- a/go/ai/tools.go +++ b/go/ai/tools.go @@ -20,6 +20,7 @@ import ( "maps" "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/internal/atype" ) // A Tool is an implementation of a single tool. @@ -41,7 +42,7 @@ func RegisterTool(definition *ToolDefinition, metadata map[string]any, fn func(c } metadata["type"] = "tool" - core.DefineAction("local", definition.Name, core.ActionTypeTool, metadata, fn) + core.DefineAction("local", definition.Name, atype.Tool, metadata, fn) } // toolActionType is the instantiated core.Action type registered @@ -51,7 +52,7 @@ type toolActionType = core.Action[map[string]any, map[string]any, struct{}] // RunTool looks up a tool registered by [RegisterTool], // runs it with the given input, and returns the result. func RunTool(ctx context.Context, name string, input map[string]any) (map[string]any, error) { - action := core.LookupAction(core.ActionTypeTool, "local", name) + action := core.LookupAction(atype.Tool, "local", name) if action == nil { return nil, fmt.Errorf("no tool named %q", name) } diff --git a/go/core/action.go b/go/core/action.go index 8b8cf4f37..8df4c083b 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -24,6 +24,7 @@ import ( "github.com/firebase/genkit/go/core/logger" "github.com/firebase/genkit/go/core/tracing" "github.com/firebase/genkit/go/internal" + "github.com/firebase/genkit/go/internal/atype" "github.com/invopop/jsonschema" ) @@ -54,7 +55,7 @@ type streamingCallback[Stream any] func(context.Context, Stream) error // Each time an Action is run, it results in a new trace span. type Action[In, Out, Stream any] struct { name string - atype ActionType + atype atype.ActionType fn Func[In, Out, Stream] tstate *tracing.State inputSchema *jsonschema.Schema @@ -67,35 +68,39 @@ type Action[In, Out, Stream any] struct { // See js/core/src/action.ts // DefineAction creates a new Action and registers it. -func DefineAction[In, Out any](provider, name string, atype ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] { +func DefineAction[In, Out any](provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] { return defineAction(globalRegistry, provider, name, atype, metadata, fn) } -func defineAction[In, Out any](r *registry, provider, name string, atype ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] { +func defineAction[In, Out any](r *registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] { a := NewAction(name, atype, metadata, fn) r.registerAction(provider, a) return a } -func DefineStreamingAction[In, Out, Stream any](provider, name string, atype ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { +func DefineStreamingAction[In, Out, Stream any](provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { return defineStreamingAction(globalRegistry, provider, name, atype, metadata, fn) } -func defineStreamingAction[In, Out, Stream any](r *registry, provider, name string, atype ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { +func defineStreamingAction[In, Out, Stream any](r *registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { a := NewStreamingAction(name, atype, metadata, fn) r.registerAction(provider, a) return a } +func DefineCustomAction[In, Out, Stream any](provider, name string, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { + return DefineStreamingAction(provider, name, atype.Custom, metadata, fn) +} + // NewAction creates a new Action with the given name and non-streaming function. -func NewAction[In, Out any](name string, atype ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] { +func NewAction[In, Out any](name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] { return NewStreamingAction(name, atype, metadata, func(ctx context.Context, in In, cb NoStream) (Out, error) { return fn(ctx, in) }) } // NewStreamingAction creates a new Action with the given name and streaming function. -func NewStreamingAction[In, Out, Stream any](name string, atype ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { +func NewStreamingAction[In, Out, Stream any](name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { var i In var o Out return &Action[In, Out, Stream]{ @@ -114,7 +119,7 @@ func NewStreamingAction[In, Out, Stream any](name string, atype ActionType, meta // Name returns the Action's name. func (a *Action[In, Out, Stream]) Name() string { return a.name } -func (a *Action[In, Out, Stream]) actionType() ActionType { return a.atype } +func (a *Action[In, Out, Stream]) actionType() atype.ActionType { return a.atype } // setTracingState sets the action's tracing.State. func (a *Action[In, Out, Stream]) setTracingState(tstate *tracing.State) { a.tstate = tstate } @@ -194,7 +199,7 @@ func (a *Action[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMes // action is the type that all Action[I, O, S] have in common. type action interface { Name() string - actionType() ActionType + actionType() atype.ActionType // runJSON uses encoding/json to unmarshal the input, // calls Action.Run, then returns the marshaled result. diff --git a/go/core/action_test.go b/go/core/action_test.go index 5cc1ee9d0..59ac581be 100644 --- a/go/core/action_test.go +++ b/go/core/action_test.go @@ -19,6 +19,8 @@ import ( "context" "slices" "testing" + + "github.com/firebase/genkit/go/internal/atype" ) func inc(_ context.Context, x int) (int, error) { @@ -26,7 +28,7 @@ func inc(_ context.Context, x int) (int, error) { } func TestActionRun(t *testing.T) { - a := NewAction("inc", ActionTypeCustom, nil, inc) + a := NewAction("inc", atype.Custom, nil, inc) got, err := a.Run(context.Background(), 3, nil) if err != nil { t.Fatal(err) @@ -37,7 +39,7 @@ func TestActionRun(t *testing.T) { } func TestActionRunJSON(t *testing.T) { - a := NewAction("inc", ActionTypeCustom, nil, inc) + a := NewAction("inc", atype.Custom, nil, inc) input := []byte("3") want := []byte("4") got, err := a.runJSON(context.Background(), input, nil) @@ -51,7 +53,7 @@ func TestActionRunJSON(t *testing.T) { func TestNewAction(t *testing.T) { // Verify that struct{} can occur in the function signature. - _ = NewAction("f", ActionTypeCustom, nil, func(context.Context, int) (struct{}, error) { return struct{}{}, nil }) + _ = NewAction("f", atype.Custom, nil, func(context.Context, int) (struct{}, error) { return struct{}{}, nil }) } // count streams the numbers from 0 to n-1, then returns n. @@ -68,7 +70,7 @@ func count(ctx context.Context, n int, cb func(context.Context, int) error) (int func TestActionStreaming(t *testing.T) { ctx := context.Background() - a := NewStreamingAction("count", ActionTypeCustom, nil, count) + a := NewStreamingAction("count", atype.Custom, nil, count) const n = 3 // Non-streaming. @@ -101,7 +103,7 @@ func TestActionStreaming(t *testing.T) { func TestActionTracing(t *testing.T) { ctx := context.Background() const actionName = "TestTracing-inc" - a := NewAction(actionName, ActionTypeCustom, nil, inc) + a := NewAction(actionName, atype.Custom, nil, inc) if _, err := a.Run(context.Background(), 3, nil); err != nil { t.Fatal(err) } diff --git a/go/core/flow.go b/go/core/flow.go index 25a38b0dc..6201fe3d8 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -27,6 +27,7 @@ import ( "github.com/firebase/genkit/go/core/logger" "github.com/firebase/genkit/go/core/tracing" "github.com/firebase/genkit/go/internal" + "github.com/firebase/genkit/go/internal/atype" "github.com/google/uuid" "github.com/invopop/jsonschema" otrace "go.opentelemetry.io/otel/trace" @@ -258,7 +259,7 @@ func (f *Flow[In, Out, Stream]) action() *Action[*flowInstruction[In], *flowStat tracing.SetCustomMetadataAttr(ctx, "flow:wrapperAction", "true") return f.runInstruction(ctx, inst, streamingCallback[Stream](cb)) } - return NewStreamingAction(f.name, ActionTypeFlow, metadata, cback) + return NewStreamingAction(f.name, atype.Flow, metadata, cback) } // runInstruction performs one of several actions on a flow, as determined by msg. diff --git a/go/core/registry.go b/go/core/registry.go index 5e5a18dd4..a41fba888 100644 --- a/go/core/registry.go +++ b/go/core/registry.go @@ -26,6 +26,7 @@ import ( "sync" "github.com/firebase/genkit/go/core/tracing" + "github.com/firebase/genkit/go/internal/atype" sdktrace "go.opentelemetry.io/otel/sdk/trace" "golang.org/x/exp/maps" ) @@ -90,23 +91,6 @@ const ( EnvironmentProd Environment = "prod" // production: user data, SLOs, etc. ) -// An ActionType is the kind of an action. -type ActionType string - -const ( - ActionTypeChatLLM ActionType = "chat-llm" - ActionTypeTextLLM ActionType = "text-llm" - ActionTypeRetriever ActionType = "retriever" - ActionTypeIndexer ActionType = "indexer" - ActionTypeEmbedder ActionType = "embedder" - ActionTypeEvaluator ActionType = "evaluator" - ActionTypeFlow ActionType = "flow" - ActionTypeModel ActionType = "model" - ActionTypePrompt ActionType = "prompt" - ActionTypeTool ActionType = "tool" - ActionTypeCustom ActionType = "custom" -) - // RegisterAction records the action in the global registry. // It panics if an action with the same type, provider and name is already // registered. @@ -138,7 +122,7 @@ func (r *registry) lookupAction(key string) action { // LookupAction returns the action for the given key in the global registry, // or nil if there is none. -func LookupAction(typ ActionType, provider, name string) action { +func LookupAction(typ atype.ActionType, provider, name string) action { key := fmt.Sprintf("/%s/%s/%s", typ, provider, name) return globalRegistry.lookupAction(key) } @@ -146,7 +130,7 @@ func LookupAction(typ ActionType, provider, name string) action { // LookupActionFor returns the action for the given key in the global registry, // or nil if there is none. // It panics if the action is of the wrong type. -func LookupActionFor[In, Out, Stream any](typ ActionType, provider, name string) *Action[In, Out, Stream] { +func LookupActionFor[In, Out, Stream any](typ atype.ActionType, provider, name string) *Action[In, Out, Stream] { a := LookupAction(typ, provider, name) if a == nil { return nil diff --git a/go/core/servers_test.go b/go/core/servers_test.go index 2ba017019..a14249ab2 100644 --- a/go/core/servers_test.go +++ b/go/core/servers_test.go @@ -24,6 +24,7 @@ import ( "testing" "github.com/firebase/genkit/go/core/tracing" + "github.com/firebase/genkit/go/internal/atype" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/invopop/jsonschema" @@ -38,10 +39,10 @@ func TestDevServer(t *testing.T) { if err != nil { t.Fatal(err) } - r.registerAction("devServer", NewAction("inc", ActionTypeCustom, map[string]any{ + r.registerAction("devServer", NewAction("inc", atype.Custom, map[string]any{ "foo": "bar", }, inc)) - r.registerAction("devServer", NewAction("dec", ActionTypeCustom, map[string]any{ + r.registerAction("devServer", NewAction("dec", atype.Custom, map[string]any{ "bar": "baz", }, dec)) srv := httptest.NewServer(newDevServeMux(r)) diff --git a/go/internal/atype/atype.go b/go/internal/atype/atype.go new file mode 100644 index 000000000..f5b48c207 --- /dev/null +++ b/go/internal/atype/atype.go @@ -0,0 +1,33 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package atype provides types for Genkit actions. +package atype + +// An ActionType is the kind of an action. +type ActionType string + +const ( + ChatLLM ActionType = "chat-llm" + TextLLM ActionType = "text-llm" + Retriever ActionType = "retriever" + Indexer ActionType = "indexer" + Embedder ActionType = "embedder" + Evaluator ActionType = "evaluator" + Flow ActionType = "flow" + Model ActionType = "model" + Prompt ActionType = "prompt" + Tool ActionType = "tool" + Custom ActionType = "custom" +)