diff --git a/R/parse.R b/R/parse.R index fe5ae50a..b1cd0765 100644 --- a/R/parse.R +++ b/R/parse.R @@ -131,13 +131,20 @@ parse_model_lhs <- function(model){ # Traverse call removing all resp() usage # This is used to evaluate the response from the input data - response_exprs <- lapply(model_lhs, traverse, - .f = function(x, y) { - if(is_resp(y)) x[[1]] else call2(call_name(y), !!!x) - }, - .g = function(x) x[-1], - # .h = function(x) if(is_resp(x)) x[[length(x)]] else x, - base = function(x) is_syntactic_literal(x) || is_symbol(x) + response_exprs <- lapply( + model_lhs, traverse, + .f = function(x, y) { + if(is_resp(y)) x[[1]] else call2(y[[1]], !!!x) + }, + .g = function(x) x[-1], + # .h = function(x) if(is_resp(x)) x[[length(x)]] else x, + base = function(x) is_syntactic_literal(x) || is_symbol(x) + ) + + has_resp <- function(x) traverse( + x, + .f = function(x,y) x[[1]]||y, .g = function(x) x[-1], .h = is_resp, + base = function(x) is_syntactic_literal(x) || is_symbol(x) ) # Traverse call AST to parse out order of transformations @@ -146,113 +153,98 @@ parse_model_lhs <- function(model){ # If the response is not set via resp(), identify the response by the maximum length object until encountering ties # # Returns a list of increasing depth of transformations - traversed_lhs <- lapply(model_lhs, traverse, - .f = function(x, y) { - # if(any(resp_pos <- map_lgl(x, function(x) any(names(x) %in% "response")))){ - # if(sum(resp_pos) != 1) abort("The `resp()` function can only be used once per response variable. For multivariate modelling, use `vars()`.") - # names(x)[resp_pos] <- "response" - # } - - vals <- list() - resp <- NULL - args <- lapply(x, function(y) { - if(inherits(y, "resp_path")) { - vals[names(attr(y, "vals"))] <<- attr(y, "vals") - resp <<- y - y[[length(y)]] - } else { - if(y$len == 1) { - vals[as_label(y$cl)] <<- y$res - } else { - resp <<- y$cl - } - y$cl - } - }) - - if(is.list(y)) y <- y[["cl"]] - path <- c(resp, as.call(c(y[[1]], args))) - - # if(inherits(x[[1]], "resp_path")) { - # x <- x[[1]] - # vals <- attr(x, "vals") - # path <- c(x, as.call(c(y[[1]], x[[length(x)]]))) - # - # } else { - # x <- transpose(x) - # x$len <- as.double(x$len) - # - # if(any(len1 <- x$len == 1)) { - # vals <- x$res[len1] - # names(vals) <- map_chr(x$cl[len1], as_label) - # } - # path <- c(x$cl[[which.max(x$len)]], as.call(c(y[[1]], x$cl))) - # } - - structure( - path, - vals = vals, - class = "resp_path" - ) - }, - .g = function(x) { - if(is.list(x)) x <- x[["cl"]] - - # traverse only call arguments - args <- x[-1] - - res <- map(args, function(y) eval(y, envir = model$data, enclos = model$specials)) - len <- map_dbl(res, length) - - # handle unspecified response - if(sum(len == max(len)) > 1) { - return(call("resp", x)) - } - - if(length(len[len!=1]) > 1){ - abort( - sprintf( - "Response variable transformation has incompatible lengths, all arguments must be the length of the data %i or 1.", - max(len) - ) - ) - } - - # handle length 1 arguments - if(any(is_singular <- len == 1)) { - nm <- map_chr(args[is_singular], as_label) - args[is_singular] <- syms(nm) - } - - # add length of each args - .mapply( - list, - dots = list(cl = args, len = len, res = res), - MoreArgs = list() - ) - }, - # .h = function(x) { - # if(is_resp(x)){ - # if(length(x) > 2) abort("The response variable accepts only one input. For multivariate modelling, use `vars()`.") - # list(response = x) - # } else x - # }, - base = function(x) { - if(is.list(x)) x <- x[["cl"]] - is_syntactic_literal(x) || is_symbol(x) || is_resp(x) - } + traversed_lhs <- lapply(model_lhs, + function(x) { + len1vals <- list() + path <- traverse( + x, + .f = function(x, y) { + # Special handling for if the response was found by a length tie + if((length(x[[1]]$response == 1L) && (as_label(x[[1]]$response) == as_label(y[[1]])))) { + return(x[[1]]) + } + + # Rebuild the expression + args <- lapply(x, function(y) { + y[[length(y)]] + }) + y <- as.call(c(y[[1]][[1]], args)) + + # Search for response + path <- compact(lapply(x, `[[`, "response")) + if(length(path) > 0) return(list(response = c(path[[1]], y))) + + # Otherwise keep the path that isn't length 1 + path <- x[[which(map_lgl(x, function(x) is.name(x[[1]]) && !(as_label(x[[1]]) %in% names(len1vals))))]] + c(path, y) + }, + .g = function(x) { + # traverse only call arguments + args <- x[-1] + + # search for resp() to avoid unneeded evaluation + resp_loc <- which(map_lgl(args, has_resp)) + if(length(resp_loc) > 1) { + abort("The `resp()` function can only be used once per response variable. For multivariate modelling, use `vars()`.") + } + non_resp <- if(length(resp_loc) == 0) seq_along(args) else -resp_loc + + res <- map(args[non_resp], function(y) eval(y, envir = model$data, enclos = model$specials)) + len <- map_dbl(res, length) + + if(length(unique(len[len!=1])) > 1){ + abort( + sprintf( + "Response variable transformation has incompatible lengths, all arguments must be the length of the data %i or 1.", + max(len) + ) + ) + } + + # store length 1 arguments for transformation environment + len1check <- function(len, arg) { + (len == 1) && (is.name(arg) || (is.call(arg) && !(as_label(arg[[1]]) %in% "length"))) + } + + if(any(is_singular <- map2_lgl(len, args[non_resp], len1check))) { + nm <- map_chr(args[non_resp][is_singular], as_label) + args[non_resp][is_singular] <- syms(nm) + len1vals[nm] <<- res[is_singular] + } + + # handle unspecified response with equal length args + if((length(resp_loc) == 0) && (sum(len == max(len)) > 1)) { + return(list(call("resp", x))) + } + + args + }, + .h = function(x) { + if(is_resp(x)){ + if(length(x[-1]) > 2) abort("The response variable accepts only one input. For multivariate modelling, use `vars()`.") + list(response = sym(as_label(x[[2]]))) + } else list(x) + }, + base = function(x) { + is_syntactic_literal(x) || is_symbol(x) || is_resp(x) + } + ) + if("response" %in% names(path)) path <- path$response + if(!is.list(path)) path <- list(path) + list(path = path, len1vals = len1vals) + } ) # Obtain parsed out response variable - responses <- map(traversed_lhs, function(x) x[[1]]) + responses <- map(traversed_lhs, function(x) x$path[[1]]) responses <- map_chr(responses, as_label) # Obtain transformation expression applied to response variable - transform_exprs <- lapply(traversed_lhs, function(x) x[[length(x)]]) + transform_exprs <- lapply(traversed_lhs, function(x) x$path[[length(x$path)]]) # Invert transformation applied to response variable inverse_exprs <- lapply(traversed_lhs, function(x){ - x <- rev(x) + x <- rev(x$path) result <- x[[length(x)]] for (i in seq_len(length(x) - 1)){ result <- undo_transformation(x[[i]], x[[i + 1]], result) @@ -262,7 +254,7 @@ parse_model_lhs <- function(model){ # Create evaluation environment for transformation functions # Includes cached values of single length arguments - transform_args <- lapply(traversed_lhs, attr, "vals") + transform_args <- lapply(traversed_lhs, `[[`, "len1vals") envs <- lapply(transform_args, new_environment, parent = model$env) # Produce transformation class functions for bt() usage