From 8bc807c50fb8761879aeb549acb43f0566fafbb5 Mon Sep 17 00:00:00 2001 From: Nautics889 Date: Thu, 21 Sep 2023 22:36:52 +0300 Subject: [PATCH] refactor: type for `prompt_id` (#586) * (refactor): update type hint for `prompt_id` * (refactor): update prompt id to be passed in tests, make it be uuid object, not empty string * (fix): update type hint for `_last_prompt_id` * (docs): update docstring for `execute_code()` --- pandasai/helpers/code_manager.py | 13 ++++++------- pandasai/smart_datalake/__init__.py | 2 +- tests/test_codemanager.py | 14 ++++++++++---- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/pandasai/helpers/code_manager.py b/pandasai/helpers/code_manager.py index e07449b1f..00445d7e7 100644 --- a/pandasai/helpers/code_manager.py +++ b/pandasai/helpers/code_manager.py @@ -1,5 +1,6 @@ import re import ast +import uuid from collections import defaultdict import astor @@ -15,7 +16,7 @@ WHITELISTED_LIBRARIES, ) from ..middlewares.charts import ChartsMiddleware -from typing import Union, List, Optional, Generator +from typing import Union, List, Optional, Generator, Any from ..helpers.logger import Logger from ..schemas.df_config import Config import logging @@ -182,8 +183,8 @@ def _required_dfs(self, code: str) -> List[str]: def execute_code( self, code: str, - prompt_id: str, - ) -> str: + prompt_id: uuid.UUID, + ) -> Any: """ Execute the python code generated by LLMs to answer the question about the input dataframe. Run the code in the current context and return the @@ -191,12 +192,10 @@ def execute_code( Args: code (str): Python code to execute. - data_frame (pd.DataFrame): Full Pandas DataFrame. - use_error_correction_framework (bool): Turn on Error Correction mechanism. - Default to True. + prompt_id (uuid.UUID): UUID of the request. Returns: - str: The result of the code execution. The type of the result depends + Any: The result of the code execution. The type of the result depends on the generated code. """ diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index fc87d6c63..7ae01e5ed 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -49,7 +49,7 @@ class SmartDatalake: _cache: Cache = None _logger: Logger _start_time: float - _last_prompt_id: uuid + _last_prompt_id: uuid.UUID _code_manager: CodeManager _memory: Memory diff --git a/tests/test_codemanager.py b/tests/test_codemanager.py index d318d1bdb..b7a21cf6c 100644 --- a/tests/test_codemanager.py +++ b/tests/test_codemanager.py @@ -1,4 +1,5 @@ """Unit tests for the CodeManager class""" +import uuid from typing import Optional from unittest.mock import Mock, patch @@ -75,18 +76,23 @@ def test_run_code_for_calculations(self, code_manager: CodeManager): code = """def analyze_data(dfs): return {'type': 'number', 'value': 1 + 1}""" - assert code_manager.execute_code(code, "")["value"] == 2 + assert code_manager.execute_code(code, uuid.uuid4())["value"] == 2 assert code_manager.last_code_executed == code def test_run_code_invalid_code(self, code_manager: CodeManager): with pytest.raises(Exception): - code_manager.execute_code("1+ ", "") + # noinspection PyStatementEffect + code_manager.execute_code("1+ ", uuid.uuid4())["value"] def test_clean_code_remove_builtins(self, code_manager: CodeManager): builtins_code = """import set def analyze_data(dfs): return {'type': 'number', 'value': set([1, 2, 3])}""" - assert code_manager.execute_code(builtins_code, "")["value"] == {1, 2, 3} + assert code_manager.execute_code(builtins_code, uuid.uuid4())["value"] == { + 1, + 2, + 3, + } assert ( code_manager.last_code_executed == """def analyze_data(dfs): @@ -123,7 +129,7 @@ def test_clean_code_raise_bad_import_error(self, code_manager: CodeManager): print(os.listdir()) """ with pytest.raises(BadImportError): - code_manager.execute_code(malicious_code, "") + code_manager.execute_code(malicious_code, uuid.uuid4()) def test_remove_dfs_overwrites(self, code_manager: CodeManager): hallucinated_code = """def analyze_data(dfs):