diff --git a/pandasai/pipelines/smart_datalake_chat/code_execution.py b/pandasai/pipelines/smart_datalake_chat/code_execution.py index ba33a6fd2..08e60699a 100644 --- a/pandasai/pipelines/smart_datalake_chat/code_execution.py +++ b/pandasai/pipelines/smart_datalake_chat/code_execution.py @@ -42,9 +42,11 @@ def execute(self, input: Any, **kwargs) -> Any: pipeline_context.get_intermediate_value("last_prompt_id"), pipeline_context.get_intermediate_value("skills"), ) - result = pipeline_context.get_intermediate_value( - "code_manager" - ).execute_code( + + result = pipeline_context.query_exec_tracker.execute_func( + pipeline_context.get_intermediate_value( + "code_manager" + ).execute_code, code=code_to_run, context=code_context, ) diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index 85c0dd376..b4c8a7432 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -347,6 +347,9 @@ def chat(self, query: str, output_type: Optional[str] = None): self.update_intermediate_value_post_pipeline_execution(pipeline_context) + # publish query tracker + self._query_exec_tracker.publish() + return result def _validate_output(self, result: dict, output_type: Optional[str] = None): diff --git a/tests/pipelines/smart_datalake/test_code_execution.py b/tests/pipelines/smart_datalake/test_code_execution.py index 36227e230..51a6296b2 100644 --- a/tests/pipelines/smart_datalake/test_code_execution.py +++ b/tests/pipelines/smart_datalake/test_code_execution.py @@ -100,6 +100,8 @@ def mock_intermediate_values(key: str): return mock_code_manager context.get_intermediate_value = Mock(side_effect=mock_intermediate_values) + context._query_exec_tracker = Mock() + context.query_exec_tracker.execute_func = Mock(return_value="Mocked Result") result = code_execution.execute( input="Test Code", context=context, logger=logger @@ -157,17 +159,24 @@ def mock_execute_code(*args, **kwargs): raise Exception("Unit test exception") return "Mocked Result after retry" + # Conditional return of execute_func method based arguments it is called with + def mock_execute_func(*args, **kwargs): + if isinstance(args[0], Mock) and args[0].name == "execute_code": + return mock_execute_code(*args, **kwargs) + else: + return [ + "Interuppted Code", + "Exception Testing", + "Successful after Retry", + ] + mock_code_manager = Mock() - mock_code_manager.execute_code = Mock(side_effect=mock_execute_code) + mock_code_manager.execute_code = Mock() + mock_code_manager.execute_code.name = "execute_code" context._query_exec_tracker = Mock() - context.query_exec_tracker.execute_func = Mock( - return_value=[ - "Interuppted Code", - "Exception Testing", - "Successful after Retry", - ] - ) + + context.query_exec_tracker.execute_func = Mock(side_effect=mock_execute_func) def mock_intermediate_values(key: str): if key == "last_prompt_id": diff --git a/tests/test_smartdatalake.py b/tests/test_smartdatalake.py index c784167de..48ed862c8 100644 --- a/tests/test_smartdatalake.py +++ b/tests/test_smartdatalake.py @@ -149,6 +149,25 @@ def test_last_result_is_saved(self, _mocked_method, smart_datalake: SmartDatalak "value": "There are 10 countries in the dataframe.", } + @patch.object( + CodeManager, + "execute_code", + return_value={ + "type": "string", + "value": "There are 10 countries in the dataframe.", + }, + ) + @patch("pandasai.helpers.query_exec_tracker.QueryExecTracker.publish") + def test_query_tracker_publish_called_in_chat_method( + self, mock_query_tracker_publish, _mocked_method, smart_datalake: SmartDatalake + ): + assert smart_datalake.last_result is None + + _mocked_method.__name__ = "execute_code" + + smart_datalake.chat("How many countries are in the dataframe?") + mock_query_tracker_publish.assert_called() + def test_retry_on_error_with_single_df( self, smart_datalake: SmartDatalake, smart_dataframe: SmartDataframe ):