Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Go] make ActionType internal #360

Merged
merged 1 commit into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/internal/atype"
)

// EmbedderAction is used to convert a document to a
Expand All @@ -34,13 +35,13 @@ type EmbedRequest struct {
// DefineEmbedder registers the given embed function as an action, and returns an
// [EmbedderAction] that runs it.
func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) ([]float32, error)) *EmbedderAction {
return core.DefineAction(provider, name, core.ActionTypeEmbedder, nil, embed)
return core.DefineAction(provider, name, atype.Embedder, nil, embed)
}

// LookupEmbedder looks up an [EmbedderAction] registered by [DefineEmbedder].
// It returns nil if the embedder was not defined.
func LookupEmbedder(provider, name string) *EmbedderAction {
action := core.LookupActionFor[*EmbedRequest, []float32, struct{}](core.ActionTypeEmbedder, provider, name)
action := core.LookupActionFor[*EmbedRequest, []float32, struct{}](atype.Embedder, provider, name)
if action == nil {
return nil
}
Expand Down
5 changes: 3 additions & 2 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/core/logger"
"github.com/firebase/genkit/go/internal/atype"
)

// A ModelAction is used to generate content from an AI model.
Expand Down Expand Up @@ -63,15 +64,15 @@ func DefineModel(provider, name string, metadata *ModelMetadata, generate func(c
}
metadataMap["supports"] = supports
}
return core.DefineStreamingAction(provider, name, core.ActionTypeModel, map[string]any{
return core.DefineStreamingAction(provider, name, atype.Model, map[string]any{
"model": metadataMap,
}, generate)
}

// LookupModel looks up a [ModelAction] registered by [DefineModel].
// It returns nil if the model was not defined.
func LookupModel(provider, name string) *ModelAction {
return core.LookupActionFor[*GenerateRequest, *GenerateResponse, *Candidate](core.ActionTypeModel, provider, name)
return core.LookupActionFor[*GenerateRequest, *GenerateResponse, *Candidate](atype.Model, provider, name)
}

// Generate applies a [ModelAction] to some input, handling tool requests.
Expand Down
3 changes: 2 additions & 1 deletion go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/internal/atype"
)

// PromptRequest is a request to execute a prompt template and
Expand Down Expand Up @@ -50,5 +51,5 @@ func RegisterPrompt(provider, name string, prompt Prompt) {
"prompt": prompt,
}
core.RegisterAction(provider,
core.NewStreamingAction(name, core.ActionTypePrompt, metadata, prompt.Generate))
core.NewStreamingAction(name, atype.Prompt, metadata, prompt.Generate))
}
9 changes: 5 additions & 4 deletions go/ai/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/internal/atype"
)

type (
Expand Down Expand Up @@ -52,25 +53,25 @@ func DefineIndexer(provider, name string, index func(context.Context, *IndexerRe
f := func(ctx context.Context, req *IndexerRequest) (struct{}, error) {
return struct{}{}, index(ctx, req)
}
return core.DefineAction(provider, name, core.ActionTypeIndexer, nil, f)
return core.DefineAction(provider, name, atype.Indexer, nil, f)
}

// LookupIndexer looks up a [IndexerAction] registered by [DefineIndexer].
// It returns nil if the model was not defined.
func LookupIndexer(provider, name string) *IndexerAction {
return core.LookupActionFor[*IndexerRequest, struct{}, struct{}](core.ActionTypeIndexer, provider, name)
return core.LookupActionFor[*IndexerRequest, struct{}, struct{}](atype.Indexer, provider, name)
}

// DefineRetriever registers the given retrieve function as an action, and returns a
// [RetrieverAction] that runs it.
func DefineRetriever(provider, name string, ret func(context.Context, *RetrieverRequest) (*RetrieverResponse, error)) *RetrieverAction {
return core.DefineAction(provider, name, core.ActionTypeRetriever, nil, ret)
return core.DefineAction(provider, name, atype.Retriever, nil, ret)
}

// LookupRetriever looks up a [RetrieverAction] registered by [DefineRetriever].
// It returns nil if the model was not defined.
func LookupRetriever(provider, name string) *RetrieverAction {
return core.LookupActionFor[*RetrieverRequest, *RetrieverResponse, struct{}](core.ActionTypeRetriever, provider, name)
return core.LookupActionFor[*RetrieverRequest, *RetrieverResponse, struct{}](atype.Retriever, provider, name)
}

// Index runs the given [IndexerAction].
Expand Down
5 changes: 3 additions & 2 deletions go/ai/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"maps"

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/internal/atype"
)

// A Tool is an implementation of a single tool.
Expand All @@ -41,7 +42,7 @@ func RegisterTool(definition *ToolDefinition, metadata map[string]any, fn func(c
}
metadata["type"] = "tool"

core.DefineAction("local", definition.Name, core.ActionTypeTool, metadata, fn)
core.DefineAction("local", definition.Name, atype.Tool, metadata, fn)
}

// toolActionType is the instantiated core.Action type registered
Expand All @@ -51,7 +52,7 @@ type toolActionType = core.Action[map[string]any, map[string]any, struct{}]
// RunTool looks up a tool registered by [RegisterTool],
// runs it with the given input, and returns the result.
func RunTool(ctx context.Context, name string, input map[string]any) (map[string]any, error) {
action := core.LookupAction(core.ActionTypeTool, "local", name)
action := core.LookupAction(atype.Tool, "local", name)
if action == nil {
return nil, fmt.Errorf("no tool named %q", name)
}
Expand Down
23 changes: 14 additions & 9 deletions go/core/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/firebase/genkit/go/core/logger"
"github.com/firebase/genkit/go/core/tracing"
"github.com/firebase/genkit/go/internal"
"github.com/firebase/genkit/go/internal/atype"
"github.com/invopop/jsonschema"
)

Expand Down Expand Up @@ -54,7 +55,7 @@ type streamingCallback[Stream any] func(context.Context, Stream) error
// Each time an Action is run, it results in a new trace span.
type Action[In, Out, Stream any] struct {
name string
atype ActionType
atype atype.ActionType
fn Func[In, Out, Stream]
tstate *tracing.State
inputSchema *jsonschema.Schema
Expand All @@ -67,35 +68,39 @@ type Action[In, Out, Stream any] struct {
// See js/core/src/action.ts

// DefineAction creates a new Action and registers it.
func DefineAction[In, Out any](provider, name string, atype ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
func DefineAction[In, Out any](provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
return defineAction(globalRegistry, provider, name, atype, metadata, fn)
}

func defineAction[In, Out any](r *registry, provider, name string, atype ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
func defineAction[In, Out any](r *registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
a := NewAction(name, atype, metadata, fn)
r.registerAction(provider, a)
return a
}

func DefineStreamingAction[In, Out, Stream any](provider, name string, atype ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
func DefineStreamingAction[In, Out, Stream any](provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
return defineStreamingAction(globalRegistry, provider, name, atype, metadata, fn)
}

func defineStreamingAction[In, Out, Stream any](r *registry, provider, name string, atype ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
func defineStreamingAction[In, Out, Stream any](r *registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
a := NewStreamingAction(name, atype, metadata, fn)
r.registerAction(provider, a)
return a
}

func DefineCustomAction[In, Out, Stream any](provider, name string, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
return DefineStreamingAction(provider, name, atype.Custom, metadata, fn)
}

// NewAction creates a new Action with the given name and non-streaming function.
func NewAction[In, Out any](name string, atype ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
func NewAction[In, Out any](name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
return NewStreamingAction(name, atype, metadata, func(ctx context.Context, in In, cb NoStream) (Out, error) {
return fn(ctx, in)
})
}

// NewStreamingAction creates a new Action with the given name and streaming function.
func NewStreamingAction[In, Out, Stream any](name string, atype ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
func NewStreamingAction[In, Out, Stream any](name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
var i In
var o Out
return &Action[In, Out, Stream]{
Expand All @@ -114,7 +119,7 @@ func NewStreamingAction[In, Out, Stream any](name string, atype ActionType, meta
// Name returns the Action's name.
func (a *Action[In, Out, Stream]) Name() string { return a.name }

func (a *Action[In, Out, Stream]) actionType() ActionType { return a.atype }
func (a *Action[In, Out, Stream]) actionType() atype.ActionType { return a.atype }

// setTracingState sets the action's tracing.State.
func (a *Action[In, Out, Stream]) setTracingState(tstate *tracing.State) { a.tstate = tstate }
Expand Down Expand Up @@ -194,7 +199,7 @@ func (a *Action[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMes
// action is the type that all Action[I, O, S] have in common.
type action interface {
Name() string
actionType() ActionType
actionType() atype.ActionType

// runJSON uses encoding/json to unmarshal the input,
// calls Action.Run, then returns the marshaled result.
Expand Down
12 changes: 7 additions & 5 deletions go/core/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@ import (
"context"
"slices"
"testing"

"github.com/firebase/genkit/go/internal/atype"
)

func inc(_ context.Context, x int) (int, error) {
return x + 1, nil
}

func TestActionRun(t *testing.T) {
a := NewAction("inc", ActionTypeCustom, nil, inc)
a := NewAction("inc", atype.Custom, nil, inc)
got, err := a.Run(context.Background(), 3, nil)
if err != nil {
t.Fatal(err)
Expand All @@ -37,7 +39,7 @@ func TestActionRun(t *testing.T) {
}

func TestActionRunJSON(t *testing.T) {
a := NewAction("inc", ActionTypeCustom, nil, inc)
a := NewAction("inc", atype.Custom, nil, inc)
input := []byte("3")
want := []byte("4")
got, err := a.runJSON(context.Background(), input, nil)
Expand All @@ -51,7 +53,7 @@ func TestActionRunJSON(t *testing.T) {

func TestNewAction(t *testing.T) {
// Verify that struct{} can occur in the function signature.
_ = NewAction("f", ActionTypeCustom, nil, func(context.Context, int) (struct{}, error) { return struct{}{}, nil })
_ = NewAction("f", atype.Custom, nil, func(context.Context, int) (struct{}, error) { return struct{}{}, nil })
}

// count streams the numbers from 0 to n-1, then returns n.
Expand All @@ -68,7 +70,7 @@ func count(ctx context.Context, n int, cb func(context.Context, int) error) (int

func TestActionStreaming(t *testing.T) {
ctx := context.Background()
a := NewStreamingAction("count", ActionTypeCustom, nil, count)
a := NewStreamingAction("count", atype.Custom, nil, count)
const n = 3

// Non-streaming.
Expand Down Expand Up @@ -101,7 +103,7 @@ func TestActionStreaming(t *testing.T) {
func TestActionTracing(t *testing.T) {
ctx := context.Background()
const actionName = "TestTracing-inc"
a := NewAction(actionName, ActionTypeCustom, nil, inc)
a := NewAction(actionName, atype.Custom, nil, inc)
if _, err := a.Run(context.Background(), 3, nil); err != nil {
t.Fatal(err)
}
Expand Down
3 changes: 2 additions & 1 deletion go/core/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/firebase/genkit/go/core/logger"
"github.com/firebase/genkit/go/core/tracing"
"github.com/firebase/genkit/go/internal"
"github.com/firebase/genkit/go/internal/atype"
"github.com/google/uuid"
"github.com/invopop/jsonschema"
otrace "go.opentelemetry.io/otel/trace"
Expand Down Expand Up @@ -258,7 +259,7 @@ func (f *Flow[In, Out, Stream]) action() *Action[*flowInstruction[In], *flowStat
tracing.SetCustomMetadataAttr(ctx, "flow:wrapperAction", "true")
return f.runInstruction(ctx, inst, streamingCallback[Stream](cb))
}
return NewStreamingAction(f.name, ActionTypeFlow, metadata, cback)
return NewStreamingAction(f.name, atype.Flow, metadata, cback)
}

// runInstruction performs one of several actions on a flow, as determined by msg.
Expand Down
22 changes: 3 additions & 19 deletions go/core/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"sync"

"github.com/firebase/genkit/go/core/tracing"
"github.com/firebase/genkit/go/internal/atype"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"golang.org/x/exp/maps"
)
Expand Down Expand Up @@ -90,23 +91,6 @@ const (
EnvironmentProd Environment = "prod" // production: user data, SLOs, etc.
)

// An ActionType is the kind of an action.
type ActionType string

const (
ActionTypeChatLLM ActionType = "chat-llm"
ActionTypeTextLLM ActionType = "text-llm"
ActionTypeRetriever ActionType = "retriever"
ActionTypeIndexer ActionType = "indexer"
ActionTypeEmbedder ActionType = "embedder"
ActionTypeEvaluator ActionType = "evaluator"
ActionTypeFlow ActionType = "flow"
ActionTypeModel ActionType = "model"
ActionTypePrompt ActionType = "prompt"
ActionTypeTool ActionType = "tool"
ActionTypeCustom ActionType = "custom"
)

// RegisterAction records the action in the global registry.
// It panics if an action with the same type, provider and name is already
// registered.
Expand Down Expand Up @@ -138,15 +122,15 @@ func (r *registry) lookupAction(key string) action {

// LookupAction returns the action for the given key in the global registry,
// or nil if there is none.
func LookupAction(typ ActionType, provider, name string) action {
func LookupAction(typ atype.ActionType, provider, name string) action {
key := fmt.Sprintf("/%s/%s/%s", typ, provider, name)
return globalRegistry.lookupAction(key)
}

// LookupActionFor returns the action for the given key in the global registry,
// or nil if there is none.
// It panics if the action is of the wrong type.
func LookupActionFor[In, Out, Stream any](typ ActionType, provider, name string) *Action[In, Out, Stream] {
func LookupActionFor[In, Out, Stream any](typ atype.ActionType, provider, name string) *Action[In, Out, Stream] {
a := LookupAction(typ, provider, name)
if a == nil {
return nil
Expand Down
5 changes: 3 additions & 2 deletions go/core/servers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"testing"

"github.com/firebase/genkit/go/core/tracing"
"github.com/firebase/genkit/go/internal/atype"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/invopop/jsonschema"
Expand All @@ -38,10 +39,10 @@ func TestDevServer(t *testing.T) {
if err != nil {
t.Fatal(err)
}
r.registerAction("devServer", NewAction("inc", ActionTypeCustom, map[string]any{
r.registerAction("devServer", NewAction("inc", atype.Custom, map[string]any{
"foo": "bar",
}, inc))
r.registerAction("devServer", NewAction("dec", ActionTypeCustom, map[string]any{
r.registerAction("devServer", NewAction("dec", atype.Custom, map[string]any{
"bar": "baz",
}, dec))
srv := httptest.NewServer(newDevServeMux(r))
Expand Down
33 changes: 33 additions & 0 deletions go/internal/atype/atype.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package atype provides types for Genkit actions.
package atype

// An ActionType is the kind of an action.
type ActionType string

const (
ChatLLM ActionType = "chat-llm"
TextLLM ActionType = "text-llm"
Retriever ActionType = "retriever"
Indexer ActionType = "indexer"
Embedder ActionType = "embedder"
Evaluator ActionType = "evaluator"
Flow ActionType = "flow"
Model ActionType = "model"
Prompt ActionType = "prompt"
Tool ActionType = "tool"
Custom ActionType = "custom"
)
Loading