Skip to content

Commit

Permalink
Adds py_require() function
Browse files Browse the repository at this point in the history
  • Loading branch information
edgararuiz committed Dec 30, 2024
1 parent 61f0fa4 commit 3576166
Show file tree
Hide file tree
Showing 2 changed files with 298 additions and 0 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ S3method(print,python.builtin.list)
S3method(print,python.builtin.module)
S3method(print,python.builtin.object)
S3method(print,python.builtin.tuple)
S3method(print,python_requirements)
S3method(py_str,default)
S3method(py_str,python.builtin.bytearray)
S3method(py_str,python.builtin.dict)
Expand Down Expand Up @@ -190,6 +191,7 @@ export(py_module_available)
export(py_none)
export(py_numpy_available)
export(py_repr)
export(py_require)
export(py_run_file)
export(py_run_string)
export(py_save_object)
Expand Down
296 changes: 296 additions & 0 deletions R/py_require.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
#' @export
py_require <- function(packages = NULL,
python_version = NULL,
action = c("add", "omit", "replace"),
silent = FALSE) {
if (missing(packages) && missing(python_version)) {
return(get_python_reqs())
}

action <- match.arg(action)

err_packages <- NULL
err_python <- NULL

msg_packages <- NULL
msg_python <- NULL

final_packages <- NULL
final_python <- NULL

has_error <- FALSE

if (!is.null(packages)) {
req_packages <- get_python_reqs("packages")
if (action %in% c("replace", "omit")) {
for (pkg in packages) {
pkg_name <- extract_name(pkg)
if (action == "omit" && pkg_name != pkg) {
matches <- pkg == req_packages
} else {
matches <- pkg_name == extract_name(req_packages)
}
if (any(matches)) {
match_pkgs <- req_packages[matches]
match_pkgs <- ennumerate_packages(match_pkgs)
req_packages <- req_packages[!matches]
if (action == "replace") {
req_packages <- c(req_packages, pkg)
msg_packages <- c(msg_packages, paste(
"Replaced", match_pkgs, "with", sprintf("\"%s\"", pkg)
))
} else {
msg_packages <- c(msg_packages, paste("Ommiting", match_pkgs))
}

} else {
has_error <- TRUE
err_msg <- sprintf("\"%s\"", pkg)
if (action == "replace" && pkg != pkg_name) {
err_msg <- sprintf("%s(searched for: \"%s\")", err_msg, pkg_name)
}
err_packages <- c(err_packages, err_msg)
}
}
final_packages <- req_packages
} else {
msg_packages <- paste("Added", ennumerate_packages(packages))
final_packages <- unique(c(req_packages, packages))
}
if (length(err_packages) > 0) {
err_packages <- c(
"Could not match",
ennumerate_packages(err_packages, FALSE)
)
if (action == "replace") {
err_packages <- c(
err_packages,
"\nTip: Check spelling, or remove from your command, and try again"
)
}
if (action == "omit") {
err_packages <- c(
err_packages,
"\nTip: Remove from your command, and try again"
)
}
err_packages <- paste0(err_packages, collapse = " ")
}
}

env_name <- environmentName(topenv(parent.frame()))
if (env_name == "R_GlobalEnv") {
env_name <- "R session"
}

entry <- list(list(
requested_from = env_name,
packages = packages,
action = action,
python_version = python_version
))

new_history <- c(get_python_reqs("history"), entry)
if (!is.null(python_version)) {
current_versions <- NULL
for (item in new_history) {
item_python <- item$python_version
if (length(item_python) > 0) {
item_python <- item_python[[1]]
item_action <- item$action
if (item_action == "add") {
current_versions <- c(current_versions, item_python)
}
if (item_action == "replace") {
current_versions <- item_python
}
if (item_action == "omit") {
matched <- item_python == current_versions
if (any(matched)) {
current_versions <- current_versions[!matched]
} else {
err_python <- sprintf(
fmt = "An entry for Python %s was not found in the history",
item_python
)
}
}
}
}
if (is.null(err_python)) {
final_python <- resolve_python(current_versions)
if (is.na(final_python)) {
has_error <- TRUE
err_python <- paste(
"Python versions are in conflict:",
ennumerate_packages(current_versions, FALSE)
)
} else {
msg_python <- paste("Setting Python version to:", final_python)
}
}
}

if (has_error) {
stop("\n", add_dash(c(err_packages, err_python)), "\n", call. = FALSE)
} else {
set_python_reqs(
packages = final_packages,
python_version = final_python,
history = new_history
)
if (!silent) {
cat(add_dash(c(msg_packages, msg_python)), "\n")
}
}

invisible()
}

add_dash <- function(x) {
if (length(x) == 1) {
dashed <- ""
spaces <- ""
} else {
dashed <- "- "
spaces <- " "
}
x <- gsub("\n", paste0("\n", spaces), x)
paste0(dashed, x, collapse = "\n")
}

resolve_python <- function(constraints) {
constraints <- paste0(constraints, collapse = ",")
candidates <- paste0("3.", 9:13)
for (check in as_version_constraint_checkers(constraints)) {
satisfies_constraint <- check(candidates)
candidates <- candidates[satisfies_constraint]
}
candidates[1]
}


extract_name <- function(x) {
as.character(lapply(x, function(x) {
# If it's a URL or path to a binary or source distribution
# (e.g., .whl, .sdist), try to extract the name
is_dist <- grepl("/", x) ||
grepl("\\.(whl|tar\\.gz|.zip|.tgz)$", x)
if (is_dist) {
# Remove path or URL leading up to the file name
x <- sub(".*/", "", x)
# Remove everything after the first "-", which
# by the spec should be the *distribution* name.
x <- sub("-.*$", "", x)

# a whl/tar.gz or other package format
# should have name standardized already
# with `-` substituted with `_` already.
return(x)
}
# If it's a package name with a version
# constraint, remove the version constraint
x <- sub("[=<>].*$", "", x) # Remove ver constraints like `=`, `<`, `>`

# If it's a package name with a modifier like
# `tensorflow[and-cuda]`, remove the modifier
x <- sub("\\[.*$", "", x) # Remove modifiers like `[and-cuda]`
# standardize, replace "-" with "_"
gsub("-", "_", x, fixed = TRUE)
}))
}

.globals$python_requirements <- structure(
.Data = list(
python_version = "",
packages = c(),
history = list()
),
class = "python_requirements"
)

#' @export
print.python_requirements <- function(x, ...) {
packages <- x$packages
if (is.null(packages)) {
packages <- "[No packages added yet]"
} else {
packages <- paste0(packages, collapse = ", ")
}
python_version <- x$python_version
if(python_version == "") {
python_version <- "[No version of Python selected yet]"
}
cat("Setup ------------------------------\n")
cat(" Packages:", packages, "\n")
cat(" Python: ", python_version, "\n")
cat("History ----------------------------\n")
for (item in x$history) {
args <- list()
if (!is.null(item$packages)) {
if (length(item$packages) > 1) {
item$packages <- paste0("\"", item$packages, "\"", collapse = ", ")
item$packages <- sprintf("c(%s)", item$packages)
args$packages <- paste("packages =", item$packages)
} else {
args$packages <- sprintf("packages = \"%s\"", item$packages)
}
}
if (!is.null(item$python_version)) {
args$python_version <- sprintf(
fmt = "python_verison = \"%s\"", item$python_version
)
}
if (item$action != "add") {
args$action <- sprintf("action = \"%s\"", item$action)
}
args <- paste0(args, collapse = ", ")
cat(sprintf(" py_require(%s) # %s\n", args, item$requested_from))
}
}

get_python_reqs <- function(
x = c("all", "python_version", "packages", "history")) {
pr <- .globals$python_requirements
x <- match.arg(x)
switch(x,
all = pr,
python_version = pr$python_version,
packages = pr$packages,
history = pr$history
)
}

set_python_reqs <- function(
python_version = NULL,
packages = NULL,
history = NULL) {
pr <- get_python_reqs("all")
pr$python_version <- python_version %||% pr$python_version
pr$packages <- packages %||% pr$packages
pr$history <- history %||% pr$history
.globals$python_requirements <- pr
get_python_reqs("all")
}

ennumerate_packages <- function(x, add_quotes = TRUE) {
out <- NULL
len_x <- length(x)
for (i in seq_along(x)) {
i_x <- len_x - i
if (i_x > 1) {
join <- ", "
} else if (i_x == 1) {
join <- ", and "
} else {
join <- NULL
}
if (add_quotes) {
xi <- sprintf("\"%s\"", x[i])
} else {
xi <- x[i]
}
out <- c(out, xi, join)
}
paste0(out, collapse = "")
}

0 comments on commit 3576166

Please sign in to comment.