-
Notifications
You must be signed in to change notification settings - Fork 150
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
da4006f
commit 4be1c2e
Showing
1 changed file
with
114 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
package ollama | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/firebase/genkit/go/ai" | ||
) | ||
|
||
func TestConcatMessages(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
messages []*ai.Message | ||
role ai.Role | ||
want string | ||
}{ | ||
{ | ||
name: "Single message with matching role", | ||
messages: []*ai.Message{ | ||
{ | ||
Role: ai.RoleUser, | ||
Content: []*ai.Part{ai.NewTextPart("Hello, how are you?")}, | ||
}, | ||
}, | ||
role: ai.RoleUser, | ||
want: "Hello, how are you?", | ||
}, | ||
{ | ||
name: "Multiple messages with mixed roles", | ||
messages: []*ai.Message{ | ||
{ | ||
Role: ai.RoleUser, | ||
Content: []*ai.Part{ai.NewTextPart("Tell me a joke.")}, | ||
}, | ||
{ | ||
Role: ai.RoleModel, | ||
Content: []*ai.Part{ai.NewTextPart("Why did the scarecrow win an award? Because he was outstanding in his field!")}, | ||
}, | ||
}, | ||
role: ai.RoleModel, | ||
want: "Why did the scarecrow win an award? Because he was outstanding in his field!", | ||
}, | ||
{ | ||
name: "No matching role", | ||
messages: []*ai.Message{ | ||
{ | ||
Role: ai.RoleUser, | ||
Content: []*ai.Part{ai.NewTextPart("Any suggestions?")}, | ||
}, | ||
}, | ||
role: ai.RoleSystem, | ||
want: "", | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
input := &ai.GenerateRequest{Messages: tt.messages} | ||
got := concatMessages(input, tt.role) | ||
if got != tt.want { | ||
t.Errorf("concatMessages() = %q, want %q", got, tt.want) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestTranslateGenerateChunk(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
input string | ||
want *ai.GenerateResponseChunk | ||
wantErr bool | ||
}{ | ||
{ | ||
name: "Valid JSON response", | ||
input: `{"model": "my-model", "created_at": "2024-06-20T12:34:56Z", "response": "This is a test response."}`, | ||
want: &ai.GenerateResponseChunk{ | ||
Content: []*ai.Part{ai.NewTextPart("This is a test response.")}, | ||
}, | ||
wantErr: false, | ||
}, | ||
{ | ||
name: "Invalid JSON", | ||
input: `{invalid}`, | ||
want: nil, | ||
wantErr: true, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
got, err := translateGenerateChunk(tt.input) | ||
if (err != nil) != tt.wantErr { | ||
t.Errorf("translateGenerateChunk() error = %v, wantErr %v", err, tt.wantErr) | ||
return | ||
} | ||
if !tt.wantErr && !equalContent(got.Content, tt.want.Content) { | ||
t.Errorf("translateGenerateChunk() got = %v, want %v", got, tt.want) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
// Helper function to compare content | ||
func equalContent(a, b []*ai.Part) bool { | ||
if len(a) != len(b) { | ||
return false | ||
} | ||
for i := range a { | ||
if a[i].Text != b[i].Text { | ||
return false | ||
} | ||
} | ||
return true | ||
} |