From f246a4cd16719bade4e717a17c5563f62238ff5f Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Wed, 25 Sep 2024 16:26:00 +1000 Subject: [PATCH] fix: Adding `with_row_index()` to previously collected lazy scan does not take effect (#18913) --- crates/polars-lazy/src/frame/mod.rs | 45 +++++++++++++++++----------- py-polars/tests/unit/io/test_scan.py | 16 ++++++++++ 2 files changed, 44 insertions(+), 17 deletions(-) diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index 3536c3071332..54163fe33544 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -1762,31 +1762,42 @@ impl LazyFrame { /// # Warning /// This can have a negative effect on query performance. This may for instance block /// predicate pushdown optimization. - pub fn with_row_index<S>(mut self, name: S, offset: Option<IdxSize>) -> LazyFrame + pub fn with_row_index<S>(self, name: S, offset: Option<IdxSize>) -> LazyFrame where S: Into<PlSmallStr>, { let name = name.into(); - let add_row_index_in_map = match &mut self.logical_plan { - DslPlan::Scan { - file_options: options, - scan_type, - .. - } if !matches!(scan_type, FileScan::Anonymous { .. }) => { - let name = name.clone(); - options.row_index = Some(RowIndex { + + match &self.logical_plan { + v @ DslPlan::Scan { scan_type, .. } + if !matches!(scan_type, FileScan::Anonymous { .. 
}) => + { + let DslPlan::Scan { + sources, + mut file_options, + scan_type, + file_info, + cached_ir: _, + } = v.clone() + else { + unreachable!() + }; + + file_options.row_index = Some(RowIndex { name, offset: offset.unwrap_or(0), }); - false - }, - _ => true, - }; - if add_row_index_in_map { - self.map_private(DslFunction::RowIndex { name, offset }) - } else { - self + DslPlan::Scan { + sources, + file_options, + scan_type, + file_info, + cached_ir: Default::default(), + } + .into() + }, + _ => self.map_private(DslFunction::RowIndex { name, offset }), } } diff --git a/py-polars/tests/unit/io/test_scan.py b/py-polars/tests/unit/io/test_scan.py index 8fdfb83f44ec..3da1ade6b1e3 100644 --- a/py-polars/tests/unit/io/test_scan.py +++ b/py-polars/tests/unit/io/test_scan.py @@ -785,3 +785,19 @@ def test_scan_stringio(method: str) -> None: def test_empty_list(method: Callable[[list[str]], pl.LazyFrame]) -> None: with pytest.raises(pl.exceptions.ComputeError, match="expected at least 1 source"): _ = (method)([]).collect() + + +def test_scan_double_collect_row_index_invalidates_cached_ir_18892() -> None: + lf = pl.scan_csv(io.BytesIO(b"a\n1\n2\n3")) + + lf.collect() + + out = lf.with_row_index().collect() + + assert_frame_equal( + out, + pl.DataFrame( + {"index": [0, 1, 2], "a": [1, 2, 3]}, + schema={"index": pl.UInt32, "a": pl.Int64}, + ), + )