Skip to content

Commit

Permalink
Add Endpoints to facilitate Ollama based chats
Browse files Browse the repository at this point in the history
Add Endpoints to facilitate Ollama based chats.

Built to use with Open WebUI
  • Loading branch information
infosecwatchman committed Dec 19, 2024
1 parent 1ce5bd4 commit 103388e
Show file tree
Hide file tree
Showing 3 changed files with 282 additions and 0 deletions.
6 changes: 6 additions & 0 deletions cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ func Cli(version string) (err error) {
return
}

if currentFlags.ServeOllama {
registry.ConfigureVendors()
err = restapi.ServeOllama(registry, currentFlags.ServeAddress, version)
return
}

if currentFlags.UpdatePatterns {
err = registry.PatternsLoader.PopulateDB()
return
Expand Down
1 change: 1 addition & 0 deletions cli/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ type Flags struct {
InputHasVars bool `long:"input-has-vars" description:"Apply variables to user input"`
DryRun bool `long:"dry-run" description:"Show what would be sent to the model without actually sending it"`
Serve bool `long:"serve" description:"Serve the Fabric Rest API"`
ServeOllama bool `long:"serveOllama" description:"Serve the Fabric Rest API with ollama endpoints"`
ServeAddress string `long:"address" description:"The address to bind the REST API" default:":8080"`
Config string `long:"config" description:"Path to YAML config file"`
Version bool `long:"version" description:"Print current version"`
Expand Down
275 changes: 275 additions & 0 deletions restapi/ollama.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
package restapi

import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/danielmiessler/fabric/core"
"github.com/gin-gonic/gin"
"io"
"log"
"net/http"
"strings"
"time"
)

type OllamaModel struct {
Models []Model `json:"models"`
}
type Model struct {
Details ModelDetails `json:"details"`
Digest string `json:"digest"`
Model string `json:"model"`
ModifiedAt string `json:"modified_at"`
Name string `json:"name"`
Size int64 `json:"size"`
}

type ModelDetails struct {
Families []string `json:"families"`
Family string `json:"family"`
Format string `json:"format"`
ParameterSize string `json:"parameter_size"`
ParentModel string `json:"parent_model"`
QuantizationLevel string `json:"quantization_level"`
}

type APIConvert struct {
registry *core.PluginRegistry
r *gin.Engine
addr *string
}

type OllamaRequestBody struct {
Messages []OllamaMessage `json:"messages"`
Model string `json:"model"`
Options struct {
} `json:"options"`
Stream bool `json:"stream"`
}

type OllamaMessage struct {
Content string `json:"content"`
Role string `json:"role"`
}

type OllamaResponse struct {
Model string `json:"model"`
CreatedAt string `json:"created_at"`
Message struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"message"`
DoneReason string `json:"done_reason,omitempty"`
Done bool `json:"done"`
TotalDuration int64 `json:"total_duration,omitempty"`
LoadDuration int `json:"load_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int64 `json:"eval_duration,omitempty"`
}

type FabricResponseFormat struct {
Type string `json:"type"`
Format string `json:"format"`
Content string `json:"content"`
}

func ServeOllama(registry *core.PluginRegistry, address string, version string) (err error) {
r := gin.New()

// Middleware
r.Use(gin.Logger())
r.Use(gin.Recovery())

// Register routes
fabricDb := registry.Db
NewPatternsHandler(r, fabricDb.Patterns)
NewContextsHandler(r, fabricDb.Contexts)
NewSessionsHandler(r, fabricDb.Sessions)
NewChatHandler(r, registry, fabricDb)
NewConfigHandler(r, fabricDb)
NewModelsHandler(r, registry.VendorManager)

typeConversion := APIConvert{
registry: registry,
r: r,
addr: &address,
}
// Ollama Endpoints
r.GET("/api/tags", typeConversion.ollamaTags)
r.GET("/api/version", func(c *gin.Context) {
c.Data(200, "application/json", []byte(fmt.Sprintf("{\"%s\"}", version)))
return
})
r.POST("/api/chat", typeConversion.ollamaChat)

// Start server
err = r.Run(address)
if err != nil {
return err
}

return
}

func (f APIConvert) ollamaTags(c *gin.Context) {
patterns, err := f.registry.Db.Patterns.GetNames()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err})
return
}
var response OllamaModel
for _, pattern := range patterns {
today := time.Now().Format("2024-11-25T12:07:58.915991813-05:00")
details := ModelDetails{
Families: []string{"fabric"},
Family: "fabric",
Format: "custom",
ParameterSize: "42.0B",
ParentModel: "",
QuantizationLevel: "",
}
response.Models = append(response.Models, Model{
Details: details,
Digest: "365c0bd3c000a25d28ddbf732fe1c6add414de7275464c4e4d1c3b5fcb5d8ad1",
Model: fmt.Sprintf("%s:latest", pattern),
ModifiedAt: today,
Name: fmt.Sprintf("%s:latest", pattern),
Size: 0,
})
}

c.JSON(200, response)

}

func (f APIConvert) ollamaChat(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("Error reading body: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "testing endpoint"})
return
}
var prompt OllamaRequestBody
err = json.Unmarshal(body, &prompt)
if err != nil {
log.Printf("Error unmarshalling body: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "testing endpoint"})
return
}
now := time.Now()
var chat ChatRequest

if len(prompt.Messages) == 1 {
chat.Prompts = []PromptRequest{{
UserInput: prompt.Messages[0].Content,
Vendor: "",
Model: "",
ContextName: "",
PatternName: strings.Split(prompt.Model, ":")[0],
}}
} else if len(prompt.Messages) > 1 {
var content string
for _, msg := range prompt.Messages {
content = fmt.Sprintf("%s%s:%s\n", content, msg.Role, msg.Content)
}
chat.Prompts = []PromptRequest{{
UserInput: content,
Vendor: "",
Model: "",
ContextName: "",
PatternName: strings.Split(prompt.Model, ":")[0],
}}
}
fabricChatReq, err := json.Marshal(chat)
if err != nil {
log.Printf("Error marshalling body: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err})
return
}
ctx := context.Background()
var req *http.Request
if strings.Contains(*f.addr, "http") {
req, err = http.NewRequest("POST", fmt.Sprintf("%s/chat", *f.addr), bytes.NewBuffer(fabricChatReq))
} else {
req, err = http.NewRequest("POST", fmt.Sprintf("http://127.0.0.1%s/chat", *f.addr), bytes.NewBuffer(fabricChatReq))
}
if err != nil {
log.Fatal(err)
}

req = req.WithContext(ctx)

fabricRes, err := http.DefaultClient.Do(req)
if err != nil {
log.Printf("Error getting /chat body: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err})
return
}
body, err = io.ReadAll(fabricRes.Body)
if err != nil {
log.Printf("Error reading body: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "testing endpoint"})
return
}
var forwardedResponse OllamaResponse
var forwardedResponses []OllamaResponse
var fabricResponse FabricResponseFormat
err = json.Unmarshal([]byte(strings.Split(strings.Split(string(body), "\n")[0], "data: ")[1]), &fabricResponse)
if err != nil {
log.Printf("Error unmarshalling body: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "testing endpoint"})
return
}
for _, word := range strings.Split(fabricResponse.Content, " ") {
forwardedResponse = OllamaResponse{
Model: "",
CreatedAt: "",
Message: struct {
Role string `json:"role"`
Content string `json:"content"`
}(struct {
Role string
Content string
}{Content: fmt.Sprintf("%s ", word), Role: "assistant"}),
Done: false,
}
forwardedResponses = append(forwardedResponses, forwardedResponse)
}
forwardedResponse.Model = prompt.Model
forwardedResponse.CreatedAt = time.Now().UTC().Format("2006-01-02T15:04:05.999999999Z")
forwardedResponse.Message.Role = "assistant"
forwardedResponse.Message.Content = ""
forwardedResponse.DoneReason = "stop"
forwardedResponse.Done = true
forwardedResponse.TotalDuration = time.Since(now).Nanoseconds()
forwardedResponse.LoadDuration = int(time.Since(now).Nanoseconds())
forwardedResponse.PromptEvalCount = 42
forwardedResponse.PromptEvalDuration = int(time.Since(now).Nanoseconds())
forwardedResponse.EvalCount = 420
forwardedResponse.EvalDuration = time.Since(now).Nanoseconds()
forwardedResponses = append(forwardedResponses, forwardedResponse)

var res []byte
for _, response := range forwardedResponses {
marshalled, err := json.Marshal(response)
if err != nil {
log.Printf("Error marshalling body: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err})
return
}
for _, bytein := range marshalled {
res = append(res, bytein)
}
for _, bytebreak := range []byte("\n") {
res = append(res, bytebreak)
}
}
c.Data(200, "application/json", res)

//c.JSON(200, forwardedResponse)
return
}

0 comments on commit 103388e

Please sign in to comment.