From 3029cf73dfe0092c3db84ce08bc7a8e075f36328 Mon Sep 17 00:00:00 2001 From: Darren Shepherd Date: Mon, 14 Oct 2024 10:10:36 -0700 Subject: [PATCH] chore: add ability to pass args to input/output filters --- pkg/runner/input.go | 7 +- pkg/runner/output.go | 40 +++++++- pkg/runner/runner.go | 16 ++-- pkg/tests/runner2_test.go | 91 +++++++++++++++++++ .../testdata/TestFilterArgs/step1.golden | 6 ++ 5 files changed, 149 insertions(+), 11 deletions(-) create mode 100644 pkg/tests/testdata/TestFilterArgs/step1.golden diff --git a/pkg/runner/input.go b/pkg/runner/input.go index 23228813..360e6274 100644 --- a/pkg/runner/input.go +++ b/pkg/runner/input.go @@ -18,12 +18,15 @@ func (r *Runner) handleInput(callCtx engine.Context, monitor Monitor, env []stri data := map[string]any{} _ = json.Unmarshal([]byte(input), &data) data["input"] = input - inputData, err := json.Marshal(data) + + inputArgs, err := argsForFilters(callCtx.Program, inputToolRef, &State{ + StartInput: &input, + }, data) if err != nil { return "", fmt.Errorf("failed to marshal input: %w", err) } - res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, inputToolRef.ToolID, string(inputData), "", engine.InputToolCategory) + res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, inputToolRef.ToolID, inputArgs, "", engine.InputToolCategory) if err != nil { return "", err } diff --git a/pkg/runner/output.go b/pkg/runner/output.go index e5fe849d..8a6aefdb 100644 --- a/pkg/runner/output.go +++ b/pkg/runner/output.go @@ -4,12 +4,48 @@ import ( "encoding/json" "errors" "fmt" + "maps" + "strings" "github.com/gptscript-ai/gptscript/pkg/engine" "github.com/gptscript-ai/gptscript/pkg/types" ) -func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []string, state *State, retErr error) (*State, error) { +func argsForFilters(prg *types.Program, tool types.ToolReference, startState *State, filterDefinedInput map[string]any) (string, error) { + startInput := "" + if startState.ResumeInput != nil { + startInput = *startState.ResumeInput + } else if startState.StartInput != nil { + startInput = *startState.StartInput + } + + parsedArgs, err := getToolRefInput(prg, tool, startInput) + if err != nil { + return "", err + } + + argData := map[string]any{} + if strings.HasPrefix(parsedArgs, "{") { + if err := json.Unmarshal([]byte(parsedArgs), &argData); err != nil { + return "", fmt.Errorf("failed to unmarshal parsedArgs for filter: %w", err) + } + } else if _, hasInput := filterDefinedInput["input"]; parsedArgs != "" && !hasInput { + argData["input"] = parsedArgs + } + + resultData := map[string]any{} + maps.Copy(resultData, filterDefinedInput) + maps.Copy(resultData, argData) + + result, err := json.Marshal(resultData) + if err != nil { + return "", fmt.Errorf("failed to marshal resultData for filter: %w", err) + } + + return string(result), nil +} + +func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []string, startState, state *State, retErr error) (*State, error) { outputToolRefs, err := callCtx.Tool.GetToolsByType(callCtx.Program, types.ToolTypeOutput) if err != nil { return nil, err @@ -40,7 +76,7 @@ func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []str } for _, outputToolRef := range outputToolRefs { - inputData, err := json.Marshal(map[string]any{ + inputData, err := argsForFilters(callCtx.Program, outputToolRef, startState, map[string]any{ "output": output, "continuation": continuation, "chat": callCtx.Tool.Chat, diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 7ac9fae0..18bc1bc4 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -269,6 +269,9 @@ func getToolRefInput(prg *types.Program, ref types.ToolReference, input string) outputMap := map[string]interface{}{} _ = json.Unmarshal([]byte(input), &inputMap) + for k, v := range inputMap { + inputMap[strings.ToLower(k)] = v + } fields := strings.Fields(ref.Arg) @@ -291,7 +294,7 @@ func getToolRefInput(prg *types.Program, ref types.ToolReference, input string) key := strings.TrimPrefix(field, "$") key = strings.TrimPrefix(key, "{") key = strings.TrimSuffix(key, "}") - val = inputMap[key] + val = inputMap[strings.ToLower(key)] } else { val = field } @@ -425,6 +428,7 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en msg = "Tool call request has been denied" } return &State{ + StartInput: &input, Continuation: &engine.Return{ Result: &msg, }, @@ -438,6 +442,7 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en } return &State{ + StartInput: &input, Continuation: ret, }, nil } @@ -447,6 +452,8 @@ type State struct { ContinuationToolID string `json:"continuationToolID,omitempty"` Result *string `json:"result,omitempty"` + StartInput *string `json:"startInput,omitempty"` + ResumeInput *string `json:"resumeInput,omitempty"` SubCalls []SubCallResult `json:"subCalls,omitempty"` SubCallID string `json:"subCallID,omitempty"` @@ -485,14 +492,9 @@ func (s State) ContinuationContent() (string, error) { return "", fmt.Errorf("illegal state: no result message found in chat response") } -type Needed struct { - Content string `json:"content,omitempty"` - Input string `json:"input,omitempty"` -} - func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, state *State) (retState *State, retErr error) { defer func() { - retState, retErr = r.handleOutput(callCtx, monitor, env, retState, retErr) + retState, retErr = r.handleOutput(callCtx, monitor, env, state, retState, retErr) }() if state.Continuation == nil { diff --git a/pkg/tests/runner2_test.go b/pkg/tests/runner2_test.go index 8dbd2ba7..165f86c8 100644 --- a/pkg/tests/runner2_test.go +++ b/pkg/tests/runner2_test.go @@ -2,10 +2,12 @@ package tests import ( "context" + "encoding/json" "testing" "github.com/gptscript-ai/gptscript/pkg/loader" "github.com/gptscript-ai/gptscript/pkg/tests/tester" + "github.com/hexops/autogold/v2" "github.com/stretchr/testify/require" ) @@ -111,3 +113,92 @@ echo '{"env": {"CRED2": "that also worked"}}' resp, err := r.Chat(context.Background(), nil, prg, nil, "") r.AssertStep(t, resp, err) } + +func TestFilterArgs(t *testing.T) { + r := tester.NewRunner(t) + prg, err := loader.ProgramFromSource(context.Background(), ` +inputfilters: input with ${Foo} +inputfilters: input with foo +inputfilters: input with * +outputfilters: output with * +outputfilters: output with foo +outputfilters: output with ${Foo} +params: Foo: a description + +#!/bin/bash +echo ${FOO} + +--- +name: input +params: notfoo: a description + +#!/bin/bash +echo "${GPTSCRIPT_INPUT}" + +--- +name: output +params: notfoo: a description + +#!/bin/bash +echo "${GPTSCRIPT_INPUT}" +`, "") + require.NoError(t, err) + + resp, err := r.Chat(context.Background(), nil, prg, nil, `{"foo":"baz", "start": true}`) + r.AssertStep(t, resp, err) + + data := map[string]any{} + err = json.Unmarshal([]byte(resp.Content), &data) + require.NoError(t, err) + + autogold.Expect(map[string]interface{}{ + "chat": false, + "continuation": false, + "notfoo": "baz", + "output": `{"chat":false,"continuation":false,"notfoo":"foo","output":"{\"chat\":false,\"continuation\":false,\"foo\":\"baz\",\"input\":\"{\\\"foo\\\":\\\"baz\\\",\\\"input\\\":\\\"{\\\\\\\"foo\\\\\\\":\\\\\\\"baz\\\\\\\", \\\\\\\"start\\\\\\\": true}\\\",\\\"notfoo\\\":\\\"baz\\\",\\\"start\\\":true}\\n\",\"notfoo\":\"foo\",\"output\":\"baz\\n\",\"start\":true}\n"} +`, + }).Equal(t, data) + + val := data["output"].(string) + data = map[string]any{} + err = json.Unmarshal([]byte(val), &data) + require.NoError(t, err) + autogold.Expect(map[string]interface{}{ + "chat": false, + "continuation": false, + "notfoo": "foo", + "output": `{"chat":false,"continuation":false,"foo":"baz","input":"{\"foo\":\"baz\",\"input\":\"{\\\"foo\\\":\\\"baz\\\", \\\"start\\\": true}\",\"notfoo\":\"baz\",\"start\":true}\n","notfoo":"foo","output":"baz\n","start":true} +`, + }).Equal(t, data) + + val = data["output"].(string) + data = map[string]any{} + err = json.Unmarshal([]byte(val), &data) + require.NoError(t, err) + autogold.Expect(map[string]interface{}{ + "chat": false, + "continuation": false, + "foo": "baz", "input": `{"foo":"baz","input":"{\"foo\":\"baz\", \"start\": true}","notfoo":"baz","start":true} +`, + "notfoo": "foo", + "output": "baz\n", + "start": true, + }).Equal(t, data) + + val = data["input"].(string) + data = map[string]any{} + err = json.Unmarshal([]byte(val), &data) + require.NoError(t, err) + autogold.Expect(map[string]interface{}{ + "foo": "baz", + "input": `{"foo":"baz", "start": true}`, + "notfoo": "baz", + "start": true, + }).Equal(t, data) + + val = data["input"].(string) + data = map[string]any{} + err = json.Unmarshal([]byte(val), &data) + require.NoError(t, err) + autogold.Expect(map[string]interface{}{"foo": "baz", "start": true}).Equal(t, data) +} diff --git a/pkg/tests/testdata/TestFilterArgs/step1.golden b/pkg/tests/testdata/TestFilterArgs/step1.golden new file mode 100644 index 00000000..a6e6599b --- /dev/null +++ b/pkg/tests/testdata/TestFilterArgs/step1.golden @@ -0,0 +1,6 @@ +`{ + "done": true, + "content": "{\"chat\":false,\"continuation\":false,\"notfoo\":\"baz\",\"output\":\"{\\\"chat\\\":false,\\\"continuation\\\":false,\\\"notfoo\\\":\\\"foo\\\",\\\"output\\\":\\\"{\\\\\\\"chat\\\\\\\":false,\\\\\\\"continuation\\\\\\\":false,\\\\\\\"foo\\\\\\\":\\\\\\\"baz\\\\\\\",\\\\\\\"input\\\\\\\":\\\\\\\"{\\\\\\\\\\\\\\\"foo\\\\\\\\\\\\\\\":\\\\\\\\\\\\\\\"baz\\\\\\\\\\\\\\\",\\\\\\\\\\\\\\\"input\\\\\\\\\\\\\\\":\\\\\\\\\\\\\\\"{\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\"foo\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\":\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\"baz\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\", \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\"start\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\": true}\\\\\\\\\\\\\\\",\\\\\\\\\\\\\\\"notfoo\\\\\\\\\\\\\\\":\\\\\\\\\\\\\\\"baz\\\\\\\\\\\\\\\",\\\\\\\\\\\\\\\"start\\\\\\\\\\\\\\\":true}\\\\\\\\n\\\\\\\",\\\\\\\"notfoo\\\\\\\":\\\\\\\"foo\\\\\\\",\\\\\\\"output\\\\\\\":\\\\\\\"baz\\\\\\\\n\\\\\\\",\\\\\\\"start\\\\\\\":true}\\\\n\\\"}\\n\"}\n", + "toolID": "", + "state": null +}`