Skip to content

Commit

Permalink
WIP: model transformation parser to cache length 1 arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
mitchelloharawild committed Aug 20, 2023
1 parent 35eeca5 commit 7a6f6c1
Showing 1 changed file with 113 additions and 60 deletions.
173 changes: 113 additions & 60 deletions R/parse.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,64 +140,107 @@ parse_model_lhs <- function(model){
base = function(x) is_syntactic_literal(x) || is_symbol(x)
)

# Traverse call to parse out AST for transformations
# Traverse call AST to parse out order of transformations
#
# If the response is set via resp(), remove all usage of resp() from the traversal
# If the response is not set via resp(), identify the response by the maximum length object until encountering ties
#
# Returns a list of increasing depth of transformations
traversed_lhs <- lapply(model_lhs, traverse,
.f = function(x, y) {
if(any(resp_pos <- map_lgl(x, function(x) any(names(x) %in% "response")))){
if(sum(resp_pos) != 1) abort("The `resp()` function can only be used once per response variable. For multivariate modelling, use `vars()`.")
names(x)[resp_pos] <- "response"
}
`attr<-`(x, "call", y[[1]])
# if(any(resp_pos <- map_lgl(x, function(x) any(names(x) %in% "response")))){
# if(sum(resp_pos) != 1) abort("The `resp()` function can only be used once per response variable. For multivariate modelling, use `vars()`.")
# names(x)[resp_pos] <- "response"
# }

vals <- list()
resp <- NULL
args <- lapply(x, function(y) {
if(inherits(y, "resp_path")) {
vals[names(attr(y, "vals"))] <<- attr(y, "vals")
resp <<- y
y[[length(y)]]
} else {
if(y$len == 1) {
vals[as_label(y$cl)] <<- y$res
} else {
resp <<- y$cl
}
y$cl
}
})

if(is.list(y)) y <- y[["cl"]]
path <- c(resp, as.call(c(y[[1]], args)))

# if(inherits(x[[1]], "resp_path")) {
# x <- x[[1]]
# vals <- attr(x, "vals")
# path <- c(x, as.call(c(y[[1]], x[[length(x)]])))
#
# } else {
# x <- transpose(x)
# x$len <- as.double(x$len)
#
# if(any(len1 <- x$len == 1)) {
# vals <- x$res[len1]
# names(vals) <- map_chr(x$cl[len1], as_label)
# }
# path <- c(x$cl[[which.max(x$len)]], as.call(c(y[[1]], x$cl)))
# }

structure(
path,
vals = vals,
class = "resp_path"
)
},
.g = function(x) x[-1],
.h = function(x) {
if(is_resp(x)){
if(length(x) > 2) abort("The response variable accepts only one input. For multivariate modelling, use `vars()`.")
list(response = x)
} else list(x)
.g = function(x) {
if(is.list(x)) x <- x[["cl"]]

# traverse only call arguments
args <- x[-1]

res <- map(args, function(y) eval(y, envir = model$data, enclos = model$specials))
len <- map_dbl(res, length)

# handle unspecified response
if(sum(len == max(len)) > 1) {
return(call("resp", x))
}

if(length(len[len!=1]) > 1){
abort(
sprintf(
"Response variable transformation has incompatible lengths, all arguments must be the length of the data %i or 1.",
max(len)
)
)
}

# handle length 1 arguments
if(any(is_singular <- len == 1)) {
nm <- map_chr(args[is_singular], as_label)
args[is_singular] <- syms(nm)
}

# add length of each args
.mapply(
list,
dots = list(cl = args, len = len, res = res),
MoreArgs = list()
)
},
base = function(x) is_syntactic_literal(x) || is_symbol(x) || is_resp(x)
)

# Reduce traversal down to the response
# If the response is set via resp(), remove all usage of resp() from the traversal
# If the response is not set via resp(), identify the response by the maximum length object until encountering ties
traversed_lhs <- lapply(traversed_lhs, traverse,
.f = function(x, y){
# Capture parent expression of base case
cl <- NULL
if(length(x) == 0){
# Multiple length `n` variables found and cannot disambiguate response
# Start with most disaggregated result of computation as response.
x <- if(is.null(attr(y, "call"))) list(y[[1]]) else syms(as_label(attr(y, "call")))
}
else{
if(is.null(attr(y, "call"))){
if(is_resp(x[[1]])){
x[[1]] <- x[[1]][[2]]
}
}
else{
# Remove resp() from call
cl <- attr(y,"call")
if(any(names(y) == "response")){
cl[[which(names(y) == "response")+1]] <- x[[1]][[length(x[[1]])]]
}
}
}
c(x[[1]], cl)
},
.g = function(x){
if(all(names(x) != "response") && !is.null(attr(x, "call"))){
# parent_len <- length(eval(attr(x, "call") %||% x[[1]], envir = model$data))
len <- map_dbl(x, function(y) length(eval(attr(y, "call") %||% y[[1]], envir = model$data, enclos = model$specials)))
if(sum(len == max(len)) == 1){
names(x)[which.max(len)] <- "response"
}
}
if("response" %in% names(x)) x["response"] else list()
},
base = function(x) !is.list(x)
# .h = function(x) {
# if(is_resp(x)){
# if(length(x) > 2) abort("The response variable accepts only one input. For multivariate modelling, use `vars()`.")
# list(response = x)
# } else x
# },
base = function(x) {
if(is.list(x)) x <- x[["cl"]]
is_syntactic_literal(x) || is_symbol(x) || is_resp(x)
}
)

# Obtain parsed out response variable
Expand All @@ -217,15 +260,25 @@ parse_model_lhs <- function(model){
result
})

# Create evaluation environment for transformation functions
# Includes cached values of single length arguments
transform_args <- lapply(traversed_lhs, attr, "vals")
envs <- lapply(transform_args, new_environment, parent = model$env)

# Produce transformation class functions for bt() usage
make_transforms <- function(exprs, responses){
map2(exprs, responses, function(x, response){
new_function(args = set_names(list(missing_arg()), response), x, env = model$env)
})
make_transforms <- function(exprs, responses, envs){
.mapply(
function(x, response, env){
new_function(args = set_names(list(missing_arg()), response), x, env = env)
},
dots = list(x = exprs, response = responses, env = envs),
MoreArgs = list()
)
}

transformations <- map2(
make_transforms(transform_exprs, responses),
make_transforms(inverse_exprs, responses),
make_transforms(transform_exprs, responses, envs),
make_transforms(inverse_exprs, responses, envs),
new_transformation
)

Expand Down

0 comments on commit 7a6f6c1

Please sign in to comment.