From edbfb28a2e618fca4b93acc2ece68f2bede1a5e7 Mon Sep 17 00:00:00 2001 From: Christopher Speller Date: Tue, 20 Aug 2024 20:38:14 +0200 Subject: [PATCH] Better behavour for anthropic with multiple messages and mixed bot usage --- server/ai/anthropic/anthropic.go | 15 +++++++++++++++ server/ai/conversation.go | 7 ------- server/post_processing.go | 7 ++++++- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/server/ai/anthropic/anthropic.go b/server/ai/anthropic/anthropic.go index 5a8ec407..1d08f0cf 100644 --- a/server/ai/anthropic/anthropic.go +++ b/server/ai/anthropic/anthropic.go @@ -32,10 +32,21 @@ func conversationToMessages(conversation ai.BotConversation) (string, []InputMes systemMessage := "" messages := make([]InputMessage, 0, len(conversation.Posts)) for _, post := range conversation.Posts { + previousRole := "" + previousContent := "" + if len(messages) > 0 { + previous := messages[len(messages)-1] + previousRole = previous.Role + previousContent = previous.Content + } switch post.Role { case ai.PostRoleSystem: systemMessage += post.Message case ai.PostRoleBot: + if previousRole == RoleAssistant { + previousContent += post.Message + continue + } messages = append(messages, InputMessage{ Role: RoleAssistant, @@ -43,6 +54,10 @@ func conversationToMessages(conversation ai.BotConversation) (string, []InputMes }, ) case ai.PostRoleUser: + if previousRole == RoleUser { + previousContent += post.Message + continue + } messages = append(messages, InputMessage{ Role: RoleUser, diff --git a/server/ai/conversation.go b/server/ai/conversation.go index a09f7dde..83eeb9ec 100644 --- a/server/ai/conversation.go +++ b/server/ai/conversation.go @@ -179,13 +179,6 @@ func (b *BotConversation) Truncate(maxTokens int, countTokens func(string) int) return false } -func GetPostRole(botID string, post *model.Post) PostRole { - if post.UserId == botID { - return PostRoleBot - } - return PostRoleUser -} - func FormatPostBody(post *model.Post) string { attachments := post.Attachments() if len(attachments) > 0 { diff --git a/server/post_processing.go b/server/post_processing.go index 5905ce18..3c72b138 100644 --- a/server/post_processing.go +++ b/server/post_processing.go @@ -404,8 +404,13 @@ func (p *Plugin) PostToAIPost(bot *Bot, post *model.Post) ai.Post { } } + role := ai.PostRoleUser + if p.IsAnyBot(post.UserId) { + role = ai.PostRoleBot + } + return ai.Post{ - Role: ai.GetPostRole(bot.mmBot.UserId, post), + Role: role, Message: ai.FormatPostBody(post), Files: files, }