Skip to content

Commit

Permalink
feat: filter only the columns that are provided in the schema (#1562)
Browse files Browse the repository at this point in the history
  • Loading branch information
gventuri authored Jan 30, 2025
1 parent 4a093c6 commit a0b5878
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 0 deletions.
18 changes: 18 additions & 0 deletions pandasai/data_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def load(
self.dataset_path = self.schema.source.path

df = self._load_from_local_source()
df = self._filter_columns(df)
df = self._apply_transformations(df)

# Convert to pandas DataFrame while preserving internal data
Expand Down Expand Up @@ -202,6 +203,23 @@ def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFra
f"Failed to execute query for '{source_type}' with: {formatted_query}"
) from e

def _filter_columns(self, df: pd.DataFrame) -> pd.DataFrame:
"""Filter DataFrame columns based on schema columns if specified.
Args:
df (pd.DataFrame): Input DataFrame to filter
Returns:
pd.DataFrame: DataFrame with only columns specified in schema
"""
if not self.schema or not self.schema.columns:
return df

schema_columns = [col.name for col in self.schema.columns]
df_columns = df.columns.tolist()
columns_to_keep = [col for col in df_columns if col in schema_columns]
return df[columns_to_keep]

def _apply_transformations(self, df: pd.DataFrame) -> pd.DataFrame:
if not self.schema.transformations:
return df
Expand Down
53 changes: 53 additions & 0 deletions tests/unit_tests/dataframe/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,56 @@ def test_load_with_schema_and_path(self, sample_schema):
match="Provide only one of 'dataset_path' or 'schema', not both.",
):
result = loader.load("test/users", sample_schema)

def test_filter_columns_with_schema_columns(self, sample_schema):
"""Test that columns are filtered correctly when schema columns are specified."""
loader = DatasetLoader()
loader.schema = sample_schema

# Create a DataFrame with extra columns
df = pd.DataFrame(
{
"email": ["[email protected]"],
"first_name": ["John"],
"timestamp": ["2023-01-01"],
"extra_col": ["extra"], # This column should be filtered out
}
)

filtered_df = loader._filter_columns(df)
assert list(filtered_df.columns) == ["email", "first_name", "timestamp"]
assert "extra_col" not in filtered_df.columns

def test_filter_columns_without_schema_columns(self):
"""Test that all columns are kept when no schema columns are specified."""
loader = DatasetLoader()
# Create schema without columns
loader.schema = SemanticLayerSchema(
**{"name": "Users", "source": {"type": "csv", "path": "users.csv"}}
)

df = pd.DataFrame({"col1": [1], "col2": [2], "col3": [3]})

filtered_df = loader._filter_columns(df)
assert list(filtered_df.columns) == ["col1", "col2", "col3"]

def test_filter_columns_with_non_matching_columns(self, sample_schema):
"""Test filtering when schema columns don't match DataFrame columns."""
loader = DatasetLoader()
loader.schema = sample_schema

# Create DataFrame with none of the schema columns
df = pd.DataFrame({"different_col1": [1], "different_col2": [2]})

filtered_df = loader._filter_columns(df)
assert len(filtered_df.columns) == 0 # Should return empty DataFrame

def test_filter_columns_without_schema(self):
"""Test that all columns are kept when no schema is set."""
loader = DatasetLoader()
loader.schema = None

df = pd.DataFrame({"col1": [1], "col2": [2]})

filtered_df = loader._filter_columns(df)
assert list(filtered_df.columns) == ["col1", "col2"]

0 comments on commit a0b5878

Please sign in to comment.