-
Notifications
You must be signed in to change notification settings - Fork 0
/
chat_cmd.go
178 lines (153 loc) · 4.01 KB
/
chat_cmd.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
package main
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"os"
"strings"
"time"
"github.com/yiblet/hlp/parse"
)
// chatCmd holds the command-line arguments of the chat subcommand. Fields are
// populated from their `arg`, `default`, and `help` struct tags by the
// project's argument parser (presumably github.com/alexflint/go-arg).
//
// File is the chat transcript to read ("-" means stdin). Write, when set, is
// where the transcript plus the model's reply is saved ("-" means overwrite
// File).
type chatCmd struct {
	File      string  `arg:"required,positional" help:"the input chat file, if you pass - the command will read from stdin"`
	Write     *string `arg:"positional" help:"the output chat file, if you pass - the output will be the same as input"`
	MaxTokens int     `arg:"--tokens,-t" default:"0" help:"the maximum amount of tokens allowed in the output"`
	// The tag key was previously misspelled as `-arg`, which the parser never
	// reads, so --temp was not recognised and the field fell back to the
	// parser's default flag name.
	Temperature float32 `arg:"--temp" default:"0.7" help:"the sampling temperature passed to the model"`
	Color       bool    `default:"false"`
	Model       string  `arg:"--model,-m" help:"set openai model"`
}
// appendChatFile writes a single chat entry to writer.
//
// role must pass parse.ValidateRole (presumably one of "system", "assistant",
// or "user" — confirm against the parse package). If validation fails, its
// error is returned and nothing is written. content is the text of the entry.
//
// The entry is written as:
//
//	--- role
//	content
//
// with a newline after content. Any error from validation or from writing to
// writer is returned.
func (args *chatCmd) appendChatFile(writer io.Writer, role, content string) error {
	if err := parse.ValidateRole(role); err != nil {
		return err
	}
	// fmt.Fprintf formats into its own buffer and issues a single Write, so
	// wrapping writer in a bufio.Writer here would add nothing.
	_, err := fmt.Fprintf(writer, "--- %s\n%s\n", role, content)
	return err
}
// writeTo writes the original chat transcript (input) to writer, followed by
// a new "assistant" entry holding content.
//
// A blank line always separates a non-empty transcript from the appended
// entry. input gets one trailing newline, plus one more if it did not already
// end in a newline. An empty input gets no separator. It previously caused an
// index-out-of-range panic.
//
// Output is buffered and flushed once at the end. Any write, validation, or
// flush error is returned.
func (args *chatCmd) writeTo(
	input string,
	content string,
	writer io.Writer,
) error {
	output := bufio.NewWriter(writer)
	if input != "" {
		if _, err := output.WriteString(input); err != nil {
			return err
		}
		sep := "\n"
		if !strings.HasSuffix(input, "\n") {
			// Finish the last line and then leave a blank line.
			sep = "\n\n"
		}
		if _, err := output.WriteString(sep); err != nil {
			return err
		}
	}
	if err := args.appendChatFile(output, "assistant", content); err != nil {
		return err
	}
	return output.Flush()
}
// write saves the transcript (input) plus the assistant reply (content) to
// the output file named by args.Write.
//
// It does nothing when args.Write is nil. When *args.Write is "-", the output
// goes back to args.File. Writing back to stdin ("-" for both) is rejected.
// The output file is created if needed and truncated, so a longer
// pre-existing file cannot leave stale bytes after the new content. An error
// from closing the file is returned if no earlier error occurred.
func (args *chatCmd) write(
	input string,
	content string,
) (err error) {
	if args.Write == nil {
		return nil
	}
	outfile := *args.Write
	if outfile == "-" {
		if args.File == "-" {
			return fmt.Errorf("cannot output to stdin")
		}
		outfile = args.File
	}
	file, err := os.OpenFile(outfile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
	if err != nil {
		return err
	}
	defer func() {
		// On a write path, Close can report a deferred write failure.
		if cerr := file.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()
	return args.writeTo(input, content, file)
}
// outputWriter picks where the streamed reply is shown and returns it along
// with a cleanup function the caller must invoke when done.
//
// With Color enabled, both values come from getOutputWriter. Otherwise the
// reply goes straight to os.Stdout and the cleanup is a no-op.
func (args *chatCmd) outputWriter() (io.Writer, func() error) {
	if args.Color {
		return getOutputWriter()
	}
	noop := func() error { return nil }
	return os.Stdout, noop
}
// readAll drains reader, discarding every byte it yields, and returns the
// first non-nil error reported by Read. A reader that ends normally yields
// io.EOF here, so callers must treat io.EOF as success.
func (args *chatCmd) readAll(reader io.Reader) error {
	var (
		scratch [4096]byte
		err     error
	)
	for err == nil {
		_, err = reader.Read(scratch[:])
	}
	return err
}
// Execute runs the chat command.
//
// It reads the chat transcript from args.File (stdin when File is "-") and
// parses it into messages. It streams the model's reply to the output writer
// while also capturing it. Finally it passes the raw transcript and the
// captured reply to write, which saves them when an output file was
// requested.
//
// The model is args.Model, or the trimmed config.Model() when that is empty.
// Errors from writing the streamed reply, and from closing the output writer,
// are returned rather than silently dropped.
func (args *chatCmd) Execute(ctx context.Context, config *config) (err error) {
	model := args.Model
	if model == "" {
		model = strings.TrimSpace(config.Model())
	}
	client := config.Client()

	var file io.ReadCloser
	if args.File != "-" {
		file, err = os.Open(args.File)
		if err != nil {
			return err
		}
		defer file.Close()
	} else {
		file = os.Stdin
	}

	// Tee the raw bytes so the transcript can later be written back verbatim.
	var inputContent strings.Builder
	reader := io.TeeReader(file, &inputContent)

	messages, err := parse.ParseChatFile(reader)
	if err != nil {
		return err
	}
	// Drain whatever the parser did not consume so inputContent holds the
	// complete transcript.
	if err := args.readAll(reader); err != nil && !errors.Is(err, io.EOF) {
		return err
	}

	var outputContent strings.Builder
	outputWriter, closeWriter := args.outputWriter()
	defer func() {
		if cerr := closeWriter(); cerr != nil && err == nil {
			err = cerr
		}
	}()
	writer := io.MultiWriter(&outputContent, outputWriter)

	err = aiStream(ctx, client, aiStreamInput{
		Messages:    messages,
		MaxTokens:   args.MaxTokens,
		Temperature: &args.Temperature,
		Model:       model,
		Timeout:     2 * time.Minute,
	}, func(message string) error {
		// Propagate display/capture failures instead of streaming into the void.
		_, werr := io.WriteString(writer, message)
		return werr
	})
	if err != nil {
		return err
	}
	return args.write(inputContent.String(), outputContent.String())
}