fix(): dropping name and description from dataframe
scaliseraoul-sinaptik committed Jan 30, 2025
1 parent 0499d4a commit 5929ca2
Showing 18 changed files with 288 additions and 447 deletions.
6 changes: 2 additions & 4 deletions pandasai/__init__.py
@@ -4,7 +4,6 @@
 """
 
 import os
-import re
 from io import BytesIO
 from typing import List, Optional, Union
 from zipfile import ZipFile
@@ -86,7 +85,6 @@ def create(
     >>> create(
     ...     path="my-org/my-dataset",
     ...     df=my_dataframe,
-    ...     name="My Dataset",
     ...     description="This is a sample dataset.",
     ...     columns=[
     ...         {"name": "id", "type": "integer", "description": "Primary key"},
@@ -241,8 +239,8 @@ def load(dataset_path: str) -> DataFrame:
 
 def read_csv(filepath: str) -> DataFrame:
     data = pd.read_csv(filepath)
-    name = f"table_{sanitize_sql_table_name(filepath)}"
-    return DataFrame(data, name=name)
+    table = f"table_{sanitize_sql_table_name(filepath)}"
+    return DataFrame(data, _table_name=table)
 
 
 __all__ = [
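With this change, `read_csv` derives the table identifier from the file path and threads it through the private `_table_name` keyword instead of the removed public `name` argument. A minimal usage sketch; the file path is hypothetical and the exact sanitized identifier depends on `sanitize_sql_table_name`:

import pandasai as pai

df = pai.read_csv("data/sales.csv")  # hypothetical path
# The identifier now lives on the schema source, not on a public df.name:
print(df.schema.source.table)  # e.g. "table_data_sales" (exact form is an assumption)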
2 changes: 1 addition & 1 deletion pandasai/agent/base.py
@@ -123,7 +123,7 @@ def _execute_local_sql_query(self, query: str) -> pd.DataFrame:
         with duckdb.connect() as con:
             # Register all DataFrames in the state
             for df in self._state.dfs:
-                con.register(df.name, df)
+                con.register(df.schema.source.table, df)
 
             # Execute the query and fetch the result as a pandas DataFrame
             result = con.sql(query).df()
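The DuckDB registration now keys each frame by its schema's source table. A standalone sketch of the same pattern, using plain pandas and a stand-in table name in place of `df.schema.source.table`:

import duckdb
import pandas as pd

frame = pd.DataFrame({"country": ["US", "UK"], "gdp": [21.4, 3.1]})
table_name = "countries"  # stand-in for df.schema.source.table

with duckdb.connect() as con:
    con.register(table_name, frame)  # expose the frame under its schema table name
    result = con.sql(f"SELECT country FROM {table_name} ORDER BY gdp DESC").df()
print(result)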
9 changes: 6 additions & 3 deletions pandasai/core/code_generation/code_cleaning.py
@@ -29,7 +29,7 @@ def _check_direct_sql_func_def_exists(self, node: ast.AST) -> bool:
         return isinstance(node, ast.FunctionDef) and node.name == "execute_sql_query"
 
     def _replace_table_names(
-        self, sql_query: str, table_names: list, allowed_table_names: list
+        self, sql_query: str, table_names: list, allowed_table_names: dict
     ) -> str:
         """
         Replace table names in the SQL query with case-sensitive or authorized table names.
@@ -54,8 +54,11 @@ def _clean_sql_query(self, sql_query: str) -> str:
         """
         sql_query = sql_query.rstrip(";")
         table_names = extract_table_names(sql_query)
-        allowed_table_names = {df.name: df.name for df in self.context.dfs} | {
-            f'"{df.name}"': df.name for df in self.context.dfs
+        allowed_table_names = {
+            df.schema.source.table: df.schema.source.table for df in self.context.dfs
+        } | {
+            f'"{df.schema.source.table}"': df.schema.source.table
+            for df in self.context.dfs
         }
         return self._replace_table_names(sql_query, table_names, allowed_table_names)
 
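The rebuilt lookup accepts both the bare and the double-quoted spelling of each schema table name and resolves both to the canonical identifier, so quoted identifiers in generated SQL are normalized the same way. A minimal sketch with hypothetical table names:

tables = ["countries", "orders"]  # stand-ins for df.schema.source.table
allowed = {t: t for t in tables} | {f'"{t}"': t for t in tables}
assert allowed["countries"] == "countries"
assert allowed['"countries"'] == "countries"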
14 changes: 4 additions & 10 deletions pandasai/data_loader/loader.py
@@ -68,8 +68,6 @@ def _build_dataset(
             return DataFrame(
                 df,
                 schema=schema,
-                name=schema.name,
-                description=schema.description,
                 path=dataset_path,
             )
         else:
@@ -123,25 +121,21 @@ def _get_loader_function(self, source_type: str):
                 f"Please install the {SUPPORTED_SOURCE_CONNECTORS[source_type]} library."
             ) from e
 
-    def _read_csv_or_parquet(self, file_path: str, format: str) -> DataFrame:
-        if format == "parquet":
+    def _read_csv_or_parquet(self, file_path: str, _format: str) -> DataFrame:
+        if _format == "parquet":
             return DataFrame(
                 pd.read_parquet(file_path),
                 schema=self.schema,
                 path=self.dataset_path,
-                name=self.schema.name,
-                description=self.schema.description,
             )
-        elif format == "csv":
+        elif _format == "csv":
             return DataFrame(
                 pd.read_csv(file_path),
                 schema=self.schema,
                 path=self.dataset_path,
-                name=self.schema.name,
-                description=self.schema.description,
             )
         else:
-            raise ValueError(f"Unsupported file format: {format}")
+            raise ValueError(f"Unsupported file format: {_format}")
 
     def _load_from_local_source(self) -> pd.DataFrame:
         source_type = self.schema.source.type
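Renaming the parameter from `format` to `_format` also stops shadowing Python's builtin `format`. The dispatch itself is unchanged; a reduced sketch of the branching with the pandasai-specific `DataFrame` wrapping omitted:

import pandas as pd

def read_csv_or_parquet(file_path: str, fmt: str) -> pd.DataFrame:
    # Dispatch on the declared format; anything else is rejected.
    if fmt == "parquet":
        return pd.read_parquet(file_path)
    elif fmt == "csv":
        return pd.read_csv(file_path)
    raise ValueError(f"Unsupported file format: {fmt}")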
4 changes: 4 additions & 0 deletions pandasai/data_loader/semantic_layer_schema.py
@@ -92,6 +92,10 @@ def validate_type_and_fields(cls, values):
                 raise ValueError(
                     f"For local source type '{_type}', 'path' must be defined."
                 )
+            if not table:
+                raise ValueError(
+                    f"For local source type '{_type}', 'table' must be defined."
+                )
 
         elif _type in REMOTE_SOURCE_TYPES:
             if not connection:
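Local sources must now declare a `table` alongside `path`. A hedged sketch of the new failure mode, assuming `csv` is among `LOCAL_SOURCE_TYPES` and that the schema accepts this minimal constructor (pydantic's `ValidationError` subclasses `ValueError`):

from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema

try:
    # No "table" on a local source -> should now be rejected by the validator.
    SemanticLayerSchema(name="sales", source={"type": "csv", "path": "data.csv"})
except ValueError as err:
    print(err)  # expected: "For local source type 'csv', 'table' must be defined."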
46 changes: 29 additions & 17 deletions pandasai/dataframe/base.py
@@ -3,7 +3,7 @@
 import hashlib
 import os
 from io import BytesIO
-from typing import TYPE_CHECKING, ClassVar, Optional, Union
+from typing import TYPE_CHECKING, Optional, Union
 from zipfile import ZipFile
 
 import pandas as pd
@@ -21,7 +21,6 @@
 from pandasai.helpers.dataframe_serializer import DataframeSerializer
 from pandasai.helpers.path import find_project_root
 from pandasai.helpers.session import get_pandaai_session
-from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name
 
 if TYPE_CHECKING:
     from pandasai.agent.base import Agent
@@ -41,9 +40,8 @@ class DataFrame(pd.DataFrame):
     _metadata = [
         "_agent",
         "_column_hash",
+        "_table_name",
         "config",
-        "description",
-        "name",
         "path",
         "schema",
     ]
@@ -57,29 +55,33 @@ def __init__(
         copy: bool | None = None,
         **kwargs,
     ) -> None:
-        _name: Optional[str] = kwargs.pop("name", None)
         _schema: Optional[SemanticLayerSchema] = kwargs.pop("schema", None)
-        _description: Optional[str] = kwargs.pop("description", None)
         _path: Optional[str] = kwargs.pop("path", None)
+        _table_name: Optional[str] = kwargs.pop("_table_name", None)
 
         super().__init__(
             data=data, index=index, columns=columns, dtype=dtype, copy=copy
         )
 
-        self._column_hash = self._calculate_column_hash()
+        print("dataframe_init method")
+        if _table_name:
+            print(f"dataframe_init {_table_name}")
+            self._table_name = _table_name
 
-        self.name = _name or f"table_{self._column_hash}"
+        self._column_hash = self._calculate_column_hash()
         self.schema = _schema or DataFrame.get_default_schema(self)
-        self.description = _description
         self.path = _path
 
         self.config = pai.config.get()
         self._agent: Optional[Agent] = None
 
     def __repr__(self) -> str:
         """Return a string representation of the DataFrame."""
-        name_str = f"name='{self.name}'" if self.name else ""
-        desc_str = f"description='{self.description}'" if self.description else ""
+        name_str = f"name='{self.schema.name}'"
+        desc_str = (
+            f"description='{self.schema.description}'"
+            if self.schema.description
+            else ""
+        )
         metadata = ", ".join(filter(None, [name_str, desc_str]))
 
         return f"PandaAI DataFrame({metadata})\n{super().__repr__()}"
@@ -143,7 +145,7 @@ def serialize_dataframe(self) -> str:
         Returns:
             str: Serialized string representation of the DataFrame
         """
-        return DataframeSerializer().serialize(self)
+        return DataframeSerializer.serialize(self)
 
     def get_head(self):
         return self.head()
@@ -160,8 +162,8 @@ def push(self):
 
         params = {
             "path": self.path,
-            "description": self.description,
-            "name": self.name if self.name else "",
+            "description": self.schema.description,
+            "name": self.schema.name,
         }
 
         dataset_directory = os.path.join(find_project_root(), "datasets", self.path)
@@ -273,8 +275,18 @@ def get_default_schema(cls, dataframe: DataFrame) -> SemanticLayerSchema:
             for name, dtype in dataframe.dtypes.items()
         ]
 
+        table_name = getattr(
+            dataframe, "_table_name", f"table_{dataframe._column_hash}"
+        )
+
+        print(f"default schema: {table_name}")
+
         return SemanticLayerSchema(
-            name=dataframe.name,
-            source=Source(type="parquet", path="data.parquet"),
+            name=f"{dataframe._column_hash}",
+            source=Source(
+                type="parquet",
+                path="data.parquet",
+                table=table_name,
+            ),
             columns=columns_list,
        )
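`get_default_schema` now falls back to a hash-derived table name whenever no `_table_name` was threaded through the constructor. A sketch of the observable behavior, assuming `DataFrame` is importable from the package root as elsewhere in this diff:

import pandasai as pai

df = pai.DataFrame({"a": [1, 2, 3]})
schema = df.schema  # built via DataFrame.get_default_schema
print(schema.source.table)  # e.g. "table_<column_hash>" (exact value is an assumption)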
16 changes: 0 additions & 16 deletions pandasai/dataframe/virtual_dataframe.py
@@ -19,9 +19,7 @@ class VirtualDataFrame(DataFrame):
         "_head",
         "_loader",
         "config",
-        "description",
         "head",
-        "name",
         "path",
         "schema",
     ]
@@ -32,22 +30,8 @@ def __init__(self, *args, **kwargs):
             raise VirtualizationError("Data loader is required for virtualization!")
         self._head = None
 
-        schema: SemanticLayerSchema = kwargs.get("schema", None)
-        if not schema:
-            raise VirtualizationError("Schema is required for virtualization!")
-
-        name = kwargs.pop("name", None)
-
-        description = kwargs.pop("description", None)
-
-        table_name = schema.source.table or name or schema.name
-
-        table_description = description or schema.description
-
         super().__init__(
             self.get_head(),
-            name=table_name,
-            description=table_description,
             *args,
             **kwargs,
         )
16 changes: 10 additions & 6 deletions pandasai/helpers/dataframe_serializer.py
@@ -1,11 +1,15 @@
-import pandas as pd
+import typing
+
+if typing.TYPE_CHECKING:
+    from ..dataframe.base import DataFrame
 
 
 class DataframeSerializer:
-    def __init__(self) -> None:
-        pass
-
-    def serialize(self, df: pd.DataFrame) -> str:
+    @staticmethod
+    def serialize(df: "DataFrame") -> str:
         """
         Convert df to csv like format where csv is wrapped inside <dataframe></dataframe>
         Args:
@@ -17,12 +21,12 @@ def serialize(self, df: pd.DataFrame) -> str:
         dataframe_info = "<table"
 
         # Add name attribute if available
-        if df.name is not None:
-            dataframe_info += f' table_name="{df.name}"'
+        if df.schema.source.table is not None:
+            dataframe_info += f' table_name="{df.schema.source.table}"'
 
         # Add description attribute if available
-        if df.description is not None:
-            dataframe_info += f' description="{df.description}"'
+        if df.schema.description is not None:
+            dataframe_info += f' description="{df.schema.description}"'
 
         dataframe_info += f' dimensions="{df.rows_count}x{df.columns_count}">'
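Since `serialize` is now a `@staticmethod` that reads the name and description from the frame's schema, no serializer instance is needed. A hedged usage sketch; the `pai.DataFrame` constructor call is assumed from elsewhere in this diff, and the exact header attributes follow the code above:

import pandasai as pai
from pandasai.helpers.dataframe_serializer import DataframeSerializer

df = pai.DataFrame({"country": ["US", "UK"], "gdp": [21.4, 3.1]})
out = DataframeSerializer.serialize(df)
print(out)  # header like <table table_name="table_..." dimensions="2x2">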
54 changes: 11 additions & 43 deletions tests/unit_tests/agent/test_agent.py
@@ -17,40 +17,6 @@
 class TestAgent:
     "Unit tests for Agent class"
 
-    @pytest.fixture
-    def mysql_schema(self):
-        raw_schema = {
-            "name": "countries",
-            "source": {
-                "type": "mysql",
-                "connection": {
-                    "host": "localhost",
-                    "port": 3306,
-                    "database": "test_db",
-                    "user": "test_user",
-                    "password": "test_password",
-                },
-                "table": "countries",
-            },
-        }
-        return SemanticLayerSchema(**raw_schema)
-
-    @pytest.fixture
-    def sample_df(self) -> DataFrame:
-        return DataFrame(
-            {
-                "country": ["United States", "United Kingdom", "Japan", "China"],
-                "gdp": [
-                    19294482071552,
-                    2891615567872,
-                    4380756541440,
-                    14631844184064,
-                ],
-                "happiness_index": [6.94, 7.22, 5.87, 5.12],
-            },
-            name="countries",
-        )
-
     @pytest.fixture
     def llm(self, output: Optional[str] = None) -> FakeLLM:
         return FakeLLM(output=output)
@@ -60,7 +26,7 @@ def config(self, llm: FakeLLM) -> dict:
         return {"llm": llm}
 
     @pytest.fixture
-    def agent(self, sample_df: pd.DataFrame, config: dict) -> Agent:
+    def agent(self, sample_df: DataFrame, config: dict) -> Agent:
         return Agent(sample_df, config, vectorstore=MagicMock())
 
     @pytest.fixture(autouse=True)
assert response == "print(United States has the highest gdp)"

@patch("pandasai.agent.base.CodeGenerator")
def test_generate_code_with_cache_hit(self, mock_generate_code, agent: Agent):
def test_generate_code_with_cache_hit(
self, mock_generate_code, agent: Agent, sample_df
):
# Set up the cache to return a pre-cached response
cached_code = """execute_sql_query('SELECT country FROM countries ORDER BY gdp DESC LIMIT 1')
cached_code = f"""execute_sql_query('SELECT A FROM {sample_df.schema.source.table}')
print('Cached result: US has the highest GDP.')"""
agent._state.config.enable_cache = True
agent._state.cache.get = MagicMock(return_value=cached_code)
Expand Down Expand Up @@ -450,19 +418,19 @@ def test_train_method_with_code_but_no_queries(self, agent):
with pytest.raises(ValueError):
agent.train(codes)

def test_execute_local_sql_query_success(self, agent):
query = "SELECT count(*) as total from countries;"
expected_result = pd.DataFrame({"total": [4]})
def test_execute_local_sql_query_success(self, agent, sample_df):
query = f"SELECT count(*) as total from {sample_df.schema.source.table};"
expected_result = pd.DataFrame({"total": [3]})
result = agent._execute_local_sql_query(query)
pd.testing.assert_frame_equal(result, expected_result)

def test_execute_local_sql_query_failure(self, agent):
with pytest.raises(RuntimeError, match="SQL execution failed"):
agent._execute_local_sql_query("wrong query;")

def test_execute_sql_query_success_local(self, agent):
query = "SELECT count(*) as total from countries;"
expected_result = pd.DataFrame({"total": [4]})
def test_execute_sql_query_success_local(self, agent, sample_df):
query = f"SELECT count(*) as total from {sample_df.schema.source.table};"
expected_result = pd.DataFrame({"total": [3]})
result = agent._execute_sql_query(query)
pd.testing.assert_frame_equal(result, expected_result)
