Skip to content

Commit

Permalink
fix(rust, python): improve concat_list with empty list error message (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
josh authored Mar 2, 2023
1 parent bac9b2d commit 5b5e3d7
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 6 deletions.
14 changes: 10 additions & 4 deletions polars/polars-lazy/polars-plan/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,16 @@ pub fn format_str<E: AsRef<[Expr]>>(format: &str, args: E) -> PolarsResult<Expr>
}

/// Concat lists entries.
pub fn concat_lst<E: AsRef<[IE]>, IE: Into<Expr> + Clone>(s: E) -> Expr {
let s = s.as_ref().iter().map(|e| e.clone().into()).collect();
pub fn concat_lst<E: AsRef<[IE]>, IE: Into<Expr> + Clone>(s: E) -> PolarsResult<Expr> {
let s: Vec<_> = s.as_ref().iter().map(|e| e.clone().into()).collect();

Expr::Function {
if s.is_empty() {
return Err(PolarsError::ComputeError(
"concat_list needs one or more expressions".into(),
));
}

Ok(Expr::Function {
input: s,
function: FunctionExpr::ListExpr(ListFunction::Concat),
options: FunctionOptions {
Expand All @@ -313,7 +319,7 @@ pub fn concat_lst<E: AsRef<[IE]>, IE: Into<Expr> + Clone>(s: E) -> Expr {
fmt_str: "concat_list",
..Default::default()
},
}
})
}

/// Create list entries that are range arrays
Expand Down
5 changes: 3 additions & 2 deletions py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,10 @@ fn concat_str(s: Vec<dsl::PyExpr>, separator: &str) -> dsl::PyExpr {
}

#[pyfunction]
fn concat_lst(s: Vec<dsl::PyExpr>) -> dsl::PyExpr {
fn concat_lst(s: Vec<dsl::PyExpr>) -> PyResult<dsl::PyExpr> {
let s = s.into_iter().map(|e| e.inner).collect::<Vec<_>>();
polars_rs::lazy::dsl::concat_lst(s).into()
let expr = polars_rs::lazy::dsl::concat_lst(s).map_err(PyPolarsErr::from)?;
Ok(expr.into())
}

#[pyfunction]
Expand Down
6 changes: 6 additions & 0 deletions py-polars/tests/unit/datatypes/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import date, datetime, time

import pandas as pd
import pytest

import polars as pl

Expand Down Expand Up @@ -360,6 +361,11 @@ def test_concat_list_in_agg_6397() -> None:
}


def test_concat_list_empty_raises() -> None:
with pytest.raises(pl.ComputeError):
pl.DataFrame({"a": [1, 2, 3]}).with_columns(pl.concat_list([]))


def test_flat_aggregation_to_list_conversion_6918() -> None:
df = pl.DataFrame({"a": [1, 2, 2], "b": [[0, 1], [2, 3], [4, 5]]})

Expand Down

0 comments on commit 5b5e3d7

Please sign in to comment.