Skip to content

Commit

Permalink
Fixes for individuals with no features and for when the age variable is not included
Browse files Browse the repository at this point in the history
  • Loading branch information
AniekMarkus committed Aug 9, 2024
1 parent 8a83468 commit 92d2cec
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
20 changes: 15 additions & 5 deletions R/EXPLORE.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ fitExplore <- function(trainData,
ParallelLogger::logInfo(paste0("Used training set size: ", nrow(exploreData)))

# convert age to decades
exploreData['1002'] <- round(exploreData['1002']/10)*10
if ("1002" %in% colnames(exploreData)) {
exploreData['1002'] <- round(exploreData['1002']/10)*10
}

# train model
fit <- tryCatch({
Expand Down Expand Up @@ -302,18 +304,26 @@ convertToExploreData <- function(trainData, modelSettings, search, analysisId, s
}

predictExplore <- function(plpModel, data, cohort) {
ParallelLogger::logInfo("Predict for Explore.")

if (is.na(plpModel$model$fit)) {
ParallelLogger::logError("Explore model is NA.")
return(NULL)
}

# Convert to dense covariates
covariates <- as.data.frame(data$covariateData$covariates)
covariates <- covariates[covariates$covariateId %in% plpModel$model$coefficients,] # Select only covariates included in model
exploreData <- reshape2::dcast(covariates, rowId ~ covariateId, value.var = 'covariateValue', fill = 0)

# exploreData <- merge(cohort[c("rowId", "outcomeCount")], exploreData, by = 'rowId', all.x = TRUE)
# exploreData[is.na(exploreData)] <- 0
# exploreData[c("rowId", "outcomeCount")] <- NULL
exploreData <- merge(cohort[c("rowId", "outcomeCount")], exploreData, by = 'rowId', all.x = TRUE)
exploreData[is.na(exploreData)] <- 0
exploreData[c("rowId", "outcomeCount")] <- NULL

# convert age to decades
exploreData['1002'] <- round(exploreData['1002']/10)*10
if ("1002" %in% colnames(exploreData)) {
exploreData['1002'] <- round(exploreData['1002']/10)*10
}

prediction <- data.frame(rowId=cohort$rowId, value=as.numeric(Explore::predictExplore(model = plpModel$model$fit, test_data = exploreData)))

Expand Down
18 changes: 11 additions & 7 deletions R/RIPPER.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,15 @@ fitRIPPER <- function(trainData,
denseData <- convertToDenseData(trainData, param$variableSelection, search, analysisId, param$saveDirectory)

# convert age to groups
denseData['1002'] <- cut(denseData[['1002']],breaks=c(0,25,50,75,100),labels=c('0-25','25-50','50-75','75-100'))

if ("1002" %in% colnames(denseData)) {
denseData['1002'] <- cut(denseData[['1002']],breaks=c(0,25,50,75,100),labels=c('0-25','25-50','50-75','75-100'))
}

# convert to factors (JRip cannot handle numeric features)
# binary_cols <- sapply(1:ncol(denseData), function(c) all(denseData[[c]] %in% 0:1))
# denseData[binary_cols] <- as.data.frame(sapply(denseData[binary_cols], function(col) factor(col, levels = c(0,1))), stringsAsFactors = TRUE)
denseData <- as.data.frame(sapply(denseData, function(col) factor(col, levels = unique(col))), stringsAsFactors = TRUE)

# train model
fit <- tryCatch({
ParallelLogger::logInfo('Running RIPPER')
Expand Down Expand Up @@ -235,7 +237,9 @@ predictRIPPER <- function(plpModel, data, cohort) {
denseData <- reshape2::dcast(covariates, rowId ~ covariateId, value.var = 'covariateValue', fill = 0)

# convert age to groups
denseData['1002'] <- cut(denseData[['1002']],breaks=c(0,25,50,75,100),labels=c('0-25','25-50','50-75','75-100'))
if ("1002" %in% colnames(denseData)) {
denseData['1002'] <- cut(denseData[['1002']],breaks=c(0,25,50,75,100),labels=c('0-25','25-50','50-75','75-100'))
}

# convert to factors (JRip cannot handle numeric features)
# binary_cols <- sapply(1:ncol(denseData), function(c) all(denseData[[c]] %in% 0:1))
Expand All @@ -246,9 +250,9 @@ predictRIPPER <- function(plpModel, data, cohort) {
addCols <- varSelection[!(varSelection %in% c(colnames(denseData), "outcomeCount"))]
denseData[addCols] <- 0

# denseData <- merge(cohort[c("rowId", "outcomeCount")], denseData, by = 'rowId', all.x = TRUE)
# denseData[is.na(denseData)] <- 0
# denseData[c("rowId", "outcomeCount")] <- NULL
denseData <- merge(cohort[c("rowId", "outcomeCount")], denseData, by = 'rowId', all.x = TRUE)
denseData[is.na(denseData)] <- 0
denseData[c("rowId", "outcomeCount")] <- NULL

prediction <- data.frame(rowId=cohort$rowId, value=as.numeric(stats::predict(plpModel$model$fit, denseData)==1))

Expand Down

0 comments on commit 92d2cec

Please sign in to comment.