Commit 0c6738b
feature(FileManager): adding FileManager to make it feasible to work with the library in other environments (#1573)

scaliseraoul authored Jan 31, 2025
1 parent d2350a1 commit 0c6738b
Showing 16 changed files with 176 additions and 268 deletions.
17 changes: 8 additions & 9 deletions pandasai/__init__.py
@@ -98,17 +98,17 @@ def create(
     org_name, dataset_name = get_validated_dataset_path(path)

-    dataset_directory = os.path.join(
-        find_project_root(), "datasets", org_name, dataset_name
-    )
+    dataset_directory = str(os.path.join(org_name, dataset_name))

-    schema_path = os.path.join(str(dataset_directory), "schema.yaml")
-    parquet_file_path = os.path.join(str(dataset_directory), "data.parquet")
+    schema_path = os.path.join(dataset_directory, "schema.yaml")
+    parquet_file_path = os.path.join(dataset_directory, "data.parquet")

+    file_manager = config.get().file_manager
     # Check if dataset already exists
-    if os.path.exists(dataset_directory) and os.path.exists(schema_path):
+    if file_manager.exists(dataset_directory) and file_manager.exists(schema_path):
         raise ValueError(f"Dataset already exists at path: {path}")

-    os.makedirs(dataset_directory, exist_ok=True)
+    file_manager.mkdir(dataset_directory)

     if df is None and source is None and not view:
         raise InvalidConfigError(
@@ -135,8 +135,7 @@ def create(
     if columns:
         schema.columns = [Column(**column) for column in columns]

-    with open(schema_path, "w") as yml_file:
-        yml_file.write(schema.to_yaml())
+    file_manager.write(schema_path, schema.to_yaml())

     print(f"Dataset saved successfully to path: {dataset_directory}")
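In effect, create() now builds dataset paths relative to the dataset root and routes every existence check, directory creation, and write through the configured file manager instead of touching the filesystem directly. A minimal sketch of the new resolution, assuming the DefaultFileManager added in this commit; the dataset path "acme-corp/sales" is hypothetical:

    import os

    from pandasai.helpers.filemanager import DefaultFileManager

    fm = DefaultFileManager()
    dataset_directory = os.path.join("acme-corp", "sales")  # relative; no find_project_root() here
    schema_path = os.path.join(dataset_directory, "schema.yaml")

    # DefaultFileManager prepends its base path, so this checks
    # <project_root>/datasets/acme-corp/sales/schema.yaml on local disk.
    print(fm.exists(schema_path))

With a non-local FileManager installed in the config, the same relative paths resolve against that backend instead.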
3 changes: 3 additions & 0 deletions pandasai/config.py
@@ -1,9 +1,11 @@
 import os
 from abc import ABC, abstractmethod
 from importlib.util import find_spec
 from typing import Any, Dict, Optional

 from pydantic import BaseModel, ConfigDict

+from pandasai.helpers.filemanager import DefaultFileManager, FileManager
 from pandasai.llm.base import LLM


@@ -13,6 +15,7 @@ class Config(BaseModel):
     enable_cache: bool = True
     max_retries: int = 3
     llm: Optional[LLM] = None
+    file_manager: FileManager = DefaultFileManager()

     model_config = ConfigDict(arbitrary_types_allowed=True)
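Config now carries the file manager as a first-class field; arbitrary_types_allowed is what lets pydantic accept a plain ABC instance there. A minimal construction sketch; how the resulting Config gets registered globally (e.g. through ConfigManager) is outside this hunk:

    from pandasai.config import Config
    from pandasai.helpers.filemanager import DefaultFileManager, FileManager

    cfg = Config(enable_cache=False, file_manager=DefaultFileManager())
    assert isinstance(cfg.file_manager, FileManager)

Any concrete FileManager subclass can be passed the same way, which is the hook that makes non-local environments workable.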
40 changes: 0 additions & 40 deletions pandasai/core/prompts/file_based_prompt.py

This file was deleted.

26 changes: 12 additions & 14 deletions pandasai/data_loader/loader.py
@@ -5,9 +5,9 @@
 from pandasai.dataframe.base import DataFrame
 from pandasai.exceptions import MethodNotImplementedError
-from pandasai.helpers.path import find_project_root
 from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name

+from .. import ConfigManager
 from ..constants import (
     LOCAL_SOURCE_TYPES,
 )
@@ -48,21 +48,22 @@ def create_loader_from_path(cls, dataset_path: str) -> "DatasetLoader":
         """
         Factory method to create the appropriate loader based on the dataset type.
         """
-        schema = cls._read_local_schema(dataset_path)
+        schema = cls._read_schema_file(dataset_path)
         return DatasetLoader.create_loader_from_schema(schema, dataset_path)

     @staticmethod
-    def _read_local_schema(dataset_path: str) -> SemanticLayerSchema:
-        schema_path = os.path.join(
-            find_project_root(), "datasets", dataset_path, "schema.yaml"
-        )
-        if not os.path.exists(schema_path):
+    def _read_schema_file(dataset_path: str) -> SemanticLayerSchema:
+        schema_path = os.path.join(dataset_path, "schema.yaml")
+
+        file_manager = ConfigManager.get().file_manager
+
+        if not file_manager.exists(schema_path):
             raise FileNotFoundError(f"Schema file not found: {schema_path}")

-        with open(schema_path, "r") as file:
-            raw_schema = yaml.safe_load(file)
-            raw_schema["name"] = sanitize_sql_table_name(raw_schema["name"])
-            return SemanticLayerSchema(**raw_schema)
+        schema_file = file_manager.load(schema_path)
+        raw_schema = yaml.safe_load(schema_file)
+        raw_schema["name"] = sanitize_sql_table_name(raw_schema["name"])
+        return SemanticLayerSchema(**raw_schema)

     def load(self) -> DataFrame:
         """
@@ -80,6 +81,3 @@ def _apply_transformations(self, df: pd.DataFrame) -> pd.DataFrame:

         transformation_manager = TransformationManager(df)
         return transformation_manager.apply_transformations(self.schema.transformations)
-
-    def _get_abs_dataset_path(self):
-        return os.path.join(find_project_root(), "datasets", self.dataset_path)
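Because _read_schema_file now goes through ConfigManager.get().file_manager with a path relative to the dataset root, schemas no longer have to live on the local disk. From the caller's side nothing changes; a hedged usage sketch, where "acme-corp/sales" is a hypothetical dataset directory containing a schema.yaml:

    from pandasai.data_loader.loader import DatasetLoader

    # Reads "acme-corp/sales/schema.yaml" via the active file manager, then
    # dispatches to a local or SQL loader based on the schema's source type.
    loader = DatasetLoader.create_loader_from_path("acme-corp/sales")
    df = loader.load()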
2 changes: 1 addition & 1 deletion pandasai/data_loader/local_loader.py
@@ -37,7 +37,7 @@ def _load_from_local_source(self) -> pd.DataFrame:
             )

         filepath = os.path.join(
-            str(self._get_abs_dataset_path()),
+            self.dataset_path,
             self.schema.source.path,
         )
1 change: 0 additions & 1 deletion pandasai/data_loader/sql_loader.py
@@ -42,7 +42,6 @@ def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFrame:
             raise MaliciousQueryError(
                 "The SQL query is deemed unsafe and will not be executed."
             )
-
         try:
             dataframe: pd.DataFrame = load_function(
                 connection_info, formatted_query, params
57 changes: 24 additions & 33 deletions pandasai/dataframe/base.py
@@ -10,7 +10,7 @@
 from pandas._typing import Axes, Dtype

 import pandasai as pai
-from pandasai.config import Config
+from pandasai.config import Config, ConfigManager
 from pandasai.core.response import BaseResponse
 from pandasai.data_loader.semantic_layer_schema import (
     Column,
@@ -19,7 +19,6 @@
 )
 from pandasai.exceptions import DatasetNotFound, PandaAIApiKeyError
 from pandasai.helpers.dataframe_serializer import DataframeSerializer
-from pandasai.helpers.path import find_project_root
 from pandasai.helpers.session import get_pandaai_session

 if TYPE_CHECKING:
@@ -164,38 +163,32 @@ def push(self):
             "name": self.schema.name,
         }

-        dataset_directory = os.path.join(find_project_root(), "datasets", self.path)
-
+        dataset_directory = os.path.join("datasets", self.path)
+        file_manager = ConfigManager.get().file_manager
         headers = {"accept": "application/json", "x-authorization": f"Bearer {api_key}"}

         files = []
         schema_file_path = os.path.join(dataset_directory, "schema.yaml")
         data_file_path = os.path.join(dataset_directory, "data.parquet")

-        try:
-            # Open schema.yaml
-            schema_file = open(schema_file_path, "rb")
-            files.append(("files", ("schema.yaml", schema_file, "application/x-yaml")))
-
-            # Check if data.parquet exists and open it
-            if os.path.exists(data_file_path):
-                data_file = open(data_file_path, "rb")
-                files.append(
-                    ("files", ("data.parquet", data_file, "application/octet-stream"))
-                )
-
-            # Send the POST request
-            request_session.post(
-                "/datasets/push",
-                files=files,
-                params=params,
-                headers=headers,
-            )
-        finally:
-            # Ensure files are closed after the request
-            for _, (name, file, _) in files:
-                file.close()
+        # Open schema.yaml
+        schema_file = file_manager.load_binary(schema_file_path)
+        files.append(("files", ("schema.yaml", schema_file, "application/x-yaml")))
+
+        # Check if data.parquet exists and open it
+        if file_manager.exists(data_file_path):
+            data_file = file_manager.load_binary(data_file_path)
+            files.append(
+                ("files", ("data.parquet", data_file, "application/octet-stream"))
+            )
+
+        # Send the POST request
+        request_session.post(
+            "/datasets/push",
+            files=files,
+            params=params,
+            headers=headers,
+        )

         print("Your dataset was successfully pushed to the remote server!")
         print(f"🔗 URL: https://app.pandabi.ai/datasets/{self.path}")
@@ -218,20 +211,18 @@ def pull(self):

         with ZipFile(BytesIO(file_data.content)) as zip_file:
             for file_name in zip_file.namelist():
-                target_path = os.path.join(
-                    find_project_root(), "datasets", self.path, file_name
-                )
+                target_path = os.path.join(self.path, file_name)

+                file_manager = ConfigManager.get().file_manager
                 # Check if the file already exists
-                if os.path.exists(target_path):
+                if file_manager.exists(target_path):
                     print(f"Replacing existing file: {target_path}")

                 # Ensure target directory exists
-                os.makedirs(os.path.dirname(target_path), exist_ok=True)
+                file_manager.mkdir(os.path.dirname(target_path))

                 # Extract the file
-                with open(target_path, "wb") as f:
-                    f.write(zip_file.read(file_name))
+                file_manager.write_binary(target_path, zip_file.read(file_name))

         # Reloads the Dataframe
         from pandasai import DatasetLoader
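The try/finally block disappears because load_binary() returns the whole payload as bytes rather than an open file handle, so there is nothing to close once the POST has been sent. A round-trip sketch of the binary helpers push() and pull() rely on, assuming DefaultFileManager; the path is hypothetical and the bytes are placeholders, not a real parquet file:

    import os

    from pandasai.helpers.filemanager import DefaultFileManager

    fm = DefaultFileManager()
    path = os.path.join("acme-corp", "sales", "data.parquet")

    fm.mkdir(os.path.dirname(path))       # like os.makedirs(..., exist_ok=True)
    fm.write_binary(path, b"placeholder")  # stand-in content
    payload = fm.load_binary(path)         # bytes in memory; no handle left open
    assert payload == b"placeholder"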
73 changes: 73 additions & 0 deletions pandasai/helpers/filemanager.py
@@ -0,0 +1,73 @@
+import os
+from abc import ABC, abstractmethod
+
+from pandasai.helpers.path import find_project_root
+
+
+class FileManager(ABC):
+    """Abstract base class for file loaders, supporting local and remote backends."""
+
+    @abstractmethod
+    def load(self, file_path: str) -> str:
+        """Reads the content of a file."""
+        pass
+
+    @abstractmethod
+    def load_binary(self, file_path: str) -> bytes:
+        """Reads the content of a file as bytes."""
+        pass
+
+    @abstractmethod
+    def write(self, file_path: str, content: str) -> None:
+        """Writes content to a file."""
+        pass
+
+    @abstractmethod
+    def write_binary(self, file_path: str, content: bytes) -> None:
+        """Writes binary content to a file."""
+        pass
+
+    @abstractmethod
+    def exists(self, file_path: str) -> bool:
+        """Checks if a file or directory exists."""
+        pass
+
+    @abstractmethod
+    def mkdir(self, dir_path: str) -> None:
+        """Creates a directory if it doesn't exist."""
+        pass
+
+
+class DefaultFileManager(FileManager):
+    """Local file system implementation of FileLoader."""
+
+    def __init__(self):
+        self.base_path = os.path.join(find_project_root(), "datasets")
+
+    def load(self, file_path: str) -> str:
+        full_path = os.path.join(self.base_path, file_path)
+        with open(full_path, "r", encoding="utf-8") as f:
+            return f.read()
+
+    def load_binary(self, file_path: str) -> bytes:
+        full_path = os.path.join(self.base_path, file_path)
+        with open(full_path, "rb") as f:
+            return f.read()
+
+    def write(self, file_path: str, content: str) -> None:
+        full_path = os.path.join(self.base_path, file_path)
+        with open(full_path, "w", encoding="utf-8") as f:
+            f.write(content)
+
+    def write_binary(self, file_path: str, content: bytes) -> None:
+        full_path = os.path.join(self.base_path, file_path)
+        with open(full_path, "wb") as f:
+            f.write(content)
+
+    def exists(self, file_path: str) -> bool:
+        full_path = os.path.join(self.base_path, file_path)
+        return os.path.exists(full_path)
+
+    def mkdir(self, dir_path: str) -> None:
+        full_path = os.path.join(self.base_path, dir_path)
+        os.makedirs(full_path, exist_ok=True)
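This abstract class is the extension point the commit title refers to: an environment without a usable local filesystem can subclass FileManager and hand an instance to Config. A hypothetical in-memory backend, sketched only to illustrate the contract; it is not part of this commit:

    from pandasai.helpers.filemanager import FileManager


    class InMemoryFileManager(FileManager):
        """Keeps dataset files in a dict, for sandboxes without writable disk."""

        def __init__(self):
            self.files = {}  # path -> bytes
            self.dirs = set()

        def load(self, file_path: str) -> str:
            return self.files[file_path].decode("utf-8")

        def load_binary(self, file_path: str) -> bytes:
            return self.files[file_path]

        def write(self, file_path: str, content: str) -> None:
            self.files[file_path] = content.encode("utf-8")

        def write_binary(self, file_path: str, content: bytes) -> None:
            self.files[file_path] = content

        def exists(self, file_path: str) -> bool:
            return file_path in self.files or file_path in self.dirs

        def mkdir(self, dir_path: str) -> None:
            self.dirs.add(dir_path)

With Config(file_manager=InMemoryFileManager()), create(), push(), and pull() read and write the dict instead of <project_root>/datasets.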
1 change: 1 addition & 0 deletions pandasai/helpers/path.py
@@ -10,6 +10,7 @@ def find_project_root(filename=None):

     # Get the path of the file that is be
     # ing executed
+
     current_file_path = os.path.abspath(os.getcwd())

     # Navigate back until we either find a $filename file or there is no parent
3 changes: 2 additions & 1 deletion tests/unit_tests/agent/test_agent_chat.py
@@ -7,13 +7,14 @@
 import pytest

 import pandasai as pai
-from pandasai import DataFrame, find_project_root
+from pandasai import DataFrame
 from pandasai.core.response import (
     ChartResponse,
     DataFrameResponse,
     NumberResponse,
     StringResponse,
 )
+from pandasai.helpers.filemanager import find_project_root

 # Read the API key from an environment variable
 API_KEY = os.getenv("PANDABI_API_KEY_TEST_CHAT", None)
3 changes: 2 additions & 1 deletion tests/unit_tests/agent/test_agent_llm_judge.py
@@ -7,7 +7,8 @@
 from pydantic import BaseModel

 import pandasai as pai
-from pandasai import DataFrame, find_project_root
+from pandasai import DataFrame
+from pandasai.helpers.path import find_project_root

 # Read the API key from an environment variable
 JUDGE_OPENAI_API_KEY = os.getenv("JUDGE_OPENAI_API_KEY", None)
13 changes: 13 additions & 0 deletions tests/unit_tests/conftest.py
@@ -5,9 +5,11 @@
 import pytest

+from pandasai import ConfigManager
 from pandasai.data_loader.loader import DatasetLoader
 from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
 from pandasai.dataframe.base import DataFrame
+from pandasai.helpers.filemanager import DefaultFileManager
 from pandasai.helpers.path import find_project_root
@@ -171,3 +173,14 @@ def mock_loader_instance(sample_df):
     mock_create_loader_from_schema.return_value = mock_loader_instance

     yield mock_loader_instance
+
+
+@pytest.fixture
+def mock_file_manager():
+    """Fixture to mock FileManager and its methods."""
+    with patch.object(ConfigManager, "get") as mock_config_get:
+        # Create a mock FileManager
+        mock_file_manager = MagicMock()
+        mock_file_manager.exists.return_value = False
+        mock_config_get.return_value.file_manager = mock_file_manager
+        yield mock_file_manager
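A hedged sketch of how a test could lean on this fixture; the test name and dataset path are hypothetical, and it assumes the config.get() call inside pandasai.create() resolves to the patched ConfigManager.get:

    import pandasai as pai


    def test_create_routes_through_file_manager(mock_file_manager, sample_df):
        pai.create("acme-corp/sales", df=sample_df)

        # exists() is stubbed to False, so create() should reach mkdir and the
        # schema write on the mocked manager rather than the real filesystem.
        mock_file_manager.mkdir.assert_called_once()
        mock_file_manager.write.assert_called_once()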