Skip to content

Commit

Permalink
Merge pull request #87 from Leizhenpeng/feat_pic_varity
Browse files Browse the repository at this point in the history
feat : support image variation
  • Loading branch information
Leizhenpeng authored Mar 12, 2023
2 parents 0069ff1 + 6709c77 commit 9875b6e
Show file tree
Hide file tree
Showing 18 changed files with 994 additions and 424 deletions.
14 changes: 14 additions & 0 deletions code/handlers/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,17 @@ func parseFileKey(content string) string {
fileKey := contentMap["file_key"].(string)
return fileKey
}

func parseImageKey(content string) string {
var contentMap map[string]interface{}
err := json.Unmarshal([]byte(content), &contentMap)
if err != nil {
fmt.Println(err)
return ""
}
if contentMap["image_key"] == nil {
return ""
}
imageKey := contentMap["image_key"].(string)
return imageKey
}
67 changes: 59 additions & 8 deletions code/handlers/event_action.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os"
"start-feishubot/initialization"
"start-feishubot/services"
"start-feishubot/services/openai"
"start-feishubot/utils"
"start-feishubot/utils/audio"
)
Expand All @@ -18,6 +19,7 @@ type MsgInfo struct {
chatId *string
qParsed string
fileKey string
imageKey string
sessionId *string
mention []*larkim.MentionEvent
}
Expand Down Expand Up @@ -93,7 +95,7 @@ func (*RolePlayAction) Execute(a *ActionInfo) bool {
if system, foundSystem := utils.EitherCutPrefix(a.info.qParsed,
"/system ", "角色扮演 "); foundSystem {
a.handler.sessionCache.Clear(*a.info.sessionId)
systemMsg := append([]services.Messages{}, services.Messages{
systemMsg := append([]openai.Messages{}, openai.Messages{
Role: "system", Content: system,
})
a.handler.sessionCache.SetMsg(*a.info.sessionId, systemMsg)
Expand Down Expand Up @@ -133,8 +135,59 @@ func (*PicAction) Execute(a *ActionInfo) bool {
return false
}

// 生成图片
mode := a.handler.sessionCache.GetMode(*a.info.sessionId)
//fmt.Println("mode: ", mode)

// 收到一张图片,且不在图片创作模式下, 提醒是否切换到图片创作模式
if a.info.msgType == "image" && mode != services.ModePicCreate {
sendPicModeCheckCard(*a.ctx, a.info.sessionId, a.info.msgId)
return false
}

if a.info.msgType == "image" && mode == services.ModePicCreate {
//保存图片
imageKey := a.info.imageKey
//fmt.Printf("fileKey: %s \n", imageKey)
msgId := a.info.msgId
//fmt.Println("msgId: ", *msgId)
req := larkim.NewGetMessageResourceReqBuilder().MessageId(
*msgId).FileKey(imageKey).Type("image").Build()
resp, err := initialization.GetLarkClient().Im.MessageResource.Get(context.Background(), req)
//fmt.Println(resp, err)
if err != nil {
//fmt.Println(err)
fmt.Sprintf("🤖️:图片下载失败,请稍后再试~\n 错误信息: %v", err)
return false
}

f := fmt.Sprintf("%s.png", imageKey)
resp.WriteFile(f)
defer os.Remove(f)
resolution := a.handler.sessionCache.GetPicResolution(*a.
info.sessionId)

openai.ConvertJpegToPNG(f)
openai.ConvertToRGBA(f, f)

//图片校验
err = openai.VerifyPngs([]string{f})
if err != nil {
replyMsg(*a.ctx, fmt.Sprintf("🤖️:无法解析图片,请发送原图并尝试重新操作~"),
a.info.msgId)
return false
}
bs64, err := a.handler.gpt.GenerateOneImageVariation(f, resolution)
if err != nil {
replyMsg(*a.ctx, fmt.Sprintf(
"🤖️:图片生成失败,请稍后再试~\n错误信息: %v", err), a.info.msgId)
return false
}
replayImagePlainByBase64(*a.ctx, bs64, a.info.msgId)
return false

}

// 生成图片
if mode == services.ModePicCreate {
resolution := a.handler.sessionCache.GetPicResolution(*a.
info.sessionId)
Expand All @@ -145,10 +198,8 @@ func (*PicAction) Execute(a *ActionInfo) bool {
"🤖️:图片生成失败,请稍后再试~\n错误信息: %v", err), a.info.msgId)
return false
}
replayImageByBase64(*a.ctx, bs64, a.info.msgId, a.info.sessionId,
replayImageCardByBase64(*a.ctx, bs64, a.info.msgId, a.info.sessionId,
a.info.qParsed)

//replayImageByBase64(*a.ctx, "", a.info.msgId, a.info.qParsed)
return false
}

Expand All @@ -160,7 +211,7 @@ type MessageAction struct { /*消息*/

func (*MessageAction) Execute(a *ActionInfo) bool {
msg := a.handler.sessionCache.GetMsg(*a.info.sessionId)
msg = append(msg, services.Messages{
msg = append(msg, openai.Messages{
Role: "user", Content: a.info.qParsed,
})
completions, err := a.handler.gpt.Completions(msg)
Expand Down Expand Up @@ -224,10 +275,10 @@ func (*AudioAction) Execute(a *ActionInfo) bool {
text, err := a.handler.gpt.AudioToText(output)
if err != nil {
fmt.Println(err)
sendMsg(*a.ctx, "🤖️:语音转换失败,请稍后再试~", a.info.msgId)

sendMsg(*a.ctx, fmt.Sprintf("🤖️:语音转换失败,请稍后再试~\n错误信息: %v", err), a.info.msgId)
return false
}
//删除文件
//fmt.Println("text: ", text)
a.info.qParsed = text
return true
Expand Down
78 changes: 68 additions & 10 deletions code/handlers/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"start-feishubot/initialization"
"start-feishubot/services"
"start-feishubot/services/openai"
"strings"

larkcard "github.com/larksuite/oapi-sdk-go/v3/card"
Expand All @@ -26,7 +27,7 @@ func chain(data *ActionInfo, actions ...Action) bool {
type MessageHandler struct {
sessionCache services.SessionServiceCacheInterface
msgCache services.MsgCacheInterface
gpt *services.ChatGPT
gpt *openai.ChatGPT
config initialization.Config
}

Expand All @@ -36,6 +37,7 @@ func (m MessageHandler) cardHandler(_ context.Context,
actionValue := cardAction.Action.Value
actionValueJson, _ := json.Marshal(actionValue)
json.Unmarshal(actionValueJson, &cardMsg)
//fmt.Println("cardMsg: ", cardMsg)
if cardMsg.Kind == ClearCardKind {
newCard, err, done := CommonProcessClearCache(cardMsg, m.sessionCache)
if done {
Expand All @@ -44,15 +46,28 @@ func (m MessageHandler) cardHandler(_ context.Context,
return nil, nil
}
if cardMsg.Kind == PicResolutionKind {
CommonProcessPicResolution(cardMsg, cardAction, m.sessionCache)
//todo: 暂时不允许 以图搜图 模式下的 再来一张
//CommonProcessPicResolution(cardMsg, cardAction, m.sessionCache)
return nil, nil
}
if cardMsg.Kind == PicMoreKind {
if cardMsg.Kind == PicTextMoreKind {
go func() {
m.CommonProcessPicMore(cardMsg)
}()
}
if cardMsg.Kind == PicVarMoreKind {
go func() {
m.CommonProcessPicMore(cardMsg)
}()
}
if cardMsg.Kind == PicModeChangeKind {
newCard, err, done := CommonProcessPicModeChange(cardMsg, m.sessionCache)
if done {
return newCard, err
}
return nil, nil

}
return nil, nil

}
Expand All @@ -63,7 +78,7 @@ func (m MessageHandler) CommonProcessPicMore(msg CardMsg) {
//fmt.Println("msg: ", msg)
question := msg.Value.(string)
bs64, _ := m.gpt.GenerateOneImage(question, resolution)
replayImageByBase64(context.Background(), bs64, &msg.MsgId,
replayImageCardByBase64(context.Background(), bs64, &msg.MsgId,
&msg.SessionId, question)
}

Expand Down Expand Up @@ -100,18 +115,60 @@ func CommonProcessClearCache(cardMsg CardMsg, session services.SessionServiceCac
return nil, nil, false
}

func CommonProcessPicModeChange(cardMsg CardMsg,
session services.SessionServiceCacheInterface) (
interface{}, error, bool) {
if cardMsg.Value == "1" {

sessionId := cardMsg.SessionId
session.Clear(sessionId)
session.SetMode(sessionId,
services.ModePicCreate)
session.SetPicResolution(sessionId,
services.Resolution256)

newCard, _ :=
newSendCard(
withHeader("🖼️ 已进入图片创作模式", larkcard.TemplateBlue),
withPicResolutionBtn(&sessionId),
withNote("提醒:回复文本或图片,让AI生成相关的图片。"))
return newCard, nil, true
}
if cardMsg.Value == "0" {
newCard, _ := newSendCard(
withHeader("️🎒 机器人提醒", larkcard.TemplateGreen),
withMainMd("依旧保留此话题的上下文信息"),
withNote("我们可以继续探讨这个话题,期待和您聊天。如果您有其他问题或者想要讨论的话题,请告诉我哦"),
)
return newCard, nil, true
}
return nil, nil, false
}
func judgeMsgType(event *larkim.P2MessageReceiveV1) (string, error) {
msgType := event.Event.Message.MessageType

switch *msgType {
case "text", "image", "audio":
return *msgType, nil
default:
return "", fmt.Errorf("unknown message type: %v", *msgType)
}

}

func (m MessageHandler) msgReceivedHandler(ctx context.Context, event *larkim.P2MessageReceiveV1) error {
handlerType := judgeChatType(event)
if handlerType == "otherChat" {
fmt.Println("unknown chat type")
return nil
}
msgType := judgeMsgType(event)
if msgType != "text" && msgType != "audio" {
fmt.Println("unknown msg type")
//fmt.Println(larkcore.Prettify(event.Event.Message))

msgType, err := judgeMsgType(event)
if err != nil {
fmt.Printf("error getting message type: %v\n", err)
return nil
}
//fmt.Println(larkcore.Prettify(event.Event.Message))

content := event.Event.Message.Content
msgId := event.Event.Message.MessageId
Expand All @@ -130,6 +187,7 @@ func (m MessageHandler) msgReceivedHandler(ctx context.Context, event *larkim.P2
chatId: chatId,
qParsed: strings.Trim(parseContent(*content), " "),
fileKey: parseFileKey(*content),
imageKey: parseImageKey(*content),
sessionId: sessionId,
mention: mention,
}
Expand All @@ -142,11 +200,11 @@ func (m MessageHandler) msgReceivedHandler(ctx context.Context, event *larkim.P2
&ProcessedUniqueAction{}, //避免重复处理
&ProcessMentionAction{}, //判断机器人是否应该被调用
&AudioAction{}, //语音处理
&PicAction{}, //图片处理
&EmptyAction{}, //空消息处理
&ClearAction{}, //清除消息处理
&HelpAction{}, //帮助处理
&RolePlayAction{}, //角色扮演处理
&PicAction{}, //图片处理
&MessageAction{}, //消息处理

}
Expand All @@ -156,7 +214,7 @@ func (m MessageHandler) msgReceivedHandler(ctx context.Context, event *larkim.P2

var _ MessageHandlerInterface = (*MessageHandler)(nil)

func NewMessageHandler(gpt *services.ChatGPT,
func NewMessageHandler(gpt *openai.ChatGPT,
config initialization.Config) MessageHandlerInterface {
return &MessageHandler{
sessionCache: services.GetSessionCache(),
Expand Down
19 changes: 2 additions & 17 deletions code/handlers/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package handlers
import (
"context"
"start-feishubot/initialization"
"start-feishubot/services"
"start-feishubot/services/openai"

larkcard "github.com/larksuite/oapi-sdk-go/v3/card"
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
Expand All @@ -24,7 +24,7 @@ const (
// handlers 所有消息类型类型的处理器
var handlers MessageHandlerInterface

func InitHandlers(gpt *services.ChatGPT, config initialization.Config) {
func InitHandlers(gpt *openai.ChatGPT, config initialization.Config) {
handlers = NewMessageHandler(gpt, config)
}

Expand Down Expand Up @@ -69,18 +69,3 @@ func judgeChatType(event *larkim.P2MessageReceiveV1) HandlerType {
}
return "otherChat"
}

func judgeMsgType(event *larkim.P2MessageReceiveV1) string {
msgType := event.Event.Message.MessageType
if *msgType == "text" {
return "text"
}
if *msgType == "image" {
return "image"
}
if *msgType == "audio" {
return "audio"
}

return ""
}
Loading

0 comments on commit 9875b6e

Please sign in to comment.