Skip to content

Commit

Permalink
[Go] make ai.Prompt its own type (#463)
Browse files Browse the repository at this point in the history
See previous PRs.
  • Loading branch information
jba authored Jun 24, 2024
1 parent e921e5a commit a50d4f1
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 22 deletions.
26 changes: 13 additions & 13 deletions go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,31 @@ import (
"github.com/invopop/jsonschema"
)

// A PromptAction is used to render a prompt template,
// producing a [GenerateRequest] that may be passed to a [ModelAction].
type PromptAction = core.Action[any, *GenerateRequest, struct{}]
// A Prompt is used to render a prompt template,
// producing a [GenerateRequest] that may be passed to a [Model].
type Prompt core.Action[any, *GenerateRequest, struct{}]

// DefinePrompt takes a function that renders a prompt template
// into a [GenerateRequest] that may be passed to a [ModelAction].
// into a [GenerateRequest] that may be passed to a [Model].
// The prompt expects some input described by inputSchema.
// DefinePrompt registers the function as an action,
// and returns a [PromptAction] that runs it.
func DefinePrompt(provider, name string, metadata map[string]any, render func(context.Context, any) (*GenerateRequest, error), inputSchema *jsonschema.Schema) *PromptAction {
// and returns a [Prompt] that runs it.
func DefinePrompt(provider, name string, metadata map[string]any, render func(context.Context, any) (*GenerateRequest, error), inputSchema *jsonschema.Schema) *Prompt {
mm := maps.Clone(metadata)
if mm == nil {
mm = make(map[string]any)
}
mm["type"] = "prompt"
return core.DefineActionWithInputSchema(provider, name, atype.Prompt, mm, render, inputSchema)
return (*Prompt)(core.DefineActionWithInputSchema(provider, name, atype.Prompt, mm, render, inputSchema))
}

// LookupPrompt looks up a [PromptAction] registered by [DefinePrompt].
// LookupPrompt looks up a [Prompt] registered by [DefinePrompt].
// It returns nil if the prompt was not defined.
func LookupPrompt(provider, name string) *PromptAction {
return core.LookupActionFor[any, *GenerateRequest, struct{}](atype.Prompt, provider, name)
func LookupPrompt(provider, name string) *Prompt {
return (*Prompt)(core.LookupActionFor[any, *GenerateRequest, struct{}](atype.Prompt, provider, name))
}

// Render renders a [PromptAction] with some input data.
func Render(ctx context.Context, p *PromptAction, input any) (*GenerateRequest, error) {
return p.Run(ctx, input, nil)
// Render renders the [Prompt] with some input data.
func (p *Prompt) Render(ctx context.Context, input any) (*GenerateRequest, error) {
return (*core.Action[any, *GenerateRequest, struct{}])(p).Run(ctx, input, nil)
}
8 changes: 4 additions & 4 deletions go/plugins/dotprompt/dotprompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ type Prompt struct {
// A hash of the prompt contents.
hash string

// An action that renders the prompt.
action *ai.PromptAction
// A prompt that renders the prompt.
prompt *ai.Prompt
}

// Config is optional configuration for a [Prompt].
Expand Down Expand Up @@ -290,10 +290,10 @@ func Define(name, templateText string, cfg Config) (*Prompt, error) {
// genkit action and flow mechanisms.
func New(name, templateText string, cfg Config) (*Prompt, error) {
if cfg.ModelName == "" && cfg.Model == nil {
return nil, errors.New("dotprompt.New: config must specify either Model or ModelAction")
return nil, errors.New("dotprompt.New: config must specify either ModelName or Model")
}
if cfg.ModelName != "" && cfg.Model != nil {
return nil, errors.New("dotprompt.New: config must specify exactly one of Model and ModelAction")
return nil, errors.New("dotprompt.New: config must specify exactly one of ModelName and Model")
}
hash := fmt.Sprintf("%02x", sha256.Sum256([]byte(templateText)))
return newPrompt(name, templateText, hash, cfg)
Expand Down
10 changes: 5 additions & 5 deletions go/plugins/dotprompt/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
)

// PromptRequest is a request to execute a dotprompt template and
// pass the result to a [ModelAction].
// pass the result to a [Model].
type PromptRequest struct {
// Input fields for the prompt. If not nil this should be a struct
// or pointer to a struct that matches the prompt's input schema.
Expand Down Expand Up @@ -129,7 +129,7 @@ func (p *Prompt) buildRequest(ctx context.Context, input any) (*ai.GenerateReque

// Register registers an action to render a prompt.
func (p *Prompt) Register() error {
if p.action != nil {
if p.prompt != nil {
return nil
}

Expand All @@ -152,7 +152,7 @@ func (p *Prompt) Register() error {
"template": p.TemplateText,
},
}
p.action = ai.DefinePrompt("dotprompt", name, metadata, p.buildRequest, p.Config.InputSchema)
p.prompt = ai.DefinePrompt("dotprompt", name, metadata, p.buildRequest, p.Config.InputSchema)

return nil
}
Expand All @@ -167,8 +167,8 @@ func (p *Prompt) Generate(ctx context.Context, pr *PromptRequest, cb func(contex

var genReq *ai.GenerateRequest
var err error
if p.action != nil {
genReq, err = ai.Render(ctx, p.action, pr.Variables)
if p.prompt != nil {
genReq, err = p.prompt.Render(ctx, pr.Variables)
} else {
genReq, err = p.buildRequest(ctx, pr.Variables)
}
Expand Down

0 comments on commit a50d4f1

Please sign in to comment.