Skip to content

Commit

Permalink
make sure recurse can handle dots
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinushey committed Oct 6, 2024
1 parent dab64b2 commit 855593e
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 29 deletions.
9 changes: 3 additions & 6 deletions R/ffi.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,13 @@

`__ffi__recurse` <- function(object, callback, ...) {

symbol <- as.symbol(names(formals(args(callback)))[[1L]])
expr <- body(callback)
envir <- new.env(parent = environment(callback))
callback <- match.fun(callback)

.Call(
"renv_ffi__recurse",
object,
symbol,
expr,
envir,
callback,
environment(),
PACKAGE = "renv"
)

Expand Down
59 changes: 37 additions & 22 deletions inst/ext/renv.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
#include <R_ext/Rdynload.h>
#include <Rinternals.h>

// needed for macro sanity below
#define DBLSXP REALSXP

#define DBL_PTR REAL
#define INT_PTR INTEGER
#define LGL_PTR LOGICAL

static const int _NILSXP = NILSXP;
static const int _INTSXP = INTSXP;
static const int _DBLSXP = DBLSXP;
Expand Down Expand Up @@ -122,18 +125,18 @@ static SEXP renv_dependencies_recurse(SEXP object,
#define GET_STRSXP(__X__, __I__) Rf_ScalarString(STRING_ELT(__X__, __I__))
#define GET_VECSXP(__X__, __I__) VECTOR_ELT(__X__, __I__)

#define EXTRACT_INTSXP(__X__) INTEGER(__X__)[0]
#define EXTRACT_DBLSXP(__X__) REAL(__X__)[0]
#define EXTRACT_LGLSXP(__X__) LOGICAL(__X__)[0]
#define EXTRACT_STRSXP(__X__) STRING_ELT(__X__, 0)
#define EXTRACT_VECSXP(__X__) __X__

#define SET_INTSXP(__X__, __I__, __V__) INTEGER(__X__)[__I__] = __V__
#define SET_DBLSXP(__X__, __I__, __V__) REAL(__X__)[__I__] = __V__
#define SET_LGLSXP(__X__, __I__, __V__) LOGICAL(__X__)[__I__] = __V__
#define SET_INTSXP(__X__, __I__, __V__) INT_PTR(__X__)[__I__] = __V__
#define SET_DBLSXP(__X__, __I__, __V__) DBL_PTR(__X__)[__I__] = __V__
#define SET_LGLSXP(__X__, __I__, __V__) LGL_PTR(__X__)[__I__] = __V__
#define SET_STRSXP(__X__, __I__, __V__) SET_STRING_ELT(__X__, __I__, __V__)
#define SET_VECSXP(__X__, __I__, __V__) SET_VECTOR_ELT(__X__, __I__, __V__)

#define EXTRACT_INTSXP(__X__) INT_PTR(__X__)[0]
#define EXTRACT_DBLSXP(__X__) DBL_PTR(__X__)[0]
#define EXTRACT_LGLSXP(__X__) LGL_PTR(__X__)[0]
#define EXTRACT_STRSXP(__X__) STRING_ELT(__X__, 0)
#define EXTRACT_VECSXP(__X__) __X__

#define COERCE_INTSXP(__X__) Rf_coerceVector(__X__, INTSXP)
#define COERCE_DBLSXP(__X__) Rf_coerceVector(__X__, DBLSXP)
#define COERCE_LGLSXP(__X__) Rf_coerceVector(__X__, LGLSXP)
Expand All @@ -149,11 +152,11 @@ static SEXP renv_dependencies_recurse(SEXP object,

#define ENUMERATE_CASE(__TYPE__) ENUMERATE_CASE_IMPL(__TYPE__, GET_NAMES##__TYPE__, ENUMERATE_CASE_IMPL##__TYPE__)

#define ENUMERATE_CASE_IMPL(__TYPE__, __NAMES__, __DISPATCH__) \
#define ENUMERATE_CASE_IMPL(__TYPE__, __GET_NAMES__, __DISPATCH__) \
do \
{ \
SEXP result = R_NilValue; \
SEXP names = PROTECT(__NAMES__(x)); \
SEXP names = PROTECT(__GET_NAMES__(x)); \
\
switch (TYPEOF(type)) \
{ \
Expand All @@ -162,7 +165,7 @@ static SEXP renv_dependencies_recurse(SEXP object,
case _LGLSXP: __DISPATCH__(result, __TYPE__, _LGLSXP); break; \
case _STRSXP: __DISPATCH__(result, __TYPE__, _STRSXP); break; \
case _VECSXP: __DISPATCH__(result, __TYPE__, _VECSXP); break; \
case _NILSXP: __DISPATCH__(result, __TYPE__, _VECSXP); break; \
case _NILSXP: __DISPATCH__(result, __TYPE__, _VECSXP); break; \
} \
\
UNPROTECT(1); \
Expand Down Expand Up @@ -267,6 +270,7 @@ static SEXP enumerate(SEXP x,

switch (TYPEOF(x))
{
case _NILSXP: return R_NilValue;
case _INTSXP: ENUMERATE_CASE(_INTSXP);
case _DBLSXP: ENUMERATE_CASE(_DBLSXP);
case _LGLSXP: ENUMERATE_CASE(_LGLSXP);
Expand All @@ -279,15 +283,17 @@ static SEXP enumerate(SEXP x,
return R_NilValue;
}

static SEXP recurse(SEXP object,
SEXP symbol,
SEXP expr,
SEXP envir)
static SEXP recurse_impl(SEXP object,
SEXP objectsym,
SEXP callback,
SEXP callbacksym,
SEXP envir)
{
if (object != R_MissingArg)
{
Rf_defineVar(symbol, object, envir);
Rf_eval(expr, envir);
Rf_defineVar(objectsym, object, envir);
SEXP call = Rf_lang3(callbacksym, objectsym, R_DotsSymbol);
R_forceAndCall(call, 1, envir);
}

switch (TYPEOF(object))
Expand All @@ -296,7 +302,7 @@ static SEXP recurse(SEXP object,
case EXPRSXP:
{
for (R_xlen_t i = 0, n = Rf_xlength(object); i < n; i++)
recurse(VECTOR_ELT(object, i), symbol, expr, envir);
recurse_impl(VECTOR_ELT(object, i), objectsym, callback, callbacksym, envir);
break;
}

Expand All @@ -305,7 +311,7 @@ static SEXP recurse(SEXP object,
{
while (object != R_NilValue)
{
recurse(CAR(object), symbol, expr, envir);
recurse_impl(CAR(object), objectsym, callback, callbacksym, envir);
object = CDR(object);
}
break;
Expand All @@ -315,13 +321,22 @@ static SEXP recurse(SEXP object,
return R_NilValue;
}

static SEXP recurse(SEXP object,
SEXP callback,
SEXP envir)
{
SEXP callbacksym = Rf_install("callback");
SEXP objectsym = Rf_install("object");
return recurse_impl(object, objectsym, callback, callbacksym, envir);
}

// Init ----

static const R_CallMethodDef callEntries[] = {
{"renv_ffi__renv_call_expect", (DL_FUNC) &renv_call_expect, 3},
{"renv_ffi__renv_dependencies_recurse", (DL_FUNC) &renv_dependencies_recurse, 4},
{"renv_ffi__enumerate", (DL_FUNC) &enumerate, 3},
{"renv_ffi__recurse", (DL_FUNC) &recurse, 4},
{"renv_ffi__recurse", (DL_FUNC) &recurse, 3},
{NULL, NULL, 0}
};

Expand Down
14 changes: 13 additions & 1 deletion tests/testthat/test-recurse.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

test_that("recurse() can handle missing objects", {
data <- substitute(list(a = A), list(A = quote(expr = )))
expect_no_error(recurse(data, function(node) print(node)))
expect_no_error(recurse(data, function(node) force(node)))
})

test_that("recurse() can handle lists", {
Expand All @@ -22,3 +22,15 @@ test_that("recurse() can handle lists", {
expect_equal(items, list(1, 2, 3, 4))

})

test_that("recurse() can handle dots", {

counter <- 0L
recurse(list(1, list(2, list(3, list(4, list(5))))), function(node, extra) {
expect_equal(extra, 42)
if (is.list(node))
counter <<- counter + 1L
}, extra = 42)
expect_equal(counter, 5L)

})

0 comments on commit 855593e

Please sign in to comment.