Skip to content

Commit

Permalink
fix: Adding with_row_index() to previously collected lazy scan does not take effect (#18913)
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored Sep 25, 2024
1 parent f235240 commit f246a4c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 17 deletions.
45 changes: 28 additions & 17 deletions crates/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1762,31 +1762,42 @@ impl LazyFrame {
/// # Warning
/// This can have a negative effect on query performance. This may for instance block
/// predicate pushdown optimization.
pub fn with_row_index<S>(mut self, name: S, offset: Option<IdxSize>) -> LazyFrame
pub fn with_row_index<S>(self, name: S, offset: Option<IdxSize>) -> LazyFrame
where
S: Into<PlSmallStr>,
{
let name = name.into();
let add_row_index_in_map = match &mut self.logical_plan {
DslPlan::Scan {
file_options: options,
scan_type,
..
} if !matches!(scan_type, FileScan::Anonymous { .. }) => {
let name = name.clone();
options.row_index = Some(RowIndex {

match &self.logical_plan {
v @ DslPlan::Scan { scan_type, .. }
if !matches!(scan_type, FileScan::Anonymous { .. }) =>
{
let DslPlan::Scan {
sources,
mut file_options,
scan_type,
file_info,
cached_ir: _,
} = v.clone()
else {
unreachable!()
};

file_options.row_index = Some(RowIndex {
name,
offset: offset.unwrap_or(0),
});
false
},
_ => true,
};

if add_row_index_in_map {
self.map_private(DslFunction::RowIndex { name, offset })
} else {
self
DslPlan::Scan {
sources,
file_options,
scan_type,
file_info,
cached_ir: Default::default(),
}
.into()
},
_ => self.map_private(DslFunction::RowIndex { name, offset }),
}
}

Expand Down
16 changes: 16 additions & 0 deletions py-polars/tests/unit/io/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,3 +785,19 @@ def test_scan_stringio(method: str) -> None:
def test_empty_list(method: Callable[[list[str]], pl.LazyFrame]) -> None:
    """Scanning an empty list of sources and collecting must raise ComputeError."""
    no_sources: list[str] = []
    # Both the scan call and the collect stay inside the context manager, so
    # the error is accepted regardless of which of the two raises it.
    with pytest.raises(pl.exceptions.ComputeError, match="expected at least 1 source"):
        method(no_sources).collect()


def test_scan_double_collect_row_index_invalidates_cached_ir_18892() -> None:
    """Check that `with_row_index()` still applies to a scan that was already collected."""
    csv_bytes = io.BytesIO(b"a\n1\n2\n3")
    lf = pl.scan_csv(csv_bytes)

    # Collect once before adding the row index; the result is intentionally unused.
    lf.collect()

    result = lf.with_row_index().collect()
    expected = pl.DataFrame(
        {"index": [0, 1, 2], "a": [1, 2, 3]},
        schema={"index": pl.UInt32, "a": pl.Int64},
    )

    assert_frame_equal(result, expected)

0 comments on commit f246a4c

Please sign in to comment.