From 103388ececa9ac9cf16f21a9a8f8a43d5a5ab951 Mon Sep 17 00:00:00 2001 From: InfosecWatchman Date: Thu, 19 Dec 2024 16:14:51 -0500 Subject: [PATCH] Add Endpoints to facilitate Ollama based chats Add Endpoints to facilitate Ollama based chats. Built to use with Open WebUI --- cli/cli.go | 6 + cli/flags.go | 1 + restapi/ollama.go | 275 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 282 insertions(+) create mode 100644 restapi/ollama.go diff --git a/cli/cli.go b/cli/cli.go index 5614f5043..81c0e0d3f 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -56,6 +56,12 @@ func Cli(version string) (err error) { return } + if currentFlags.ServeOllama { + registry.ConfigureVendors() + err = restapi.ServeOllama(registry, currentFlags.ServeAddress, version) + return + } + if currentFlags.UpdatePatterns { err = registry.PatternsLoader.PopulateDB() return diff --git a/cli/flags.go b/cli/flags.go index 552d8b5e9..33730a7a7 100644 --- a/cli/flags.go +++ b/cli/flags.go @@ -60,6 +60,7 @@ type Flags struct { InputHasVars bool `long:"input-has-vars" description:"Apply variables to user input"` DryRun bool `long:"dry-run" description:"Show what would be sent to the model without actually sending it"` Serve bool `long:"serve" description:"Serve the Fabric Rest API"` + ServeOllama bool `long:"serveOllama" description:"Serve the Fabric Rest API with ollama endpoints"` ServeAddress string `long:"address" description:"The address to bind the REST API" default:":8080"` Config string `long:"config" description:"Path to YAML config file"` Version bool `long:"version" description:"Print current version"` diff --git a/restapi/ollama.go b/restapi/ollama.go new file mode 100644 index 000000000..0d806729a --- /dev/null +++ b/restapi/ollama.go @@ -0,0 +1,275 @@ +package restapi + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "github.com/danielmiessler/fabric/core" + "github.com/gin-gonic/gin" + "io" + "log" + "net/http" + "strings" + "time" +) + +type OllamaModel struct { + Models []Model `json:"models"` +} +type Model struct { + Details ModelDetails `json:"details"` + Digest string `json:"digest"` + Model string `json:"model"` + ModifiedAt string `json:"modified_at"` + Name string `json:"name"` + Size int64 `json:"size"` +} + +type ModelDetails struct { + Families []string `json:"families"` + Family string `json:"family"` + Format string `json:"format"` + ParameterSize string `json:"parameter_size"` + ParentModel string `json:"parent_model"` + QuantizationLevel string `json:"quantization_level"` +} + +type APIConvert struct { + registry *core.PluginRegistry + r *gin.Engine + addr *string +} + +type OllamaRequestBody struct { + Messages []OllamaMessage `json:"messages"` + Model string `json:"model"` + Options struct { + } `json:"options"` + Stream bool `json:"stream"` +} + +type OllamaMessage struct { + Content string `json:"content"` + Role string `json:"role"` +} + +type OllamaResponse struct { + Model string `json:"model"` + CreatedAt string `json:"created_at"` + Message struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"message"` + DoneReason string `json:"done_reason,omitempty"` + Done bool `json:"done"` + TotalDuration int64 `json:"total_duration,omitempty"` + LoadDuration int `json:"load_duration,omitempty"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` + PromptEvalDuration int `json:"prompt_eval_duration,omitempty"` + EvalCount int `json:"eval_count,omitempty"` + EvalDuration int64 `json:"eval_duration,omitempty"` +} + +type FabricResponseFormat struct 
{
+	Type    string `json:"type"`
+	Format  string `json:"format"`
+	Content string `json:"content"`
+}
+
+func ServeOllama(registry *core.PluginRegistry, address string, version string) (err error) {
+	r := gin.New()
+
+	// Middleware
+	r.Use(gin.Logger())
+	r.Use(gin.Recovery())
+
+	// Register routes
+	fabricDb := registry.Db
+	NewPatternsHandler(r, fabricDb.Patterns)
+	NewContextsHandler(r, fabricDb.Contexts)
+	NewSessionsHandler(r, fabricDb.Sessions)
+	NewChatHandler(r, registry, fabricDb)
+	NewConfigHandler(r, fabricDb)
+	NewModelsHandler(r, registry.VendorManager)
+
+	typeConversion := APIConvert{
+		registry: registry,
+		r:        r,
+		addr:     &address,
+	}
+	// Ollama endpoints
+	r.GET("/api/tags", typeConversion.ollamaTags)
+	r.GET("/api/version", func(c *gin.Context) {
+		c.JSON(http.StatusOK, gin.H{"version": version})
+	})
+	r.POST("/api/chat", typeConversion.ollamaChat)
+
+	// Start server
+	return r.Run(address)
+}
+
+func (f APIConvert) ollamaTags(c *gin.Context) {
+	patterns, err := f.registry.Db.Patterns.GetNames()
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+	var response OllamaModel
+	// Advertise each Fabric pattern as an Ollama "model"; digest, size and
+	// parameter details are static placeholders.
+	modifiedAt := time.Now().Format("2006-01-02T15:04:05.999999999-07:00")
+	for _, pattern := range patterns {
+		details := ModelDetails{
+			Families:          []string{"fabric"},
+			Family:            "fabric",
+			Format:            "custom",
+			ParameterSize:     "42.0B",
+			ParentModel:       "",
+			QuantizationLevel: "",
+		}
+		response.Models = append(response.Models, Model{
+			Details:    details,
+			Digest:     "365c0bd3c000a25d28ddbf732fe1c6add414de7275464c4e4d1c3b5fcb5d8ad1",
+			Model:      fmt.Sprintf("%s:latest", pattern),
+			ModifiedAt: modifiedAt,
+			Name:       fmt.Sprintf("%s:latest", pattern),
+			Size:       0,
+		})
+	}
+
+	c.JSON(http.StatusOK, response)
+}
+
+func (f APIConvert) ollamaChat(c *gin.Context) {
+	body, err := io.ReadAll(c.Request.Body)
+	if err != nil {
+		log.Printf("Error reading body: %v", err)
+		c.JSON(http.StatusInternalServerError, gin.H{"error": "could not read request body"})
+		return
+	}
+	var prompt OllamaRequestBody
+	err = json.Unmarshal(body, &prompt)
+	if err != nil {
+		log.Printf("Error unmarshalling body: %v", err)
+		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
+		return
+	}
+	now := time.Now()
+	var chat ChatRequest
+
+	// Map the Ollama chat request onto a single Fabric prompt: the requested
+	// "model" (pattern:latest) selects the pattern, and multi-message
+	// conversations are flattened into one role-prefixed transcript.
+	if len(prompt.Messages) == 1 {
+		chat.Prompts = []PromptRequest{{
+			UserInput:   prompt.Messages[0].Content,
+			Vendor:      "",
+			Model:       "",
+			ContextName: "",
+			PatternName: strings.Split(prompt.Model, ":")[0],
+		}}
+	} else if len(prompt.Messages) > 1 {
+		var content string
+		for _, msg := range prompt.Messages {
+			content = fmt.Sprintf("%s%s:%s\n", content, msg.Role, msg.Content)
+		}
+		chat.Prompts = []PromptRequest{{
+			UserInput:   content,
+			Vendor:      "",
+			Model:       "",
+			ContextName: "",
+			PatternName: strings.Split(prompt.Model, ":")[0],
+		}}
+	}
+	fabricChatReq, err := json.Marshal(chat)
+	if err != nil {
+		log.Printf("Error marshalling body: %v", err)
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+	ctx := context.Background()
+	// Forward the converted request to the regular Fabric /chat endpoint on
+	// this same listener.
+	var req *http.Request
+	if strings.Contains(*f.addr, "http") {
+		req, err = http.NewRequest("POST", fmt.Sprintf("%s/chat", *f.addr), bytes.NewBuffer(fabricChatReq))
+	} else {
+		req, err = http.NewRequest("POST", fmt.Sprintf("http://127.0.0.1%s/chat", *f.addr), bytes.NewBuffer(fabricChatReq))
+	}
+	if err != nil {
+		log.Printf("Error building /chat request: %v", err)
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	req = req.WithContext(ctx)
+
+	fabricRes, err := http.DefaultClient.Do(req)
+	if err != nil {
+		log.Printf("Error getting /chat body: %v", err)
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+	defer fabricRes.Body.Close()
+	body, err = io.ReadAll(fabricRes.Body)
+	if err != nil {
+		log.Printf("Error reading body: %v", err)
+		c.JSON(http.StatusInternalServerError, gin.H{"error": "could not read /chat response"})
+		return
+	}
+	// /chat answers as a server-sent-event stream; the first "data: " line
+	// carries the complete Fabric response.
+	firstLine := strings.SplitN(string(body), "\n", 2)[0]
+	if !strings.HasPrefix(firstLine, "data: ") {
+		log.Printf("Unexpected /chat response: %s", firstLine)
+		c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected /chat response"})
+		return
+	}
+	var fabricResponse FabricResponseFormat
+	err = json.Unmarshal([]byte(strings.TrimPrefix(firstLine, "data: ")), &fabricResponse)
+	if err != nil {
+		log.Printf("Error unmarshalling body: %v", err)
+		c.JSON(http.StatusInternalServerError, gin.H{"error": "could not parse /chat response"})
+		return
+	}
+	// Re-emit the response word by word as Ollama-style streaming chunks,
+	// followed by a terminating "done" message with placeholder statistics.
+	var forwardedResponses []OllamaResponse
+	for _, word := range strings.Split(fabricResponse.Content, " ") {
+		var chunk OllamaResponse
+		chunk.Message.Role = "assistant"
+		chunk.Message.Content = fmt.Sprintf("%s ", word)
+		forwardedResponses = append(forwardedResponses, chunk)
+	}
+	var final OllamaResponse
+	final.Model = prompt.Model
+	final.CreatedAt = time.Now().UTC().Format("2006-01-02T15:04:05.999999999Z")
+	final.Message.Role = "assistant"
+	final.Message.Content = ""
+	final.DoneReason = "stop"
+	final.Done = true
+	final.TotalDuration = time.Since(now).Nanoseconds()
+	final.LoadDuration = int(time.Since(now).Nanoseconds())
+	final.PromptEvalCount = 42
+	final.PromptEvalDuration = int(time.Since(now).Nanoseconds())
+	final.EvalCount = 420
+	final.EvalDuration = time.Since(now).Nanoseconds()
+	forwardedResponses = append(forwardedResponses, final)
+
+	// Ollama streams newline-delimited JSON objects.
+	var res []byte
+	for _, response := range forwardedResponses {
+		marshalled, err := json.Marshal(response)
+		if err != nil {
+			log.Printf("Error marshalling body: %v", err)
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+		res = append(res, marshalled...)
+		res = append(res, '\n')
+	}
+	c.Data(http.StatusOK, "application/json", res)
+}
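
Usage sketch (not part of the patch): a minimal Go client for exercising the new Ollama-compatible endpoints once the server is running via "fabric --serveOllama" on the default ":8080" address from flags.go. The base URL, the "summarize:latest" model name, and the message text are illustrative assumptions, not values taken from the diff.

// Example client; assumes a Fabric server started with --serveOllama on :8080.
package main

import (
	"bytes"
	"fmt"
	"io"
	"log"
	"net/http"
)

func main() {
	baseURL := "http://127.0.0.1:8080" // assumed --address value

	// List the Fabric patterns that /api/tags advertises as Ollama models.
	tagsRes, err := http.Get(baseURL + "/api/tags")
	if err != nil {
		log.Fatal(err)
	}
	defer tagsRes.Body.Close()
	tags, _ := io.ReadAll(tagsRes.Body)
	fmt.Println(string(tags))

	// Chat against a pattern; "summarize:latest" is a placeholder model name.
	payload := []byte(`{"model":"summarize:latest","stream":true,` +
		`"messages":[{"role":"user","content":"Hello from an Ollama client"}]}`)
	chatRes, err := http.Post(baseURL+"/api/chat", "application/json", bytes.NewBuffer(payload))
	if err != nil {
		log.Fatal(err)
	}
	defer chatRes.Body.Close()
	chunks, _ := io.ReadAll(chatRes.Body)
	fmt.Println(string(chunks)) // newline-delimited OllamaResponse objects
}

This is roughly what Open WebUI does when pointed at the server as an Ollama backend: it discovers "models" from /api/tags and posts conversations to /api/chat.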