fixing random forest
jreps committed Apr 18, 2024
1 parent 778b775 commit dda160f
Showing 2 changed files with 60 additions and 25 deletions.
19 changes: 12 additions & 7 deletions R/HyperparameterOptimization.R
@@ -20,7 +20,7 @@ autoTuning <- function(

performances <- list()
for(i in 1:self$resamplingFunction$getInterationCount()){

message(paste0('resample ', i, ' of ', self$resamplingFunction$getInterationCount()))
# function to split trainData into train/validation - cv/bootstrap
dataIndexes <- self$resamplingFunction$getIndexes(
data = trainData,
@@ -30,21 +30,23 @@
validationPrediction <- self$fit(
data = trainData,
trainIndex = dataIndexes$trainRowIds,
validationIndex= dataIndexes$validationRowIds,
validationIndex = dataIndexes$validationRowIds,
returnPredictionOnly = T
)
# user-specified performance metric that takes a prediction and returns a performance value (could be multiple inputs)
performanceTemp <- self$performanceFunction$metricFunction(validationPrediction)
performances[[length(performances)+1]] <- performanceTemp

summaryPerformance[[length(summaryPerformance) + 1]] <- list(
hyperparameter = hyperparameter,
fold = i,
performance = performanceTemp
)
#summaryPerformance[[length(summaryPerformance) + 1]] <- list(
# hyperparameter = hyperparameter,
# fold = i,
# performance = performanceTemp
# )
}

message('Aggregating performance')
aggregatePerformanceIteration <- self$performanceFunction$aggregateFunction(performances)
message('aggregate performance: ', aggregatePerformanceIteration)

summaryPerformance[[length(summaryPerformance) + 1]] <- list(
hyperparameter = hyperparameter,
@@ -54,18 +56,21 @@

if(start){
start <- F
message('Setting initial currentOptimal')
currentOptimal <- aggregatePerformanceIteration
optimalHyperparameters <- hyperparameter
}

# performance selection: compare the aggregated performance to the current optimum and keep the better hyperparameters
if(self$performanceFunction$maxmize){
if(currentOptimal < aggregatePerformanceIteration){
message('New maximum')
optimalHyperparameters = hyperparameter
currentOptimal <- aggregatePerformanceIteration
}
} else{
if(currentOptimal > aggregatePerformanceIteration){
message('New minimum')
optimalHyperparameters = hyperparameter
currentOptimal <- aggregatePerformanceIteration
}
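Note on the loop above: each hyperparameter set is scored by resampling (fit on the train rows, score the validation prediction with the user-supplied metric), the per-resample scores are aggregated, and the aggregate is compared against the running optimum. A minimal standalone sketch of that select-by-aggregate pattern, using hypothetical plain-function stand-ins for the R6 pieces:

selectBestHyperparameters <- function(grid, fitAndScore, nFolds = 3, maximize = TRUE) {
  # grid: list of hyperparameter sets
  # fitAndScore: function(hyperparameter, fold) -> numeric performance
  bestScore <- if (maximize) -Inf else Inf
  best <- NULL
  for (hyperparameter in grid) {
    # score the candidate on every resample
    foldScores <- vapply(
      seq_len(nFolds),
      function(fold) fitAndScore(hyperparameter, fold),
      numeric(1)
    )
    score <- mean(foldScores) # aggregate across resamples
    isBetter <- if (maximize) score > bestScore else score < bestScore
    if (isBetter) {
      bestScore <- score
      best <- hyperparameter
    }
  }
  list(hyperparameters = best, performance = bestScore)
}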
66 changes: 48 additions & 18 deletions R/newRandomForest.R
@@ -86,7 +86,8 @@ fitSklearnBase <- function(
trainY <- reticulate::r_to_py(data$labels$outcomeCount[trainIndexes])
trainX <- reticulate::r_to_py(data$matrix[trainIndexes,])
validationX <- reticulate::r_to_py(data$matrix[validationIndexes,])

validationY <- data$labels[validationIndexes,]

if(requiresDenseMatrix){
ParallelLogger::logInfo('Converting sparse matrix to dense matrix (CV)')
trainX <- trainX$toarray()
@@ -96,7 +97,12 @@
model <- fitPythonModel(classifier, hyperparameters, seed, trainX, trainY, np, pythonClassifier)

ParallelLogger::logInfo("Calculating predictions on validation...")
prediction <- predictValues(model = model, data = validationX, cohort = data$labels[validationIdexes,], type = 'binary')
prediction <- predictValues(
model = model,
data = validationX,
cohort = validationY,
type = 'binary'
)

if(returnPredictionOnly){
return(prediction)
@@ -114,21 +120,30 @@
}

# feature importance
variableImportance <- tryCatch({reticulate::py_to_r(model$feature_importances_)}, error = function(e){ParallelLogger::logInfo(e);return(rep(1,ncol(matrixData)))})
covariateImportance <- data$covariateRef
importanceValues <- tryCatch({reticulate::py_to_r(model$feature_importances_)}, error = function(e){ParallelLogger::logInfo(e);return(rep(1,ncol(matrixData)))})
importanceValues[is.na(importanceValues)] <- 0
covariateImportance$included <- 1
#covariateImportance$included <- 0
#covariateImportance$included[importanceValues > 0 ] <- 1
covariateImportance$covariateValue <- unlist(importanceValues)


return(
list(
prediction = prediction,
variableImportance = variableImportance
covariateImportance = covariateImportance
)
)

}
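The feature-importance block builds covariateImportance from data$covariateRef plus the fitted model's feature_importances_, falling back to a constant vector when the attribute is unavailable. A hedged sketch of just that extraction step, mirroring the tryCatch pattern above (assumes a fitted sklearn estimator reachable via reticulate; not all estimators expose the attribute):

library(reticulate)

getImportanceValues <- function(model, nFeatures) {
  # feature_importances_ exists only on fitted tree ensembles; fall back to 1s otherwise
  values <- tryCatch(
    as.numeric(py_to_r(model$feature_importances_)),
    error = function(e) {
      message('feature_importances_ unavailable: ', conditionMessage(e))
      rep(1, nFeatures)
    }
  )
  values[is.na(values)] <- 0 # mirror the NA handling above
  values
}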

predictSklearnBase <- function(
plpModelLocation,
covariateMap,
requireDenseMatrix,
plpModelLocation, #self
covariateMap, # self
covariateImportance, #self
requiresDenseMatrix, # self
saveToJson, #self
data,
cohort
){
@@ -149,7 +164,7 @@
}

# load model
if(attr(plpModel,'saveToJson')){
if(saveToJson){
modelLocation <- reticulate::r_to_py(file.path(plpModelLocation,"model.json"))
model <- sklearnFromJson(path=modelLocation)
} else{
@@ -162,7 +177,7 @@
pythonData <- reticulate::r_to_py(newData[,included, drop = F])

# make dense if needed
if(requireDenseMatrix){
if(requiresDenseMatrix){
pythonData <- pythonData$toarray()
}
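predictSklearnBase now receives everything it needs (model location, covariate map, dense-matrix flag, JSON flag) as plain arguments rather than reading attributes off a model object, which keeps the function callable and testable without a full R6 instance. A small sketch of that delegation pattern, with a hypothetical class and a stub body:

library(R6)

# Hypothetical: the plain function holds the logic; the R6 method only forwards its fields.
predictImpl <- function(modelLocation, requiresDenseMatrix, data) {
  message('loading model from ', modelLocation)
  if (requiresDenseMatrix) data <- as.matrix(data)
  data # stub: real code would load the model and return predictions
}

Classifier <- R6Class("Classifier", list(
  modelLocation = NULL,
  requiresDenseMatrix = FALSE,
  predict = function(data) {
    predictImpl(
      modelLocation = self$modelLocation,
      requiresDenseMatrix = self$requiresDenseMatrix,
      data = data
    )
  }
))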

@@ -180,6 +195,9 @@ RandomForest <- R6Class("RandomForest", list(
nJobs = NULL,
modelLocation = NULL,

covariateMap = NULL,
covariateImportance = NULL,

requiresDenseMatrix = F,
name = "Random forest",
fitFunction = 'fitSklearnBase',
@@ -313,14 +331,14 @@
hyperparameterGenerator = GridHyperparameter
) {

self$hyperparameters$ntrees$grid = ntrees
self$hyperparameters$ntrees$grid = lapply(ntrees, function(x){as.integer(x)})
self$hyperparameters$criterion$grid = criterion
self$hyperparameters$maxDepth$grid = maxDepth
self$hyperparameters$minSamplesSplit$grid = minSamplesSplit
self$hyperparameters$minSamplesLeaf$grid = minSamplesLeaf
self$hyperparameters$maxDepth$grid = lapply(maxDepth, function(x){as.integer(x)})
self$hyperparameters$minSamplesSplit$grid = lapply(minSamplesSplit, function(x){if(x>=1){as.integer(x)}else{x}})
self$hyperparameters$minSamplesLeaf$grid = lapply(minSamplesLeaf, function(x){if(x>=1){as.integer(x)}else{x}})
self$hyperparameters$minWeightFractionLeaf$grid = minWeightFractionLeaf
self$hyperparameters$mtries$grid = mtries
self$hyperparameters$maxLeafNodes$grid = maxLeafNodes
self$hyperparameters$maxLeafNodes$grid = lapply(maxLeafNodes, function(x){if(!is.null(x)){as.integer(x)}else{x}})
self$hyperparameters$minImpurityDecrease$grid = minImpurityDecrease
self$hyperparameters$bootstrap$grid = bootstrap
self$hyperparameters$maxSamples$grid = maxSamples
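The as.integer coercion matters because reticulate maps R doubles to Python floats, and sklearn validates integer-typed parameters such as n_estimators and max_depth strictly, so 100 must cross the boundary as 100L rather than 100.0. minSamplesSplit and minSamplesLeaf are the dual-typed cases in sklearn (an integer is an absolute count, a fractional value below 1 is a fraction of samples), which is why those grids are only coerced when the value is at least 1. A quick illustration of the conversion rule (assumes reticulate is installed):

library(reticulate)

r_to_py(100)             # crosses as a Python float: 100.0
r_to_py(100L)            # crosses as a Python int: 100
r_to_py(as.integer(100)) # same as 100L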
@@ -375,16 +393,24 @@
returnPredictionOnly = returnPredictionOnly
)

return(result$prediction)
# store the mapping
self$covariateMap <- data$covariateMap

if(!returnPredictionOnly){
self$covariateImportance <- result$covariateImportance
}

return(invisible(result)) # do we want to return these or set self$trainPrediction?
},

predict = function(data, cohort) {

prediction <- predictSklearnBase(
plpModelLocation = self$modelLocation,
covariateMap = self$covariateMap,
requireDenseMatrix = self$requireDenseMatrix,
covariateMap = self$covariateMap,
covariateImportance = self$covariateImportance,
requiresDenseMatrix = self$requiresDenseMatrix,
saveToJson = self$saveToJson,
data = data,
cohort = cohort
)
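With fit() recording covariateMap (and, for full fits, covariateImportance) on the object and predict() forwarding that stored state, a round trip looks roughly like the following sketch (trainData, the row-id vectors, and newData/newCohort are hypothetical placeholders; constructor defaults as shown above):

rf <- RandomForest$new()
rf$fit(
  data = trainData,
  trainIndex = trainRowIds,
  validationIndex = validationRowIds,
  returnPredictionOnly = FALSE
)
prediction <- rf$predict(data = newData, cohort = newCohort)
head(rf$covariateImportance) # per-covariate importance from the fitted forest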
@@ -501,6 +527,10 @@ CrossValidationSampler <- R6Class("CrossValidationSampler", list(
meanList <- function(x){
mean(unlist(x))
}
computeAucNew <- function(prediction){
return(PatientLevelPrediction:::aucWithoutCi(prediction = prediction$value, truth = prediction$outcomeCount))
}


PerformanceFunction <- R6Class("PerformanceFunction", list(
maxmize = NULL,
@@ -509,7 +539,7 @@

initialize = function(
maxmize = T,
metricFunctionName = 'computeAuc',
metricFunctionName = 'computeAucNew',
aggregateFunctionName = 'meanList'
) {
self$maxmize <- maxmize
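Because PerformanceFunction looks up its metric and aggregation functions by name, plugging in a different metric only requires a function that maps a prediction data frame to a single number. A hedged example using the Brier score, which is minimized rather than maximized (assumes the same value/outcomeCount columns used by computeAucNew; note the field is spelled maxmize in the class above):

# Brier score: mean squared difference between predicted risk and observed outcome
computeBrierNew <- function(prediction) {
  mean((prediction$value - prediction$outcomeCount)^2)
}

brierPerformance <- PerformanceFunction$new(
  maxmize = FALSE, # lower Brier scores are better
  metricFunctionName = 'computeBrierNew',
  aggregateFunctionName = 'meanList'
)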
