diff --git a/tests/unit_tests/services/gpt4all.py b/tests/unit_tests/services/gpt4all.py new file mode 100644 index 0000000..e139b4e --- /dev/null +++ b/tests/unit_tests/services/gpt4all.py @@ -0,0 +1,43 @@ +import unittest +from unittest.mock import patch, MagicMock +from gpt4all_service import GPT4AllService +from omegaconf import DictConfig + +class TestGPT4AllService(unittest.TestCase): + + def setUp(self): + self.cfg_mock = MagicMock(spec=DictConfig) + self.collection_name = "test_collection" + self.gpt4all_service = GPT4AllService(cfg=self.cfg_mock, collection_name=self.collection_name) + + def test_init(self): + self.assertEqual(self.gpt4all_service.collection_name, self.collection_name) + # Add more assertions as needed + + @patch('gpt4all_service.generate_collection') + def test_create_collection(self, generate_collection_mock): + collection_name = "new_collection" + self.assertTrue(self.gpt4all_service.create_collection(name=collection_name)) + generate_collection_mock.assert_called_once_with(collection_name, self.cfg_mock.gpt4all_embeddings.size) + + @patch('gpt4all_service.DirectoryLoader') + @patch('gpt4all_service.logger') + def test_embed_documents(self, logger_mock, directory_loader_mock): + directory_loader_mock.return_value.load_and_split.return_value = [] + self.gpt4all_service.embed_documents(directory="dummy_path", file_ending=".pdf") + directory_loader_mock.assert_called_once() + logger_mock.info.assert_called() + + @patch('gpt4all_service.GPT4All') + @patch('gpt4all_service.generate_prompt') + def test_summarize_text(self, generate_prompt_mock, gpt4all_mock): + generate_prompt_mock.return_value = "Mock Prompt" + gpt4all_instance = gpt4all_mock.return_value + gpt4all_instance.generate.return_value = "Mock Summary" + summary = self.gpt4all_service.summarize_text(text="Dummy text") + generate_prompt_mock.assert_called_once() + gpt4all_instance.generate.assert_called_once_with("Mock Prompt", max_tokens=300) + self.assertEqual(summary, "Mock Summary") + +if __name__ == '__main__': + unittest.main() diff --git a/tests/unit_tests/services/openai.py b/tests/unit_tests/services/openai.py new file mode 100644 index 0000000..f461463 --- /dev/null +++ b/tests/unit_tests/services/openai.py @@ -0,0 +1,49 @@ +import unittest +from unittest.mock import patch, MagicMock +from open_ai_service import OpenAIService +from omegaconf import DictConfig + +class TestOpenAIService(unittest.TestCase): + + def setUp(self): + self.cfg_mock = MagicMock(spec=DictConfig) + self.collection_name = "test_collection" + self.openai_service = OpenAIService(cfg=self.cfg_mock, collection_name=self.collection_name) + + def test_init(self): + # Test initialization logic here + self.assertEqual(self.openai_service.collection_name, self.collection_name) + # Add more assertions as needed + + @patch('open_ai_service.generate_collection') + @patch('open_ai_service.logger') + def test_create_collection(self, logger_mock, generate_collection_mock): + collection_name = "new_collection" + self.assertTrue(self.openai_service.create_collection(name=collection_name)) + generate_collection_mock.assert_called_once_with(collection_name, self.cfg_mock.openai_embeddings.size) + logger_mock.info.assert_called_once() + + @patch('open_ai_service.DirectoryLoader') + @patch('open_ai_service.logger') + def test_embed_documents(self, logger_mock, directory_loader_mock): + # Mock the loader and its return value + directory_loader_mock.return_value.load_and_split.return_value = [] + self.openai_service.embed_documents(directory="dummy_path", file_ending=".pdf") + directory_loader_mock.assert_called_once() + logger_mock.info.assert_called() + + @patch('open_ai_service.openai') + @patch('open_ai_service.generate_prompt') + def test_summarize_text(self, generate_prompt_mock, openai_mock): + # Setup mock return values + generate_prompt_mock.return_value = "Mock Prompt" + openai_mock.api_key = "dummy_key" + openai_mock.chat.completions.create.return_value = MagicMock(choices=[MagicMock(messages=MagicMock(content="Mock Summary"))]) + + summary = self.openai_service.summarize_text(text="Dummy text") + generate_prompt_mock.assert_called_once() + openai_mock.chat.completions.create.assert_called_once() + self.assertEqual(summary, "Mock Summary") + +if __name__ == '__main__': + unittest.main()