Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Schema parameter type - Unit Tests #896

Closed
wants to merge 8 commits into from
86 changes: 86 additions & 0 deletions tests/test_csv_scraper_multi_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import unittest

from scrapegraphai.graphs.csv_scraper_graph import CSVScraperGraph
from scrapegraphai.graphs.csv_scraper_multi_graph import CSVScraperMultiGraph
from unittest.mock import MagicMock, patch

class TestCSVScraperMultiGraph(unittest.TestCase):
def test_create_graph_structure_and_run(self):
"""
Test if the CSVScraperMultiGraph creates the correct graph structure
with GraphIteratorNode and MergeAnswersNode, initializes properly,
and executes the run method correctly.
"""
prompt = "Test prompt"
source = ["url1", "url2"]
config = {
"llm": {
"model": "test-model",
"model_provider": "openai",
"temperature": 0 # Adding temperature to match the actual implementation
},
"embedder": {"model": "test-embedder"},
"headless": True,
"verbose": False,
"model_token": 1000
}

with patch('scrapegraphai.graphs.csv_scraper_multi_graph.GraphIteratorNode') as mock_iterator_node, \
patch('scrapegraphai.graphs.csv_scraper_multi_graph.MergeAnswersNode') as mock_merge_node, \
patch('scrapegraphai.graphs.csv_scraper_multi_graph.BaseGraph') as mock_base_graph, \
patch('scrapegraphai.graphs.abstract_graph.AbstractGraph._create_llm') as mock_create_llm:

# Mock the _create_llm method to return a MagicMock
mock_llm = MagicMock()
mock_create_llm.return_value = mock_llm

csv_scraper_multi_graph = CSVScraperMultiGraph(prompt, source, config)

# Check if GraphIteratorNode is created with correct parameters
mock_iterator_node.assert_called_once_with(
input="user_prompt & jsons",
output=["results"],
node_config={
"graph_instance": CSVScraperGraph,
"scraper_config": csv_scraper_multi_graph.copy_config,
}
)

# Check if MergeAnswersNode is created with correct parameters
mock_merge_node.assert_called_once_with(
input="user_prompt & results",
output=["answer"],
node_config={"llm_model": mock_llm, "schema": csv_scraper_multi_graph.copy_schema}
)

# Check if BaseGraph is created with correct structure
mock_base_graph.assert_called_once()
graph_args = mock_base_graph.call_args[1]
self.assertEqual(len(graph_args['nodes']), 2)
self.assertEqual(len(graph_args['edges']), 1)
self.assertEqual(graph_args['entry_point'], mock_iterator_node.return_value)
self.assertEqual(graph_args['graph_name'], "CSVScraperMultiGraph")

# Check if the graph attribute is set correctly
self.assertIsInstance(csv_scraper_multi_graph.graph, MagicMock)

# Check if the prompt and source are set correctly
self.assertEqual(csv_scraper_multi_graph.prompt, prompt)
self.assertEqual(csv_scraper_multi_graph.source, source)

# Check if the config is copied correctly
self.assertDictEqual(csv_scraper_multi_graph.copy_config, config)

# Test the run method
mock_execute = MagicMock(return_value=({"answer": "Test answer"}, {}))
csv_scraper_multi_graph.graph.execute = mock_execute

result = csv_scraper_multi_graph.run()

mock_execute.assert_called_once_with({"user_prompt": prompt, "jsons": source})
self.assertEqual(result, "Test answer")

# Test the case when no answer is found
mock_execute.return_value = ({}, {})
result = csv_scraper_multi_graph.run()
self.assertEqual(result, "No answer found.")
56 changes: 56 additions & 0 deletions tests/test_json_scraper_multi_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from scrapegraphai.graphs.json_scraper_multi_graph import JSONScraperMultiGraph
from unittest import TestCase
from unittest.mock import MagicMock, patch

class TestJSONScraperMultiGraph(TestCase):
def test_empty_source_list(self):
"""
Test that JSONScraperMultiGraph handles an empty source list gracefully.
This test ensures that the graph is created correctly with an empty list of sources
and returns a default message when run with no results.
"""
prompt = "Test prompt"
empty_source = []
config = {
"llm": {
"model": "test_model",
"model_provider": "test_provider"
}
}

with patch('scrapegraphai.graphs.json_scraper_multi_graph.BaseGraph') as mock_base_graph, \
patch('scrapegraphai.graphs.json_scraper_multi_graph.GraphIteratorNode') as mock_graph_iterator_node, \
patch('scrapegraphai.graphs.json_scraper_multi_graph.MergeAnswersNode') as mock_merge_answers_node, \
patch('scrapegraphai.graphs.abstract_graph.AbstractGraph._create_llm') as mock_create_llm:

# Create mock instances
mock_graph_instance = MagicMock()
mock_base_graph.return_value = mock_graph_instance

# Mock the execute method to return a dictionary with no answer
mock_graph_instance.execute.return_value = ({"answer": "No answer found."}, {})

# Mock the _create_llm method
mock_create_llm.return_value = MagicMock()

# Initialize the JSONScraperMultiGraph
graph = JSONScraperMultiGraph(prompt, empty_source, config)

# Run the graph
result = graph.run()

# Assert that the graph was created with the correct nodes
mock_graph_iterator_node.assert_called_once()
mock_merge_answers_node.assert_called_once()

# Assert that BaseGraph was initialized with the correct parameters
mock_base_graph.assert_called_once()
_, kwargs = mock_base_graph.call_args
self.assertEqual(len(kwargs['nodes']), 2)
self.assertEqual(len(kwargs['edges']), 1)

# Assert that the execute method was called with the correct inputs
mock_graph_instance.execute.assert_called_once_with({"user_prompt": prompt, "jsons": empty_source})

# Assert the result
self.assertEqual(result, "No answer found.")
Loading
Loading