Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update ROC curve #14

Merged
merged 14 commits into from
Oct 31, 2023
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ Imports:
farff,
Rcpp (>= 1.0.5),
RcppParallel,
stringr
stringr,
caret,
pracma
Encoding: UTF-8
LinkingTo: Rcpp, BH (>= 1.51.0), RcppParallel
RoxygenNote: 7.2.3
Expand Down
9 changes: 4 additions & 5 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
# Generated by roxygen2: do not edit by hand

export(aurocEXPLORE)
export(changeSetting)
export(getSetting)
export(modelsCurveExplore)
export(predictExplore)
export(saveData)
export(settingsExplore)
export(rocCurveExplore)
export(trainExplore)
import(Rcpp)
import(checkmate)
importFrom(RcppParallel,RcppParallelLibs)
importFrom(caret,confusionMatrix)
importFrom(farff,writeARFF)
importFrom(pracma,trapz)
importFrom(stringr,str_extract)
importFrom(stringr,str_replace_all)
importFrom(stringr,str_split_fixed)
Expand Down
15 changes: 1 addition & 14 deletions R/HelperFunctions.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#'
#' @return A parameter value, character.
#' @importFrom stringr str_extract str_replace_all
#' @export
getSetting <- function(settings, parameter, type = "value") {
extraction <- stringr::str_extract(settings, paste0(parameter , "=.*?\u000A"))[[1]]
extraction <- stringr::str_replace_all(extraction, "\\n", "")
Expand All @@ -33,8 +32,6 @@ getSetting <- function(settings, parameter, type = "value") {
#'
#' @return A setting parameter value
#' @importFrom utils write.table
#'
#' @export
changeSetting <- function(settings, parameter, input, default_setting) {

current_setting <- getSetting(settings, parameter, type = "complete")
Expand Down Expand Up @@ -77,8 +74,6 @@ changeSetting <- function(settings, parameter, input, default_setting) {
#'
#' @importFrom farff writeARFF
#' @importFrom utils write.table
#'
#' @export
saveData <- function(output_path, train_data, file_name) {

# Save data as arff file
Expand All @@ -96,12 +91,4 @@ saveData <- function(output_path, train_data, file_name) {
row.names = FALSE)

# TODO: Support other file formats?
}

simple_auc <- function(TPR, FPR){
# inputs already sorted, best scores first
# TODO: different computation? is it same as standard packages (how LASSO computed)?
dFPR <- c(diff(FPR), 0)
dTPR <- c(diff(TPR), 0)
sum(TPR * dFPR) + sum(dTPR * dFPR)/2
}
}
124 changes: 94 additions & 30 deletions R/MainFunctions.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
#' @param ClassFeature String, should be name of one of columns in data train. Always provided by the user. The string should be enclused in single quotation marks, e.g. 'class'
#' @param PositiveClass 1 or string (?) (should be one of elements of column 'ClassFeature' in data train). Always provided by the user. The string should be enclused in single quotation marks, e.g. 'class'
#' @param FeatureInclude Empty or string (should be name of one of columns in data train)
#' @param Maximize One of list with strings, list = "ACCURACY", ...
#' @param Maximize One of list with strings, list = "ACCURACY", "SENSITIVITY", "SPECIFICITY", ...
#' @param Accuracy Float 0-0.999 -> default = 0 (if 0, make empty = computationally more beneficial)
#' @param BalancedAccuracy Float 0-0.999 -> default = 0 (if 0, make empty = computationally more beneficial)
#' @param Specificity float 0-0.999, default = 0
#' @param PrintSettings True or False
#' @param PrintPerformance True or False
Expand All @@ -43,17 +44,18 @@ trainExplore <- function(train_data = NULL,
FeatureInclude = "",
Maximize = "ACCURACY",
Accuracy = 0,
BalancedAccuracy = 0,
Specificity = 0,
PrintSettings = TRUE,
PrintPerformance = TRUE,
Subsumption = TRUE,
BranchBound = TRUE,
Parallel = FALSE) {

if (!dir.exists(output_path)) {
dir.create(output_path, recursive = TRUE)
}

}
# Create output folder
if(!endsWith(output_path, "/")) {
warning("Output path should end with /, add this")
Expand All @@ -68,7 +70,7 @@ trainExplore <- function(train_data = NULL,
OutputFile <- paste0(output_path, file_name, ".result")
} else {
checkmate::checkFileExists(OutputFile,
add = errorMessage)
add = errorMessage)
}

# check settings_path
Expand All @@ -91,6 +93,7 @@ trainExplore <- function(train_data = NULL,
checkString(FeatureInclude),
checkString(Maximize),
checkDouble(Accuracy),
checkDouble(BalancedAccuracy),
checkDouble(Specificity),
checkLogical(PrintSettings),
checkLogical(PrintPerformance),
Expand All @@ -101,13 +104,14 @@ trainExplore <- function(train_data = NULL,
combine = "and"
)
checkmate::reportAssertions(collection = errorMessage)

PrintSettings <- ifelse(PrintSettings == TRUE, "yes", "no")
PrintPerformance <- ifelse(PrintPerformance == TRUE, "yes", "no")
Subsumption <- ifelse(Subsumption == TRUE, "yes", "no")
BranchBound <- ifelse(BranchBound == TRUE, "yes", "no")
Parallel <- ifelse(Parallel == TRUE, "yes", "no")
Accuracy <- ifelse(Accuracy == 0, "", Specificity)
Accuracy <- ifelse(Accuracy == 0, "", Accuracy)
BalancedAccuracy <- ifelse(BalancedAccuracy == 0, "", BalancedAccuracy)
Specificity <- ifelse(Specificity == 0, "", Specificity)

# Create project setting
Expand Down Expand Up @@ -146,6 +150,7 @@ trainExplore <- function(train_data = NULL,
FeatureInclude = FeatureInclude,
Maximize = Maximize,
Accuracy = Accuracy,
BalancedAccuracy = BalancedAccuracy,
Specificity = Specificity,
PrintSettings = PrintSettings,
PrintPerformance = PrintPerformance,
Expand All @@ -163,7 +168,7 @@ trainExplore <- function(train_data = NULL,

# Load model
rule_string <- stringr::str_extract(results, "Best candidate \\(overall\\):.*?\u000A")

# Clean string
rule_string <- stringr::str_replace(rule_string, "Best candidate \\(overall\\):", "")
rule_string <- stringr::str_replace_all(rule_string, " ", "")
Expand Down Expand Up @@ -200,7 +205,6 @@ trainExplore <- function(train_data = NULL,
#'
#' @return Settings path
#' @import checkmate
#' @export
settingsExplore <- function(settings,
output_path, # C++ cannot handle spaces in file path well, avoid those
file_name,
Expand All @@ -215,14 +219,15 @@ settingsExplore <- function(settings,
FeatureInclude = "",
Maximize = "ACCURACY",
Accuracy = 0,
BalancedAccuracy = 0,
Specificity = 0,
PrintSettings = "yes",
PrintPerformance = "yes",
Subsumption = "yes",
BranchBound = "yes",
Parallel = "no") {


# Insert location training data and cutoff file if train_data is entered
if (!is.null(train_data)) {
settings <- changeSetting(settings, parameter = "DataFile", input = paste0(output_path, file_name, ".arff"))
Expand All @@ -240,6 +245,7 @@ settingsExplore <- function(settings,
settings <- changeSetting(settings, parameter = "FeatureInclude", input = FeatureInclude)
settings <- changeSetting(settings, parameter = "Maximize", input = Maximize)
settings <- changeSetting(settings, parameter = "Accuracy", input = Accuracy)
settings <- changeSetting(settings, parameter = "BalancedAccuracy", input = BalancedAccuracy)
settings <- changeSetting(settings, parameter = "Specificity", input = Specificity)
settings <- changeSetting(settings, parameter = "PrintSettings", input = PrintSettings)
settings <- changeSetting(settings, parameter = "PrintPerformance", input = PrintPerformance)
Expand Down Expand Up @@ -303,7 +309,7 @@ predictExplore <- function(model, test_data) {
}


#' aucrocExplore
#' modelsCurveExplore # TODO: update documentation?
#'
#' @param output_path A string declaring the path to the settings
#' @param train_data Train data
Expand All @@ -312,33 +318,91 @@ predictExplore <- function(model, test_data) {
#' @param ... List of arguments
#'
#' @import checkmate
#' @return auroc
#' @return models for different sensitivities/specificities
#' @export
aurocEXPLORE <- function(output_path, train_data, settings_path, file_name, ...) {
# TODO: check with latest implementation in PLP
modelsCurveExplore <- function(train_data = NULL,
settings_path = NULL,
output_path,
file_name = "train_data",
OutputFile = NULL,
StartRulelength = 1,
EndRulelength = 3,
OperatorMethod = "EXHAUSTIVE",
CutoffMethod = "RVAC",
ClassFeature = "'class'",
PositiveClass = "'Iris-versicolor'",
FeatureInclude = "",
Maximize = "ACCURACY",
Accuracy = 0,
BalancedAccuracy = 0,
Specificity = 0,
PrintSettings = TRUE,
PrintPerformance = TRUE,
Subsumption = TRUE,
BranchBound = TRUE,
Parallel = FALSE) {
# TODO: only input required variables?

# Range of specificities to check
specificities <- seq(from = 0.01, to = 0.99, by = 0.02)
constraints <- c(seq(0.05,0.65,0.1), seq(0.75,0.97,0.02))

# Set specificity constraint and maximize sensitivity
sensitivities <- rep(NA, length(specificities))
for (s in 1:length(specificities)) { # s <- 0.1

model <- trainExplore(output_path = output_path, train_data = train_data, settings_path = settings_path, Maximize = "SENSITIVITY", Specificity = specificities[s], ...)
modelsCurve <- tryCatch({
models <- sapply(constraints, function(constraint) {
print(paste0("Model for specificity: ", as.character(constraint)))

# Fit EXPLORE
model <- Explore::trainExplore(output_path = file.path(output_path, "modelsCurve"), train_data = train_data,
settings_path = settings_path,
file_name = paste0("explore_specificity", as.character(constraint)),
OutputFile = OutputFile,
StartRulelength = StartRulelength, EndRulelength = EndRulelength,
OperatorMethod = OperatorMethod, CutoffMethod = CutoffMethod,
ClassFeature = ClassFeature, PositiveClass = PositiveClass,
FeatureInclude = FeatureInclude, Maximize = "SENSITIVITY",
Accuracy = Accuracy, BalancedAccuracy = BalancedAccuracy, Specificity = constraint,
PrintSettings = PrintSettings, PrintPerformance = PrintPerformance,
Subsumption = Subsumption, BranchBound = BranchBound,
Parallel = Parallel)

return(model)
})
},
finally = warning("No model for specificity.")
)

return(modelsCurve)
}


#' rocCurveExplore
#'
#' @return auc value for EXPLORE
#' @export
#' @importFrom caret confusionMatrix
#' @importFrom pracma trapz
rocCurveExplore <- function(modelsCurve, data, labels) { # labels <- cohort$outcomeCount

# TODO: input checks?

# Combine all these results
curve_TPR <- c(1,0)
curve_FPR <- c(1,0)

for (c in length(modelsCurve):1) {
model <- modelsCurve[c]

# Extract sensitivity from results file
results <- paste(readLines(paste0(output_path, "train_data.result")), collapse="\n")
# Predict using train and test
predict <- tryCatch(as.numeric(Explore::predictExplore(model = model, test_data = data)))

sensitivity <- stringr::str_extract_all(results, "Train-set: .*?\u000A")[[1]]
sensitivity <- stringr::str_extract(results, "SE:.*? ")[[1]]
sensitivity <- stringr::str_remove_all(sensitivity, "SE:")
sensitivity <- stringr::str_replace_all(sensitivity, " ", "")
# Compute metrics
conf_matrix <- table(factor(predict, levels = c(0,1)), factor(labels, levels = c(0,1))) # binary prediction
performance <- caret::confusionMatrix(conf_matrix, positive = '1')

sensitivities[s] <- as.numeric(sensitivity)
curve_TPR[c+2] <- performance$byClass['Sensitivity']
curve_FPR[c+2] <- 1 - performance$byClass['Specificity']
}

auroc <- simple_auc(TPR = rev(sensitivities), FPR = rev(1 - specificities))
# plot(1-specificities, sensitivities)
roc <- pracma::trapz(curve_FPR[length(curve_FPR):1],curve_TPR[length(curve_TPR):1])

return(auroc)
return (roc)
}
1 change: 1 addition & 0 deletions inst/examples/iris.project
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ FeatureRule=
[Constraints]
Maximize=ACCURACY
Accuracy=
BalancedAccuracy=
Specificity=
[Output]
OutputMethod=BEST
Expand Down
1 change: 1 addition & 0 deletions inst/settings/template.project
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ FeatureRule=
[Constraints]
Maximize=ACCURACY
Accuracy=
BalancedAccuracy=
Specificity=
[Output]
OutputMethod=BEST
Expand Down
8 changes: 0 additions & 8 deletions man/Explore-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 0 additions & 25 deletions man/aurocEXPLORE.Rd

This file was deleted.

Loading
Loading