Skip to content

Commit

Permalink
fix: [Go] action name must include provider (#438)
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelgj authored Jun 20, 2024
1 parent 8fd188a commit c25b351
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 29 deletions.
1 change: 0 additions & 1 deletion go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ func DefinePrompt(provider, name string, metadata map[string]any, render func(co
mm = make(map[string]any)
}
mm["type"] = "prompt"
mm["prompt"] = true // required by genkit ui
return core.DefineActionWithInputSchema(provider, name, atype.Prompt, mm, render, inputSchema)
}

Expand Down
12 changes: 6 additions & 6 deletions go/core/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ func DefineAction[In, Out any](provider, name string, atype atype.ActionType, me
}

func defineAction[In, Out any](r *registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
a := newAction(name, atype, metadata, fn)
r.registerAction(provider, a)
a := newAction(provider+"/"+name, atype, metadata, fn)
r.registerAction(a)
return a
}

Expand All @@ -83,8 +83,8 @@ func DefineStreamingAction[In, Out, Stream any](provider, name string, atype aty
}

func defineStreamingAction[In, Out, Stream any](r *registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
a := newStreamingAction(name, atype, metadata, fn)
r.registerAction(provider, a)
a := newStreamingAction(provider+"/"+name, atype, metadata, fn)
r.registerAction(a)
return a
}

Expand All @@ -101,8 +101,8 @@ func DefineActionWithInputSchema[Out any](provider, name string, atype atype.Act
}

func defineActionWithInputSchema[Out any](r *registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, any) (Out, error), inputSchema *jsonschema.Schema) *Action[any, Out, struct{}] {
a := newActionWithInputSchema(name, atype, metadata, fn, inputSchema)
r.registerAction(provider, a)
a := newActionWithInputSchema(provider+"/"+name, atype, metadata, fn, inputSchema)
r.registerAction(a)
return a
}

Expand Down
2 changes: 1 addition & 1 deletion go/core/conformance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestFlowConformance(t *testing.T) {
t.Fatal(err)
}
_ = defineFlow(r, test.Name, flowFunction(test.Commands))
key := fmt.Sprintf("/flow/%s/%[1]s", test.Name)
key := fmt.Sprintf("/flow/%s", test.Name)
resp, err := runAction(context.Background(), r, key, test.Input, nil)
if err != nil {
t.Fatal(err)
Expand Down
2 changes: 1 addition & 1 deletion go/core/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func defineFlow[In, Out, Stream any](r *registry, name string, fn Func[In, Out,
// TODO(jba): set stateStore?
}
a := f.action()
r.registerAction(name, a)
r.registerAction(a)
// TODO(jba): this is a roundabout way to transmit the tracing state. Is there a cleaner way?
f.tstate = a.tstate
r.registerFlow(f)
Expand Down
5 changes: 2 additions & 3 deletions go/core/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ const (
// registerAction records the action in the registry.
// It panics if an action with the same type, provider and name is already
// registered.
func (r *registry) registerAction(provider string, a action) {
key := fmt.Sprintf("/%s/%s/%s", a.actionType(), provider, a.Name())
func (r *registry) registerAction(a action) {
key := fmt.Sprintf("/%s/%s", a.actionType(), a.Name())
r.mu.Lock()
defer r.mu.Unlock()
if _, ok := r.actions[key]; ok {
Expand All @@ -105,7 +105,6 @@ func (r *registry) registerAction(provider string, a action) {
r.actions[key] = a
slog.Info("RegisterAction",
"type", a.actionType(),
"provider", provider,
"name", a.Name())
}

Expand Down
8 changes: 4 additions & 4 deletions go/core/servers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ func TestDevServer(t *testing.T) {
if err != nil {
t.Fatal(err)
}
r.registerAction("devServer", newAction("inc", atype.Custom, map[string]any{
r.registerAction(newAction("devServer/inc", atype.Custom, map[string]any{
"foo": "bar",
}, inc))
r.registerAction("devServer", newAction("dec", atype.Custom, map[string]any{
r.registerAction(newAction("devServer/dec", atype.Custom, map[string]any{
"bar": "baz",
}, dec))
srv := httptest.NewServer(newDevServeMux(r))
Expand Down Expand Up @@ -87,7 +87,7 @@ func TestDevServer(t *testing.T) {
want := map[string]actionDesc{
"/custom/devServer/inc": {
Key: "/custom/devServer/inc",
Name: "inc",
Name: "devServer/inc",
InputSchema: &jsonschema.Schema{Type: "integer"},
OutputSchema: &jsonschema.Schema{Type: "integer"},
Metadata: map[string]any{"foo": "bar"},
Expand All @@ -96,7 +96,7 @@ func TestDevServer(t *testing.T) {
Key: "/custom/devServer/dec",
InputSchema: &jsonschema.Schema{Type: "integer"},
OutputSchema: &jsonschema.Schema{Type: "integer"},
Name: "dec",
Name: "devServer/dec",
Metadata: map[string]any{"bar": "baz"},
},
}
Expand Down
14 changes: 9 additions & 5 deletions go/plugins/dotprompt/dotprompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,12 @@ type Prompt struct {

Config

// The template for the prompt.
// The parsed prompt template.
Template *raymond.Template

// The original prompt template text.
TemplateText string

// A hash of the prompt contents.
hash string

Expand Down Expand Up @@ -197,10 +200,11 @@ func newPrompt(name, templateText, hash string, config Config) (*Prompt, error)
}
template.RegisterHelpers(templateHelpers)
return &Prompt{
Name: name,
Config: config,
hash: hash,
Template: template,
Name: name,
Config: config,
hash: hash,
Template: template,
TemplateText: templateText,
}, nil
}

Expand Down
18 changes: 16 additions & 2 deletions go/plugins/dotprompt/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,11 @@ func (p *Prompt) buildVariables(variables any) (map[string]any, error) {
}

v := reflect.Indirect(reflect.ValueOf(variables))
if v.Kind() == reflect.Map {
return variables.(map[string]any), nil
}
if v.Kind() != reflect.Struct {
return nil, errors.New("dotprompt: fields not a struct or pointer to a struct")
return nil, errors.New("dotprompt: fields not a struct or pointer to a struct or a map")
}
vt := v.Type()

Expand Down Expand Up @@ -138,7 +141,18 @@ func (p *Prompt) Register() error {
name += "." + p.Variant
}

p.action = ai.DefinePrompt("dotprompt", name, nil, p.buildRequest, p.Config.InputSchema)
// TODO: Undo clearing of the Version once Monaco Editor supports newer than JSON schema draft-07.
p.InputSchema.Version = ""

metadata := map[string]any{
"prompt": map[string]any{
"name": p.Name,
"input": map[string]any{"schema": p.InputSchema},
"output": map[string]any{"format": p.OutputFormat},
"template": p.TemplateText,
},
}
p.action = ai.DefinePrompt("dotprompt", name, metadata, p.buildRequest, p.Config.InputSchema)

return nil
}
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/localvec/localvec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func TestInit(t *testing.T) {
if err != nil {
t.Fatal(err)
}
want := []string{"a", "b"}
want := []string{"devLocalVectorStore/a", "devLocalVectorStore/b"}

if got := names(is); !slices.Equal(got, want) {
t.Errorf("got %v, want %v", got, want)
Expand Down
10 changes: 5 additions & 5 deletions go/samples/coffee-shop/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,15 @@ func main() {
log.Fatal(err)
}

r := &jsonschema.Reflector{
AllowAdditionalProperties: false,
DoNotReference: true,
}
g := googleai.Model("gemini-1.5-pro")
simpleGreetingPrompt, err := dotprompt.Define("simpleGreeting", simpleGreetingPromptTemplate,
dotprompt.Config{
ModelAction: g,
InputSchema: jsonschema.Reflect(simpleGreetingInput{}),
InputSchema: r.Reflect(simpleGreetingInput{}),
OutputFormat: ai.OutputFormatText,
},
)
Expand Down Expand Up @@ -176,10 +180,6 @@ func main() {
return text, nil
})

r := &jsonschema.Reflector{
AllowAdditionalProperties: false,
DoNotReference: true,
}
schema := r.Reflect(simpleGreetingOutput{})
jsonBytes, err := schema.MarshalJSON()
if err != nil {
Expand Down

0 comments on commit c25b351

Please sign in to comment.