fix(load): path fixed
ArslanSaleem committed Dec 13, 2024
1 parent 89acbdc commit a6192ee
Showing 4 changed files with 36 additions and 15 deletions.
4 changes: 4 additions & 0 deletions pandasai/__init__.py
@@ -86,6 +86,10 @@ def load(dataset_path: str, virtualized=False) -> DataFrame:
     Returns:
         DataFrame: A new PandasAI DataFrame instance with loaded data.
     """
+    path_parts = dataset_path.split("/")
+    if len(path_parts) != 2:
+        raise ValueError("Path must be in format 'organization/dataset'")
+
     global _dataset_loader
     dataset_full_path = os.path.join(find_project_root(), "datasets", dataset_path)
     if not os.path.exists(dataset_full_path):
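
Note: the new guard makes load() fail fast on a malformed dataset path instead of surfacing a confusing missing-directory error from the os.path.exists check below it. A minimal sketch of the resulting behavior (the "acme/sales" dataset name is invented for illustration and assumes such a dataset exists under <project_root>/datasets/):

    import pandasai as pai

    df = pai.load("acme/sales")  # OK: matches 'organization/dataset'
    pai.load("sales")            # ValueError: Path must be in format 'organization/dataset'
    pai.load("a/b/c")            # ValueError: too many path segments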
4 changes: 2 additions & 2 deletions pandasai/chat/code_execution/environment.py
@@ -6,7 +6,7 @@
 import importlib
 import sys
 import warnings
-from typing import List
+from typing import List, Union
 
 from pandas.util.version import Version
 
@@ -92,7 +92,7 @@ def import_dependency(
     name: str,
     extra: str = "",
     errors: str = "raise",
-    min_version: str | None = None,
+    min_version: Union[str, None] = None,
 ):
     """
     Import an optional dependency.
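
Note: this is a compatibility fix, not a behavior change. The PEP 604 union syntax str | None is only valid at runtime on Python 3.10+; on older interpreters, evaluating the annotation raises "TypeError: unsupported operand type(s) for |" when the function is defined (unless annotations are deferred with `from __future__ import annotations`). A standalone sketch of the portable form:

    from typing import Union

    # Works on Python 3.7+, equivalent to `str | None` on 3.10+:
    def import_dependency(name: str, min_version: Union[str, None] = None):
        ...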
25 changes: 17 additions & 8 deletions pandasai/data_loader/loader.py
@@ -47,6 +47,13 @@ def load(self, dataset_path: str, virtualized=False) -> DataFrame:
             )
         else:
             # Initialize new dataset loader for virtualization
+            source_type = self.schema["source"]["type"]
+
+            if source_type in ["csv", "parquet"]:
+                raise ValueError(
+                    "Virtualization is not supported for CSV and Parquet files."
+                )
+
             data_loader = self.copy()
             table_name = self.schema["source"].get("table", None) or self.schema["name"]
             table_description = self.schema.get("description", None)
@@ -58,10 +65,11 @@ def load(self, dataset_path: str, virtualized=False) -> DataFrame:
                 path=dataset_path,
             )
 
+    def _get_abs_dataset_path(self):
+        return os.path.join(find_project_root(), "datasets", self.dataset_path)
+
     def _load_schema(self):
-        schema_path = os.path.join(
-            find_project_root(), "datasets", self.dataset_path, "schema.yaml"
-        )
+        schema_path = os.path.join(self._get_abs_dataset_path(), "schema.yaml")
         if not os.path.exists(schema_path):
             raise FileNotFoundError(f"Schema file not found: {schema_path}")
 
@@ -79,13 +87,13 @@ def _validate_source_type(self):
     def _get_cache_file_path(self) -> str:
         if "path" in self.schema["destination"]:
             return os.path.join(
-                "datasets", self.dataset_path, self.schema["destination"]["path"]
+                self._get_abs_dataset_path(), self.schema["destination"]["path"]
             )
 
         file_extension = (
             "parquet" if self.schema["destination"]["format"] == "parquet" else "csv"
         )
-        return os.path.join("datasets", self.dataset_path, f"data.{file_extension}")
+        return os.path.join(self._get_abs_dataset_path(), f"data.{file_extension}")
 
     def _is_cache_valid(self, cache_file: str) -> bool:
         if not os.path.exists(cache_file):
@@ -154,10 +162,11 @@ def _read_csv_or_parquet(self, file_path: str, format: str) -> DataFrame:
     def _load_from_source(self) -> pd.DataFrame:
         source_type = self.schema["source"]["type"]
         if source_type in ["csv", "parquet"]:
-            filpath = os.path.join(
-                "datasets", self.dataset_path, self.schema["source"]["path"]
+            filepath = os.path.join(
+                self._get_abs_dataset_path(),
+                self.schema["source"]["path"],
             )
-            return self._read_csv_or_parquet(filpath, source_type)
+            return self._read_csv_or_parquet(filepath, source_type)
 
         query_builder = QueryBuilder(self.schema)
         query = query_builder.build_query()
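
Note: these hunks carry the fix named in the commit title. Cache and source paths were previously built as relative paths ("datasets/..."), so they resolved against the current working directory and broke whenever the process was started outside the project root; schema loading already used the absolute form. The new _get_abs_dataset_path helper gives all of them a single project-root anchor, and the filpath → filepath typo is fixed along the way. A standalone sketch of the idea (the root and dataset values are hypothetical):

    import os

    def get_abs_dataset_path(project_root: str, dataset_path: str) -> str:
        # One place that anchors dataset paths at the project root, shared
        # by schema loading, cache lookup, and source reads.
        return os.path.join(project_root, "datasets", dataset_path)

    print(get_abs_dataset_path("/home/user/project", "acme/sales"))
    # /home/user/project/datasets/acme/sales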
18 changes: 13 additions & 5 deletions pandasai/data_loader/query_builder.py
@@ -21,6 +21,16 @@ def _get_columns(self) -> str:
         else:
             return "*"
 
+    def _get_table_name(self):
+        table_name = self.schema["source"].get("table", None) or self.schema["name"]
+
+        if not table_name:
+            raise ValueError("Table name not found in schema!")
+
+        table_name = table_name.lower()
+
+        return table_name
+
     def _add_order_by(self) -> str:
         if "order_by" not in self.schema:
             return ""
@@ -40,16 +50,14 @@ def get_head_query(self, n=5):
         source = self.schema.get("source", {})
         source_type = source.get("type")
 
-        table_name = self.schema["source"]["table"]
+        table_name = self._get_table_name()
 
         columns = self._get_columns()
 
-        order_by = "RAND()"
-        if source_type in {"sqlite", "postgres"}:
-            order_by = "RANDOM()"
+        order_by = "RANDOM()" if source_type in {"sqlite", "postgres"} else "RAND()"
 
         return f"SELECT {columns} FROM {table_name} ORDER BY {order_by} LIMIT {n}"
 
     def get_row_count(self):
-        table_name = self.schema["source"]["table"]
+        table_name = self._get_table_name()
         return f"SELECT COUNT(*) FROM {table_name}"
