Skip to content

Commit

Permalink
chore: add ability to pass args to input/output filters
Browse files Browse the repository at this point in the history
  • Loading branch information
ibuildthecloud committed Oct 14, 2024
1 parent 6ec5178 commit 3029cf7
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 11 deletions.
7 changes: 5 additions & 2 deletions pkg/runner/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@ func (r *Runner) handleInput(callCtx engine.Context, monitor Monitor, env []stri
data := map[string]any{}
_ = json.Unmarshal([]byte(input), &data)
data["input"] = input
inputData, err := json.Marshal(data)

inputArgs, err := argsForFilters(callCtx.Program, inputToolRef, &State{
StartInput: &input,
}, data)
if err != nil {
return "", fmt.Errorf("failed to marshal input: %w", err)
}

res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, inputToolRef.ToolID, string(inputData), "", engine.InputToolCategory)
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, inputToolRef.ToolID, inputArgs, "", engine.InputToolCategory)
if err != nil {
return "", err
}
Expand Down
40 changes: 38 additions & 2 deletions pkg/runner/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,48 @@ import (
"encoding/json"
"errors"
"fmt"
"maps"
"strings"

"github.com/gptscript-ai/gptscript/pkg/engine"
"github.com/gptscript-ai/gptscript/pkg/types"
)

func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []string, state *State, retErr error) (*State, error) {
func argsForFilters(prg *types.Program, tool types.ToolReference, startState *State, filterDefinedInput map[string]any) (string, error) {
startInput := ""
if startState.ResumeInput != nil {
startInput = *startState.ResumeInput
} else if startState.StartInput != nil {
startInput = *startState.StartInput
}

parsedArgs, err := getToolRefInput(prg, tool, startInput)
if err != nil {
return "", err
}

argData := map[string]any{}
if strings.HasPrefix(parsedArgs, "{") {
if err := json.Unmarshal([]byte(parsedArgs), &argData); err != nil {
return "", fmt.Errorf("failed to unmarshal parsedArgs for filter: %w", err)
}
} else if _, hasInput := filterDefinedInput["input"]; parsedArgs != "" && !hasInput {
argData["input"] = parsedArgs
}

resultData := map[string]any{}
maps.Copy(resultData, filterDefinedInput)
maps.Copy(resultData, argData)

result, err := json.Marshal(resultData)
if err != nil {
return "", fmt.Errorf("failed to marshal resultData for filter: %w", err)
}

return string(result), nil
}

func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []string, startState, state *State, retErr error) (*State, error) {
outputToolRefs, err := callCtx.Tool.GetToolsByType(callCtx.Program, types.ToolTypeOutput)
if err != nil {
return nil, err
Expand Down Expand Up @@ -40,7 +76,7 @@ func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []str
}

for _, outputToolRef := range outputToolRefs {
inputData, err := json.Marshal(map[string]any{
inputData, err := argsForFilters(callCtx.Program, outputToolRef, startState, map[string]any{
"output": output,
"continuation": continuation,
"chat": callCtx.Tool.Chat,
Expand Down
16 changes: 9 additions & 7 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,9 @@ func getToolRefInput(prg *types.Program, ref types.ToolReference, input string)
outputMap := map[string]interface{}{}

_ = json.Unmarshal([]byte(input), &inputMap)
for k, v := range inputMap {
inputMap[strings.ToLower(k)] = v
}

fields := strings.Fields(ref.Arg)

Expand All @@ -291,7 +294,7 @@ func getToolRefInput(prg *types.Program, ref types.ToolReference, input string)
key := strings.TrimPrefix(field, "$")
key = strings.TrimPrefix(key, "{")
key = strings.TrimSuffix(key, "}")
val = inputMap[key]
val = inputMap[strings.ToLower(key)]
} else {
val = field
}
Expand Down Expand Up @@ -425,6 +428,7 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
msg = "Tool call request has been denied"
}
return &State{
StartInput: &input,
Continuation: &engine.Return{
Result: &msg,
},
Expand All @@ -438,6 +442,7 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
}

return &State{
StartInput: &input,
Continuation: ret,
}, nil
}
Expand All @@ -447,6 +452,8 @@ type State struct {
ContinuationToolID string `json:"continuationToolID,omitempty"`
Result *string `json:"result,omitempty"`

StartInput *string `json:"startInput,omitempty"`

ResumeInput *string `json:"resumeInput,omitempty"`
SubCalls []SubCallResult `json:"subCalls,omitempty"`
SubCallID string `json:"subCallID,omitempty"`
Expand Down Expand Up @@ -485,14 +492,9 @@ func (s State) ContinuationContent() (string, error) {
return "", fmt.Errorf("illegal state: no result message found in chat response")
}

type Needed struct {
Content string `json:"content,omitempty"`
Input string `json:"input,omitempty"`
}

func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, state *State) (retState *State, retErr error) {
defer func() {
retState, retErr = r.handleOutput(callCtx, monitor, env, retState, retErr)
retState, retErr = r.handleOutput(callCtx, monitor, env, state, retState, retErr)
}()

if state.Continuation == nil {
Expand Down
91 changes: 91 additions & 0 deletions pkg/tests/runner2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package tests

import (
"context"
"encoding/json"
"testing"

"github.com/gptscript-ai/gptscript/pkg/loader"
"github.com/gptscript-ai/gptscript/pkg/tests/tester"
"github.com/hexops/autogold/v2"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -111,3 +113,92 @@ echo '{"env": {"CRED2": "that also worked"}}'
resp, err := r.Chat(context.Background(), nil, prg, nil, "")
r.AssertStep(t, resp, err)
}

func TestFilterArgs(t *testing.T) {
r := tester.NewRunner(t)
prg, err := loader.ProgramFromSource(context.Background(), `
inputfilters: input with ${Foo}
inputfilters: input with foo
inputfilters: input with *
outputfilters: output with *
outputfilters: output with foo
outputfilters: output with ${Foo}
params: Foo: a description
#!/bin/bash
echo ${FOO}
---
name: input
params: notfoo: a description
#!/bin/bash
echo "${GPTSCRIPT_INPUT}"
---
name: output
params: notfoo: a description
#!/bin/bash
echo "${GPTSCRIPT_INPUT}"
`, "")
require.NoError(t, err)

resp, err := r.Chat(context.Background(), nil, prg, nil, `{"foo":"baz", "start": true}`)
r.AssertStep(t, resp, err)

data := map[string]any{}
err = json.Unmarshal([]byte(resp.Content), &data)
require.NoError(t, err)

autogold.Expect(map[string]interface{}{
"chat": false,
"continuation": false,
"notfoo": "baz",
"output": `{"chat":false,"continuation":false,"notfoo":"foo","output":"{\"chat\":false,\"continuation\":false,\"foo\":\"baz\",\"input\":\"{\\\"foo\\\":\\\"baz\\\",\\\"input\\\":\\\"{\\\\\\\"foo\\\\\\\":\\\\\\\"baz\\\\\\\", \\\\\\\"start\\\\\\\": true}\\\",\\\"notfoo\\\":\\\"baz\\\",\\\"start\\\":true}\\n\",\"notfoo\":\"foo\",\"output\":\"baz\\n\",\"start\":true}\n"}
`,
}).Equal(t, data)

val := data["output"].(string)
data = map[string]any{}
err = json.Unmarshal([]byte(val), &data)
require.NoError(t, err)
autogold.Expect(map[string]interface{}{
"chat": false,
"continuation": false,
"notfoo": "foo",
"output": `{"chat":false,"continuation":false,"foo":"baz","input":"{\"foo\":\"baz\",\"input\":\"{\\\"foo\\\":\\\"baz\\\", \\\"start\\\": true}\",\"notfoo\":\"baz\",\"start\":true}\n","notfoo":"foo","output":"baz\n","start":true}
`,
}).Equal(t, data)

val = data["output"].(string)
data = map[string]any{}
err = json.Unmarshal([]byte(val), &data)
require.NoError(t, err)
autogold.Expect(map[string]interface{}{
"chat": false,
"continuation": false,
"foo": "baz", "input": `{"foo":"baz","input":"{\"foo\":\"baz\", \"start\": true}","notfoo":"baz","start":true}
`,
"notfoo": "foo",
"output": "baz\n",
"start": true,
}).Equal(t, data)

val = data["input"].(string)
data = map[string]any{}
err = json.Unmarshal([]byte(val), &data)
require.NoError(t, err)
autogold.Expect(map[string]interface{}{
"foo": "baz",
"input": `{"foo":"baz", "start": true}`,
"notfoo": "baz",
"start": true,
}).Equal(t, data)

val = data["input"].(string)
data = map[string]any{}
err = json.Unmarshal([]byte(val), &data)
require.NoError(t, err)
autogold.Expect(map[string]interface{}{"foo": "baz", "start": true}).Equal(t, data)
}
6 changes: 6 additions & 0 deletions pkg/tests/testdata/TestFilterArgs/step1.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
`{
"done": true,
"content": "{\"chat\":false,\"continuation\":false,\"notfoo\":\"baz\",\"output\":\"{\\\"chat\\\":false,\\\"continuation\\\":false,\\\"notfoo\\\":\\\"foo\\\",\\\"output\\\":\\\"{\\\\\\\"chat\\\\\\\":false,\\\\\\\"continuation\\\\\\\":false,\\\\\\\"foo\\\\\\\":\\\\\\\"baz\\\\\\\",\\\\\\\"input\\\\\\\":\\\\\\\"{\\\\\\\\\\\\\\\"foo\\\\\\\\\\\\\\\":\\\\\\\\\\\\\\\"baz\\\\\\\\\\\\\\\",\\\\\\\\\\\\\\\"input\\\\\\\\\\\\\\\":\\\\\\\\\\\\\\\"{\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\"foo\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\":\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\"baz\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\", \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\"start\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\": true}\\\\\\\\\\\\\\\",\\\\\\\\\\\\\\\"notfoo\\\\\\\\\\\\\\\":\\\\\\\\\\\\\\\"baz\\\\\\\\\\\\\\\",\\\\\\\\\\\\\\\"start\\\\\\\\\\\\\\\":true}\\\\\\\\n\\\\\\\",\\\\\\\"notfoo\\\\\\\":\\\\\\\"foo\\\\\\\",\\\\\\\"output\\\\\\\":\\\\\\\"baz\\\\\\\\n\\\\\\\",\\\\\\\"start\\\\\\\":true}\\\\n\\\"}\\n\"}\n",
"toolID": "",
"state": null
}`

0 comments on commit 3029cf7

Please sign in to comment.