Skip to content

Commit

Permalink
enhacne: add lock
Browse files Browse the repository at this point in the history
  • Loading branch information
TankyH committed Sep 15, 2024
1 parent ebbbff2 commit 53df7e2
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 100 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
.DS_Store
.idea
.Rproj.user
.Rhistory
node_modules/
Expand Down
211 changes: 111 additions & 100 deletions robyn_api/robynapi_enpoints.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
### Import necessary libraries ###

# Function to locate and load required virtual environment used to install nevergrad
load_pythonenv <- function(env="r-reticulate"){
load_pythonenv <- function(env = "r-reticulate") {
tryCatch(
{
library("reticulate")
if(reticulate::condaenv_exists(env)) {use_condaenv(env)}
else if (reticulate::virtualenv_exists(env)) {use_virtualenv(env, required = TRUE)}
else {message('Install nevergrad to proceed')}
},
error=function(e) {
{
library("reticulate")
if (reticulate::condaenv_exists(env)) { use_condaenv(env) }
else if (reticulate::virtualenv_exists(env)) { use_virtualenv(env, required = TRUE) }
else { message('Install nevergrad to proceed') }
},
error = function(e) {
message('Install nevergrad to proceed')
}
)
Expand All @@ -30,16 +30,19 @@ suppressPackageStartupMessages({
library(Robyn)
library(tibble)
library(promises)
library(synchronicity)
})

### GLOBAL LOCK ###
mutex = boost.mutex()

### FUNCTIONS ###

#* Convert hex data to raw bytes
#* This function is called to import the table data such as dt_simulated_weekly, dt_prophet_holidays
hex_to_raw <- function(x) {
chars <- unlist(regmatches(x, gregexpr("..", x)))
as.raw(strtoi(chars, base=16L))
as.raw(strtoi(chars, base = 16L))
}

#* Whether an object is a named list
Expand Down Expand Up @@ -86,57 +89,61 @@ convert_dates_to_Date <- function(json_data) {
recursive_convert <- function(x) {
if (is.list(x)) {
lapply(x, recursive_convert)
} else if (is.character(x) && length(x) == 1 && grepl("^\\d{4}-\\d{2}-\\d{2}$", x)) {
} else if (is.character(x) &&
length(x) == 1 &&
grepl("^\\d{4}-\\d{2}-\\d{2}$", x)) {
as.Date(x)
} else {
x
}
}

# Recursively convert date strings to Date objects
converted_data <- recursive_convert(json_data)

return(converted_data)
}

### Robyn functions expect data/objects to be R unique one, but if bypassing data/obj via REST API, we need to convert these into R unique type like tibble or factor.
#* transform InputCollect from API
transform_InputCollect <- function(InputCollect) {
promise({
InputCollect <- jsonlite::fromJSON(InputCollect) %>% convert_dates_to_Date()

# list > tibble
vars_to_tibble <- c("dt_input", "dt_holidays", "dt_mod", "dt_modRollWind", "dt_inputRollWind", "calibration_input")
for (var in vars_to_tibble) {
InputCollect[[var]] <- as_tibble(InputCollect[[var]])
InputCollect[[var]][] <- lapply(InputCollect[[var]], function(col) {
if (all(grepl("^\\d{4}-\\d{2}-\\d{2}$", col))) {
return(as.Date(col))
}
return(col)
})
}

# Null Treatment
for (var in names(InputCollect)) {
if(length(InputCollect[[var]])==0) {
InputCollect[[var]] <- NULL
named_list <- setNames(alist(x=NULL), var)
InputCollect <- c(InputCollect, named_list)

}
}

# Add class name which is used as a checker in Robyn
class(InputCollect) <- c("robyn_inputs", "list")

return(InputCollect)
lock(mutex)
InputCollect <- jsonlite::fromJSON(InputCollect) %>% convert_dates_to_Date()

# list > tibble
vars_to_tibble <- c("dt_input", "dt_holidays", "dt_mod", "dt_modRollWind", "dt_inputRollWind", "calibration_input")
for (var in vars_to_tibble) {
InputCollect[[var]] <- as_tibble(InputCollect[[var]])
InputCollect[[var]][] <- lapply(InputCollect[[var]], function(col) {
if (all(grepl("^\\d{4}-\\d{2}-\\d{2}$", col))) {
return(as.Date(col))
}
return(col)
})
}

# Null Treatment
for (var in names(InputCollect)) {
if (length(InputCollect[[var]]) == 0) {
InputCollect[[var]] <- NULL
named_list <- setNames(alist(x = NULL), var)
InputCollect <- c(InputCollect, named_list)

}
}

# Add class name which is used as a checker in Robyn
class(InputCollect) <- c("robyn_inputs", "list")
unlock(mutex)
return(InputCollect)
})
}

#* transform OutputCollect from API
transform_OutputCollect <- function(OutputCollect, select_model=FALSE) {
transform_OutputCollect <- function(OutputCollect, select_model = FALSE) {
promise({
lock(mutex)
OutputCollect <- jsonlite::fromJSON(OutputCollect)
# Add class name which is used as a checker in Robyn
class(OutputCollect) <- c("robyn_outputs", "list")
Expand Down Expand Up @@ -165,13 +172,13 @@ transform_OutputCollect <- function(OutputCollect, select_model=FALSE) {
}

# convert only target model data
if (!select_model==FALSE) {
if (!select_model == FALSE) {
OutputCollect[['allPareto']][['plotDataCollect']][[select_model]][['plot2data']][['plotWaterfallLoop']] <-
OutputCollect[['allPareto']][['plotDataCollect']][[select_model]][['plot2data']][['plotWaterfallLoop']] %>%
as_tibble() %>%
mutate(across(where(is.character), as.factor))
as_tibble() %>%
mutate(across(where(is.character), as.factor))
}

unlock(mutex)
return(OutputCollect)
})

Expand Down Expand Up @@ -210,20 +217,20 @@ function() {
#* @param calibration_input A hexadecimal string representing the binary content of a calibration data feather file.
#* @serializer json list(digits = 20, na = 'null')
#* @post /robyn_inputs
function(dt_input=FALSE, dt_holidays=FALSE, jsonInputArgs=FALSE, InputCollect=FALSE, calibration_input=FALSE) {

inputArgs <- if (!jsonInputArgs==FALSE) jsonlite::fromJSON(jsonInputArgs) else NULL
dt_input <- if (!dt_input==FALSE) hex_to_raw(dt_input) %>% arrow::read_feather() else NULL
dt_holidays <- if (!dt_holidays==FALSE) hex_to_raw(dt_holidays) %>% arrow::read_feather() else NULL
InputCollect <- if (!InputCollect==FALSE) transform_InputCollect(InputCollect) else NULL
calibration_input <- if (!calibration_input==FALSE) hex_to_raw(calibration_input) %>% arrow::read_feather() else NULL
function(dt_input = FALSE, dt_holidays = FALSE, jsonInputArgs = FALSE, InputCollect = FALSE, calibration_input = FALSE) {

inputArgs <- if (!jsonInputArgs == FALSE) jsonlite::fromJSON(jsonInputArgs) else NULL
dt_input <- if (!dt_input == FALSE) hex_to_raw(dt_input) %>% arrow::read_feather() else NULL
dt_holidays <- if (!dt_holidays == FALSE) hex_to_raw(dt_holidays) %>% arrow::read_feather() else NULL
InputCollect <- if (!InputCollect == FALSE) transform_InputCollect(InputCollect) else NULL
calibration_input <- if (!calibration_input == FALSE) hex_to_raw(calibration_input) %>% arrow::read_feather() else NULL

InputCollect <- do.call(robyn_inputs, c(list(dt_input = dt_input,
dt_holidays = dt_holidays,
InputCollect = InputCollect,
calibration_input = calibration_input
), inputArgs))
), inputArgs))

return(recursive_ggplot_serialize(InputCollect))
}

Expand All @@ -235,13 +242,13 @@ function(dt_input=FALSE, dt_holidays=FALSE, jsonInputArgs=FALSE, InputCollect=FA
#* @serializer json list(digits = 20, na = 'null')
#* @post /robyn_run
function(InputCollect, jsonRunArgs) {

runArgs <- jsonlite::fromJSON(jsonRunArgs)
InputCollect <- transform_InputCollect(InputCollect)

OutputModels <- do.call(robyn_run, c(list(InputCollect = InputCollect
), runArgs))
), runArgs))

return(recursive_ggplot_serialize(OutputModels))
}

Expand All @@ -255,15 +262,15 @@ function(InputCollect, jsonRunArgs) {
#* @serializer json list(digits = 20, na = 'null')
#* @post /robyn_outputs
function(InputCollect, OutputModels, jsonOutputsArgs) {

outputsArgs <- jsonlite::fromJSON(jsonOutputsArgs)
InputCollect <- transform_InputCollect(InputCollect)
OutputModels <- jsonlite::fromJSON(OutputModels)
OutputCollect <- do.call(robyn_outputs, c(list(InputCollect = InputCollect,

OutputCollect <- do.call(robyn_outputs, c(list(InputCollect = InputCollect,
OutputModels = OutputModels
), outputsArgs))
), outputsArgs))

return(recursive_ggplot_serialize(OutputCollect))
}

Expand All @@ -277,21 +284,21 @@ function(InputCollect, OutputModels, jsonOutputsArgs) {
#* @param width The width of the image to be returned, specified in inches.
#* @param height The height of the image to be returned, specified in inches.
#* @post /robyn_onepagers
function(InputCollect, OutputCollect, jsonOnepagersArgs, dpi=100, width=12, height=8) {
function(InputCollect, OutputCollect, jsonOnepagersArgs, dpi = 100, width = 12, height = 8) {

onepagersArgs <- jsonlite::fromJSON(jsonOnepagersArgs)
InputCollect <- transform_InputCollect(InputCollect)
OutputCollect <- transform_OutputCollect(OutputCollect, onepagersArgs[["select_model"]])
onepager <- do.call(robyn_onepagers, c(list(InputCollect = InputCollect,

onepager <- do.call(robyn_onepagers, c(list(InputCollect = InputCollect,
OutputCollect = OutputCollect
), onepagersArgs))
), onepagersArgs))

dpi <- as.numeric(dpi)
width <- as.numeric(width)
height <- as.numeric(height)
return(ggplot_serialize(onepager[[onepagersArgs[["select_model"]]]], dpi=dpi, width=width, height=height))

return(ggplot_serialize(onepager[[onepagersArgs[["select_model"]]]], dpi = dpi, width = width, height = height))
}

#* Generates and returns a serialized image of the allocation one-pager
Expand All @@ -304,21 +311,21 @@ function(InputCollect, OutputCollect, jsonOnepagersArgs, dpi=100, width=12, heig
#* @param width The width of the image to be returned, specified in inches.
#* @param height The height of the image to be returned, specified in inches.
#* @post /robyn_allocator
function(InputCollect, OutputCollect, jsonAllocatorArgs, dpi=100, width=12, height=8) {
function(InputCollect, OutputCollect, jsonAllocatorArgs, dpi = 100, width = 12, height = 8) {

allocatorArgs <- jsonlite::fromJSON(jsonAllocatorArgs)
InputCollect <- transform_InputCollect(InputCollect)
OutputCollect <- transform_OutputCollect(OutputCollect, allocatorArgs[["select_model"]])

AllocatorCollect <- do.call(robyn_allocator, c(list(InputCollect = InputCollect,
OutputCollect = OutputCollect
), allocatorArgs))
), allocatorArgs))

dpi <- as.numeric(dpi)
width <- as.numeric(width)
height <- as.numeric(height)
return(ggplot_serialize(AllocatorCollect$plots$plots, dpi=dpi, width=width, height=height))

return(ggplot_serialize(AllocatorCollect$plots$plots, dpi = dpi, width = width, height = height))
}

#* Exports model data in JSON format
Expand All @@ -328,13 +335,13 @@ function(InputCollect, OutputCollect, jsonAllocatorArgs, dpi=100, width=12, heig
#* @param OutputModels A JSON string representing the models created by 'robyn_run()'.
#* @param jsonWriteArgs A JSON string containing additional parameters for the 'robyn_write()' function.
#* @post /robyn_write
function(InputCollect=FALSE, OutputCollect=FALSE, OutputModels=FALSE, jsonWriteArgs) {
function(InputCollect = FALSE, OutputCollect = FALSE, OutputModels = FALSE, jsonWriteArgs) {

writeArgs <- jsonlite::fromJSON(jsonWriteArgs)
InputCollect <- if (!InputCollect==FALSE) transform_InputCollect(InputCollect) else NULL
OutputModels <- if (!OutputModels==FALSE) jsonlite::fromJSON(OutputModels) else NULL
OutputCollect <- if (!OutputCollect==FALSE) transform_OutputCollect(OutputCollect) else NULL
InputCollect <- if (!InputCollect == FALSE) transform_InputCollect(InputCollect) else NULL
OutputModels <- if (!OutputModels == FALSE) jsonlite::fromJSON(OutputModels) else NULL
OutputCollect <- if (!OutputCollect == FALSE) transform_OutputCollect(OutputCollect) else NULL

do.call(robyn_write, c(list(InputCollect = InputCollect, OutputCollect = OutputCollect, OutputModels = OutputModels), writeArgs))
}

Expand All @@ -346,16 +353,20 @@ function(InputCollect=FALSE, OutputCollect=FALSE, OutputModels=FALSE, jsonWriteA
#* @serializer json list(digits = 20, na = 'null')
#* @post /robyn_recreate
function(dt_input, dt_holidays, jsonRecreateArgs) {

recreateArgs <- jsonlite::fromJSON(jsonRecreateArgs)
dt_input <- dt_input %>% hex_to_raw() %>% arrow::read_feather()
dt_holidays <- dt_holidays %>% hex_to_raw() %>% arrow::read_feather()

dt_input <- dt_input %>%
hex_to_raw() %>%
arrow::read_feather()
dt_holidays <- dt_holidays %>%
hex_to_raw() %>%
arrow::read_feather()

RobynRecreated <- do.call(robyn_recreate, c(list(dt_input = dt_input,
dt_holidays = dt_holidays
), recreateArgs))
return(recursive_ggplot_serialize(RobynRecreated))
), recreateArgs))

return(recursive_ggplot_serialize(RobynRecreated))
}

#* Retrieves the names of hyperparameters based on adstock and media spend data
Expand All @@ -364,9 +375,9 @@ function(dt_input, dt_holidays, jsonRecreateArgs) {
#* @param all_media A JSON string representing the list of paid media spends.
#* @post /hyper_names
function(adstock, all_media) {

hyper_names_list <- hyper_names(adstock = adstock, all_media = jsonlite::fromJSON(all_media))

return(hyper_names_list)
}

Expand All @@ -379,14 +390,14 @@ function(adstock, all_media) {
#* @serializer json list(digits = 20, na = 'null')
#* @post /robyn_refresh
function(dt_input, dt_holidays, jsonRefreshArgs) {

refreshArgs <- jsonlite::fromJSON(jsonRefreshArgs)
dt_input <- if (!dt_input==FALSE) hex_to_raw(dt_input) %>% arrow::read_feather() else NULL
dt_holidays <- if (!dt_holidays==FALSE) hex_to_raw(dt_holidays) %>% arrow::read_feather() else NULL
dt_input <- if (!dt_input == FALSE) hex_to_raw(dt_input) %>% arrow::read_feather() else NULL
dt_holidays <- if (!dt_holidays == FALSE) hex_to_raw(dt_holidays) %>% arrow::read_feather() else NULL

RobynRefresh <- do.call(robyn_refresh, c(list(dt_input = dt_input,
dt_holidays = dt_holidays
), refreshArgs))
), refreshArgs))

return(recursive_ggplot_serialize(RobynRefresh))
}

0 comments on commit 53df7e2

Please sign in to comment.