Skip to content

Commit

Permalink
Refactor lhs parser to support resp() and not cache length() outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
mitchelloharawild committed Sep 4, 2023
1 parent 7a6f6c1 commit 45a4c46
Showing 1 changed file with 98 additions and 106 deletions.
204 changes: 98 additions & 106 deletions R/parse.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,20 @@ parse_model_lhs <- function(model){

# Traverse call removing all resp() usage
# This is used to evaluate the response from the input data
response_exprs <- lapply(model_lhs, traverse,
.f = function(x, y) {
if(is_resp(y)) x[[1]] else call2(call_name(y), !!!x)
},
.g = function(x) x[-1],
# .h = function(x) if(is_resp(x)) x[[length(x)]] else x,
base = function(x) is_syntactic_literal(x) || is_symbol(x)
response_exprs <- lapply(
model_lhs, traverse,
.f = function(x, y) {
if(is_resp(y)) x[[1]] else call2(y[[1]], !!!x)
},
.g = function(x) x[-1],
# .h = function(x) if(is_resp(x)) x[[length(x)]] else x,
base = function(x) is_syntactic_literal(x) || is_symbol(x)
)

has_resp <- function(x) traverse(
x,
.f = function(x,y) x[[1]]||y, .g = function(x) x[-1], .h = is_resp,
base = function(x) is_syntactic_literal(x) || is_symbol(x)
)

# Traverse call AST to parse out order of transformations
Expand All @@ -146,113 +153,98 @@ parse_model_lhs <- function(model){
# If the response is not set via resp(), identify the response by the maximum length object until encountering ties
#
# Returns a list of increasing depth of transformations
traversed_lhs <- lapply(model_lhs, traverse,
.f = function(x, y) {
# if(any(resp_pos <- map_lgl(x, function(x) any(names(x) %in% "response")))){
# if(sum(resp_pos) != 1) abort("The `resp()` function can only be used once per response variable. For multivariate modelling, use `vars()`.")
# names(x)[resp_pos] <- "response"
# }

vals <- list()
resp <- NULL
args <- lapply(x, function(y) {
if(inherits(y, "resp_path")) {
vals[names(attr(y, "vals"))] <<- attr(y, "vals")
resp <<- y
y[[length(y)]]
} else {
if(y$len == 1) {
vals[as_label(y$cl)] <<- y$res
} else {
resp <<- y$cl
}
y$cl
}
})

if(is.list(y)) y <- y[["cl"]]
path <- c(resp, as.call(c(y[[1]], args)))

# if(inherits(x[[1]], "resp_path")) {
# x <- x[[1]]
# vals <- attr(x, "vals")
# path <- c(x, as.call(c(y[[1]], x[[length(x)]])))
#
# } else {
# x <- transpose(x)
# x$len <- as.double(x$len)
#
# if(any(len1 <- x$len == 1)) {
# vals <- x$res[len1]
# names(vals) <- map_chr(x$cl[len1], as_label)
# }
# path <- c(x$cl[[which.max(x$len)]], as.call(c(y[[1]], x$cl)))
# }

structure(
path,
vals = vals,
class = "resp_path"
)
},
.g = function(x) {
if(is.list(x)) x <- x[["cl"]]

# traverse only call arguments
args <- x[-1]

res <- map(args, function(y) eval(y, envir = model$data, enclos = model$specials))
len <- map_dbl(res, length)

# handle unspecified response
if(sum(len == max(len)) > 1) {
return(call("resp", x))
}

if(length(len[len!=1]) > 1){
abort(
sprintf(
"Response variable transformation has incompatible lengths, all arguments must be the length of the data %i or 1.",
max(len)
)
)
}

# handle length 1 arguments
if(any(is_singular <- len == 1)) {
nm <- map_chr(args[is_singular], as_label)
args[is_singular] <- syms(nm)
}

# add length of each args
.mapply(
list,
dots = list(cl = args, len = len, res = res),
MoreArgs = list()
)
},
# .h = function(x) {
# if(is_resp(x)){
# if(length(x) > 2) abort("The response variable accepts only one input. For multivariate modelling, use `vars()`.")
# list(response = x)
# } else x
# },
base = function(x) {
if(is.list(x)) x <- x[["cl"]]
is_syntactic_literal(x) || is_symbol(x) || is_resp(x)
}
traversed_lhs <- lapply(model_lhs,
function(x) {
len1vals <- list()
path <- traverse(
x,
.f = function(x, y) {
# Special handling for if the response was found by a length tie
if((length(x[[1]]$response == 1L) && (as_label(x[[1]]$response) == as_label(y[[1]])))) {
return(x[[1]])
}

# Rebuild the expression
args <- lapply(x, function(y) {
y[[length(y)]]
})
y <- as.call(c(y[[1]][[1]], args))

# Search for response
path <- compact(lapply(x, `[[`, "response"))
if(length(path) > 0) return(list(response = c(path[[1]], y)))

# Otherwise keep the path that isn't length 1
path <- x[[which(map_lgl(x, function(x) is.name(x[[1]]) && !(as_label(x[[1]]) %in% names(len1vals))))]]
c(path, y)
},
.g = function(x) {
# traverse only call arguments
args <- x[-1]

# search for resp() to avoid unneeded evaluation
resp_loc <- which(map_lgl(args, has_resp))
if(length(resp_loc) > 1) {
abort("The `resp()` function can only be used once per response variable. For multivariate modelling, use `vars()`.")
}
non_resp <- if(length(resp_loc) == 0) seq_along(args) else -resp_loc

res <- map(args[non_resp], function(y) eval(y, envir = model$data, enclos = model$specials))
len <- map_dbl(res, length)

if(length(unique(len[len!=1])) > 1){
abort(
sprintf(
"Response variable transformation has incompatible lengths, all arguments must be the length of the data %i or 1.",
max(len)
)
)
}

# store length 1 arguments for transformation environment
len1check <- function(len, arg) {
(len == 1) && (is.name(arg) || (is.call(arg) && !(as_label(arg[[1]]) %in% "length")))
}

if(any(is_singular <- map2_lgl(len, args[non_resp], len1check))) {
nm <- map_chr(args[non_resp][is_singular], as_label)
args[non_resp][is_singular] <- syms(nm)
len1vals[nm] <<- res[is_singular]
}

# handle unspecified response with equal length args
if((length(resp_loc) == 0) && (sum(len == max(len)) > 1)) {
return(list(call("resp", x)))
}

args
},
.h = function(x) {
if(is_resp(x)){
if(length(x[-1]) > 2) abort("The response variable accepts only one input. For multivariate modelling, use `vars()`.")
list(response = sym(as_label(x[[2]])))
} else list(x)
},
base = function(x) {
is_syntactic_literal(x) || is_symbol(x) || is_resp(x)
}
)
if("response" %in% names(path)) path <- path$response
if(!is.list(path)) path <- list(path)
list(path = path, len1vals = len1vals)
}
)

# Obtain parsed out response variable
responses <- map(traversed_lhs, function(x) x[[1]])
responses <- map(traversed_lhs, function(x) x$path[[1]])
responses <- map_chr(responses, as_label)

# Obtain transformation expression applied to response variable
transform_exprs <- lapply(traversed_lhs, function(x) x[[length(x)]])
transform_exprs <- lapply(traversed_lhs, function(x) x$path[[length(x$path)]])

# Invert transformation applied to response variable
inverse_exprs <- lapply(traversed_lhs, function(x){
x <- rev(x)
x <- rev(x$path)
result <- x[[length(x)]]
for (i in seq_len(length(x) - 1)){
result <- undo_transformation(x[[i]], x[[i + 1]], result)
Expand All @@ -262,7 +254,7 @@ parse_model_lhs <- function(model){

# Create evaluation environment for transformation functions
# Includes cached values of single length arguments
transform_args <- lapply(traversed_lhs, attr, "vals")
transform_args <- lapply(traversed_lhs, `[[`, "len1vals")
envs <- lapply(transform_args, new_environment, parent = model$env)

# Produce transformation class functions for bt() usage
Expand Down

0 comments on commit 45a4c46

Please sign in to comment.