Skip to content

Commit

Permalink
chore: add multiline parsing and refactor share cred behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
ibuildthecloud committed Oct 9, 2024
1 parent cc5e5ed commit 1b7c477
Show file tree
Hide file tree
Showing 7 changed files with 291 additions and 38 deletions.
1 change: 0 additions & 1 deletion pkg/config/cliconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ func (a *AuthConfig) UnmarshalJSON(data []byte) error {
type CLIConfig struct {
Auths map[string]AuthConfig `json:"auths,omitempty"`
CredentialsStore string `json:"credsStore,omitempty"`
GatewayURL string `json:"gatewayURL,omitempty"`
Integrations map[string]string `json:"integrations,omitempty"`

auths map[string]types.AuthConfig
Expand Down
2 changes: 1 addition & 1 deletion pkg/credentials/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func validateCredentialCtx(ctxs []string) error {
}

// check alphanumeric
r := regexp.MustCompile("^[a-zA-Z0-9]+$")
r := regexp.MustCompile("^[-a-zA-Z0-9]+$")
for _, c := range ctxs {
if !r.MatchString(c) {
return fmt.Errorf("credential contexts must be alphanumeric")
Expand Down
109 changes: 88 additions & 21 deletions pkg/parser/parser.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package parser

import (
"bufio"
"fmt"
"io"
"maps"
Expand All @@ -17,8 +16,10 @@ import (

var (
sepRegex = regexp.MustCompile(`^\s*---+\s*$`)
endHeaderRegex = regexp.MustCompile(`^\s*===+\s*$`)
strictSepRegex = regexp.MustCompile(`^---\n$`)
skipRegex = regexp.MustCompile(`^![-.:*\w]+\s*$`)
nameRegex = regexp.MustCompile(`^[a-z]+$`)
)

func normalize(key string) string {
Expand Down Expand Up @@ -56,6 +57,9 @@ func addArg(line string, tool *types.Tool) error {
tool.Parameters.Arguments = &openapi3.Schema{
Type: &openapi3.Types{"object"},
Properties: openapi3.Schemas{},
AdditionalProperties: openapi3.AdditionalProperties{
Has: new(bool),
},
}
}

Expand All @@ -74,7 +78,7 @@ func addArg(line string, tool *types.Tool) error {
return nil
}

func isParam(line string, tool *types.Tool) (_ bool, err error) {
func isParam(line string, tool *types.Tool, scan *simplescanner) (_ bool, err error) {
key, value, ok := strings.Cut(line, ":")
if !ok {
return false, nil
Expand All @@ -90,7 +94,7 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
case "globalmodel", "globalmodelname":
tool.Parameters.GlobalModelName = value
case "description":
tool.Parameters.Description = value
tool.Parameters.Description = scan.AddMultiline(value)
case "internalprompt":
v, err := toBool(value)
if err != nil {
Expand All @@ -104,27 +108,33 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
}
tool.Parameters.Chat = v
case "export", "exporttool", "exports", "exporttools", "sharetool", "sharetools", "sharedtool", "sharedtools":
tool.Parameters.Export = append(tool.Parameters.Export, csv(value)...)
tool.Parameters.Export = append(tool.Parameters.Export, csv(scan.AddMultiline(value))...)
case "tool", "tools":
tool.Parameters.Tools = append(tool.Parameters.Tools, csv(value)...)
tool.Parameters.Tools = append(tool.Parameters.Tools, csv(scan.AddMultiline(value))...)
case "inputfilter", "inputfilters":
tool.Parameters.InputFilters = append(tool.Parameters.InputFilters, csv(value)...)
tool.Parameters.InputFilters = append(tool.Parameters.InputFilters, csv(scan.AddMultiline(value))...)
case "shareinputfilter", "shareinputfilters", "sharedinputfilter", "sharedinputfilters":
tool.Parameters.ExportInputFilters = append(tool.Parameters.ExportInputFilters, csv(value)...)
tool.Parameters.ExportInputFilters = append(tool.Parameters.ExportInputFilters, csv(scan.AddMultiline(value))...)
case "outputfilter", "outputfilters":
tool.Parameters.OutputFilters = append(tool.Parameters.OutputFilters, csv(value)...)
tool.Parameters.OutputFilters = append(tool.Parameters.OutputFilters, csv(scan.AddMultiline(value))...)
case "shareoutputfilter", "shareoutputfilters", "sharedoutputfilter", "sharedoutputfilters":
tool.Parameters.ExportOutputFilters = append(tool.Parameters.ExportOutputFilters, csv(value)...)
tool.Parameters.ExportOutputFilters = append(tool.Parameters.ExportOutputFilters, csv(scan.AddMultiline(value))...)
case "agent", "agents":
tool.Parameters.Agents = append(tool.Parameters.Agents, csv(value)...)
tool.Parameters.Agents = append(tool.Parameters.Agents, csv(scan.AddMultiline(value))...)
case "globaltool", "globaltools":
tool.Parameters.GlobalTools = append(tool.Parameters.GlobalTools, csv(value)...)
tool.Parameters.GlobalTools = append(tool.Parameters.GlobalTools, csv(scan.AddMultiline(value))...)
case "exportcontext", "exportcontexts", "sharecontext", "sharecontexts", "sharedcontext", "sharedcontexts":
tool.Parameters.ExportContext = append(tool.Parameters.ExportContext, csv(value)...)
tool.Parameters.ExportContext = append(tool.Parameters.ExportContext, csv(scan.AddMultiline(value))...)
case "context":
tool.Parameters.Context = append(tool.Parameters.Context, csv(value)...)
tool.Parameters.Context = append(tool.Parameters.Context, csv(scan.AddMultiline(value))...)
case "metadata":
mkey, mvalue, _ := strings.Cut(scan.AddMultiline(value), ":")
if tool.MetaData == nil {
tool.MetaData = map[string]string{}
}
tool.MetaData[strings.TrimSpace(mkey)] = strings.TrimSpace(mvalue)
case "args", "arg", "param", "params", "parameters", "parameter":
if err := addArg(value, tool); err != nil {
if err := addArg(scan.AddMultiline(value), tool); err != nil {
return false, err
}
case "maxtoken", "maxtokens":
Expand All @@ -149,13 +159,13 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
return false, err
}
case "credentials", "creds", "credential", "cred":
tool.Parameters.Credentials = append(tool.Parameters.Credentials, value)
tool.Parameters.Credentials = append(tool.Parameters.Credentials, csv(scan.AddMultiline(value))...)
case "sharecredentials", "sharecreds", "sharecredential", "sharecred", "sharedcredentials", "sharedcreds", "sharedcredential", "sharedcred":
tool.Parameters.ExportCredentials = append(tool.Parameters.ExportCredentials, value)
tool.Parameters.ExportCredentials = append(tool.Parameters.ExportCredentials, scan.AddMultiline(value))
case "type":
tool.Type = types.ToolType(strings.ToLower(value))
default:
return false, nil
return nameRegex.MatchString(key), nil
}

return true, nil
Expand Down Expand Up @@ -206,6 +216,7 @@ func (c *context) finish(tools *[]Node) {
len(c.tool.ExportInputFilters) > 0 ||
len(c.tool.ExportOutputFilters) > 0 ||
len(c.tool.Agents) > 0 ||
len(c.tool.ExportCredentials) > 0 ||
c.tool.Chat {
*tools = append(*tools, Node{
ToolNode: &ToolNode{
Expand Down Expand Up @@ -391,7 +402,10 @@ func assignMetadata(nodes []Node) (result []Node) {

for _, node := range nodes {
if node.ToolNode != nil {
node.ToolNode.Tool.MetaData = metadata[node.ToolNode.Tool.Name]
if node.ToolNode.Tool.MetaData == nil {
node.ToolNode.Tool.MetaData = map[string]string{}
}
maps.Copy(node.ToolNode.Tool.MetaData, metadata[node.ToolNode.Tool.Name])
for wildcard := range metadata {
if strings.Contains(wildcard, "*") {
if m, err := path.Match(wildcard, node.ToolNode.Tool.Name); m && err == nil {
Expand Down Expand Up @@ -433,15 +447,64 @@ func isGPTScriptHashBang(line string) bool {
return false
}

func parse(input io.Reader) ([]Node, error) {
scan := bufio.NewScanner(input)
type simplescanner struct {
lines []string
}

func newSimpleScanner(data []byte) *simplescanner {
if len(data) == 0 {
return &simplescanner{}
}
lines := strings.Split(string(data), "\n")
return &simplescanner{
lines: append([]string{""}, lines...),
}
}

func (s *simplescanner) AddMultiline(current string) string {
result := current
for {
if len(s.lines) < 2 || len(s.lines[1]) == 0 {
return result
}
if strings.HasPrefix(s.lines[1], " ") || strings.HasPrefix(s.lines[1], "\t") {
result += " " + s.lines[1]
s.lines = s.lines[1:]
} else {
return result
}
}
}

func (s *simplescanner) Text() string {
if len(s.lines) == 0 {
return ""
}
return s.lines[0]
}

func (s *simplescanner) Scan() bool {
if len(s.lines) == 0 {
return false
}
s.lines = s.lines[1:]
return true
}

func parse(input io.Reader) ([]Node, error) {
var (
tools []Node
context context
lineNo int
)

data, err := io.ReadAll(input)
if err != nil {
return nil, err
}

scan := newSimpleScanner(data)

for scan.Scan() {
lineNo++
if context.tool.Source.LineNo == 0 {
Expand Down Expand Up @@ -488,11 +551,15 @@ func parse(input io.Reader) ([]Node, error) {
}

// Look for params
if isParam, err := isParam(line, &context.tool); err != nil {
if isParam, err := isParam(line, &context.tool, scan); err != nil {
return nil, NewErrLine("", lineNo, err)
} else if isParam {
context.seenParam = true
continue
} else if endHeaderRegex.MatchString(line) {
// force the end of the header and don't include the current line in the header
context.inBody = true
continue
}
}

Expand Down
91 changes: 91 additions & 0 deletions pkg/parser/parser_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package parser

import (
"reflect"
"strings"
"testing"

Expand Down Expand Up @@ -244,6 +245,7 @@ share output filters: shared
func TestParseMetaData(t *testing.T) {
input := `
name: first
metadata: foo: bar
body
---
Expand All @@ -269,8 +271,97 @@ foo bar

assert.Len(t, tools, 1)
autogold.Expect(map[string]string{
"foo": "bar",
"package.json": "foo=base\nf",
"requirements.txt": "asdf",
"other": "foo bar",
}).Equal(t, tools[0].MetaData)

autogold.Expect(`Name: first
Meta Data: foo: bar
Meta Data: other: foo bar
Meta Data: requirements.txt: asdf
body
---
!metadata:first:package.json
foo=base
f
`).Equal(t, tools[0].String())
}

func TestFormatWithBadInstruction(t *testing.T) {
input := types.Tool{
ToolDef: types.ToolDef{
Parameters: types.Parameters{
Name: "foo",
},
Instructions: "foo: bar",
},
}
autogold.Expect("Name: foo\n===\nfoo: bar\n").Equal(t, input.String())

tools, err := ParseTools(strings.NewReader(input.String()))
require.NoError(t, err)
if reflect.DeepEqual(input, tools[0]) {
t.Errorf("expected %v, got %v", input, tools[0])
}
}

func TestSingleTool(t *testing.T) {
input := `
name: foo
#!sys.echo
hi
`

tools, err := ParseTools(strings.NewReader(input))
require.NoError(t, err)
autogold.Expect(types.Tool{
ToolDef: types.ToolDef{
Parameters: types.Parameters{
Name: "first",
ModelName: "the model",
Credentials: []string{
"foo",
"bar",
"baz",
},
},
Instructions: "body",
},
Source: types.ToolSource{LineNo: 1},
}).Equal(t, tools[0])
}

func TestMultiline(t *testing.T) {
input := `
name: first
credential: foo
, bar,
baz
model: the model
body
`
tools, err := ParseTools(strings.NewReader(input))
require.NoError(t, err)

assert.Len(t, tools, 1)
autogold.Expect(types.Tool{
ToolDef: types.ToolDef{
Parameters: types.Parameters{
Name: "first",
ModelName: "the model",
Credentials: []string{
"foo",
"bar",
"baz",
},
},
Instructions: "body",
},
Source: types.ToolSource{LineNo: 1},
}).Equal(t, tools[0])
}
32 changes: 32 additions & 0 deletions pkg/tests/runner2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,35 @@ echo ${FOO}:${INPUT}
resp, err = r.Chat(context.Background(), nil, prg, nil, `"foo":"123"}`)
r.AssertStep(t, resp, err)
}

func TestShareCreds(t *testing.T) {
r := tester.NewRunner(t)
prg, err := loader.ProgramFromSource(context.Background(), `
creds: foo
#!/bin/bash
echo $CRED
echo $CRED2
---
name: foo
share credentials: bar
---
name: bar
share credentials: baz
#!/bin/bash
echo '{"env": {"CRED": "that worked"}}'
---
name: baz
#!/bin/bash
echo '{"env": {"CRED2": "that also worked"}}'
`, "")
require.NoError(t, err)

resp, err := r.Chat(context.Background(), nil, prg, nil, "")
r.AssertStep(t, resp, err)
}
3 changes: 1 addition & 2 deletions pkg/types/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"fmt"
"strings"

"github.com/fatih/color"
"github.com/getkin/kin-openapi/openapi3"
)

Expand Down Expand Up @@ -112,7 +111,7 @@ func (c CompletionMessage) String() string {
}
buf.WriteString(content.Text)
if content.ToolCall != nil {
buf.WriteString(fmt.Sprintf("<tool call> %s -> %s", color.GreenString(content.ToolCall.Function.Name), content.ToolCall.Function.Arguments))
buf.WriteString(fmt.Sprintf("<tool call> %s -> %s", content.ToolCall.Function.Name, content.ToolCall.Function.Arguments))
}
}
return buf.String()
Expand Down
Loading

0 comments on commit 1b7c477

Please sign in to comment.