Skip to content

Commit

Permalink
接入ai绘图
Browse files Browse the repository at this point in the history
  • Loading branch information
lokistars committed Mar 28, 2023
1 parent f4bbbb8 commit 3cfa8bb
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,41 +11,20 @@ import (

var client = &http.Client{}

type request struct {
Model string `json:"model"`
Message []message `json:"messages"`
MaxTokens int `json:"-"`
Temperature float32 `json:"temperature"`
TopP float32 `json:"top_p"`
N int `json:"n"`
Stream bool `json:"stream"`
Stop string `json:"stop"`
}

type message struct {
Role string `json:"role"`
Content string `json:"content"`
}

type response struct {
Model string
messages []string
Choices []struct {
Text string `json:"text"`
} `json:"choices"`
}

// RequestOpenAiChat 请求OpenAi 聊天模型
// https://platform.openai.com/docs/api-reference/introduction 接口文档
func RequestOpenAiChat(msg string) []byte {
func RequestOpenAiChat(msg string, w http.ResponseWriter) []byte {

url := "https://api.openai.com/v1/chat/completions"

fmt.Println("message:", msg)

messages := make([]message, 1)
messages[0] = message{
Role: "user",
Content: msg,
}
reqData := request{
reqData := requestChat{
Model: "gpt-3.5-turbo",
Message: messages,
Temperature: 0.5,
Expand All @@ -63,42 +42,57 @@ func RequestOpenAiChat(msg string) []byte {
_ = resp.Body.Close()
}
}()

reader := bufio.NewReader(resp.Body)

decoder := json.NewDecoder(reader)
decoder.UseNumber()
for {
var delta struct {
Role string `json:"role"`
scanner := bufio.NewScanner(resp.Body)
role := ""

flusher, _ := w.(http.Flusher)
w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
w.Header().Set("Transfer-Encoding", "chunked") // 支持分块传输
w.WriteHeader(http.StatusOK)
flusher.Flush()

for scanner.Scan() {
line := scanner.Text()[6:]
if line == "[DONE]" {
fmt.Println()
_, _ = w.Write([]byte("\n"))
break
}
err := decoder.Decode(&delta)
if err != nil {
if err == io.EOF {
break
}
fmt.Println("Error:", err)
res := responseChat{}
_ = json.Unmarshal([]byte(line), &res)
scanner.Scan()
choices := res.Choices[0]
if choices.FinishReason == "stop" {
role = ""
continue
}
switch delta.Role {
case "batch":
break
if role == "" {
role = choices.Delta.Role
}
switch role {
case "assistant":
//_, _ = w.Write([]byte(choices.Delta.Content))
_, _ = io.WriteString(w, choices.Delta.Content)
flusher.Flush()

// 实现流式响应
//encoder := json.NewEncoder(w)
//encoder.Encode(choices.Delta.Content)
//w.(http.Flusher).Flush()
case "additions":
fmt.Printf("Completion: %s\n", decoder)
case "batch":
case "selection":
break
default:
fmt.Println("Unknown delta role:", delta.Role)
fmt.Println("Unknown delta role:", choices.Delta.Role)
break
}
}

body, _ := io.ReadAll(resp.Body)
fmt.Println("成功,返回结果:", string(body))
}

return body
return nil
}

// RequestCompletions 请求OpenAi 智能补全模型
func RequestCompletions(msg string) []byte {
url := "https://api.openai.com/v1/completions"
type request struct {
Expand Down
55 changes: 55 additions & 0 deletions src/com/lucky/ai/api/OpenAiImages.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package api

import (
"bytes"
"encoding/json"
"fmt"
"net/http"
)

// RequestOpenAiImages 请求OpenAi 图片模型
func RequestOpenAiImages(prompt string) string {
url := "https://api.openai.com/v1/images/generations"
fmt.Println("prompt:", prompt)
reqData := requestImage{
Prompt: prompt,
Size: "1024x1024",
N: 1,
ResponseFormat: "url",
}
reqBody, _ := json.Marshal(reqData)
resp := httpRequest(http.MethodPost, url, bytes.NewReader(reqBody))

defer func() {
if nil != resp {
_ = resp.Body.Close()
}
}()
decoder := json.NewDecoder(resp.Body)

resBody := &responseImage{}

_ = decoder.Decode(resBody)
return resBody.Data[0].Url
}

// RequestOpenAiImageEdit 请求OpenAi 图片编辑模型
func RequestOpenAiImageEdit(prompt string) string {
url := "https://api.openai.com/v1/images/edits"
reqData := ImageEdit{
Image: "",
Mask: "",
Prompt: prompt,
Size: "1024x1024",
N: 1,
ResponseFormat: "url",
}
reqBody, _ := json.Marshal(reqData)
resp := httpRequest(http.MethodPost, url, bytes.NewReader(reqBody))
defer func() {
if nil != resp {
_ = resp.Body.Close()
}
}()
return ""
}
34 changes: 34 additions & 0 deletions src/com/lucky/ai/api/request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package api

type requestChat struct {
Model string `json:"model"`
Message []message `json:"messages"`
MaxTokens int `json:"-"`
Temperature float32 `json:"temperature"`
TopP float32 `json:"top_p"`
N int `json:"n"`
Stream bool `json:"stream"`
Stop string `json:"-"`
}

type message struct {
Role string `json:"role"`
Content string `json:"content"`
}

type requestImage struct {
Prompt string `json:"prompt"` // 描述信息
N int `json:"n" default:"1"`
Size string `json:"size" default:"1024x1024"`
ResponseFormat string `json:"response_format" default:"url"` // b64_json
}

type ImageEdit struct {
Image string `json:"image"` // 修改的头像
Mask string `json:"mask"` // 附加头像,需要编辑的位置
Prompt string `json:"prompt"`
Size string `json:"size"`
N int `json:"n"`
ResponseFormat string `json:"response_format"` // 响应格式 url 或 b64_json

}
22 changes: 22 additions & 0 deletions src/com/lucky/ai/api/response.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package api

type responseImage struct {
Created int `json:"created"`
Data []struct {
B64Json string `json:"b64_json"`
Url string `json:"url"`
} `json:"data"`
}

type responseChat struct {
Id string `json:"id"`
Model string `json:"model"`
Choices []struct {
Delta struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"delta"`
FinishReason string `json:"finish_reason"`
Index string `json:"index"`
} `json:"choices"`
}
56 changes: 49 additions & 7 deletions src/com/lucky/ai/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"github.com/gorilla/websocket"
"io"
"lucky-ai/src/com/lucky/ai/api"
"net"
"net/http"
Expand All @@ -14,6 +15,7 @@ func Server() {
http.HandleFunc("/openai", webSocketHandShake)
http.HandleFunc("/openai/modelList", gatModelHandle)
http.HandleFunc("/openai/modelDetails", gatModelHandle)
http.HandleFunc("/images", imageHandle)

var str string
addrs, _ := net.InterfaceAddrs()
Expand Down Expand Up @@ -58,7 +60,7 @@ func webSocketHandShake(w http.ResponseWriter, r *http.Request) {
defer func() {
_ = conn.Close()
}()

fmt.Print("Request socket URL:", r.RemoteAddr, ": ")
for {
_, data, err := conn.ReadMessage()

Expand All @@ -67,6 +69,7 @@ func webSocketHandShake(w http.ResponseWriter, r *http.Request) {
break
}
fmt.Printf("收到消息:%v", data)
api.RequestCompletions(string(data))
}
}

Expand All @@ -76,7 +79,7 @@ func gatModelHandle(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
return
}

fmt.Print("Request model URL:", r.RemoteAddr, ": ")
switch r.Method {
case http.MethodGet:
if r.URL.Path == "/openai/modelList" {
Expand All @@ -90,11 +93,11 @@ func gatModelHandle(w http.ResponseWriter, r *http.Request) {
} else if r.URL.Path == "/openai/modelDetails" {
value := r.FormValue("model")
if "" == value {
w.Write([]byte("model is null"))
_, _ = w.Write([]byte("model is null"))
return
}
bytes := api.RetrieveModel(value)
w.Write(bytes)
_, _ = w.Write(bytes)
}
break
default:
Expand All @@ -107,15 +110,54 @@ func openAiHandle(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
return
}
fmt.Print("Request ai URL:", r.RemoteAddr, ": ")
switch r.Method {
case http.MethodGet:
msg := r.FormValue("msg")
if "" == msg {
w.Write([]byte("msg is null"))
_, _ = w.Write([]byte("msg is null"))
fmt.Println()
return
}
api.RequestOpenAiChat(msg, w)
break
default:
http.NotFound(w, r)
}
}

func imageHandle(w http.ResponseWriter, r *http.Request) {
if nil == w || nil == r {
http.NotFound(w, r)
return
}
fmt.Print("Request images URL:", r.RemoteAddr, ": ")
switch r.Method {
case http.MethodGet:

value := r.FormValue("prompt")
if "" == value {
_, _ = w.Write([]byte("prompt is null"))
return
}
w.Header().Set("Content-Type", "image/png")

images := api.RequestOpenAiImages(value)

// 读取图像文件
file, err := http.Get(images)
if err != nil {
fmt.Println(err)
return
}
defer file.Body.Close()

// 将图像写入响应中
_, err = io.Copy(w, file.Body)
if err != nil {
fmt.Println(err)
return
}
bytes := api.RequestOpenAiChat(msg)
w.Write(bytes)
break
default:
http.NotFound(w, r)
Expand Down
22 changes: 2 additions & 20 deletions src/com/lucky/main.go
Original file line number Diff line number Diff line change
@@ -1,26 +1,8 @@
package main

import (
"encoding/json"
"fmt"
)
import "lucky-ai/src/com/lucky/ai/server"

func main() {
// 通过代理请求 curl https://www.google.com -x http://127.0.0.1:7890
//server.Server()
str := "data: {\"id\":\"chatcmpl-6yOiFpOOqOOsfkj2285j6p1CHZr9Z\",\"object\":\"chat.completion.chunk\",\"created\":1679852695,\"model\":\"gpt-3.5-turbo-0301\",\"choices\":[{\"delta\":{\"content\":\"\"},\"index\":0,\"finish_reason\":null}]}"
//fmt.Println(str)
bytes := []byte(str)
for {
var responseJSON map[string]interface{}
err := json.Unmarshal(bytes, &responseJSON)
if err != nil {
break
}
choices := responseJSON["choices"].([]interface{})
if len(choices) > 0 {
text := choices[0].(map[string]interface{})["delta"].(map[string]interface{})["content"].(string)
fmt.Print(text)
}
}
server.Server()
}

0 comments on commit 3cfa8bb

Please sign in to comment.