Skip to content

Commit

Permalink
Turn off streaming code for now
Browse files Browse the repository at this point in the history
  • Loading branch information
huangjeff5 committed Jun 14, 2024
1 parent c59ea6d commit 61d4998
Showing 1 changed file with 60 additions and 54 deletions.
114 changes: 60 additions & 54 deletions go/plugins/ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,12 @@
package ollama

import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"time"

Expand All @@ -31,14 +29,20 @@ import (

const provider = "ollama"

// roleMapping translates ai.Role values into the role strings stored in
// ollamaMessage.Role. Roles without an entry here look up as the empty string.
var roleMapping = map[ai.Role]string{
ai.RoleUser: "user",
ai.RoleModel: "assistant",
ai.RoleSystem: "system",
}

// defineModel registers the Ollama model called name with the ai package under
// the "ollama" provider. serverAddress is stored on the generator whose
// generate method is registered as the model's implementation.
func defineModel(name string, serverAddress string) {
	meta := &ai.ModelMetadata{
		Label: "Ollama - " + name,
		Supports: ai.ModelCapabilities{
			Multiturn: true,
		},
	}
	// generate is declared on *generator, so hold the generator by pointer.
	// It must be declared exactly once: a second `g :=` in the same scope
	// does not compile.
	g := &generator{model: name, serverAddress: serverAddress}
	ai.DefineModel(provider, name, meta, g.generate)
}

Expand All @@ -50,7 +54,7 @@ func Model(name string) *ai.ModelAction {

// Config provides configuration options for the Init function.
type Config struct {
// API key. Required.
// Server address of Ollama.
ServerAddress string
// Generative models to provide.
Models []string
Expand Down Expand Up @@ -80,16 +84,22 @@ stream: if false the response will be returned as a single response object, rath
raw: if true no formatting will be applied to the prompt. You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API
keep_alive: controls how long the model will stay loaded into memory following the request (default: 5m)
*/

// ollamaMessage is a single chat message with a role and its text content.
//
// The JSON names must be real struct tags (backtick-quoted). Written as
// trailing comments they are ignored by encoding/json, and the fields would be
// marshaled as "Role" and "Content" instead of "role" and "content".
type ollamaMessage struct {
	// Role is the speaker of the message, e.g. "user", "assistant" or "system".
	Role string `json:"role"`
	// Content is the text of the message.
	Content string `json:"content"`
}

// ollamaRequest is a chat request that serializes to JSON with the keys
// "messages", "model" and "stream".
//
// Each field is declared exactly once; Messages uses the typed
// []*ollamaMessage form rather than the looser []map[string]string.
type ollamaRequest struct {
	// Messages is the conversation history, in order.
	Messages []*ollamaMessage `json:"messages"`
	// Model is the name of the model to run.
	Model string `json:"model"`
	// Stream selects a streamed response; when false the response is
	// returned as a single response object.
	Stream bool `json:"stream"`
}

// Generate makes a request to the Ollama API and processes the response.
func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb func(context.Context, *ai.GenerateResponseChunk) error) (*ai.GenerateResponse, error) {
// Step 1: Combine parts from all messages into a single payload slice
var messages []map[string]string
var messages []*ollamaMessage

// Add all messages to history field.
for _, m := range input.Messages {
Expand Down Expand Up @@ -136,65 +146,49 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
}
return response, nil
} else {
// Handle streaming response here
scanner := bufio.NewScanner(resp.Body) // Create a scanner to read lines
for scanner.Scan() {
line := scanner.Text()

chunk, err := translateChunk(line)
if err != nil {
// Handle parsing error (log, maybe send an error candidate?)
return nil, fmt.Errorf("error translating chunk: %v", err)
// TODO: Handle Streaming
/*
// Handle streaming response here
scanner := bufio.NewScanner(resp.Body) // Create a scanner to read lines
for scanner.Scan() {
line := scanner.Text()
chunk, err := translateChunk(line)
if err != nil {
// Handle parsing error (log, maybe send an error candidate?)
return nil, fmt.Errorf("error translating chunk: %v", err)
}
cb(ctx, chunk)
}
cb(ctx, chunk)
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error reading stream: %v", err)
}
// Handle end of stream (optional: send a final candidate to signal completion)
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error reading stream: %v", err)
}
// TODO: Handle end of stream (optional: send a final candidate to signal completion)
*/
}
//Return an empty generate response, since we use callback for streaming
return &ai.GenerateResponse{}, nil
}

// convertParts serializes a slice of *ai.Part into a a map (represent Ollama message type)
func convertParts(role ai.Role, parts []*ai.Part) (map[string]string, error) {
roleMapping := map[ai.Role]string{
ai.RoleUser: "user",
ai.RoleModel: "assistant",
ai.RoleSystem: "system",
// Add more mappings as needed
// convertParts serializes a slice of *ai.Part into an ollamaMessage (represents Ollama message type)
func convertParts(role ai.Role, parts []*ai.Part) (*ollamaMessage, error) {
// Initialize the message with the correct role from the mapping
message := &ollamaMessage{
Role: roleMapping[role],
}

partMap := map[string]string{}
// Concatenate content from all parts
for _, part := range parts {
partMap["role"] = roleMapping[role]
switch {
case part.IsText():
partMap["content"] = part.Text
default:
if part.IsText() {
message.Content += part.Text
} else {
return nil, errors.New("unknown content type")
}
}
return partMap, nil
return message, nil
}

// translateChunk logs that a chunk is being translated, checks that input is
// valid JSON for a GenerateResponse, and returns a chunk with Index 0 and an
// empty Content slice (capacity 1). The decoded fields are not copied into the
// returned chunk.
func translateChunk(input string) (*ai.GenerateResponseChunk, error) {
	log.Printf("translating chunk")

	parsed := new(GenerateResponse)
	if err := json.Unmarshal([]byte(input), parsed); err != nil {
		return nil, fmt.Errorf("error parsing response JSON: %v", err)
	}

	return &ai.GenerateResponseChunk{
		Index:   0,
		Content: make([]*ai.Part, 0, 1),
	}, nil
}

type GenerateResponse struct {
type generateResponse struct {
Model string `json:"model"`
CreatedAt string `json:"created_at"`
Message struct {
Expand All @@ -205,7 +199,7 @@ type GenerateResponse struct {

// translateResponse deserializes a JSON response from the Ollama API into a GenerateResponse.
func translateResponse(responseData []byte) (*ai.GenerateResponse, error) {
var response GenerateResponse
var response generateResponse

if err := json.Unmarshal(responseData, &response); err != nil {
return nil, fmt.Errorf("error parsing response JSON: %v", err)
Expand All @@ -226,3 +220,15 @@ func translateResponse(responseData []byte) (*ai.GenerateResponse, error) {
generateResponse.Candidates = append(generateResponse.Candidates, aiCandidate)
return generateResponse, nil
}

// translateChunk decodes input as a JSON-encoded generateResponse and returns
// a chunk whose Content holds one text part built from the decoded message
// content. A JSON decoding failure is returned as an error with a nil chunk.
func translateChunk(input string) (*ai.GenerateResponseChunk, error) {
	var parsed generateResponse
	if err := json.Unmarshal([]byte(input), &parsed); err != nil {
		return nil, fmt.Errorf("error parsing response JSON: %v", err)
	}

	textPart := ai.NewTextPart(parsed.Message.Content)
	return &ai.GenerateResponseChunk{
		Content: []*ai.Part{textPart},
	}, nil
}

0 comments on commit 61d4998

Please sign in to comment.