-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: filter only the columns that are provided in the schema (#1562)
- Loading branch information
Showing
2 changed files
with
71 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -426,3 +426,56 @@ def test_load_with_schema_and_path(self, sample_schema): | |
match="Provide only one of 'dataset_path' or 'schema', not both.", | ||
): | ||
result = loader.load("test/users", sample_schema) | ||
|
||
def test_filter_columns_with_schema_columns(self, sample_schema): | ||
"""Test that columns are filtered correctly when schema columns are specified.""" | ||
loader = DatasetLoader() | ||
loader.schema = sample_schema | ||
|
||
# Create a DataFrame with extra columns | ||
df = pd.DataFrame( | ||
{ | ||
"email": ["[email protected]"], | ||
"first_name": ["John"], | ||
"timestamp": ["2023-01-01"], | ||
"extra_col": ["extra"], # This column should be filtered out | ||
} | ||
) | ||
|
||
filtered_df = loader._filter_columns(df) | ||
assert list(filtered_df.columns) == ["email", "first_name", "timestamp"] | ||
assert "extra_col" not in filtered_df.columns | ||
|
||
def test_filter_columns_without_schema_columns(self): | ||
"""Test that all columns are kept when no schema columns are specified.""" | ||
loader = DatasetLoader() | ||
# Create schema without columns | ||
loader.schema = SemanticLayerSchema( | ||
**{"name": "Users", "source": {"type": "csv", "path": "users.csv"}} | ||
) | ||
|
||
df = pd.DataFrame({"col1": [1], "col2": [2], "col3": [3]}) | ||
|
||
filtered_df = loader._filter_columns(df) | ||
assert list(filtered_df.columns) == ["col1", "col2", "col3"] | ||
|
||
def test_filter_columns_with_non_matching_columns(self, sample_schema): | ||
"""Test filtering when schema columns don't match DataFrame columns.""" | ||
loader = DatasetLoader() | ||
loader.schema = sample_schema | ||
|
||
# Create DataFrame with none of the schema columns | ||
df = pd.DataFrame({"different_col1": [1], "different_col2": [2]}) | ||
|
||
filtered_df = loader._filter_columns(df) | ||
assert len(filtered_df.columns) == 0 # Should return empty DataFrame | ||
|
||
def test_filter_columns_without_schema(self): | ||
"""Test that all columns are kept when no schema is set.""" | ||
loader = DatasetLoader() | ||
loader.schema = None | ||
|
||
df = pd.DataFrame({"col1": [1], "col2": [2]}) | ||
|
||
filtered_df = loader._filter_columns(df) | ||
assert list(filtered_df.columns) == ["col1", "col2"] |