Skip to content

Commit

Permalink
feat: Add strict param to eager/lazy frame "rename"
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Sep 30, 2024
1 parent e402e70 commit b1c59d4
Show file tree
Hide file tree
Showing 13 changed files with 66 additions and 26 deletions.
8 changes: 5 additions & 3 deletions crates/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,9 +392,10 @@ impl LazyFrame {
///
/// `existing` and `new` are iterables of the same length containing the old and
/// corresponding new column names. Renaming happens to all `existing` columns
/// simultaneously, not iteratively. (In particular, all columns in `existing` must
/// already exist in the `LazyFrame` when `rename` is called.)
pub fn rename<I, J, T, S>(self, existing: I, new: J) -> Self
/// simultaneously, not iteratively. If `strict` is true, all columns in `existing`
/// must be present in the `LazyFrame` when `rename` is called; otherwise, only
/// those columns that are actually found will be renamed (others will be ignored).
pub fn rename<I, J, T, S>(self, existing: I, new: J, strict: bool) -> Self
where
I: IntoIterator<Item = T>,
J: IntoIterator<Item = S>,
Expand All @@ -420,6 +421,7 @@ impl LazyFrame {
self.map_private(DslFunction::Rename {
existing: existing_vec.into(),
new: new_vec.into(),
strict,
})
}

Expand Down
2 changes: 1 addition & 1 deletion crates/polars-lazy/src/tests/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ fn test_cse_columns_projections() -> PolarsResult<()> {

let left = left.cross_join(right.clone().select([col("A")]), None);
let q = left.join(
right.rename(["B"], ["C"]),
right.rename(["B"], ["C"], true),
[col("A"), col("C")],
[col("A"), col("C")],
JoinType::Left.into(),
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-lazy/src/tests/optimization_checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ fn test_lazy_filter_and_rename() {
let lf = df
.clone()
.lazy()
.rename(["a"], ["x"])
.rename(["a"], ["x"], true)
.filter(col("x").map(
|s: Column| Ok(Some(s.as_materialized_series().gt(3)?.into_column())),
GetOutput::from_type(DataType::Boolean),
Expand All @@ -337,7 +337,7 @@ fn test_lazy_filter_and_rename() {
assert!(lf.collect().unwrap().equals(&correct));

// now we check if the column is rename or added when we don't select
let lf = df.lazy().rename(["a"], ["x"]).filter(col("x").map(
let lf = df.lazy().rename(["a"], ["x"], true).filter(col("x").map(
|s: Column| Ok(Some(s.as_materialized_series().gt(3)?.into_column())),
GetOutput::from_type(DataType::Boolean),
));
Expand Down
12 changes: 9 additions & 3 deletions crates/polars-plan/src/plans/functions/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub enum DslFunction {
Rename {
existing: Arc<[PlSmallStr]>,
new: Arc<[PlSmallStr]>,
strict: bool,
},
Unnest(Vec<Selector>),
Stats(StatsFunction),
Expand Down Expand Up @@ -119,10 +120,15 @@ impl DslFunction {
offset,
schema: Default::default(),
},
DslFunction::Rename { existing, new } => {
DslFunction::Rename {
existing,
new,
strict,
} => {
let swapping = new.iter().any(|name| input_schema.get(name).is_some());
validate_columns_in_input(existing.as_ref(), input_schema, "rename")?;

if strict {
validate_columns_in_input(existing.as_ref(), input_schema, "rename")?;
}
FunctionIR::Rename {
existing,
new,
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-python/src/lazyframe/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -962,9 +962,9 @@ impl PyLazyFrame {
ldf.with_columns_seq(exprs.to_exprs()).into()
}

fn rename(&mut self, existing: Vec<String>, new: Vec<String>) -> Self {
fn rename(&mut self, existing: Vec<String>, new: Vec<String>, strict: bool) -> Self {
let ldf = self.ldf.clone();
ldf.rename(existing, new).into()
ldf.rename(existing, new, strict).into()
}

fn reverse(&self) -> Self {
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,7 @@ impl SQLContext {
lf = lf.rename(
select_modifiers.rename.keys(),
select_modifiers.rename.values(),
true,
);
};
lf
Expand Down Expand Up @@ -1380,7 +1381,7 @@ impl SQLContext {
} else {
let existing_columns: Vec<_> = schema.iter_names().collect();
let new_columns: Vec<_> = alias.columns.iter().map(|c| c.value.clone()).collect();
Ok(lf.rename(existing_columns, new_columns))
Ok(lf.rename(existing_columns, new_columns, true))
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ impl SQLExprVisitor<'_> {
let schema_entry = schema.get_at_index(0);
if let Some((old_name, _)) = schema_entry {
let new_name = String::from(old_name.as_str()) + rand_string.as_str();
lf = lf.rename([old_name.to_string()], [new_name.clone()]);
lf = lf.rename([old_name.to_string()], [new_name.clone()], true);
return Ok(Expr::SubPlan(
SpecialEq::new(Arc::new(lf.logical_plan)),
vec![new_name],
Expand Down
2 changes: 1 addition & 1 deletion crates/polars/tests/it/lazy/predicate_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ fn test_predicate_after_renaming() -> PolarsResult<()> {
"bar" => [3, 2, 1]
]?
.lazy()
.rename(["foo", "bar"], ["foo2", "bar2"])
.rename(["foo", "bar"], ["foo2", "bar2"], true)
.filter(col("foo2").eq(col("bar2")))
.collect()?;

Expand Down
2 changes: 1 addition & 1 deletion crates/polars/tests/it/lazy/projection_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ fn test_swap_rename() -> PolarsResult<()> {
"b" => [2],
]?
.lazy()
.rename(["a", "b"], ["b", "a"])
.rename(["a", "b"], ["b", "a"], true)
.collect()?;

let expected = df![
Expand Down
14 changes: 10 additions & 4 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4479,7 +4479,9 @@ def reverse(self) -> DataFrame:
"""
return self.select(F.col("*").reverse())

def rename(self, mapping: dict[str, str] | Callable[[str], str]) -> DataFrame:
def rename(
self, mapping: dict[str, str] | Callable[[str], str], *, strict: bool = True
) -> DataFrame:
"""
Rename column names.
Expand All @@ -4488,6 +4490,10 @@ def rename(self, mapping: dict[str, str] | Callable[[str], str]) -> DataFrame:
mapping
Key value pairs that map from old name to new name, or a function
that takes the old name as input and returns the new name.
strict
Validate that all column names exist in the current schema,
and throw an exception if any do not. (Note that this parameter
is a no-op when passing a function to `mapping`).
Examples
--------
Expand Down Expand Up @@ -4517,7 +4523,7 @@ def rename(self, mapping: dict[str, str] | Callable[[str], str]) -> DataFrame:
│ 3 ┆ 8 ┆ c │
└─────┴─────┴─────┘
"""
return self.lazy().rename(mapping).collect(_eager=True)
return self.lazy().rename(mapping, strict=strict).collect(_eager=True)

def insert_column(self, index: int, column: Series) -> DataFrame:
"""
Expand Down Expand Up @@ -7475,8 +7481,8 @@ def drop(
Names of the columns that should be removed from the dataframe.
Accepts column selector input.
strict
Validate that all column names exist in the schema and throw an
exception if a column name does not exist in the schema.
Validate that all column names exist in the current schema,
and throw an exception if any do not.
Examples
--------
Expand Down
14 changes: 10 additions & 4 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4973,8 +4973,8 @@ def drop(
Names of the columns that should be removed from the dataframe.
Accepts column selector input.
strict
Validate that all column names exist in the schema and throw an
exception if a column name does not exist in the schema.
Validate that all column names exist in the current schema,
and throw an exception if any do not.
Examples
--------
Expand Down Expand Up @@ -5031,7 +5031,9 @@ def drop(
drop_cols = parse_into_list_of_expressions(*columns)
return self._from_pyldf(self._ldf.drop(drop_cols, strict=strict))

def rename(self, mapping: dict[str, str] | Callable[[str], str]) -> LazyFrame:
def rename(
self, mapping: dict[str, str] | Callable[[str], str], *, strict: bool = True
) -> LazyFrame:
"""
Rename column names.
Expand All @@ -5040,6 +5042,10 @@ def rename(self, mapping: dict[str, str] | Callable[[str], str]) -> LazyFrame:
mapping
Key value pairs that map from old name to new name, or a function
that takes the old name as input and returns the new name.
strict
Validate that all column names exist in the current schema,
and throw an exception if any do not. (Note that this parameter
is a no-op when passing a function to `mapping`).
Notes
-----
Expand Down Expand Up @@ -5083,7 +5089,7 @@ def rename(self, mapping: dict[str, str] | Callable[[str], str]) -> LazyFrame:
else:
existing = list(mapping.keys())
new = list(mapping.values())
return self._from_pyldf(self._ldf.rename(existing, new))
return self._from_pyldf(self._ldf.rename(existing, new, strict))

def reverse(self) -> LazyFrame:
"""
Expand Down
17 changes: 14 additions & 3 deletions py-polars/tests/unit/lazyframe/test_rename.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
import pytest

import polars as pl
from polars.exceptions import ColumnNotFoundError


def test_lazy_rename() -> None:
df = pl.DataFrame({"x": [1], "y": [2]})
lf = pl.LazyFrame({"x": [1], "y": [2]})

result = lf.rename({"y": "x", "x": "y"}).select(["x", "y"]).collect()
assert result.to_dict(as_series=False) == {"x": [2], "y": [1]}

# the `strict` param controls whether we fail on columns not found in the frame
remap_colnames = {"b": "a", "y": "x", "a": "b", "x": "y"}
with pytest.raises(ColumnNotFoundError, match="'b' is invalid"):
lf.rename(remap_colnames).collect()

result = df.lazy().rename({"y": "x", "x": "y"}).select(["x", "y"])
assert result.collect().to_dict(as_series=False) == {"x": [2], "y": [1]}
result = lf.rename(remap_colnames, strict=False).collect()
assert result.to_dict(as_series=False) == {"x": [2], "y": [1]}


def test_remove_redundant_mapping_4668() -> None:
Expand Down
8 changes: 8 additions & 0 deletions py-polars/tests/unit/test_projections.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,17 @@ def test_projection_join_names_9955() -> None:

def test_projection_rename_10595() -> None:
lf = pl.LazyFrame(schema={"a": pl.Float32, "b": pl.Float32})

result = lf.select("a", "b").rename({"b": "a", "a": "b"}).select("a")
assert result.collect().schema == {"a": pl.Float32}

result = (
lf.select("a", "b")
.rename({"c": "d", "b": "a", "d": "c", "a": "b"}, strict=False)
.select("a")
)
assert result.collect().schema == {"a": pl.Float32}


def test_projection_count_11841() -> None:
pl.LazyFrame({"x": 1}).select(records=pl.len()).select(
Expand Down

0 comments on commit b1c59d4

Please sign in to comment.