Skip to content

Commit

Permalink
Merge pull request #1215 from infosecwatchman/main
Browse files Browse the repository at this point in the history
Add Endpoints to facilitate Ollama based chats
  • Loading branch information
eugeis authored Dec 21, 2024
2 parents 89edd71 + 103388e commit 65285fd
Show file tree
Hide file tree
Showing 3 changed files with 282 additions and 0 deletions.
6 changes: 6 additions & 0 deletions cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ func Cli(version string) (err error) {
return
}

if currentFlags.ServeOllama {
registry.ConfigureVendors()
err = restapi.ServeOllama(registry, currentFlags.ServeAddress, version)
return
}

if currentFlags.UpdatePatterns {
err = registry.PatternsLoader.PopulateDB()
return
Expand Down
1 change: 1 addition & 0 deletions cli/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ type Flags struct {
InputHasVars bool `long:"input-has-vars" description:"Apply variables to user input"`
DryRun bool `long:"dry-run" description:"Show what would be sent to the model without actually sending it"`
Serve bool `long:"serve" description:"Serve the Fabric Rest API"`
ServeOllama bool `long:"serveOllama" description:"Serve the Fabric Rest API with ollama endpoints"`
ServeAddress string `long:"address" description:"The address to bind the REST API" default:":8080"`
Config string `long:"config" description:"Path to YAML config file"`
Version bool `long:"version" description:"Print current version"`
Expand Down
275 changes: 275 additions & 0 deletions restapi/ollama.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
package restapi

import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/danielmiessler/fabric/core"
"github.com/gin-gonic/gin"
"io"
"log"
"net/http"
"strings"
"time"
)

// OllamaModel is the JSON body returned by the /api/tags handler (ollamaTags):
// a single "models" array with one entry per Fabric pattern.
type OllamaModel struct {
Models []Model `json:"models"`
}
// Model is one entry of OllamaModel.Models. ollamaTags fills Name and Model
// with "<pattern>:latest" and uses fixed values for Digest, Size and Details.
type Model struct {
Details ModelDetails `json:"details"`
Digest string `json:"digest"`
Model string `json:"model"`
// ModifiedAt is a formatted timestamp string; ollamaTags derives it from time.Now().
ModifiedAt string `json:"modified_at"`
Name string `json:"name"`
Size int64 `json:"size"`
}

// ModelDetails is the "details" object nested in each Model. ollamaTags
// populates it with constant values ("fabric" family, "custom" format).
type ModelDetails struct {
Families []string `json:"families"`
Family string `json:"family"`
Format string `json:"format"`
ParameterSize string `json:"parameter_size"`
ParentModel string `json:"parent_model"`
QuantizationLevel string `json:"quantization_level"`
}

// APIConvert carries the state shared by the Ollama-compatible handlers
// (ollamaTags and ollamaChat), which are defined as methods on it.
type APIConvert struct {
// registry gives the handlers access to the Fabric pattern database.
registry *core.PluginRegistry
// r is the gin engine the handlers are registered on.
r *gin.Engine
// addr points at the listen address passed to ServeOllama; ollamaChat uses
// it to build the URL of the local /chat endpoint.
addr *string
}

// OllamaRequestBody is the JSON payload ollamaChat decodes from a
// POST /api/chat request.
type OllamaRequestBody struct {
Messages []OllamaMessage `json:"messages"`
// Model is the requested model name; ollamaChat uses the part before the
// first ":" as the Fabric pattern name.
Model string `json:"model"`
// Options is declared as an empty struct, so any option fields sent by the
// client are discarded during decoding.
Options struct {
} `json:"options"`
// Stream is decoded but not read by ollamaChat.
Stream bool `json:"stream"`
}

// OllamaMessage is a single chat message (role plus text) inside
// OllamaRequestBody.Messages.
type OllamaMessage struct {
Content string `json:"content"`
Role string `json:"role"`
}

// OllamaResponse is one JSON object of the newline-separated reply written by
// ollamaChat. Intermediate objects carry a piece of the answer with Done set
// to false; the last object has Done set to true and carries the statistics
// fields, which are omitted from the JSON when zero.
type OllamaResponse struct {
Model string `json:"model"`
CreatedAt string `json:"created_at"`
Message struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"message"`
DoneReason string `json:"done_reason,omitempty"`
Done bool `json:"done"`
// The duration fields are filled by ollamaChat with elapsed nanoseconds
// measured from the start of the handler.
TotalDuration int64 `json:"total_duration,omitempty"`
LoadDuration int `json:"load_duration,omitempty"`
// The count fields are set to fixed constants by ollamaChat.
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int64 `json:"eval_duration,omitempty"`
}

// FabricResponseFormat is the JSON object ollamaChat decodes from the first
// "data: " line of the Fabric /chat response. Only Content is used afterwards.
type FabricResponseFormat struct {
Type string `json:"type"`
Format string `json:"format"`
Content string `json:"content"`
}

// ServeOllama starts the Fabric REST API and additionally exposes a small set
// of Ollama-compatible endpoints (/api/tags, /api/version and /api/chat) so
// that Ollama clients can talk to Fabric patterns as if they were models.
//
// registry supplies the database and vendor manager used by the handlers,
// address is the listen address handed to gin (for example ":8080") and
// version is the string reported by /api/version.
//
// The returned error is whatever gin's Run reports when the server stops.
func ServeOllama(registry *core.PluginRegistry, address string, version string) (err error) {
	r := gin.New()

	// Middleware
	r.Use(gin.Logger())
	r.Use(gin.Recovery())

	// Register the regular Fabric routes.
	fabricDb := registry.Db
	NewPatternsHandler(r, fabricDb.Patterns)
	NewContextsHandler(r, fabricDb.Contexts)
	NewSessionsHandler(r, fabricDb.Sessions)
	NewChatHandler(r, registry, fabricDb)
	NewConfigHandler(r, fabricDb)
	NewModelsHandler(r, registry.VendorManager)

	typeConversion := APIConvert{
		registry: registry,
		r:        r,
		addr:     &address,
	}

	// Ollama endpoints
	r.GET("/api/tags", typeConversion.ollamaTags)
	r.GET("/api/version", func(c *gin.Context) {
		// Ollama clients expect {"version": "<value>"}. Letting gin encode
		// the object also escapes the version string correctly, which the
		// previous hand-built `{"<value>"}` body (not valid JSON) did not.
		c.JSON(http.StatusOK, gin.H{"version": version})
	})
	r.POST("/api/chat", typeConversion.ollamaChat)

	// Start server
	return r.Run(address)
}

// ollamaTags handles GET /api/tags. It lists every Fabric pattern as an
// Ollama model named "<pattern>:latest" so that Ollama clients can offer the
// patterns in their model picker.
//
// On success it answers 200 with an OllamaModel body; if the pattern names
// cannot be loaded it answers 500 with the error message.
func (f APIConvert) ollamaTags(c *gin.Context) {
	patterns, err := f.registry.Db.Patterns.GetNames()
	if err != nil {
		// Send the message text: an error value itself does not encode to
		// anything useful as JSON.
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// Both values are identical for every entry, so build them once.
	// time.RFC3339Nano yields timestamps such as
	// 2024-11-25T12:07:58.915991813-05:00; Go layouts must be written in
	// terms of the reference time, not an arbitrary example date.
	modifiedAt := time.Now().Format(time.RFC3339Nano)
	details := ModelDetails{
		Families:          []string{"fabric"},
		Family:            "fabric",
		Format:            "custom",
		ParameterSize:     "42.0B",
		ParentModel:       "",
		QuantizationLevel: "",
	}

	// Start from an empty, non-nil slice so that "models" is encoded as []
	// rather than null when no patterns exist.
	response := OllamaModel{Models: make([]Model, 0, len(patterns))}
	for _, pattern := range patterns {
		name := fmt.Sprintf("%s:latest", pattern)
		response.Models = append(response.Models, Model{
			Details:    details,
			Digest:     "365c0bd3c000a25d28ddbf732fe1c6add414de7275464c4e4d1c3b5fcb5d8ad1",
			Model:      name,
			ModifiedAt: modifiedAt,
			Name:       name,
			Size:       0,
		})
	}

	c.JSON(http.StatusOK, response)
}

// ollamaChat handles POST /api/chat. It translates an Ollama chat request
// into a Fabric chat request, forwards it to this server's own /chat
// endpoint, and converts the first event of the reply into the
// newline-separated JSON objects an Ollama client expects: one object per
// space-separated word, followed by a final object with Done set to true.
//
// The model name up to the first ":" selects the Fabric pattern. A single
// message is passed through as-is; several messages are flattened into
// "role:content" lines.
//
// Responses: 400 for an unreadable, malformed or message-less request, 500
// when the upstream /chat call fails or returns something unparseable, and
// 200 with the converted body otherwise.
func (f APIConvert) ollamaChat(c *gin.Context) {
	started := time.Now()

	body, err := io.ReadAll(c.Request.Body)
	if err != nil {
		log.Printf("Error reading body: %v", err)
		c.JSON(http.StatusBadRequest, gin.H{"error": "could not read request body"})
		return
	}
	var prompt OllamaRequestBody
	if err = json.Unmarshal(body, &prompt); err != nil {
		log.Printf("Error unmarshalling body: %v", err)
		c.JSON(http.StatusBadRequest, gin.H{"error": "request body is not valid JSON"})
		return
	}
	if len(prompt.Messages) == 0 {
		// Without this guard an empty chat request would be forwarded.
		c.JSON(http.StatusBadRequest, gin.H{"error": "at least one message is required"})
		return
	}

	userInput := prompt.Messages[0].Content
	if len(prompt.Messages) > 1 {
		var transcript strings.Builder
		for _, msg := range prompt.Messages {
			transcript.WriteString(msg.Role)
			transcript.WriteString(":")
			transcript.WriteString(msg.Content)
			transcript.WriteString("\n")
		}
		userInput = transcript.String()
	}

	var chat ChatRequest
	chat.Prompts = []PromptRequest{{
		UserInput:   userInput,
		Vendor:      "",
		Model:       "",
		ContextName: "",
		PatternName: strings.Split(prompt.Model, ":")[0],
	}}
	fabricChatReq, err := json.Marshal(chat)
	if err != nil {
		log.Printf("Error marshalling body: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// Tie the upstream call to the incoming request so that it is abandoned
	// when the client disconnects or this handler returns.
	ctx, cancel := context.WithCancel(c.Request.Context())
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, ollamaFabricChatURL(*f.addr), bytes.NewReader(fabricChatReq))
	if err != nil {
		// A bad request must not take the whole server down, so report it
		// instead of calling log.Fatal.
		log.Printf("Error building /chat request: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	req.Header.Set("Content-Type", "application/json")

	fabricRes, err := http.DefaultClient.Do(req)
	if err != nil {
		log.Printf("Error getting /chat body: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	defer fabricRes.Body.Close()

	body, err = io.ReadAll(fabricRes.Body)
	if err != nil {
		log.Printf("Error reading body: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "could not read /chat response"})
		return
	}
	fabricResponse, err := ollamaFirstFabricEvent(body)
	if err != nil {
		log.Printf("Error unmarshalling body: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	createdAt := time.Now().UTC().Format("2006-01-02T15:04:05.999999999Z")

	// Encoder.Encode writes each object followed by a newline, which is the
	// framing used for the reply.
	var res bytes.Buffer
	enc := json.NewEncoder(&res)
	for _, word := range strings.Split(fabricResponse.Content, " ") {
		var chunk OllamaResponse
		chunk.Model = prompt.Model
		chunk.CreatedAt = createdAt
		chunk.Message.Role = "assistant"
		chunk.Message.Content = word + " "
		if err = enc.Encode(chunk); err != nil {
			log.Printf("Error marshalling body: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}
	}

	elapsed := time.Since(started).Nanoseconds()
	var final OllamaResponse
	final.Model = prompt.Model
	final.CreatedAt = createdAt
	final.Message.Role = "assistant"
	final.Message.Content = ""
	final.DoneReason = "stop"
	final.Done = true
	final.TotalDuration = elapsed
	final.LoadDuration = int(elapsed)
	// The token counts are hard-coded placeholder figures.
	final.PromptEvalCount = 42
	final.PromptEvalDuration = int(elapsed)
	final.EvalCount = 420
	final.EvalDuration = elapsed
	if err = enc.Encode(final); err != nil {
		log.Printf("Error marshalling body: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.Data(http.StatusOK, "application/json", res.Bytes())
}

// ollamaFabricChatURL builds the URL of the local Fabric /chat endpoint from
// the server's listen address. A full URL is used as given, a port-only
// address such as ":8080" is reached through 127.0.0.1, and a "host:port"
// address is used with its own host.
func ollamaFabricChatURL(addr string) string {
	switch {
	case strings.HasPrefix(addr, "http://"), strings.HasPrefix(addr, "https://"):
		return addr + "/chat"
	case strings.HasPrefix(addr, ":"):
		return "http://127.0.0.1" + addr + "/chat"
	default:
		return "http://" + addr + "/chat"
	}
}

// ollamaFirstFabricEvent decodes the JSON payload that follows "data: " on
// the first line of a Fabric /chat response. It returns an error instead of
// panicking when the line carries no such payload.
func ollamaFirstFabricEvent(body []byte) (event FabricResponseFormat, err error) {
	const marker = "data: "
	firstLine := strings.Split(string(body), "\n")[0]
	idx := strings.Index(firstLine, marker)
	if idx < 0 {
		return event, fmt.Errorf("unexpected /chat response: %q", firstLine)
	}
	if err = json.Unmarshal([]byte(firstLine[idx+len(marker):]), &event); err != nil {
		return event, fmt.Errorf("decoding /chat event: %w", err)
	}
	return event, nil
}

0 comments on commit 65285fd

Please sign in to comment.