diff --git a/pkg/builtin/builtin.go b/pkg/builtin/builtin.go index 23db5152..f972d14c 100644 --- a/pkg/builtin/builtin.go +++ b/pkg/builtin/builtin.go @@ -217,6 +217,7 @@ var tools = map[string]types.Tool{ "message", "The message to display to the user", "fields", "A comma-separated list of fields to prompt for", "sensitive", "(true or false) Whether the input should be hidden", + "metadata", "(optional) A JSON object of metadata to attach to the prompt", ), }, BuiltinFunc: prompt.SysPrompt, diff --git a/pkg/prompt/prompt.go b/pkg/prompt/prompt.go index 44cb20f1..f91a04b6 100644 --- a/pkg/prompt/prompt.go +++ b/pkg/prompt/prompt.go @@ -51,25 +51,29 @@ func sysPromptHTTP(ctx context.Context, envs []string, url string, prompt types. func SysPrompt(ctx context.Context, envs []string, input string, _ chan<- string) (_ string, err error) { var params struct { - Message string `json:"message,omitempty"` - Fields string `json:"fields,omitempty"` - Sensitive string `json:"sensitive,omitempty"` + Message string `json:"message,omitempty"` + Fields string `json:"fields,omitempty"` + Sensitive string `json:"sensitive,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` } if err := json.Unmarshal([]byte(input), ¶ms); err != nil { return "", err } + var fields []string for _, env := range envs { if url, ok := strings.CutPrefix(env, types.PromptURLEnvVar+"="); ok { - var fields []string if params.Fields != "" { fields = strings.Split(params.Fields, ",") } + httpPrompt := types.Prompt{ Message: params.Message, Fields: fields, Sensitive: params.Sensitive == "true", + Metadata: params.Metadata, } + return sysPromptHTTP(ctx, envs, url, httpPrompt) } } diff --git a/pkg/sdkserver/prompt.go b/pkg/sdkserver/prompt.go index 8d34fc53..a519f7b2 100644 --- a/pkg/sdkserver/prompt.go +++ b/pkg/sdkserver/prompt.go @@ -76,11 +76,7 @@ func (s *server) prompt(w http.ResponseWriter, r *http.Request) { }(id) s.events.C <- event{ - Prompt: types.Prompt{ - Message: prompt.Message, - Fields: prompt.Fields, - Sensitive: prompt.Sensitive, - }, + Prompt: prompt, Event: gserver.Event{ RunID: id, Event: runner.Event{ diff --git a/pkg/types/prompt.go b/pkg/types/prompt.go index ea17c11c..653ad066 100644 --- a/pkg/types/prompt.go +++ b/pkg/types/prompt.go @@ -6,7 +6,8 @@ const ( ) type Prompt struct { - Message string `json:"message,omitempty"` - Fields []string `json:"fields,omitempty"` - Sensitive bool `json:"sensitive,omitempty"` + Message string `json:"message,omitempty"` + Fields []string `json:"fields,omitempty"` + Sensitive bool `json:"sensitive,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` }