diff --git a/tests/test_csv_scraper_multi_graph.py b/tests/test_csv_scraper_multi_graph.py
new file mode 100644
index 00000000..294c2d70
--- /dev/null
+++ b/tests/test_csv_scraper_multi_graph.py
@@ -0,0 +1,86 @@
+import unittest
+
+from scrapegraphai.graphs.csv_scraper_graph import CSVScraperGraph
+from scrapegraphai.graphs.csv_scraper_multi_graph import CSVScraperMultiGraph
+from unittest.mock import MagicMock, patch
+
+class TestCSVScraperMultiGraph(unittest.TestCase):
+    def test_create_graph_structure_and_run(self):
+        """
+        Test that CSVScraperMultiGraph creates the correct graph structure
+        with GraphIteratorNode and MergeAnswersNode, initializes properly,
+        and executes the run method correctly.
+        """
+        prompt = "Test prompt"
+        source = ["url1", "url2"]
+        config = {
+            "llm": {
+                "model": "test-model",
+                "model_provider": "openai",
+                "temperature": 0  # Adding temperature to match the actual implementation
+            },
+            "embedder": {"model": "test-embedder"},
+            "headless": True,
+            "verbose": False,
+            "model_token": 1000
+        }
+
+        with patch('scrapegraphai.graphs.csv_scraper_multi_graph.GraphIteratorNode') as mock_iterator_node, \
+             patch('scrapegraphai.graphs.csv_scraper_multi_graph.MergeAnswersNode') as mock_merge_node, \
+             patch('scrapegraphai.graphs.csv_scraper_multi_graph.BaseGraph') as mock_base_graph, \
+             patch('scrapegraphai.graphs.abstract_graph.AbstractGraph._create_llm') as mock_create_llm:
+
+            # Mock the _create_llm method to return a MagicMock
+            mock_llm = MagicMock()
+            mock_create_llm.return_value = mock_llm
+
+            csv_scraper_multi_graph = CSVScraperMultiGraph(prompt, source, config)
+
+            # Check if GraphIteratorNode is created with correct parameters
+            mock_iterator_node.assert_called_once_with(
+                input="user_prompt & jsons",
+                output=["results"],
+                node_config={
+                    "graph_instance": CSVScraperGraph,
+                    "scraper_config": csv_scraper_multi_graph.copy_config,
+                }
+            )
+
+            # Check if MergeAnswersNode is created with correct parameters
+            mock_merge_node.assert_called_once_with(
+                input="user_prompt & results",
+                output=["answer"],
+                node_config={"llm_model": mock_llm, "schema": csv_scraper_multi_graph.copy_schema}
+            )
+
+            # Check if BaseGraph is created with correct structure
+            mock_base_graph.assert_called_once()
+            graph_args = mock_base_graph.call_args[1]
+            self.assertEqual(len(graph_args['nodes']), 2)
+            self.assertEqual(len(graph_args['edges']), 1)
+            self.assertEqual(graph_args['entry_point'], mock_iterator_node.return_value)
+            self.assertEqual(graph_args['graph_name'], "CSVScraperMultiGraph")
+
+            # Check if the graph attribute is set correctly
+            self.assertIsInstance(csv_scraper_multi_graph.graph, MagicMock)
+
+            # Check if the prompt and source are set correctly
+            self.assertEqual(csv_scraper_multi_graph.prompt, prompt)
+            self.assertEqual(csv_scraper_multi_graph.source, source)
+
+            # Check if the config is copied correctly
+            self.assertDictEqual(csv_scraper_multi_graph.copy_config, config)
+
+            # Test the run method
+            mock_execute = MagicMock(return_value=({"answer": "Test answer"}, {}))
+            csv_scraper_multi_graph.graph.execute = mock_execute
+
+            result = csv_scraper_multi_graph.run()
+
+            mock_execute.assert_called_once_with({"user_prompt": prompt, "jsons": source})
+            self.assertEqual(result, "Test answer")
+
+            # Test the case when no answer is found
+            mock_execute.return_value = ({}, {})
+            result = csv_scraper_multi_graph.run()
+            self.assertEqual(result, "No answer found.")
\ No newline at end of file
diff --git a/tests/test_json_scraper_multi_graph.py b/tests/test_json_scraper_multi_graph.py
new file mode 100644
index 00000000..78c02277
--- /dev/null
+++ b/tests/test_json_scraper_multi_graph.py
@@ -0,0 +1,56 @@
+from scrapegraphai.graphs.json_scraper_multi_graph import JSONScraperMultiGraph
+from unittest import TestCase
+from unittest.mock import MagicMock, patch
+
+class TestJSONScraperMultiGraph(TestCase):
+    def test_empty_source_list(self):
+        """
+        Test that JSONScraperMultiGraph handles an empty source list gracefully.
+        This test ensures that the graph is created correctly with an empty list of sources
+        and returns a default message when run with no results.
+        """
+        prompt = "Test prompt"
+        empty_source = []
+        config = {
+            "llm": {
+                "model": "test_model",
+                "model_provider": "test_provider"
+            }
+        }
+
+        with patch('scrapegraphai.graphs.json_scraper_multi_graph.BaseGraph') as mock_base_graph, \
+             patch('scrapegraphai.graphs.json_scraper_multi_graph.GraphIteratorNode') as mock_graph_iterator_node, \
+             patch('scrapegraphai.graphs.json_scraper_multi_graph.MergeAnswersNode') as mock_merge_answers_node, \
+             patch('scrapegraphai.graphs.abstract_graph.AbstractGraph._create_llm') as mock_create_llm:
+
+            # Create mock instances
+            mock_graph_instance = MagicMock()
+            mock_base_graph.return_value = mock_graph_instance
+
+            # Mock the execute method to return a dictionary with no answer
+            mock_graph_instance.execute.return_value = ({"answer": "No answer found."}, {})
+
+            # Mock the _create_llm method
+            mock_create_llm.return_value = MagicMock()
+
+            # Initialize the JSONScraperMultiGraph
+            graph = JSONScraperMultiGraph(prompt, empty_source, config)
+
+            # Run the graph
+            result = graph.run()
+
+            # Assert that the graph was created with the correct nodes
+            mock_graph_iterator_node.assert_called_once()
+            mock_merge_answers_node.assert_called_once()
+
+            # Assert that BaseGraph was initialized with the correct parameters
+            mock_base_graph.assert_called_once()
+            _, kwargs = mock_base_graph.call_args
+            self.assertEqual(len(kwargs['nodes']), 2)
+            self.assertEqual(len(kwargs['edges']), 1)
+
+            # Assert that the execute method was called with the correct inputs
+            mock_graph_instance.execute.assert_called_once_with({"user_prompt": prompt, "jsons": empty_source})
+
+            # Assert the result
+            self.assertEqual(result, "No answer found.")
\ No newline at end of file
diff --git a/tests/test_script_creator_graph.py b/tests/test_script_creator_graph.py
new file mode 100644
index 00000000..a798aeac
--- /dev/null
+++ b/tests/test_script_creator_graph.py
@@ -0,0 +1,195 @@
+import unittest
+
+from pydantic import BaseModel
+from scrapegraphai.graphs.base_graph import BaseGraph
+from scrapegraphai.graphs.script_creator_graph import ScriptCreatorGraph
+from unittest.mock import MagicMock, patch
+
+class TestScriptCreatorGraph(unittest.TestCase):
+    def test_init_with_local_dir(self):
+        """
+        Test initializing ScriptCreatorGraph with a local directory source.
+        This test verifies that the input_key is set correctly for local sources.
+ """ + prompt = "Generate a script to scrape local HTML files" + source = "/path/to/local/directory" + config = { + "library": "beautifulsoup", + "llm": {"model": "mock_model"} + } + + with patch('scrapegraphai.graphs.script_creator_graph.AbstractGraph.__init__') as mock_init: + graph = ScriptCreatorGraph(prompt, source, config) + + mock_init.assert_called_once_with(prompt, config, source, None) + self.assertEqual(graph.library, "beautifulsoup") + self.assertEqual(graph.input_key, "local_dir") + + @patch('scrapegraphai.graphs.abstract_graph.AbstractGraph.__init__') + @patch('scrapegraphai.graphs.base_graph.BaseGraph') + def test_run_method(self, mock_base_graph, mock_abstract_init): + """ + Test the run method of ScriptCreatorGraph. + This test verifies that the run method correctly executes the graph + and returns the expected answer for both successful and unsuccessful scenarios. + """ + # Setup + prompt = "Test prompt" + source = "https://example.com" + config = { + "library": "beautifulsoup", + "llm": {"model": "mock_model"} + } + + # Mock AbstractGraph.__init__ + mock_abstract_init.return_value = None + + # Create ScriptCreatorGraph instance + graph = ScriptCreatorGraph(prompt, source, config) + + # Set necessary attributes manually + graph.prompt = prompt + graph.source = source + graph.input_key = "url" + graph.model_token = 1000 + + # Mock BaseGraph instance + mock_graph_instance = MagicMock() + mock_base_graph.return_value = mock_graph_instance + graph.graph = mock_graph_instance + + # Test successful scenario + mock_graph_instance.execute.return_value = ({"answer": "Mocked answer"}, None) + result = graph.run() + self.assertEqual(result, "Mocked answer") + mock_graph_instance.execute.assert_called_once_with({ + "user_prompt": prompt, + "url": source + }) + + # Test unsuccessful scenario (no answer found) + mock_graph_instance.execute.return_value = ({}, None) + result = graph.run() + self.assertEqual(result, "No answer found ") + + # Verify that execute was called twice in total + self.assertEqual(mock_graph_instance.execute.call_count, 2) + + @patch('scrapegraphai.graphs.abstract_graph.AbstractGraph.__init__') + def test_run_method(self, mock_abstract_init): + """ + Test the run method of ScriptCreatorGraph. + This test verifies that: + 1. The run method correctly executes the graph and returns the expected answer. + 2. The input to the graph execution is correctly constructed. + 3. The method handles both successful and unsuccessful scenarios. 
+ """ + # Setup + prompt = "Test prompt" + source = "https://example.com" + config = { + "library": "beautifulsoup", + "llm": {"model": "mock_model"} + } + + # Mock AbstractGraph.__init__ + mock_abstract_init.return_value = None + + # Create ScriptCreatorGraph instance + graph = ScriptCreatorGraph(prompt, source, config) + + # Set necessary attributes manually + graph.prompt = prompt + graph.source = source + graph.input_key = "url" + graph.model_token = 1000 + + # Mock graph execution + mock_base_graph = MagicMock() + graph.graph = mock_base_graph + + # Test successful scenario + mock_base_graph.execute.return_value = ({"answer": "Mocked answer"}, None) + result = graph.run() + + # Assertions for successful scenario + self.assertEqual(result, "Mocked answer") + mock_base_graph.execute.assert_called_once_with({ + "user_prompt": prompt, + "url": source + }) + + # Reset mock for next test + mock_base_graph.execute.reset_mock() + + # Test unsuccessful scenario (no answer found) + mock_base_graph.execute.return_value = ({}, None) + result = graph.run() + + # Assertions for unsuccessful scenario + self.assertEqual(result, "No answer found ") + mock_base_graph.execute.assert_called_once_with({ + "user_prompt": prompt, + "url": source + }) + + @patch('scrapegraphai.graphs.abstract_graph.AbstractGraph.__init__') + @patch('scrapegraphai.graphs.script_creator_graph.ScriptCreatorGraph._create_graph') + def test_run_method(self, mock_create_graph, mock_abstract_init): + """ + Test the run method of ScriptCreatorGraph. + This test verifies that: + 1. The ScriptCreatorGraph is initialized correctly. + 2. The run method executes the graph and returns the expected answer. + 3. The input to the graph execution is correctly constructed. + 4. The method handles both successful and unsuccessful scenarios. + """ + # Setup + prompt = "Test prompt" + source = "https://example.com" + config = { + "library": "beautifulsoup", + "llm": {"model": "mock_model"} + } + + # Mock AbstractGraph.__init__ + mock_abstract_init.return_value = None + + # Create ScriptCreatorGraph instance + graph = ScriptCreatorGraph(prompt, source, config) + + # Set necessary attributes manually + graph.prompt = prompt + graph.source = source + graph.input_key = "url" + graph.model_token = 1000 + + # Mock graph creation and execution + mock_base_graph = MagicMock() + mock_create_graph.return_value = mock_base_graph + graph.graph = mock_base_graph + + # Test successful scenario + mock_base_graph.execute.return_value = ({"answer": "Mocked answer"}, None) + result = graph.run() + + # Assertions for successful scenario + self.assertEqual(result, "Mocked answer") + mock_base_graph.execute.assert_called_once_with({ + "user_prompt": prompt, + "url": source + }) + + # Reset mock for next test + mock_base_graph.execute.reset_mock() + + # Test unsuccessful scenario (no answer found) + mock_base_graph.execute.return_value = ({}, None) + result = graph.run() + + # Assertions for unsuccessful scenario + self.assertEqual(result, "No answer found ") + mock_base_graph.execute.assert_called_once_with({ + "user_prompt": prompt, + "url": source + }) + + @patch('scrapegraphai.graphs.abstract_graph.AbstractGraph.__init__') + @patch('scrapegraphai.graphs.script_creator_graph.ScriptCreatorGraph._create_graph') + def test_run_method(self, mock_create_graph, mock_abstract_init): + """ + Test the run method of ScriptCreatorGraph. + This test verifies that: + 1. The ScriptCreatorGraph is initialized correctly. + 2. 
+ """ + # Setup + prompt = "Test prompt" + source = "https://example.com" + config = { + "library": "beautifulsoup", + "llm": {"model": "mock_model"} + } + + # Mock AbstractGraph.__init__ + mock_abstract_init.return_value = None + + # Create ScriptCreatorGraph instance + graph = ScriptCreatorGraph(prompt, source, config) + + # Set necessary attributes manually + graph.prompt = prompt + graph.source = source + graph.input_key = "url" + graph.model_token = 1000 + + # Mock graph creation and execution + mock_base_graph = MagicMock() + mock_create_graph.return_value = mock_base_graph + graph.graph = mock_base_graph + + # Test successful scenario + mock_base_graph.execute.return_value = ({"answer": "Mocked answer"}, None) + result = graph.run() + + # Assertions for successful scenario + self.assertEqual(result, "Mocked answer") + mock_base_graph.execute.assert_called_once_with({ + "user_prompt": prompt, + "url": source + }) + + # Reset mock for next test + mock_base_graph.execute.reset_mock() + + # Test unsuccessful scenario (no answer found) + mock_base_graph.execute.return_value = ({}, None) + result = graph.run() + + # Assertions for unsuccessful scenario + self.assertEqual(result, "No answer found ") + mock_base_graph.execute.assert_called_once_with({ + "user_prompt": prompt, + "url": source + }) \ No newline at end of file diff --git a/tests/test_script_creator_multi_graph.py b/tests/test_script_creator_multi_graph.py new file mode 100644 index 00000000..cabd8f8e --- /dev/null +++ b/tests/test_script_creator_multi_graph.py @@ -0,0 +1,299 @@ +import pytest + +from copy import deepcopy +from pydantic import BaseModel +from scrapegraphai.graphs.base_graph import BaseGraph +from scrapegraphai.graphs.script_creator_multi_graph import ScriptCreatorMultiGraph +from unittest.mock import AsyncMock, MagicMock, patch + +class TestScriptCreatorMultiGraph: + @pytest.mark.asyncio + @patch('scrapegraphai.graphs.script_creator_multi_graph.GraphIteratorNode') + @patch('scrapegraphai.graphs.script_creator_multi_graph.MergeGeneratedScriptsNode') + async def test_run_with_empty_source(self, mock_merge_node, mock_iterator_node): + """ + Test the ScriptCreatorMultiGraph.run() method with an empty source list. + This test checks if the graph handles the case when no URLs are provided. + """ + # Arrange + prompt = "What is Chioggia famous for?" + source = [] + config = {"llm": {"model": "openai/gpt-3.5-turbo"}} + + # Mock the execute method of BaseGraph to return a predefined state + mock_state = {"merged_script": "No URLs provided, unable to generate script."} + with patch('scrapegraphai.graphs.script_creator_multi_graph.BaseGraph.execute', return_value=(mock_state, {})): + graph = ScriptCreatorMultiGraph(prompt, source, config) + + # Act + result = graph.run() + + # Assert + assert result == "No URLs provided, unable to generate script." + assert mock_iterator_node.call_count == 1 + assert mock_merge_node.call_count == 1 + + @pytest.mark.asyncio + @patch('scrapegraphai.graphs.script_creator_multi_graph.GraphIteratorNode') + @patch('scrapegraphai.graphs.script_creator_multi_graph.MergeGeneratedScriptsNode') + @patch('scrapegraphai.graphs.script_creator_multi_graph.BaseGraph.execute') + async def test_run_with_multiple_urls(self, mock_execute, mock_merge_node, mock_iterator_node): + """ + Test the ScriptCreatorMultiGraph.run() method with multiple URLs in the source list. + This test checks if the graph correctly processes multiple URLs and generates a merged script. 
+ """ + # Arrange + prompt = "What are the main attractions in Venice and Chioggia?" + source = ["https://example.com/venice", "https://example.com/chioggia"] + config = {"llm": {"model": "openai/gpt-3.5-turbo"}} + + mock_state = {"merged_script": "Generated script for Venice and Chioggia attractions"} + mock_execute.return_value = (mock_state, {}) + + graph = ScriptCreatorMultiGraph(prompt, source, config) + + # Act + result = graph.run() + + # Assert + assert result == "Generated script for Venice and Chioggia attractions" + mock_execute.assert_called_once() + mock_iterator_node.assert_called_once() + mock_merge_node.assert_called_once() + + # Check if the correct inputs were passed to the execute method + expected_inputs = {"user_prompt": prompt, "urls": source} + actual_inputs = mock_execute.call_args[0][0] + assert actual_inputs == expected_inputs + + @pytest.mark.asyncio + @patch('scrapegraphai.graphs.script_creator_multi_graph.BaseGraph') + async def test_invalid_llm_configuration(self, mock_base_graph): + """ + Test the ScriptCreatorMultiGraph initialization with an invalid LLM configuration. + This test checks if the graph raises a ValueError when an unsupported LLM model is provided. + """ + # Arrange + prompt = "What is Chioggia famous for?" + source = ["https://example.com/chioggia"] + invalid_config = {"llm": {"model": "unsupported_model"}} + + # Act & Assert + with pytest.raises(ValueError, match="Unsupported LLM model"): + ScriptCreatorMultiGraph(prompt, source, invalid_config) + + # Ensure that BaseGraph was not instantiated due to the invalid configuration + mock_base_graph.assert_not_called() + + @pytest.mark.asyncio + @patch('scrapegraphai.graphs.script_creator_multi_graph.BaseGraph.execute') + async def test_run_with_execution_failure(self, mock_execute: BaseGraph.execute): + """ + Test the ScriptCreatorMultiGraph.run() method when graph execution fails. + This test checks if the method handles the failure gracefully and returns an error message. + """ + # Arrange + prompt = "What is Chioggia famous for?" + source = ["https://example.com/chioggia"] + config = {"llm": {"model": "openai/gpt-3.5-turbo"}} + + # Simulate a failure in graph execution + mock_execute.side_effect = Exception("Graph execution failed") + + graph = ScriptCreatorMultiGraph(prompt, source, config) + + # Act + result = graph.run() + + # Assert + assert result == "Failed to generate the script." + mock_execute.assert_called_once() + + # Check if the correct inputs were passed to the execute method + expected_inputs = {"user_prompt": prompt, "urls": source} + actual_inputs = mock_execute.call_args[0][0] + assert actual_inputs == expected_inputs + + @pytest.mark.asyncio + @patch('scrapegraphai.graphs.script_creator_multi_graph.GraphIteratorNode') + async def test_custom_schema_passed_to_graph_iterator(self, mock_graph_iterator_node): + """ + Test that a custom schema is correctly passed to the GraphIteratorNode + when initializing ScriptCreatorMultiGraph. + """ + # Arrange + class CustomSchema(BaseModel): + title: str + content: str + + prompt = "What is Chioggia famous for?" 
+ source = ["https://example.com/chioggia"] + config = {"llm": {"model": "openai/gpt-3.5-turbo"}} + + # Act + graph = ScriptCreatorMultiGraph(prompt, source, config, schema=CustomSchema) + + # Assert + mock_graph_iterator_node.assert_called_once() + _, kwargs = mock_graph_iterator_node.call_args + assert kwargs['schema'] == CustomSchema + assert isinstance(graph.copy_schema, type(CustomSchema)) + + @pytest.mark.asyncio + async def test_config_and_schema_deep_copy(self): + """ + Test that the config and schema are properly deep copied during initialization + of ScriptCreatorMultiGraph. This ensures that modifications to the original + config or schema don't affect the internal state of the ScriptCreatorMultiGraph instance. + """ + # Arrange + class CustomSchema(BaseModel): + title: str + content: str + + prompt = "What is Chioggia famous for?" + source = ["https://example.com/chioggia"] + config = {"llm": {"model": "openai/gpt-3.5-turbo"}, "custom_key": {"nested": "value"}} + schema = CustomSchema + + # Act + graph = ScriptCreatorMultiGraph(prompt, source, config, schema=schema) + + # Assert + assert graph.copy_config == config + assert graph.copy_config is not config + assert graph.copy_schema == schema + assert graph.copy_schema is not schema + + # Modify original config and schema + config["custom_key"]["nested"] = "modified" + schema.update_forward_refs() + + # Check that the copied versions remain unchanged + assert graph.copy_config["custom_key"]["nested"] == "value" + assert not hasattr(graph.copy_schema, "update_forward_refs") + + @pytest.mark.asyncio + @patch('scrapegraphai.graphs.script_creator_multi_graph.BaseGraph.execute') + async def test_run_with_merge_failure(self, mock_execute): + """ + Test the ScriptCreatorMultiGraph.run() method when the MergeGeneratedScriptsNode fails to merge scripts. + This test checks if the method handles the failure gracefully and returns an error message + when the merged_script is not present in the final state. + """ + # Arrange + prompt = "What is Chioggia famous for?" + source = ["https://example.com/chioggia"] + config = {"llm": {"model": "openai/gpt-3.5-turbo"}} + + # Simulate a failure in merging scripts by returning a state without 'merged_script' + mock_execute.return_value = ({"some_other_key": "value"}, {}) + + graph = ScriptCreatorMultiGraph(prompt, source, config) + + # Act + result = graph.run() + + # Assert + assert result == "Failed to generate the script." + mock_execute.assert_called_once() + + # Check if the correct inputs were passed to the execute method + expected_inputs = {"user_prompt": prompt, "urls": source} + actual_inputs = mock_execute.call_args[0][0] + assert actual_inputs == expected_inputs + + @pytest.mark.asyncio + @patch('scrapegraphai.graphs.script_creator_multi_graph.GraphIteratorNode') + @patch('scrapegraphai.graphs.script_creator_multi_graph.MergeGeneratedScriptsNode') + @patch('scrapegraphai.graphs.script_creator_multi_graph.BaseGraph.execute') + async def test_run_with_empty_scripts_list(self, mock_execute, mock_merge_node, mock_iterator_node): + """ + Test the ScriptCreatorMultiGraph.run() method when the GraphIteratorNode returns an empty list of scripts. + This test checks if the graph handles the case when no scripts are generated from the input URLs. + """ + # Arrange + prompt = "What is Chioggia famous for?" 
+ source = ["https://example.com/chioggia"] + config = {"llm": {"model": "openai/gpt-3.5-turbo"}} + + # Mock the GraphIteratorNode to return an empty list of scripts + mock_iterator_node.return_value.execute.return_value = ({"scripts": []}, {}) + + # Mock the MergeGeneratedScriptsNode to return a failure message + mock_merge_node.return_value.execute.return_value = ({"merged_script": "No scripts were generated."}, {}) + + # Mock the BaseGraph.execute to return the result of MergeGeneratedScriptsNode + mock_execute.return_value = ({"merged_script": "No scripts were generated."}, {}) + + graph = ScriptCreatorMultiGraph(prompt, source, config) + + # Act + result = graph.run() + + # Assert + assert result == "No scripts were generated." + mock_iterator_node.assert_called_once() + mock_merge_node.assert_called_once() + mock_execute.assert_called_once() + + # Check if MergeGeneratedScriptsNode was called with an empty list of scripts + merge_node_inputs = mock_merge_node.return_value.execute.call_args[0][0] + assert merge_node_inputs['scripts'] == [] + + @pytest.mark.asyncio + @patch('scrapegraphai.graphs.script_creator_multi_graph.BaseGraph') + @patch('scrapegraphai.graphs.script_creator_multi_graph.GraphIteratorNode') + @patch('scrapegraphai.graphs.script_creator_multi_graph.MergeGeneratedScriptsNode') + async def test_custom_embedder_model_configuration(self, mock_merge_node, mock_iterator_node, mock_base_graph): + """ + Test that a custom embedder model configuration is correctly passed to the graph nodes + when initializing ScriptCreatorMultiGraph. + """ + # Arrange + prompt = "What is Chioggia famous for?" + source = ["https://example.com/chioggia"] + config = { + "llm": {"model": "openai/gpt-3.5-turbo"}, + "embedder": {"model": "custom/embedder-model"} + } + + # Act + graph = ScriptCreatorMultiGraph(prompt, source, config) + + # Assert + mock_iterator_node.assert_called_once() + iterator_node_config = mock_iterator_node.call_args[1]['node_config'] + assert 'scraper_config' in iterator_node_config + assert iterator_node_config['scraper_config']['embedder']['model'] == "custom/embedder-model" + + mock_merge_node.assert_called_once() + merge_node_config = mock_merge_node.call_args[1]['node_config'] + assert 'llm_model' in merge_node_config + assert merge_node_config['llm_model']['model'] == "openai/gpt-3.5-turbo" + + mock_base_graph.assert_called_once() + + @pytest.mark.asyncio + async def test_custom_model_token_limit(self): + """ + Test that a custom model token limit is correctly set when initializing ScriptCreatorMultiGraph. + This test verifies that the model_token attribute is set correctly and that it's included in the copy_config. + """ + # Arrange + prompt = "What is Chioggia famous for?" 
+ source = ["https://example.com/chioggia"] + custom_token_limit = 2000 + config = { + "llm": {"model": "openai/gpt-3.5-turbo"}, + "model_token": custom_token_limit + } + + # Act + graph = ScriptCreatorMultiGraph(prompt, source, config) + + # Assert + assert graph.model_token == custom_token_limit + assert graph.copy_config['model_token'] == custom_token_limit + assert graph.copy_config is not config + assert graph.copy_config == config # We use == instead of deepcopy for simplicity \ No newline at end of file diff --git a/tests/test_search_graph.py b/tests/test_search_graph.py new file mode 100644 index 00000000..4ebb6f72 --- /dev/null +++ b/tests/test_search_graph.py @@ -0,0 +1,41 @@ +import os +import unittest + +from scrapegraphai.graphs.search_graph import SearchGraph +from unittest.mock import MagicMock, patch + +class TestSearchGraph(unittest.TestCase): + @patch.dict(os.environ, {"OPENAI_API_KEY": "dummy_api_key"}) + @patch('scrapegraphai.graphs.base_graph.BaseGraph.execute') + def test_get_considered_urls(self, mock_execute): + """ + Test that get_considered_urls() returns the correct list of URLs after running the graph. + This test mocks the OpenAI API key and the graph execution to simulate the behavior. + """ + # Mock the configuration + config = { + "llm": {"model": "openai/gpt-3.5-turbo"}, + "max_results": 2 + } + + # Mock the execute method to return a predefined final state + mock_execute.return_value = ( + {"urls": ["https://example1.com", "https://example2.com"], "answer": "Chioggia is famous for its beaches."}, + {} + ) + + # Create a SearchGraph instance + search_graph = SearchGraph("What is Chioggia famous for?", config) + + # Run the graph + result = search_graph.run() + + # Check if the result is correct + self.assertEqual(result, "Chioggia is famous for its beaches.") + + # Check if get_considered_urls returns the correct list + considered_urls = search_graph.get_considered_urls() + self.assertEqual(considered_urls, ["https://example1.com", "https://example2.com"]) + + # Verify that the execute method was called + mock_execute.assert_called_once() \ No newline at end of file diff --git a/tests/test_smart_scraper_multi_concat_graph.py b/tests/test_smart_scraper_multi_concat_graph.py new file mode 100644 index 00000000..e384bf59 --- /dev/null +++ b/tests/test_smart_scraper_multi_concat_graph.py @@ -0,0 +1,81 @@ +import unittest + +from scrapegraphai.graphs.abstract_graph import AbstractGraph +from scrapegraphai.graphs.base_graph import BaseGraph +from scrapegraphai.graphs.smart_scraper_multi_concat_graph import SmartScraperMultiConcatGraph +from unittest import mock + +class TestSmartScraperMultiConcatGraph(unittest.TestCase): + @mock.patch.object(AbstractGraph, '_create_llm') + @mock.patch.object(BaseGraph, 'execute') + def test_concat_answers_when_results_less_than_or_equal_to_two(self, mock_execute, mock_create_llm): + """ + Test that the ConcatAnswersNode is used when the number of results + is less than or equal to 2. 
+ """ + # Mock the _create_llm method to return a dummy LLM object + mock_create_llm.return_value = mock.MagicMock() + + # Mock the config and schema + mock_config = {"llm": {"model": "openai/gpt-3.5-turbo"}} + mock_schema = None + + # Mock the BaseGraph execute method to return a predefined result + mock_execute.return_value = ({"answer": "Concatenated answer"}, {}) + + # Create an instance of SmartScraperMultiConcatGraph + graph = SmartScraperMultiConcatGraph( + prompt="Test prompt", + source=["http://example1.com", "http://example2.com"], + config=mock_config, + schema=mock_schema + ) + + # Run the graph + result = graph.run() + + # Assert that the result is the concatenated answer + self.assertEqual(result, "Concatenated answer") + + # Verify that the execute method was called with the correct inputs + mock_execute.assert_called_once_with({ + "user_prompt": "Test prompt", + "urls": ["http://example1.com", "http://example2.com"] + }) + + @mock.patch.object(AbstractGraph, '_create_llm') + @mock.patch.object(BaseGraph, 'execute') + def test_merge_answers_when_results_more_than_two(self, mock_execute, mock_create_llm): + """ + Test that the MergeAnswersNode is used when the number of results + is more than 2. + """ + # Mock the _create_llm method to return a dummy LLM object + mock_create_llm.return_value = mock.MagicMock() + + # Mock the config and schema + mock_config = {"llm": {"model": "openai/gpt-3.5-turbo"}} + mock_schema = None + + # Mock the BaseGraph execute method to return a predefined result + mock_execute.return_value = ({"answer": "Merged answer"}, {}) + + # Create an instance of SmartScraperMultiConcatGraph + graph = SmartScraperMultiConcatGraph( + prompt="Test prompt", + source=["http://example1.com", "http://example2.com", "http://example3.com"], + config=mock_config, + schema=mock_schema + ) + + # Run the graph + result = graph.run() + + # Assert that the result is the merged answer + self.assertEqual(result, "Merged answer") + + # Verify that the execute method was called with the correct inputs + mock_execute.assert_called_once_with({ + "user_prompt": "Test prompt", + "urls": ["http://example1.com", "http://example2.com", "http://example3.com"] + }) \ No newline at end of file diff --git a/tests/test_speech_graph.py b/tests/test_speech_graph.py new file mode 100644 index 00000000..76ccb81f --- /dev/null +++ b/tests/test_speech_graph.py @@ -0,0 +1,220 @@ +import unittest + +from pydantic import BaseModel +from scrapegraphai.graphs.speech_graph import SpeechGraph +from scrapegraphai.models.openai_tts import OpenAITextToSpeech +from unittest.mock import MagicMock, patch + +class TestSpeechGraph(unittest.TestCase): + @patch('scrapegraphai.graphs.speech_graph.BaseGraph') + @patch('scrapegraphai.graphs.abstract_graph.init_chat_model') + @patch('scrapegraphai.models.openai_tts.OpenAI') + def test_speech_graph_initialization(self, mock_openai, mock_init_chat_model, mock_base_graph): + """ + Test the initialization of SpeechGraph with both URL and local directory sources. + This test covers the scenario where input_key is set correctly based on the source, + and ensures that the graph is created with the proper configuration. 
+ """ + # Arrange + prompt = "Summarize the contents" + config = { + "llm": {"model": "openai/gpt-3.5-turbo"}, + "tts_model": { + "api_key": "test_api_key", + "base_url": "https://api.openai.com/v1" + } + } + + class TestSchema(BaseModel): + summary: str + + # Mock the LLM initialization + mock_llm = MagicMock() + mock_init_chat_model.return_value = mock_llm + + # Test with URL source + url_source = "https://example.com" + url_speech_graph = SpeechGraph(prompt, url_source, config, TestSchema) + + # Assert for URL source + self.assertEqual(url_speech_graph.input_key, "url") + self.assertEqual(url_speech_graph.prompt, prompt) + self.assertEqual(url_speech_graph.source, url_source) + self.assertEqual(url_speech_graph.config, config) + self.assertEqual(url_speech_graph.schema, TestSchema) + + # Test with local directory source + local_source = "/path/to/local/directory" + local_speech_graph = SpeechGraph(prompt, local_source, config, TestSchema) + + # Assert for local directory source + self.assertEqual(local_speech_graph.input_key, "local_dir") + self.assertEqual(local_speech_graph.prompt, prompt) + self.assertEqual(local_speech_graph.source, local_source) + self.assertEqual(local_speech_graph.config, config) + self.assertEqual(local_speech_graph.schema, TestSchema) + + # Verify that _create_graph was called for both instances + self.assertEqual(mock_base_graph.call_count, 2) + + # Verify that OpenAI client was initialized with the correct configuration + mock_openai.assert_called_with(api_key="test_api_key", base_url="https://api.openai.com/v1") + + # Verify that the graph attribute is set for both instances + self.assertIsNotNone(url_speech_graph.graph) + self.assertIsNotNone(local_speech_graph.graph) + + @patch('scrapegraphai.graphs.speech_graph.BaseGraph') + @patch('scrapegraphai.graphs.abstract_graph.init_chat_model') + @patch('scrapegraphai.models.openai_tts.OpenAI') + def test_speech_graph_initialization(self, mock_openai, mock_init_chat_model, mock_base_graph): + """ + Test the initialization of SpeechGraph with both URL and local directory sources. + This test covers the scenario where input_key is set correctly based on the source, + ensures that the graph is created with the proper configuration, and verifies + that the OpenAI client for text-to-speech is initialized correctly. 
+ """ + # Arrange + prompt = "Summarize the contents" + config = { + "llm": {"model": "openai/gpt-3.5-turbo"}, + "tts_model": { + "api_key": "test_api_key", + "base_url": "https://api.openai.com/v1" + } + } + + class TestSchema(BaseModel): + summary: str + + # Mock the LLM initialization + mock_llm = MagicMock() + mock_init_chat_model.return_value = mock_llm + + # Test with URL source + url_source = "https://example.com" + url_speech_graph = SpeechGraph(prompt, url_source, config, TestSchema) + + # Assert for URL source + self.assertEqual(url_speech_graph.input_key, "url") + self.assertEqual(url_speech_graph.prompt, prompt) + self.assertEqual(url_speech_graph.source, url_source) + self.assertEqual(url_speech_graph.config, config) + self.assertEqual(url_speech_graph.schema, TestSchema) + + # Test with local directory source + local_source = "/path/to/local/directory" + local_speech_graph = SpeechGraph(prompt, local_source, config, TestSchema) + + # Assert for local directory source + self.assertEqual(local_speech_graph.input_key, "local_dir") + self.assertEqual(local_speech_graph.prompt, prompt) + self.assertEqual(local_speech_graph.source, local_source) + self.assertEqual(local_speech_graph.config, config) + self.assertEqual(local_speech_graph.schema, TestSchema) + + # Verify that _create_graph was called for both instances + self.assertEqual(mock_base_graph.call_count, 2) + + # Verify that OpenAI client was initialized with the correct configuration + mock_openai.assert_called_with(api_key="test_api_key", base_url="https://api.openai.com/v1") + + # Verify that the graph attribute is set for both instances + self.assertIsNotNone(url_speech_graph.graph) + self.assertIsNotNone(local_speech_graph.graph) + + @patch('scrapegraphai.graphs.speech_graph.BaseGraph') + @patch('scrapegraphai.graphs.abstract_graph.init_chat_model') + @patch('scrapegraphai.models.openai_tts.OpenAI') + def test_speech_graph_initialization(self, mock_openai, mock_init_chat_model, mock_base_graph): + """ + Test the initialization of SpeechGraph with both URL and local directory sources. + This test covers the scenario where input_key is set correctly based on the source, + ensures that the graph is created with the proper configuration, and verifies + that the OpenAI client for text-to-speech is initialized correctly. 
+ """ + # Arrange + prompt = "Summarize the contents" + config = { + "llm": {"model": "openai/gpt-3.5-turbo"}, + "tts_model": { + "api_key": "test_api_key", + "base_url": "https://api.openai.com/v1" + }, + "output_path": "test_output.mp3" + } + + class TestSchema(BaseModel): + summary: str + + # Mock the LLM initialization + mock_llm = MagicMock() + mock_init_chat_model.return_value = mock_llm + + # Test with URL source + url_source = "https://example.com" + url_speech_graph = SpeechGraph(prompt, url_source, config, TestSchema) + + # Assert for URL source + self.assertEqual(url_speech_graph.input_key, "url") + self.assertEqual(url_speech_graph.prompt, prompt) + self.assertEqual(url_speech_graph.source, url_source) + self.assertEqual(url_speech_graph.config, config) + self.assertEqual(url_speech_graph.schema, TestSchema) + + # Test with local directory source + local_source = "/path/to/local/directory" + local_speech_graph = SpeechGraph(prompt, local_source, config, TestSchema) + + # Assert for local directory source + self.assertEqual(local_speech_graph.input_key, "local_dir") + self.assertEqual(local_speech_graph.prompt, prompt) + self.assertEqual(local_speech_graph.source, local_source) + self.assertEqual(local_speech_graph.config, config) + self.assertEqual(local_speech_graph.schema, TestSchema) + + # Verify that _create_graph was called for both instances + self.assertEqual(mock_base_graph.call_count, 2) + + # Verify that OpenAI client was initialized with the correct configuration + mock_openai.assert_called_with(api_key="test_api_key", base_url="https://api.openai.com/v1") + + # Verify that the graph attribute is set for both instances + self.assertIsNotNone(url_speech_graph.graph) + self.assertIsNotNone(local_speech_graph.graph) + + @patch('scrapegraphai.graphs.speech_graph.BaseGraph') + @patch('scrapegraphai.graphs.abstract_graph.init_chat_model') + @patch('scrapegraphai.graphs.speech_graph.save_audio_from_bytes') + def test_speech_graph_run(self, mock_save_audio, mock_init_chat_model, mock_base_graph): + """ + Test the run method of SpeechGraph to ensure it executes the graph and saves the audio output. 
+ """ + # Arrange + prompt = "Summarize the contents" + source = "https://example.com" + config = { + "llm": {"model": "openai/gpt-3.5-turbo"}, + "tts_model": { + "api_key": "test_api_key", + "base_url": "https://api.openai.com/v1" + }, + "output_path": "test_output.mp3" + } + + mock_llm = MagicMock() + mock_init_chat_model.return_value = mock_llm + + mock_graph = MagicMock() + mock_graph.execute.return_value = ({"answer": "Test answer", "audio": b"fake_audio_data"}, {}) + mock_base_graph.return_value = mock_graph + + speech_graph = SpeechGraph(prompt, source, config) + + # Act + result = speech_graph.run() + + # Assert + self.assertEqual(result, "Test answer") + mock_graph.execute.assert_called_once_with({"user_prompt": prompt, "url": source}) + mock_save_audio.assert_called_once_with(b"fake_audio_data", "test_output.mp3") \ No newline at end of file diff --git a/tests/test_xml_scraper_graph.py b/tests/test_xml_scraper_graph.py new file mode 100644 index 00000000..51b8086b --- /dev/null +++ b/tests/test_xml_scraper_graph.py @@ -0,0 +1,50 @@ +import unittest + +from scrapegraphai.graphs.xml_scraper_graph import XMLScraperGraph +from unittest.mock import MagicMock, patch + +class TestXMLScraperGraph(unittest.TestCase): + @patch('scrapegraphai.graphs.xml_scraper_graph.BaseGraph') + @patch.object(XMLScraperGraph, '_create_llm') + def test_xml_scraper_graph_with_directory_source(self, mock_create_llm, MockBaseGraph): + """ + Test XMLScraperGraph with a directory source containing multiple XML files. + This test checks if the graph correctly handles a directory input and processes multiple XML files. + """ + # Mock the _create_llm method to return a mock LLM + mock_llm = MagicMock() + mock_create_llm.return_value = mock_llm + + # Mock the BaseGraph and its execute method + mock_execute = MagicMock(return_value=({ + "answer": "Processed multiple XML files from directory" + }, {})) + MockBaseGraph.return_value.execute = mock_execute + + # Create a mock directory path + mock_dir = "/path/to/xml/directory" + + # Create an instance of XMLScraperGraph with a directory source + xml_scraper = XMLScraperGraph( + prompt="Summarize the content of all XML files", + source=mock_dir, + config={"llm": {"model": "mock_model"}} + ) + + # Assert that the input_key is set to "xml_dir" for directory source + self.assertEqual(xml_scraper.input_key, "xml_dir") + + # Run the graph + result = xml_scraper.run() + + # Assert that the execute method was called with the correct inputs + mock_execute.assert_called_once_with({ + "user_prompt": "Summarize the content of all XML files", + "xml_dir": mock_dir + }) + + # Assert that the result is as expected + self.assertEqual(result, "Processed multiple XML files from directory") + + # Assert that _create_llm was called + mock_create_llm.assert_called_once() \ No newline at end of file