From 58fdac78a41e0b49c4957c601bc21e838f6684aa Mon Sep 17 00:00:00 2001 From: venjiang Date: Thu, 15 Aug 2024 20:11:34 +0800 Subject: [PATCH 1/5] function call on stream --- examples/go-chat/main.go | 12 +++++--- examples/go-generate/main.go | 7 +++++ openai/openai.go | 45 +++++++++++++++++++++++++--- server/routes.go | 58 +++++++++++++++++++++++++++++++++++- 4 files changed, 113 insertions(+), 9 deletions(-) diff --git a/examples/go-chat/main.go b/examples/go-chat/main.go index 7663fb8f4ee..b225954fd09 100644 --- a/examples/go-chat/main.go +++ b/examples/go-chat/main.go @@ -15,19 +15,19 @@ func main() { } messages := []api.Message{ - api.Message{ + { Role: "system", Content: "Provide very brief, concise responses", }, - api.Message{ + { Role: "user", Content: "Name some unusual animals", }, - api.Message{ + { Role: "assistant", Content: "Monotreme, platypus, echidna", }, - api.Message{ + { Role: "user", Content: "which of these is the most dangerous?", }, @@ -35,7 +35,11 @@ func main() { ctx := context.Background() req := &api.ChatRequest{ +<<<<<<< ours Model: "llama3.1", +======= + Model: "llama3ch", +>>>>>>> theirs Messages: messages, } diff --git a/examples/go-generate/main.go b/examples/go-generate/main.go index 2fe28742b5a..5aae4f12034 100644 --- a/examples/go-generate/main.go +++ b/examples/go-generate/main.go @@ -15,8 +15,15 @@ func main() { } req := &api.GenerateRequest{ +<<<<<<< ours Model: "gemma2", Prompt: "how many planets are there?", +======= + // Model: "gemma", + Model: "llama3ch", + // Prompt: "how many planets are there?", + Prompt: "有多少行星?", +>>>>>>> theirs // set streaming to false Stream: new(bool), diff --git a/openai/openai.go b/openai/openai.go index bda42b4da3d..468ce4d118d 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -32,8 +32,8 @@ type ErrorResponse struct { } type Message struct { - Role string `json:"role"` - Content any `json:"content"` + Role string `json:"role,omitempty"` + Content any `json:"content,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` } @@ -45,7 +45,7 @@ type Choice struct { type ChunkChoice struct { Index int `json:"index"` - Delta Message `json:"delta"` + Delta Message `json:"delta,omitempty"` FinishReason *string `json:"finish_reason"` } @@ -139,6 +139,8 @@ type CompletionChunk struct { } type ToolCall struct { + // Index is not nil only in chat completion chunk object + Index *int `json:"index,omitempty"` ID string `json:"id"` Type string `json:"type"` Function struct { @@ -243,7 +245,31 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { } } +// TODO: 修改这里兼容 /chal/completion stream func toChunk(id string, r api.ChatResponse) ChatCompletionChunk { + toolCalls := make([]ToolCall, len(r.Message.ToolCalls)) + for i, tc := range r.Message.ToolCalls { + idx := i + toolCalls[i].Index = &idx + toolCalls[i].ID = toolCallId() + toolCalls[i].Type = "function" + toolCalls[i].Function.Name = tc.Function.Name + + args, err := json.Marshal(tc.Function.Arguments) + if err != nil { + slog.Error("could not marshall function arguments to json", "error", err) + continue + } + + toolCalls[i].Function.Arguments = string(args) + } + slog.Warn("toChunk", "toolCalls", toolCalls) + + message := Message{Role: "assistant", Content: r.Message.Content} + hasToolCalls := len(toolCalls) > 0 + if hasToolCalls { + message = Message{ToolCalls: toolCalls} + } return ChatCompletionChunk{ Id: id, Object: "chat.completion.chunk", @@ -252,8 +278,12 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk { SystemFingerprint: "fp_ollama", Choices: []ChunkChoice{{ Index: 0, - Delta: Message{Role: "assistant", Content: r.Message.Content}, + // Delta: Message{Role: "assistant", Content: r.Message.Content}, + Delta: message, FinishReason: func(reason string) *string { + // if hasToolCalls { + // reason = "tool_calls" + // } if len(reason) > 0 { return &reason } @@ -588,6 +618,7 @@ func (w *BaseWriter) writeError(code int, data []byte) (int, error) { } func (w *ChatWriter) writeResponse(data []byte) (int, error) { + // slog.Warn("writeResponse", "data", string(data)) var chatResponse api.ChatResponse err := json.Unmarshal(data, &chatResponse) if err != nil { @@ -596,10 +627,12 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) { // chat chunk if w.stream { + // slog.Warn("writeResponse chunk", "resp", chatResponse) d, err := json.Marshal(toChunk(w.id, chatResponse)) if err != nil { return 0, err } + // slog.Warn("writeResponse chunk", "data", string(d)) w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) @@ -610,9 +643,13 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) { if chatResponse.Done { _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n")) if err != nil { + slog.Error("writeResponse done", "err", err) return 0, err } } + slog.Warn("writeResponse chunk", "done", chatResponse.Done, "len", len(data)) + + // slog.Warn("writeResponse stream", "done", chatResponse.Done, "data", string(data), "len", len(data)) return len(data), nil } diff --git a/server/routes.go b/server/routes.go index 6c470c174c1..53eee7f03ec 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1362,9 +1362,14 @@ func (s *Server) ChatHandler(c *gin.Context) { slog.Debug("chat request", "images", len(images), "prompt", prompt) + toolCallsCh := make(chan []api.ToolCall, 1) + doneCh := make(chan bool, 1) ch := make(chan any) go func() { + var sb strings.Builder defer close(ch) + defer close(toolCallsCh) + defer close(doneCh) if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, Images: images, @@ -1384,8 +1389,21 @@ func (s *Server) ChatHandler(c *gin.Context) { EvalDuration: r.EvalDuration, }, } + sb.WriteString(r.Content) if r.Done { + doneCh <- true + if len(req.Tools) > 0 { + slog.Warn("ollama resp", "content", sb.String()) + if toolCalls, ok := m.parseToolCalls(sb.String()); ok { + slog.Warn("ollama resp", "tool_calls", fmt.Sprintf("%+v", toolCalls)) + // res.Message.ToolCalls = toolCalls + // go func() { + toolCallsCh <- toolCalls + // }() + } + } + res.TotalDuration = time.Since(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart) } @@ -1419,7 +1437,7 @@ func (s *Server) ChatHandler(c *gin.Context) { } resp.Message.Content = sb.String() - + // TODO: 可以简化 if len(req.Tools) > 0 { if toolCalls, ok := m.parseToolCalls(sb.String()); ok { resp.Message.ToolCalls = toolCalls @@ -1431,6 +1449,44 @@ func (s *Server) ChatHandler(c *gin.Context) { return } + // change stream response with tool calls + // slog.Warn("tool_calls", "len", len(toolCallsCh), "toolCallsResp", fmt.Sprintf("%+v", toolCallsCh)) + var toolCalls []api.ToolCall +LOOP: + for { + select { + case tc := <-toolCallsCh: + toolCalls = tc + slog.Warn("tool_calls-1", "len", len(toolCalls), "toolCallsResp", fmt.Sprintf("%+v", toolCalls)) + break LOOP + + case done, ok := <-doneCh: + slog.Warn("done", "done", done, "ok", ok) + case <-ch: + continue + } + } + slog.Warn("tool_calls-2", "len", len(toolCalls), "toolCallsResp", fmt.Sprintf("%+v", toolCalls)) + if len(toolCalls) > 0 { + // reset the channel + ch = make(chan any, len(toolCalls)) + for i, toolCall := range toolCalls { + res := api.ChatResponse{ + Model: req.Model, + CreatedAt: time.Now().UTC(), + Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{toolCall}}, + } + if i == len(toolCalls)-1 { + // TODO: metrics + res.Done = true + res.DoneReason = "tool_calls" + } + slog.Warn("add tool calls resp", "resp", fmt.Sprintf("%+v", res)) + ch <- res + } + close(ch) + } + streamResponse(c, ch) } From 5f4d3c351a6cd9e061147fa90034083365dc98e1 Mon Sep 17 00:00:00 2001 From: venjiang Date: Tue, 20 Aug 2024 19:32:09 +0800 Subject: [PATCH 2/5] it's work --- openai/openai.go | 8 ----- server/routes.go | 91 +++++++++++++++++++++++++----------------------- 2 files changed, 48 insertions(+), 51 deletions(-) diff --git a/openai/openai.go b/openai/openai.go index 468ce4d118d..1c1910791c1 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -245,7 +245,6 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { } } -// TODO: 修改这里兼容 /chal/completion stream func toChunk(id string, r api.ChatResponse) ChatCompletionChunk { toolCalls := make([]ToolCall, len(r.Message.ToolCalls)) for i, tc := range r.Message.ToolCalls { @@ -263,7 +262,6 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk { toolCalls[i].Function.Arguments = string(args) } - slog.Warn("toChunk", "toolCalls", toolCalls) message := Message{Role: "assistant", Content: r.Message.Content} hasToolCalls := len(toolCalls) > 0 @@ -618,7 +616,6 @@ func (w *BaseWriter) writeError(code int, data []byte) (int, error) { } func (w *ChatWriter) writeResponse(data []byte) (int, error) { - // slog.Warn("writeResponse", "data", string(data)) var chatResponse api.ChatResponse err := json.Unmarshal(data, &chatResponse) if err != nil { @@ -627,12 +624,10 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) { // chat chunk if w.stream { - // slog.Warn("writeResponse chunk", "resp", chatResponse) d, err := json.Marshal(toChunk(w.id, chatResponse)) if err != nil { return 0, err } - // slog.Warn("writeResponse chunk", "data", string(d)) w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) @@ -647,9 +642,6 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) { return 0, err } } - slog.Warn("writeResponse chunk", "done", chatResponse.Done, "len", len(data)) - - // slog.Warn("writeResponse stream", "done", chatResponse.Done, "data", string(data), "len", len(data)) return len(data), nil } diff --git a/server/routes.go b/server/routes.go index 53eee7f03ec..733826b0b64 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1363,13 +1363,11 @@ func (s *Server) ChatHandler(c *gin.Context) { slog.Debug("chat request", "images", len(images), "prompt", prompt) toolCallsCh := make(chan []api.ToolCall, 1) - doneCh := make(chan bool, 1) ch := make(chan any) go func() { var sb strings.Builder defer close(ch) defer close(toolCallsCh) - defer close(doneCh) if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, Images: images, @@ -1392,15 +1390,11 @@ func (s *Server) ChatHandler(c *gin.Context) { sb.WriteString(r.Content) if r.Done { - doneCh <- true if len(req.Tools) > 0 { - slog.Warn("ollama resp", "content", sb.String()) + // slog.Warn("ollama resp", "content", sb.String()) if toolCalls, ok := m.parseToolCalls(sb.String()); ok { - slog.Warn("ollama resp", "tool_calls", fmt.Sprintf("%+v", toolCalls)) - // res.Message.ToolCalls = toolCalls - // go func() { + // slog.Warn("ollama resp", "tool_calls", fmt.Sprintf("%+v", toolCalls)) toolCallsCh <- toolCalls - // }() } } @@ -1437,12 +1431,12 @@ func (s *Server) ChatHandler(c *gin.Context) { } resp.Message.Content = sb.String() - // TODO: 可以简化 if len(req.Tools) > 0 { - if toolCalls, ok := m.parseToolCalls(sb.String()); ok { - resp.Message.ToolCalls = toolCalls - resp.Message.Content = "" - } + // if toolCalls, ok := m.parseToolCalls(sb.String()); ok { + toolCalls := <-toolCallsCh + resp.Message.ToolCalls = toolCalls + resp.Message.Content = "" + // } } c.JSON(http.StatusOK, resp) @@ -1451,43 +1445,54 @@ func (s *Server) ChatHandler(c *gin.Context) { // change stream response with tool calls // slog.Warn("tool_calls", "len", len(toolCallsCh), "toolCallsResp", fmt.Sprintf("%+v", toolCallsCh)) - var toolCalls []api.ToolCall -LOOP: - for { - select { - case tc := <-toolCallsCh: - toolCalls = tc - slog.Warn("tool_calls-1", "len", len(toolCalls), "toolCallsResp", fmt.Sprintf("%+v", toolCalls)) - break LOOP - - case done, ok := <-doneCh: - slog.Warn("done", "done", done, "ok", ok) - case <-ch: - continue + streamCh := make(chan any) + for rr := range ch { + switch t := rr.(type) { + case api.ChatResponse: + go func() { + // slog.Warn("reassign chat response", "content", t.Message.Content) + streamCh <- t + }() + case gin.H: + msg, ok := t["error"].(string) + if !ok { + msg = "unexpected error format in response" + } + c.JSON(http.StatusInternalServerError, gin.H{"error": msg}) + return + default: + c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"}) + return } } - slog.Warn("tool_calls-2", "len", len(toolCalls), "toolCallsResp", fmt.Sprintf("%+v", toolCalls)) - if len(toolCalls) > 0 { - // reset the channel - ch = make(chan any, len(toolCalls)) - for i, toolCall := range toolCalls { + + // if request tool calls + if len(req.Tools) > 0 { + toolCalls := <-toolCallsCh + hasToolCalls := len(toolCalls) > 0 + if hasToolCalls { + // reset the channel + toolCallRespCh := make(chan any, 1) res := api.ChatResponse{ - Model: req.Model, - CreatedAt: time.Now().UTC(), - Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{toolCall}}, - } - if i == len(toolCalls)-1 { - // TODO: metrics - res.Done = true - res.DoneReason = "tool_calls" + Model: req.Model, + CreatedAt: time.Now().UTC(), + Message: api.Message{Role: "assistant", ToolCalls: toolCalls}, + Done: true, + DoneReason: "tool_calls", } - slog.Warn("add tool calls resp", "resp", fmt.Sprintf("%+v", res)) - ch <- res + toolCallRespCh <- res + close(toolCallRespCh) + slog.Info("#1 tool calls stream response") + streamResponse(c, toolCallRespCh) + return + } else { + slog.Info("#1 no tool calls") } - close(ch) } - streamResponse(c, ch) + slog.Info("#2 stream response") + defer close(streamCh) + streamResponse(c, streamCh) } func handleScheduleError(c *gin.Context, name string, err error) { From aab555d37147ce3e1f50f27ee8bd2df006b39435 Mon Sep 17 00:00:00 2001 From: venjiang Date: Wed, 21 Aug 2024 17:34:44 +0800 Subject: [PATCH 3/5] fix --- server/routes.go | 51 ++++++++++++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/server/routes.go b/server/routes.go index 733826b0b64..3801b0b4d3e 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1363,11 +1363,14 @@ func (s *Server) ChatHandler(c *gin.Context) { slog.Debug("chat request", "images", len(images), "prompt", prompt) toolCallsCh := make(chan []api.ToolCall, 1) + contentCh := make(chan string, 1) + ch := make(chan any) go func() { var sb strings.Builder defer close(ch) defer close(toolCallsCh) + defer close(contentCh) if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, Images: images, @@ -1390,10 +1393,10 @@ func (s *Server) ChatHandler(c *gin.Context) { sb.WriteString(r.Content) if r.Done { + content := sb.String() + contentCh <- content if len(req.Tools) > 0 { - // slog.Warn("ollama resp", "content", sb.String()) - if toolCalls, ok := m.parseToolCalls(sb.String()); ok { - // slog.Warn("ollama resp", "tool_calls", fmt.Sprintf("%+v", toolCalls)) + if toolCalls, ok := m.parseToolCalls(content); ok { toolCallsCh <- toolCalls } } @@ -1408,13 +1411,15 @@ func (s *Server) ChatHandler(c *gin.Context) { } }() + toolsRequired := len(req.Tools) > 0 + // no stream response if req.Stream != nil && !*req.Stream { var resp api.ChatResponse - var sb strings.Builder + // var sb strings.Builder for rr := range ch { switch t := rr.(type) { case api.ChatResponse: - sb.WriteString(t.Message.Content) + // sb.WriteString(t.Message.Content) resp = t case gin.H: msg, ok := t["error"].(string) @@ -1430,11 +1435,12 @@ func (s *Server) ChatHandler(c *gin.Context) { } } - resp.Message.Content = sb.String() - if len(req.Tools) > 0 { + // resp.Message.Content = sb.String() + content := <-contentCh + resp.Message.Content = content + if toolsRequired { // if toolCalls, ok := m.parseToolCalls(sb.String()); ok { - toolCalls := <-toolCallsCh - resp.Message.ToolCalls = toolCalls + resp.Message.ToolCalls = <-toolCallsCh resp.Message.Content = "" // } } @@ -1443,8 +1449,7 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - // change stream response with tool calls - // slog.Warn("tool_calls", "len", len(toolCallsCh), "toolCallsResp", fmt.Sprintf("%+v", toolCallsCh)) + // stream response streamCh := make(chan any) for rr := range ch { switch t := rr.(type) { @@ -1452,6 +1457,10 @@ func (s *Server) ChatHandler(c *gin.Context) { go func() { // slog.Warn("reassign chat response", "content", t.Message.Content) streamCh <- t + if t.Done { + // slog.Warn("close stream channel") + close(streamCh) + } }() case gin.H: msg, ok := t["error"].(string) @@ -1466,13 +1475,14 @@ func (s *Server) ChatHandler(c *gin.Context) { } } - // if request tool calls - if len(req.Tools) > 0 { + // if tools are required + if toolsRequired { toolCalls := <-toolCallsCh + // if tool calls are present, use different channel respose hasToolCalls := len(toolCalls) > 0 if hasToolCalls { // reset the channel - toolCallRespCh := make(chan any, 1) + toolCallsCh := make(chan any, 1) res := api.ChatResponse{ Model: req.Model, CreatedAt: time.Now().UTC(), @@ -1480,18 +1490,17 @@ func (s *Server) ChatHandler(c *gin.Context) { Done: true, DoneReason: "tool_calls", } - toolCallRespCh <- res - close(toolCallRespCh) - slog.Info("#1 tool calls stream response") - streamResponse(c, toolCallRespCh) + toolCallsCh <- res + close(toolCallsCh) + slog.Info("[tools] stream response") + streamResponse(c, toolCallsCh) return } else { - slog.Info("#1 no tool calls") + slog.Info("[tools] no call") } } - slog.Info("#2 stream response") - defer close(streamCh) + slog.Info("stream response") streamResponse(c, streamCh) } From e990dec90fd0255aa6cbe9a0a82d2ede4d560f68 Mon Sep 17 00:00:00 2001 From: venjiang Date: Wed, 21 Aug 2024 18:16:41 +0800 Subject: [PATCH 4/5] clean --- examples/go-chat/main.go | 4 ---- examples/go-generate/main.go | 7 ------- 2 files changed, 11 deletions(-) diff --git a/examples/go-chat/main.go b/examples/go-chat/main.go index b225954fd09..bdbd2ae645b 100644 --- a/examples/go-chat/main.go +++ b/examples/go-chat/main.go @@ -35,11 +35,7 @@ func main() { ctx := context.Background() req := &api.ChatRequest{ -<<<<<<< ours Model: "llama3.1", -======= - Model: "llama3ch", ->>>>>>> theirs Messages: messages, } diff --git a/examples/go-generate/main.go b/examples/go-generate/main.go index 5aae4f12034..2fe28742b5a 100644 --- a/examples/go-generate/main.go +++ b/examples/go-generate/main.go @@ -15,15 +15,8 @@ func main() { } req := &api.GenerateRequest{ -<<<<<<< ours Model: "gemma2", Prompt: "how many planets are there?", -======= - // Model: "gemma", - Model: "llama3ch", - // Prompt: "how many planets are there?", - Prompt: "有多少行星?", ->>>>>>> theirs // set streaming to false Stream: new(bool), From 79379daab4075289e441961ee912d5393c290ce1 Mon Sep 17 00:00:00 2001 From: venjiang Date: Mon, 26 Aug 2024 16:27:44 +0800 Subject: [PATCH 5/5] fix: when tools exists, but no call function response is empty --- server/routes.go | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/server/routes.go b/server/routes.go index 3801b0b4d3e..5230378b187 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1415,11 +1415,9 @@ func (s *Server) ChatHandler(c *gin.Context) { // no stream response if req.Stream != nil && !*req.Stream { var resp api.ChatResponse - // var sb strings.Builder for rr := range ch { switch t := rr.(type) { case api.ChatResponse: - // sb.WriteString(t.Message.Content) resp = t case gin.H: msg, ok := t["error"].(string) @@ -1435,14 +1433,14 @@ func (s *Server) ChatHandler(c *gin.Context) { } } - // resp.Message.Content = sb.String() content := <-contentCh resp.Message.Content = content if toolsRequired { - // if toolCalls, ok := m.parseToolCalls(sb.String()); ok { - resp.Message.ToolCalls = <-toolCallsCh - resp.Message.Content = "" - // } + toolCalls := <-toolCallsCh + if len(toolCalls) > 0 { + resp.Message.ToolCalls = toolCalls + resp.Message.Content = "" + } } c.JSON(http.StatusOK, resp)