From 79237c545c18b32540d6d49ea8107fd1cdd861ab Mon Sep 17 00:00:00 2001 From: Etienne Bacher <52219252+etiennebacher@users.noreply.github.com> Date: Thu, 16 Jan 2025 14:12:37 +0100 Subject: [PATCH] feat: `pl$arg_where()` and `$arg_true()` --- R/000-wrappers.R | 6 ++++++ R/expr-expr.R | 20 ++++++++++---------- R/functions-lazy.R | 15 +++++++++++++++ man/expr__arg_true.Rd | 18 ++++++++++++++++++ man/pl.Rd | 2 +- man/pl__arg_where.Rd | 23 +++++++++++++++++++++++ src/init.c | 6 ++++++ src/rust/Cargo.toml | 1 + src/rust/api.h | 1 + src/rust/src/functions/lazy.rs | 5 +++++ tests/testthat/test-expr-expr.R | 8 ++++++++ tests/testthat/test-functions-lazy.R | 8 ++++++++ 12 files changed, 102 insertions(+), 11 deletions(-) create mode 100644 man/expr__arg_true.Rd create mode 100644 man/pl__arg_where.Rd diff --git a/R/000-wrappers.R b/R/000-wrappers.R index f4ccd0d5..46554ef8 100644 --- a/R/000-wrappers.R +++ b/R/000-wrappers.R @@ -48,6 +48,12 @@ NULL } +`arg_where` <- function(`condition`) { + `condition` <- .savvy_extract_ptr(`condition`, "PlRExpr") + .savvy_wrap_PlRExpr(.Call(savvy_arg_where__impl, `condition`)) +} + + `as_struct` <- function(`exprs`) { .savvy_wrap_PlRExpr(.Call(savvy_as_struct__impl, `exprs`)) } diff --git a/R/expr-expr.R b/R/expr-expr.R index eb07afbd..7a662652 100644 --- a/R/expr-expr.R +++ b/R/expr-expr.R @@ -1510,16 +1510,16 @@ expr__arg_unique <- function() { wrap() } -# TODO-REWRITE: requires pl$arg_where() -# #' Return indices where expression is true -# #' -# #' @inherit as_polars_expr return -# #' @examples -# #' df <- pl$DataFrame(a = c(1, 1, 2, 1)) -# #' df$select((pl$col("a") == 1)$arg_true()) -# expr__arg_true <- function() { -# pl$arg_where(self$`_rexpr`) -# } +#' Return indices where expression is true +#' +#' @inherit as_polars_expr return +#' @examples +#' df <- pl$DataFrame(a = c(1, 1, 2, 1)) +#' df$select((pl$col("a") == 1)$arg_true()) +expr__arg_true <- function() { + arg_where(self$`_rexpr`) |> + wrap() +} #' Get the number of non-null 
elements in the column #' diff --git a/R/functions-lazy.R b/R/functions-lazy.R index 228e7d28..0738a798 100644 --- a/R/functions-lazy.R +++ b/R/functions-lazy.R @@ -60,3 +60,18 @@ pl__coalesce <- function(...) { coalesce() |> wrap() } + +#' Return indices where `condition` evaluates to `TRUE` +#' +#' @param condition Boolean expression to evaluate. +#' @inherit as_polars_expr return +#' +#' @examples +#' df <- pl$DataFrame(a = 1:5) +#' df$select( +#' pl$arg_where(pl$col("a") %% 2 == 0) +#' ) +pl__arg_where <- function(condition) { + arg_where(as_polars_expr(condition)$`_rexpr`) |> + wrap() +} diff --git a/man/expr__arg_true.Rd b/man/expr__arg_true.Rd new file mode 100644 index 00000000..78f20a28 --- /dev/null +++ b/man/expr__arg_true.Rd @@ -0,0 +1,18 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/expr-expr.R +\name{expr__arg_true} +\alias{expr__arg_true} +\title{Return indices where expression is true} +\usage{ +expr__arg_true() +} +\value{ +A polars \link{expression} +} +\description{ +Return indices where expression is true +} +\examples{ +df <- pl$DataFrame(a = c(1, 1, 2, 1)) +df$select((pl$col("a") == 1)$arg_true()) +} diff --git a/man/pl.Rd b/man/pl.Rd index 33e6b48d..0c796695 100644 --- a/man/pl.Rd +++ b/man/pl.Rd @@ -5,7 +5,7 @@ \alias{pl} \title{Polars top-level function namespace} \format{ -An object of class \code{polars_object} of length 75. +An object of class \code{polars_object} of length 76. 
} \usage{ pl diff --git a/man/pl__arg_where.Rd b/man/pl__arg_where.Rd new file mode 100644 index 00000000..baa4945a --- /dev/null +++ b/man/pl__arg_where.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/functions-lazy.R +\name{pl__arg_where} +\alias{pl__arg_where} +\title{Return indices where \code{condition} evaluates to \code{TRUE}} +\usage{ +pl__arg_where(condition) +} +\arguments{ +\item{condition}{Boolean expression to evaluate.} +} +\value{ +A polars \link{expression} +} +\description{ +Return indices where \code{condition} evaluates to \code{TRUE} +} +\examples{ +df <- pl$DataFrame(a = 1:5) +df$select( + pl$arg_where(pl$col("a") \%\% 2 == 0) +) +} diff --git a/src/init.c b/src/init.c index edb61641..62e8e1b4 100644 --- a/src/init.c +++ b/src/init.c @@ -44,6 +44,11 @@ SEXP savvy_any_horizontal__impl(SEXP c_arg__exprs) { return handle_result(res); } +SEXP savvy_arg_where__impl(SEXP c_arg__condition) { + SEXP res = savvy_arg_where__ffi(c_arg__condition); + return handle_result(res); +} + SEXP savvy_as_struct__impl(SEXP c_arg__exprs) { SEXP res = savvy_as_struct__ffi(c_arg__exprs); return handle_result(res); @@ -2558,6 +2563,7 @@ SEXP savvy_PlRWhen_then__impl(SEXP self__, SEXP c_arg__statement) { static const R_CallMethodDef CallEntries[] = { {"savvy_all_horizontal__impl", (DL_FUNC) &savvy_all_horizontal__impl, 1}, {"savvy_any_horizontal__impl", (DL_FUNC) &savvy_any_horizontal__impl, 1}, + {"savvy_arg_where__impl", (DL_FUNC) &savvy_arg_where__impl, 1}, {"savvy_as_struct__impl", (DL_FUNC) &savvy_as_struct__impl, 1}, {"savvy_coalesce__impl", (DL_FUNC) &savvy_coalesce__impl, 1}, {"savvy_col__impl", (DL_FUNC) &savvy_col__impl, 1}, diff --git a/src/rust/Cargo.toml b/src/rust/Cargo.toml index 8644f9da..9df3711e 100644 --- a/src/rust/Cargo.toml +++ b/src/rust/Cargo.toml @@ -28,6 +28,7 @@ rev = "841c387d99d7024037556c4ef79d96bf2caac397" features = [ "abs", "approx_unique", + "arg_where", "array_any_all", "array_count", 
"array_to_struct", diff --git a/src/rust/api.h b/src/rust/api.h index bfe5178f..715e3fda 100644 --- a/src/rust/api.h +++ b/src/rust/api.h @@ -1,5 +1,6 @@ SEXP savvy_all_horizontal__ffi(SEXP c_arg__exprs); SEXP savvy_any_horizontal__ffi(SEXP c_arg__exprs); +SEXP savvy_arg_where__ffi(SEXP c_arg__condition); SEXP savvy_as_struct__ffi(SEXP c_arg__exprs); SEXP savvy_coalesce__ffi(SEXP c_arg__exprs); SEXP savvy_col__ffi(SEXP c_arg__name); diff --git a/src/rust/src/functions/lazy.rs b/src/rust/src/functions/lazy.rs index e5b1aa49..08f6a138 100644 --- a/src/rust/src/functions/lazy.rs +++ b/src/rust/src/functions/lazy.rs @@ -264,3 +264,8 @@ pub fn concat_lf_diagonal( .map_err(RPolarsErr::from)?; Ok(lf.into()) } + +#[savvy] +pub fn arg_where(condition: PlRExpr) -> Result<PlRExpr> { + Ok(dsl::arg_where(condition.inner.clone()).into()) +} diff --git a/tests/testthat/test-expr-expr.R b/tests/testthat/test-expr-expr.R index af6129ff..b7092b81 100644 --- a/tests/testthat/test-expr-expr.R +++ b/tests/testthat/test-expr-expr.R @@ -1338,6 +1338,14 @@ test_that("arg_unique", { ) }) +test_that("arg_true", { + df <- pl$DataFrame(a = c(1, 1, 2, 1)) + expect_equal( + df$select((pl$col("a") == 1)$arg_true()), + pl$DataFrame(a = c(0, 1, 3))$cast(pl$UInt32) + ) +}) + # test_that("Expr_quantile", { # v <- sample(0:100) # expect_equal( diff --git a/tests/testthat/test-functions-lazy.R b/tests/testthat/test-functions-lazy.R index 37793a35..f839b0d6 100644 --- a/tests/testthat/test-functions-lazy.R +++ b/tests/testthat/test-functions-lazy.R @@ -18,3 +18,11 @@ test_that("pl$coalesce()", { "must be passed by position, not name" ) }) + +test_that("arg_where", { + df <- pl$DataFrame(a = 1:5) + expect_equal( + df$select(pl$arg_where(pl$col("a") %% 2 == 0)), + pl$DataFrame(a = c(1, 3))$cast(pl$UInt32) + ) +})