From 4be1c2e123b893470bbef469eedf0f941569f2d2 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Thu, 20 Jun 2024 13:54:34 -0500 Subject: [PATCH] Add basic unit test --- go/plugins/ollama/ollama_test.go | 114 +++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) create mode 100644 go/plugins/ollama/ollama_test.go diff --git a/go/plugins/ollama/ollama_test.go b/go/plugins/ollama/ollama_test.go new file mode 100644 index 000000000..a0b61f2b5 --- /dev/null +++ b/go/plugins/ollama/ollama_test.go @@ -0,0 +1,114 @@ +package ollama + +import ( + "testing" + + "github.com/firebase/genkit/go/ai" +) + +func TestConcatMessages(t *testing.T) { + tests := []struct { + name string + messages []*ai.Message + role ai.Role + want string + }{ + { + name: "Single message with matching role", + messages: []*ai.Message{ + { + Role: ai.RoleUser, + Content: []*ai.Part{ai.NewTextPart("Hello, how are you?")}, + }, + }, + role: ai.RoleUser, + want: "Hello, how are you?", + }, + { + name: "Multiple messages with mixed roles", + messages: []*ai.Message{ + { + Role: ai.RoleUser, + Content: []*ai.Part{ai.NewTextPart("Tell me a joke.")}, + }, + { + Role: ai.RoleModel, + Content: []*ai.Part{ai.NewTextPart("Why did the scarecrow win an award? Because he was outstanding in his field!")}, + }, + }, + role: ai.RoleModel, + want: "Why did the scarecrow win an award? Because he was outstanding in his field!", + }, + { + name: "No matching role", + messages: []*ai.Message{ + { + Role: ai.RoleUser, + Content: []*ai.Part{ai.NewTextPart("Any suggestions?")}, + }, + }, + role: ai.RoleSystem, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + input := &ai.GenerateRequest{Messages: tt.messages} + got := concatMessages(input, tt.role) + if got != tt.want { + t.Errorf("concatMessages() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestTranslateGenerateChunk(t *testing.T) { + tests := []struct { + name string + input string + want *ai.GenerateResponseChunk + wantErr bool + }{ + { + name: "Valid JSON response", + input: `{"model": "my-model", "created_at": "2024-06-20T12:34:56Z", "response": "This is a test response."}`, + want: &ai.GenerateResponseChunk{ + Content: []*ai.Part{ai.NewTextPart("This is a test response.")}, + }, + wantErr: false, + }, + { + name: "Invalid JSON", + input: `{invalid}`, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := translateGenerateChunk(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("translateGenerateChunk() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && !equalContent(got.Content, tt.want.Content) { + t.Errorf("translateGenerateChunk() got = %v, want %v", got, tt.want) + } + }) + } +} + +// Helper function to compare content +func equalContent(a, b []*ai.Part) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i].Text != b[i].Text { + return false + } + } + return true +}