Skip to content

Commit

Permalink
Add basic unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
huangjeff5 committed Jun 20, 2024
1 parent da4006f commit 4be1c2e
Showing 1 changed file with 114 additions and 0 deletions.
114 changes: 114 additions & 0 deletions go/plugins/ollama/ollama_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package ollama

import (
"testing"

"github.com/firebase/genkit/go/ai"
)

func TestConcatMessages(t *testing.T) {
tests := []struct {
name string
messages []*ai.Message
role ai.Role
want string
}{
{
name: "Single message with matching role",
messages: []*ai.Message{
{
Role: ai.RoleUser,
Content: []*ai.Part{ai.NewTextPart("Hello, how are you?")},
},
},
role: ai.RoleUser,
want: "Hello, how are you?",
},
{
name: "Multiple messages with mixed roles",
messages: []*ai.Message{
{
Role: ai.RoleUser,
Content: []*ai.Part{ai.NewTextPart("Tell me a joke.")},
},
{
Role: ai.RoleModel,
Content: []*ai.Part{ai.NewTextPart("Why did the scarecrow win an award? Because he was outstanding in his field!")},
},
},
role: ai.RoleModel,
want: "Why did the scarecrow win an award? Because he was outstanding in his field!",
},
{
name: "No matching role",
messages: []*ai.Message{
{
Role: ai.RoleUser,
Content: []*ai.Part{ai.NewTextPart("Any suggestions?")},
},
},
role: ai.RoleSystem,
want: "",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
input := &ai.GenerateRequest{Messages: tt.messages}
got := concatMessages(input, tt.role)
if got != tt.want {
t.Errorf("concatMessages() = %q, want %q", got, tt.want)
}
})
}
}

func TestTranslateGenerateChunk(t *testing.T) {
tests := []struct {
name string
input string
want *ai.GenerateResponseChunk
wantErr bool
}{
{
name: "Valid JSON response",
input: `{"model": "my-model", "created_at": "2024-06-20T12:34:56Z", "response": "This is a test response."}`,
want: &ai.GenerateResponseChunk{
Content: []*ai.Part{ai.NewTextPart("This is a test response.")},
},
wantErr: false,
},
{
name: "Invalid JSON",
input: `{invalid}`,
want: nil,
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := translateGenerateChunk(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("translateGenerateChunk() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && !equalContent(got.Content, tt.want.Content) {
t.Errorf("translateGenerateChunk() got = %v, want %v", got, tt.want)
}
})
}
}

// Helper function to compare content
func equalContent(a, b []*ai.Part) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i].Text != b[i].Text {
return false
}
}
return true
}

0 comments on commit 4be1c2e

Please sign in to comment.