From b1c59d4e647d63c0d2bf278f0be5fa23a4f9895d Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Mon, 30 Sep 2024 13:19:13 +0400 Subject: [PATCH] feat: Add `strict` param to eager/lazy frame "rename" --- crates/polars-lazy/src/frame/mod.rs | 8 +++++--- crates/polars-lazy/src/tests/cse.rs | 2 +- .../src/tests/optimization_checks.rs | 4 ++-- crates/polars-plan/src/plans/functions/dsl.rs | 12 +++++++++--- crates/polars-python/src/lazyframe/general.rs | 4 ++-- crates/polars-sql/src/context.rs | 3 ++- crates/polars-sql/src/sql_expr.rs | 2 +- .../polars/tests/it/lazy/predicate_queries.rs | 2 +- .../polars/tests/it/lazy/projection_queries.rs | 2 +- py-polars/polars/dataframe/frame.py | 14 ++++++++++---- py-polars/polars/lazyframe/frame.py | 14 ++++++++++---- py-polars/tests/unit/lazyframe/test_rename.py | 17 ++++++++++++++--- py-polars/tests/unit/test_projections.py | 8 ++++++++ 13 files changed, 66 insertions(+), 26 deletions(-) diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index 54163fe33544..648bb198c57d 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -392,9 +392,10 @@ impl LazyFrame { /// /// `existing` and `new` are iterables of the same length containing the old and /// corresponding new column names. Renaming happens to all `existing` columns - /// simultaneously, not iteratively. (In particular, all columns in `existing` must - /// already exist in the `LazyFrame` when `rename` is called.) - pub fn rename(self, existing: I, new: J) -> Self + /// simultaneously, not iteratively. If `strict` is true, all columns in `existing` + /// must be present in the `LazyFrame` when `rename` is called; otherwise, only + /// those columns that are actually found will be renamed (others will be ignored). + pub fn rename(self, existing: I, new: J, strict: bool) -> Self where I: IntoIterator, J: IntoIterator, @@ -420,6 +421,7 @@ impl LazyFrame { self.map_private(DslFunction::Rename { existing: existing_vec.into(), new: new_vec.into(), + strict, }) } diff --git a/crates/polars-lazy/src/tests/cse.rs b/crates/polars-lazy/src/tests/cse.rs index 9122ba3814ff..6ed8e1cc67c8 100644 --- a/crates/polars-lazy/src/tests/cse.rs +++ b/crates/polars-lazy/src/tests/cse.rs @@ -307,7 +307,7 @@ fn test_cse_columns_projections() -> PolarsResult<()> { let left = left.cross_join(right.clone().select([col("A")]), None); let q = left.join( - right.rename(["B"], ["C"]), + right.rename(["B"], ["C"], true), [col("A"), col("C")], [col("A"), col("C")], JoinType::Left.into(), diff --git a/crates/polars-lazy/src/tests/optimization_checks.rs b/crates/polars-lazy/src/tests/optimization_checks.rs index 4a99413d48cc..51a6bf26cfb0 100644 --- a/crates/polars-lazy/src/tests/optimization_checks.rs +++ b/crates/polars-lazy/src/tests/optimization_checks.rs @@ -323,7 +323,7 @@ fn test_lazy_filter_and_rename() { let lf = df .clone() .lazy() - .rename(["a"], ["x"]) + .rename(["a"], ["x"], true) .filter(col("x").map( |s: Column| Ok(Some(s.as_materialized_series().gt(3)?.into_column())), GetOutput::from_type(DataType::Boolean), @@ -337,7 +337,7 @@ fn test_lazy_filter_and_rename() { assert!(lf.collect().unwrap().equals(&correct)); // now we check if the column is rename or added when we don't select - let lf = df.lazy().rename(["a"], ["x"]).filter(col("x").map( + let lf = df.lazy().rename(["a"], ["x"], true).filter(col("x").map( |s: Column| Ok(Some(s.as_materialized_series().gt(3)?.into_column())), GetOutput::from_type(DataType::Boolean), )); diff --git a/crates/polars-plan/src/plans/functions/dsl.rs b/crates/polars-plan/src/plans/functions/dsl.rs index fd4b740af9df..c72e4ebe3c06 100644 --- a/crates/polars-plan/src/plans/functions/dsl.rs +++ b/crates/polars-plan/src/plans/functions/dsl.rs @@ -42,6 +42,7 @@ pub enum DslFunction { Rename { existing: Arc<[PlSmallStr]>, new: Arc<[PlSmallStr]>, + strict: bool, }, Unnest(Vec), Stats(StatsFunction), @@ -119,10 +120,15 @@ impl DslFunction { offset, schema: Default::default(), }, - DslFunction::Rename { existing, new } => { + DslFunction::Rename { + existing, + new, + strict, + } => { let swapping = new.iter().any(|name| input_schema.get(name).is_some()); - validate_columns_in_input(existing.as_ref(), input_schema, "rename")?; - + if strict { + validate_columns_in_input(existing.as_ref(), input_schema, "rename")?; + } FunctionIR::Rename { existing, new, diff --git a/crates/polars-python/src/lazyframe/general.rs b/crates/polars-python/src/lazyframe/general.rs index 18b9323388e7..eb21666d2685 100644 --- a/crates/polars-python/src/lazyframe/general.rs +++ b/crates/polars-python/src/lazyframe/general.rs @@ -962,9 +962,9 @@ impl PyLazyFrame { ldf.with_columns_seq(exprs.to_exprs()).into() } - fn rename(&mut self, existing: Vec, new: Vec) -> Self { + fn rename(&mut self, existing: Vec, new: Vec, strict: bool) -> Self { let ldf = self.ldf.clone(); - ldf.rename(existing, new).into() + ldf.rename(existing, new, strict).into() } fn reverse(&self) -> Self { diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index e95bda62916d..342a5e0883d2 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -751,6 +751,7 @@ impl SQLContext { lf = lf.rename( select_modifiers.rename.keys(), select_modifiers.rename.values(), + true, ); }; lf @@ -1380,7 +1381,7 @@ impl SQLContext { } else { let existing_columns: Vec<_> = schema.iter_names().collect(); let new_columns: Vec<_> = alias.columns.iter().map(|c| c.value.clone()).collect(); - Ok(lf.rename(existing_columns, new_columns)) + Ok(lf.rename(existing_columns, new_columns, true)) } } } diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 148a7fe5735e..f9caa288cb82 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -263,7 +263,7 @@ impl SQLExprVisitor<'_> { let schema_entry = schema.get_at_index(0); if let Some((old_name, _)) = schema_entry { let new_name = String::from(old_name.as_str()) + rand_string.as_str(); - lf = lf.rename([old_name.to_string()], [new_name.clone()]); + lf = lf.rename([old_name.to_string()], [new_name.clone()], true); return Ok(Expr::SubPlan( SpecialEq::new(Arc::new(lf.logical_plan)), vec![new_name], diff --git a/crates/polars/tests/it/lazy/predicate_queries.rs b/crates/polars/tests/it/lazy/predicate_queries.rs index 63cccd65aeaa..49460facc118 100644 --- a/crates/polars/tests/it/lazy/predicate_queries.rs +++ b/crates/polars/tests/it/lazy/predicate_queries.rs @@ -11,7 +11,7 @@ fn test_predicate_after_renaming() -> PolarsResult<()> { "bar" => [3, 2, 1] ]? .lazy() - .rename(["foo", "bar"], ["foo2", "bar2"]) + .rename(["foo", "bar"], ["foo2", "bar2"], true) .filter(col("foo2").eq(col("bar2"))) .collect()?; diff --git a/crates/polars/tests/it/lazy/projection_queries.rs b/crates/polars/tests/it/lazy/projection_queries.rs index 03b7a44bc114..e5870e81ce4e 100644 --- a/crates/polars/tests/it/lazy/projection_queries.rs +++ b/crates/polars/tests/it/lazy/projection_queries.rs @@ -22,7 +22,7 @@ fn test_swap_rename() -> PolarsResult<()> { "b" => [2], ]? .lazy() - .rename(["a", "b"], ["b", "a"]) + .rename(["a", "b"], ["b", "a"], true) .collect()?; let expected = df![ diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 911d057a1f46..da79a0cf57df 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -4479,7 +4479,9 @@ def reverse(self) -> DataFrame: """ return self.select(F.col("*").reverse()) - def rename(self, mapping: dict[str, str] | Callable[[str], str]) -> DataFrame: + def rename( + self, mapping: dict[str, str] | Callable[[str], str], *, strict: bool = True + ) -> DataFrame: """ Rename column names. @@ -4488,6 +4490,10 @@ def rename(self, mapping: dict[str, str] | Callable[[str], str]) -> DataFrame: mapping Key value pairs that map from old name to new name, or a function that takes the old name as input and returns the new name. + strict + Validate that all column names exist in the current schema, + and throw an exception if any do not. (Note that this parameter + is a no-op when passing a function to `mapping`). Examples -------- @@ -4517,7 +4523,7 @@ def rename(self, mapping: dict[str, str] | Callable[[str], str]) -> DataFrame: │ 3 ┆ 8 ┆ c │ └─────┴─────┴─────┘ """ - return self.lazy().rename(mapping).collect(_eager=True) + return self.lazy().rename(mapping, strict=strict).collect(_eager=True) def insert_column(self, index: int, column: Series) -> DataFrame: """ @@ -7475,8 +7481,8 @@ def drop( Names of the columns that should be removed from the dataframe. Accepts column selector input. strict - Validate that all column names exist in the schema and throw an - exception if a column name does not exist in the schema. + Validate that all column names exist in the current schema, + and throw an exception if any do not. Examples -------- diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 9e65f918d385..7aa9a870b9c6 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -4973,8 +4973,8 @@ def drop( Names of the columns that should be removed from the dataframe. Accepts column selector input. strict - Validate that all column names exist in the schema and throw an - exception if a column name does not exist in the schema. + Validate that all column names exist in the current schema, + and throw an exception if any do not. Examples -------- @@ -5031,7 +5031,9 @@ def drop( drop_cols = parse_into_list_of_expressions(*columns) return self._from_pyldf(self._ldf.drop(drop_cols, strict=strict)) - def rename(self, mapping: dict[str, str] | Callable[[str], str]) -> LazyFrame: + def rename( + self, mapping: dict[str, str] | Callable[[str], str], *, strict: bool = True + ) -> LazyFrame: """ Rename column names. @@ -5040,6 +5042,10 @@ def rename(self, mapping: dict[str, str] | Callable[[str], str]) -> LazyFrame: mapping Key value pairs that map from old name to new name, or a function that takes the old name as input and returns the new name. + strict + Validate that all column names exist in the current schema, + and throw an exception if any do not. (Note that this parameter + is a no-op when passing a function to `mapping`). Notes ----- @@ -5083,7 +5089,7 @@ def rename(self, mapping: dict[str, str] | Callable[[str], str]) -> LazyFrame: else: existing = list(mapping.keys()) new = list(mapping.values()) - return self._from_pyldf(self._ldf.rename(existing, new)) + return self._from_pyldf(self._ldf.rename(existing, new, strict)) def reverse(self) -> LazyFrame: """ diff --git a/py-polars/tests/unit/lazyframe/test_rename.py b/py-polars/tests/unit/lazyframe/test_rename.py index 15eae7d4cc34..45e2d26ac7ca 100644 --- a/py-polars/tests/unit/lazyframe/test_rename.py +++ b/py-polars/tests/unit/lazyframe/test_rename.py @@ -1,11 +1,22 @@ +import pytest + import polars as pl +from polars.exceptions import ColumnNotFoundError def test_lazy_rename() -> None: - df = pl.DataFrame({"x": [1], "y": [2]}) + lf = pl.LazyFrame({"x": [1], "y": [2]}) + + result = lf.rename({"y": "x", "x": "y"}).select(["x", "y"]).collect() + assert result.to_dict(as_series=False) == {"x": [2], "y": [1]} + + # the `strict` param controls whether we fail on columns not found in the frame + remap_colnames = {"b": "a", "y": "x", "a": "b", "x": "y"} + with pytest.raises(ColumnNotFoundError, match="'b' is invalid"): + lf.rename(remap_colnames).collect() - result = df.lazy().rename({"y": "x", "x": "y"}).select(["x", "y"]) - assert result.collect().to_dict(as_series=False) == {"x": [2], "y": [1]} + result = lf.rename(remap_colnames, strict=False).collect() + assert result.to_dict(as_series=False) == {"x": [2], "y": [1]} def test_remove_redundant_mapping_4668() -> None: diff --git a/py-polars/tests/unit/test_projections.py b/py-polars/tests/unit/test_projections.py index 7c279648fa1c..48b3077423e0 100644 --- a/py-polars/tests/unit/test_projections.py +++ b/py-polars/tests/unit/test_projections.py @@ -360,9 +360,17 @@ def test_projection_join_names_9955() -> None: def test_projection_rename_10595() -> None: lf = pl.LazyFrame(schema={"a": pl.Float32, "b": pl.Float32}) + result = lf.select("a", "b").rename({"b": "a", "a": "b"}).select("a") assert result.collect().schema == {"a": pl.Float32} + result = ( + lf.select("a", "b") + .rename({"c": "d", "b": "a", "d": "c", "a": "b"}, strict=False) + .select("a") + ) + assert result.collect().schema == {"a": pl.Float32} + def test_projection_count_11841() -> None: pl.LazyFrame({"x": 1}).select(records=pl.len()).select(