From 53df7e204e8c38c69d93a81ea93eb73f41e31f8d Mon Sep 17 00:00:00 2001 From: "tengqiu.huang" Date: Sun, 15 Sep 2024 21:36:57 +0800 Subject: [PATCH] enhacne: add lock --- .gitignore | 1 + robyn_api/robynapi_enpoints.R | 211 ++++++++++++++++++---------------- 2 files changed, 112 insertions(+), 100 deletions(-) diff --git a/.gitignore b/.gitignore index 01fb872a3..56dc493be 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .DS_Store +.idea .Rproj.user .Rhistory node_modules/ diff --git a/robyn_api/robynapi_enpoints.R b/robyn_api/robynapi_enpoints.R index 65cee92c0..8e11f64ca 100644 --- a/robyn_api/robynapi_enpoints.R +++ b/robyn_api/robynapi_enpoints.R @@ -6,15 +6,15 @@ ### Import necessary libraries ### # Function to locate and load required virtual environment used to install nevergrad -load_pythonenv <- function(env="r-reticulate"){ +load_pythonenv <- function(env = "r-reticulate") { tryCatch( - { - library("reticulate") - if(reticulate::condaenv_exists(env)) {use_condaenv(env)} - else if (reticulate::virtualenv_exists(env)) {use_virtualenv(env, required = TRUE)} - else {message('Install nevergrad to proceed')} - }, - error=function(e) { + { + library("reticulate") + if (reticulate::condaenv_exists(env)) { use_condaenv(env) } + else if (reticulate::virtualenv_exists(env)) { use_virtualenv(env, required = TRUE) } + else { message('Install nevergrad to proceed') } + }, + error = function(e) { message('Install nevergrad to proceed') } ) @@ -30,8 +30,11 @@ suppressPackageStartupMessages({ library(Robyn) library(tibble) library(promises) + library(synchronicity) }) +### GLOBAL LOCK ### +mutex = boost.mutex() ### FUNCTIONS ### @@ -39,7 +42,7 @@ suppressPackageStartupMessages({ #* This function is called to import the table data such as dt_simulated_weekly, dt_prophet_holidays hex_to_raw <- function(x) { chars <- unlist(regmatches(x, gregexpr("..", x))) - as.raw(strtoi(chars, base=16L)) + as.raw(strtoi(chars, base = 16L)) } #* Whether an object is a named list @@ -86,16 +89,18 @@ convert_dates_to_Date <- function(json_data) { recursive_convert <- function(x) { if (is.list(x)) { lapply(x, recursive_convert) - } else if (is.character(x) && length(x) == 1 && grepl("^\\d{4}-\\d{2}-\\d{2}$", x)) { + } else if (is.character(x) && + length(x) == 1 && + grepl("^\\d{4}-\\d{2}-\\d{2}$", x)) { as.Date(x) } else { x } } - + # Recursively convert date strings to Date objects converted_data <- recursive_convert(json_data) - + return(converted_data) } @@ -103,40 +108,42 @@ convert_dates_to_Date <- function(json_data) { #* transform InputCollect from API transform_InputCollect <- function(InputCollect) { promise({ - InputCollect <- jsonlite::fromJSON(InputCollect) %>% convert_dates_to_Date() - - # list > tibble - vars_to_tibble <- c("dt_input", "dt_holidays", "dt_mod", "dt_modRollWind", "dt_inputRollWind", "calibration_input") - for (var in vars_to_tibble) { - InputCollect[[var]] <- as_tibble(InputCollect[[var]]) - InputCollect[[var]][] <- lapply(InputCollect[[var]], function(col) { - if (all(grepl("^\\d{4}-\\d{2}-\\d{2}$", col))) { - return(as.Date(col)) - } - return(col) - }) - } - - # Null Treatment - for (var in names(InputCollect)) { - if(length(InputCollect[[var]])==0) { - InputCollect[[var]] <- NULL - named_list <- setNames(alist(x=NULL), var) - InputCollect <- c(InputCollect, named_list) - - } - } - - # Add class name which is used as a checker in Robyn - class(InputCollect) <- c("robyn_inputs", "list") - - return(InputCollect) + lock(mutex) + InputCollect <- jsonlite::fromJSON(InputCollect) %>% convert_dates_to_Date() + + # list > tibble + vars_to_tibble <- c("dt_input", "dt_holidays", "dt_mod", "dt_modRollWind", "dt_inputRollWind", "calibration_input") + for (var in vars_to_tibble) { + InputCollect[[var]] <- as_tibble(InputCollect[[var]]) + InputCollect[[var]][] <- lapply(InputCollect[[var]], function(col) { + if (all(grepl("^\\d{4}-\\d{2}-\\d{2}$", col))) { + return(as.Date(col)) + } + return(col) + }) + } + + # Null Treatment + for (var in names(InputCollect)) { + if (length(InputCollect[[var]]) == 0) { + InputCollect[[var]] <- NULL + named_list <- setNames(alist(x = NULL), var) + InputCollect <- c(InputCollect, named_list) + + } + } + + # Add class name which is used as a checker in Robyn + class(InputCollect) <- c("robyn_inputs", "list") + unlock(mutex) + return(InputCollect) }) } #* transform OutputCollect from API -transform_OutputCollect <- function(OutputCollect, select_model=FALSE) { +transform_OutputCollect <- function(OutputCollect, select_model = FALSE) { promise({ + lock(mutex) OutputCollect <- jsonlite::fromJSON(OutputCollect) # Add class name which is used as a checker in Robyn class(OutputCollect) <- c("robyn_outputs", "list") @@ -165,13 +172,13 @@ transform_OutputCollect <- function(OutputCollect, select_model=FALSE) { } # convert only target model data - if (!select_model==FALSE) { + if (!select_model == FALSE) { OutputCollect[['allPareto']][['plotDataCollect']][[select_model]][['plot2data']][['plotWaterfallLoop']] <- OutputCollect[['allPareto']][['plotDataCollect']][[select_model]][['plot2data']][['plotWaterfallLoop']] %>% - as_tibble() %>% - mutate(across(where(is.character), as.factor)) + as_tibble() %>% + mutate(across(where(is.character), as.factor)) } - + unlock(mutex) return(OutputCollect) }) @@ -210,20 +217,20 @@ function() { #* @param calibration_input A hexadecimal string representing the binary content of a calibration data feather file. #* @serializer json list(digits = 20, na = 'null') #* @post /robyn_inputs -function(dt_input=FALSE, dt_holidays=FALSE, jsonInputArgs=FALSE, InputCollect=FALSE, calibration_input=FALSE) { - - inputArgs <- if (!jsonInputArgs==FALSE) jsonlite::fromJSON(jsonInputArgs) else NULL - dt_input <- if (!dt_input==FALSE) hex_to_raw(dt_input) %>% arrow::read_feather() else NULL - dt_holidays <- if (!dt_holidays==FALSE) hex_to_raw(dt_holidays) %>% arrow::read_feather() else NULL - InputCollect <- if (!InputCollect==FALSE) transform_InputCollect(InputCollect) else NULL - calibration_input <- if (!calibration_input==FALSE) hex_to_raw(calibration_input) %>% arrow::read_feather() else NULL - +function(dt_input = FALSE, dt_holidays = FALSE, jsonInputArgs = FALSE, InputCollect = FALSE, calibration_input = FALSE) { + + inputArgs <- if (!jsonInputArgs == FALSE) jsonlite::fromJSON(jsonInputArgs) else NULL + dt_input <- if (!dt_input == FALSE) hex_to_raw(dt_input) %>% arrow::read_feather() else NULL + dt_holidays <- if (!dt_holidays == FALSE) hex_to_raw(dt_holidays) %>% arrow::read_feather() else NULL + InputCollect <- if (!InputCollect == FALSE) transform_InputCollect(InputCollect) else NULL + calibration_input <- if (!calibration_input == FALSE) hex_to_raw(calibration_input) %>% arrow::read_feather() else NULL + InputCollect <- do.call(robyn_inputs, c(list(dt_input = dt_input, dt_holidays = dt_holidays, InputCollect = InputCollect, calibration_input = calibration_input - ), inputArgs)) - + ), inputArgs)) + return(recursive_ggplot_serialize(InputCollect)) } @@ -235,13 +242,13 @@ function(dt_input=FALSE, dt_holidays=FALSE, jsonInputArgs=FALSE, InputCollect=FA #* @serializer json list(digits = 20, na = 'null') #* @post /robyn_run function(InputCollect, jsonRunArgs) { - + runArgs <- jsonlite::fromJSON(jsonRunArgs) InputCollect <- transform_InputCollect(InputCollect) - + OutputModels <- do.call(robyn_run, c(list(InputCollect = InputCollect - ), runArgs)) - + ), runArgs)) + return(recursive_ggplot_serialize(OutputModels)) } @@ -255,15 +262,15 @@ function(InputCollect, jsonRunArgs) { #* @serializer json list(digits = 20, na = 'null') #* @post /robyn_outputs function(InputCollect, OutputModels, jsonOutputsArgs) { - + outputsArgs <- jsonlite::fromJSON(jsonOutputsArgs) InputCollect <- transform_InputCollect(InputCollect) OutputModels <- jsonlite::fromJSON(OutputModels) - - OutputCollect <- do.call(robyn_outputs, c(list(InputCollect = InputCollect, + + OutputCollect <- do.call(robyn_outputs, c(list(InputCollect = InputCollect, OutputModels = OutputModels - ), outputsArgs)) - + ), outputsArgs)) + return(recursive_ggplot_serialize(OutputCollect)) } @@ -277,21 +284,21 @@ function(InputCollect, OutputModels, jsonOutputsArgs) { #* @param width The width of the image to be returned, specified in inches. #* @param height The height of the image to be returned, specified in inches. #* @post /robyn_onepagers -function(InputCollect, OutputCollect, jsonOnepagersArgs, dpi=100, width=12, height=8) { - +function(InputCollect, OutputCollect, jsonOnepagersArgs, dpi = 100, width = 12, height = 8) { + onepagersArgs <- jsonlite::fromJSON(jsonOnepagersArgs) InputCollect <- transform_InputCollect(InputCollect) OutputCollect <- transform_OutputCollect(OutputCollect, onepagersArgs[["select_model"]]) - - onepager <- do.call(robyn_onepagers, c(list(InputCollect = InputCollect, + + onepager <- do.call(robyn_onepagers, c(list(InputCollect = InputCollect, OutputCollect = OutputCollect - ), onepagersArgs)) - + ), onepagersArgs)) + dpi <- as.numeric(dpi) width <- as.numeric(width) height <- as.numeric(height) - - return(ggplot_serialize(onepager[[onepagersArgs[["select_model"]]]], dpi=dpi, width=width, height=height)) + + return(ggplot_serialize(onepager[[onepagersArgs[["select_model"]]]], dpi = dpi, width = width, height = height)) } #* Generates and returns a serialized image of the allocation one-pager @@ -304,21 +311,21 @@ function(InputCollect, OutputCollect, jsonOnepagersArgs, dpi=100, width=12, heig #* @param width The width of the image to be returned, specified in inches. #* @param height The height of the image to be returned, specified in inches. #* @post /robyn_allocator -function(InputCollect, OutputCollect, jsonAllocatorArgs, dpi=100, width=12, height=8) { - +function(InputCollect, OutputCollect, jsonAllocatorArgs, dpi = 100, width = 12, height = 8) { + allocatorArgs <- jsonlite::fromJSON(jsonAllocatorArgs) InputCollect <- transform_InputCollect(InputCollect) OutputCollect <- transform_OutputCollect(OutputCollect, allocatorArgs[["select_model"]]) - + AllocatorCollect <- do.call(robyn_allocator, c(list(InputCollect = InputCollect, OutputCollect = OutputCollect - ), allocatorArgs)) - + ), allocatorArgs)) + dpi <- as.numeric(dpi) width <- as.numeric(width) height <- as.numeric(height) - - return(ggplot_serialize(AllocatorCollect$plots$plots, dpi=dpi, width=width, height=height)) + + return(ggplot_serialize(AllocatorCollect$plots$plots, dpi = dpi, width = width, height = height)) } #* Exports model data in JSON format @@ -328,13 +335,13 @@ function(InputCollect, OutputCollect, jsonAllocatorArgs, dpi=100, width=12, heig #* @param OutputModels A JSON string representing the models created by 'robyn_run()'. #* @param jsonWriteArgs A JSON string containing additional parameters for the 'robyn_write()' function. #* @post /robyn_write -function(InputCollect=FALSE, OutputCollect=FALSE, OutputModels=FALSE, jsonWriteArgs) { - +function(InputCollect = FALSE, OutputCollect = FALSE, OutputModels = FALSE, jsonWriteArgs) { + writeArgs <- jsonlite::fromJSON(jsonWriteArgs) - InputCollect <- if (!InputCollect==FALSE) transform_InputCollect(InputCollect) else NULL - OutputModels <- if (!OutputModels==FALSE) jsonlite::fromJSON(OutputModels) else NULL - OutputCollect <- if (!OutputCollect==FALSE) transform_OutputCollect(OutputCollect) else NULL - + InputCollect <- if (!InputCollect == FALSE) transform_InputCollect(InputCollect) else NULL + OutputModels <- if (!OutputModels == FALSE) jsonlite::fromJSON(OutputModels) else NULL + OutputCollect <- if (!OutputCollect == FALSE) transform_OutputCollect(OutputCollect) else NULL + do.call(robyn_write, c(list(InputCollect = InputCollect, OutputCollect = OutputCollect, OutputModels = OutputModels), writeArgs)) } @@ -346,16 +353,20 @@ function(InputCollect=FALSE, OutputCollect=FALSE, OutputModels=FALSE, jsonWriteA #* @serializer json list(digits = 20, na = 'null') #* @post /robyn_recreate function(dt_input, dt_holidays, jsonRecreateArgs) { - + recreateArgs <- jsonlite::fromJSON(jsonRecreateArgs) - dt_input <- dt_input %>% hex_to_raw() %>% arrow::read_feather() - dt_holidays <- dt_holidays %>% hex_to_raw() %>% arrow::read_feather() - + dt_input <- dt_input %>% + hex_to_raw() %>% + arrow::read_feather() + dt_holidays <- dt_holidays %>% + hex_to_raw() %>% + arrow::read_feather() + RobynRecreated <- do.call(robyn_recreate, c(list(dt_input = dt_input, dt_holidays = dt_holidays - ), recreateArgs)) - - return(recursive_ggplot_serialize(RobynRecreated)) + ), recreateArgs)) + + return(recursive_ggplot_serialize(RobynRecreated)) } #* Retrieves the names of hyperparameters based on adstock and media spend data @@ -364,9 +375,9 @@ function(dt_input, dt_holidays, jsonRecreateArgs) { #* @param all_media A JSON string representing the list of paid media spends. #* @post /hyper_names function(adstock, all_media) { - + hyper_names_list <- hyper_names(adstock = adstock, all_media = jsonlite::fromJSON(all_media)) - + return(hyper_names_list) } @@ -379,14 +390,14 @@ function(adstock, all_media) { #* @serializer json list(digits = 20, na = 'null') #* @post /robyn_refresh function(dt_input, dt_holidays, jsonRefreshArgs) { - + refreshArgs <- jsonlite::fromJSON(jsonRefreshArgs) - dt_input <- if (!dt_input==FALSE) hex_to_raw(dt_input) %>% arrow::read_feather() else NULL - dt_holidays <- if (!dt_holidays==FALSE) hex_to_raw(dt_holidays) %>% arrow::read_feather() else NULL - + dt_input <- if (!dt_input == FALSE) hex_to_raw(dt_input) %>% arrow::read_feather() else NULL + dt_holidays <- if (!dt_holidays == FALSE) hex_to_raw(dt_holidays) %>% arrow::read_feather() else NULL + RobynRefresh <- do.call(robyn_refresh, c(list(dt_input = dt_input, dt_holidays = dt_holidays - ), refreshArgs)) - + ), refreshArgs)) + return(recursive_ggplot_serialize(RobynRefresh)) } \ No newline at end of file