diff --git a/ask_cmd.go b/ask_cmd.go index ba4c3fa..d934c9b 100644 --- a/ask_cmd.go +++ b/ask_cmd.go @@ -3,6 +3,8 @@ package main import ( "context" "fmt" + "io" + "os" "strings" "github.com/PullRequestInc/go-gpt3" @@ -30,17 +32,48 @@ type askCmd struct { Temperature float32 `default:"0.7"` Bash bool `arg:"--bash" help:"output only valid bash"` Model string `arg:"--model,-m" help:"set openai model"` + Attach []string `arg:"--attach,-a,separate" help:"attach additional files at the end of the message"` } -func (args *askCmd) messages() []gpt3.ChatCompletionRequestMessage { +func (args *askCmd) buildContent(ctx context.Context) (string, error) { + var sb strings.Builder + for idx, q := range args.Question { + if idx != 0 { + sb.WriteRune(' ') + } + sb.WriteString(q) + } + + if len(args.Question) > 0 && + !strings.HasSuffix(args.Question[len(args.Question)-1], "\n") { + sb.WriteRune('\n') + } + + for _, a := range args.Attach { + sb.WriteRune('\n') + file, err := os.Open(a) + if err != nil { + return "", err + } + defer file.Close() + _, err = io.Copy(&sb, file) + if err != nil { + return "", err + } + } + + return sb.String(), nil +} + +func (args *askCmd) messages(content string) []gpt3.ChatCompletionRequestMessage { if args.Bash { return []gpt3.ChatCompletionRequestMessage{ {Role: "system", Content: systemMessage}, - {Role: "user", Content: strings.Join(args.Question, " ")}, + {Role: "user", Content: content}, } } else { return []gpt3.ChatCompletionRequestMessage{ - {Role: "system", Content: strings.Join(args.Question, " ")}, + {Role: "system", Content: content}, } } @@ -54,8 +87,12 @@ func (args *askCmd) Execute(ctx context.Context, config *config) error { client := config.Client() lastMessage := "" - err := client.ChatCompletionStream(ctx, gpt3.ChatCompletionRequest{ - Messages: args.messages(), + content, err := args.buildContent(ctx) + if err != nil { + return fmt.Errorf("cannot build message: %w", err) + } + err = client.ChatCompletionStream(ctx, gpt3.ChatCompletionRequest{ + Messages: args.messages(content), MaxTokens: args.MaxTokens, Temperature: &args.Temperature, Stream: true,