From 35761664eb3d4814b6114c1a8ca7adc189221dd2 Mon Sep 17 00:00:00 2001 From: Edgar Ruiz Date: Mon, 30 Dec 2024 11:54:55 -0600 Subject: [PATCH] Adds py_require() function --- NAMESPACE | 2 + R/py_require.R | 296 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 298 insertions(+) create mode 100644 R/py_require.R diff --git a/NAMESPACE b/NAMESPACE index c0f2257c7..1abeee449 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -74,6 +74,7 @@ S3method(print,python.builtin.list) S3method(print,python.builtin.module) S3method(print,python.builtin.object) S3method(print,python.builtin.tuple) +S3method(print,python_requirements) S3method(py_str,default) S3method(py_str,python.builtin.bytearray) S3method(py_str,python.builtin.dict) @@ -190,6 +191,7 @@ export(py_module_available) export(py_none) export(py_numpy_available) export(py_repr) +export(py_require) export(py_run_file) export(py_run_string) export(py_save_object) diff --git a/R/py_require.R b/R/py_require.R new file mode 100644 index 000000000..6af801771 --- /dev/null +++ b/R/py_require.R @@ -0,0 +1,296 @@ +#' @export +py_require <- function(packages = NULL, + python_version = NULL, + action = c("add", "omit", "replace"), + silent = FALSE) { + if (missing(packages) && missing(python_version)) { + return(get_python_reqs()) + } + + action <- match.arg(action) + + err_packages <- NULL + err_python <- NULL + + msg_packages <- NULL + msg_python <- NULL + + final_packages <- NULL + final_python <- NULL + + has_error <- FALSE + + if (!is.null(packages)) { + req_packages <- get_python_reqs("packages") + if (action %in% c("replace", "omit")) { + for (pkg in packages) { + pkg_name <- extract_name(pkg) + if (action == "omit" && pkg_name != pkg) { + matches <- pkg == req_packages + } else { + matches <- pkg_name == extract_name(req_packages) + } + if (any(matches)) { + match_pkgs <- req_packages[matches] + match_pkgs <- ennumerate_packages(match_pkgs) + req_packages <- req_packages[!matches] + if (action == "replace") { + req_packages <- c(req_packages, pkg) + msg_packages <- c(msg_packages, paste( + "Replaced", match_pkgs, "with", sprintf("\"%s\"", pkg) + )) + } else { + msg_packages <- c(msg_packages, paste("Ommiting", match_pkgs)) + } + + } else { + has_error <- TRUE + err_msg <- sprintf("\"%s\"", pkg) + if (action == "replace" && pkg != pkg_name) { + err_msg <- sprintf("%s(searched for: \"%s\")", err_msg, pkg_name) + } + err_packages <- c(err_packages, err_msg) + } + } + final_packages <- req_packages + } else { + msg_packages <- paste("Added", ennumerate_packages(packages)) + final_packages <- unique(c(req_packages, packages)) + } + if (length(err_packages) > 0) { + err_packages <- c( + "Could not match", + ennumerate_packages(err_packages, FALSE) + ) + if (action == "replace") { + err_packages <- c( + err_packages, + "\nTip: Check spelling, or remove from your command, and try again" + ) + } + if (action == "omit") { + err_packages <- c( + err_packages, + "\nTip: Remove from your command, and try again" + ) + } + err_packages <- paste0(err_packages, collapse = " ") + } + } + + env_name <- environmentName(topenv(parent.frame())) + if (env_name == "R_GlobalEnv") { + env_name <- "R session" + } + + entry <- list(list( + requested_from = env_name, + packages = packages, + action = action, + python_version = python_version + )) + + new_history <- c(get_python_reqs("history"), entry) + if (!is.null(python_version)) { + current_versions <- NULL + for (item in new_history) { + item_python <- item$python_version + if (length(item_python) > 0) { + item_python <- item_python[[1]] + item_action <- item$action + if (item_action == "add") { + current_versions <- c(current_versions, item_python) + } + if (item_action == "replace") { + current_versions <- item_python + } + if (item_action == "omit") { + matched <- item_python == current_versions + if (any(matched)) { + current_versions <- current_versions[!matched] + } else { + err_python <- sprintf( + fmt = "An entry for Python %s was not found in the history", + item_python + ) + } + } + } + } + if (is.null(err_python)) { + final_python <- resolve_python(current_versions) + if (is.na(final_python)) { + has_error <- TRUE + err_python <- paste( + "Python versions are in conflict:", + ennumerate_packages(current_versions, FALSE) + ) + } else { + msg_python <- paste("Setting Python version to:", final_python) + } + } + } + + if (has_error) { + stop("\n", add_dash(c(err_packages, err_python)), "\n", call. = FALSE) + } else { + set_python_reqs( + packages = final_packages, + python_version = final_python, + history = new_history + ) + if (!silent) { + cat(add_dash(c(msg_packages, msg_python)), "\n") + } + } + + invisible() +} + +add_dash <- function(x) { + if (length(x) == 1) { + dashed <- "" + spaces <- "" + } else { + dashed <- "- " + spaces <- " " + } + x <- gsub("\n", paste0("\n", spaces), x) + paste0(dashed, x, collapse = "\n") +} + +resolve_python <- function(constraints) { + constraints <- paste0(constraints, collapse = ",") + candidates <- paste0("3.", 9:13) + for (check in as_version_constraint_checkers(constraints)) { + satisfies_constraint <- check(candidates) + candidates <- candidates[satisfies_constraint] + } + candidates[1] +} + + +extract_name <- function(x) { + as.character(lapply(x, function(x) { + # If it's a URL or path to a binary or source distribution + # (e.g., .whl, .sdist), try to extract the name + is_dist <- grepl("/", x) || + grepl("\\.(whl|tar\\.gz|.zip|.tgz)$", x) + if (is_dist) { + # Remove path or URL leading up to the file name + x <- sub(".*/", "", x) + # Remove everything after the first "-", which + # by the spec should be the *distribution* name. + x <- sub("-.*$", "", x) + + # a whl/tar.gz or other package format + # should have name standardized already + # with `-` substituted with `_` already. + return(x) + } + # If it's a package name with a version + # constraint, remove the version constraint + x <- sub("[=<>].*$", "", x) # Remove ver constraints like `=`, `<`, `>` + + # If it's a package name with a modifier like + # `tensorflow[and-cuda]`, remove the modifier + x <- sub("\\[.*$", "", x) # Remove modifiers like `[and-cuda]` + # standardize, replace "-" with "_" + gsub("-", "_", x, fixed = TRUE) + })) +} + +.globals$python_requirements <- structure( + .Data = list( + python_version = "", + packages = c(), + history = list() + ), + class = "python_requirements" +) + +#' @export +print.python_requirements <- function(x, ...) { + packages <- x$packages + if (is.null(packages)) { + packages <- "[No packages added yet]" + } else { + packages <- paste0(packages, collapse = ", ") + } + python_version <- x$python_version + if(python_version == "") { + python_version <- "[No version of Python selected yet]" + } + cat("Setup ------------------------------\n") + cat(" Packages:", packages, "\n") + cat(" Python: ", python_version, "\n") + cat("History ----------------------------\n") + for (item in x$history) { + args <- list() + if (!is.null(item$packages)) { + if (length(item$packages) > 1) { + item$packages <- paste0("\"", item$packages, "\"", collapse = ", ") + item$packages <- sprintf("c(%s)", item$packages) + args$packages <- paste("packages =", item$packages) + } else { + args$packages <- sprintf("packages = \"%s\"", item$packages) + } + } + if (!is.null(item$python_version)) { + args$python_version <- sprintf( + fmt = "python_verison = \"%s\"", item$python_version + ) + } + if (item$action != "add") { + args$action <- sprintf("action = \"%s\"", item$action) + } + args <- paste0(args, collapse = ", ") + cat(sprintf(" py_require(%s) # %s\n", args, item$requested_from)) + } +} + +get_python_reqs <- function( + x = c("all", "python_version", "packages", "history")) { + pr <- .globals$python_requirements + x <- match.arg(x) + switch(x, + all = pr, + python_version = pr$python_version, + packages = pr$packages, + history = pr$history + ) +} + +set_python_reqs <- function( + python_version = NULL, + packages = NULL, + history = NULL) { + pr <- get_python_reqs("all") + pr$python_version <- python_version %||% pr$python_version + pr$packages <- packages %||% pr$packages + pr$history <- history %||% pr$history + .globals$python_requirements <- pr + get_python_reqs("all") +} + +ennumerate_packages <- function(x, add_quotes = TRUE) { + out <- NULL + len_x <- length(x) + for (i in seq_along(x)) { + i_x <- len_x - i + if (i_x > 1) { + join <- ", " + } else if (i_x == 1) { + join <- ", and " + } else { + join <- NULL + } + if (add_quotes) { + xi <- sprintf("\"%s\"", x[i]) + } else { + xi <- x[i] + } + out <- c(out, xi, join) + } + paste0(out, collapse = "") +}