Skip to content

Commit

Permalink
R6 class for module
Browse files Browse the repository at this point in the history
  • Loading branch information
egillax committed Aug 14, 2024
1 parent 93ae511 commit 67041d9
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 176 deletions.
1 change: 0 additions & 1 deletion .Rprofile

This file was deleted.

294 changes: 119 additions & 175 deletions Main.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,185 +14,129 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Module methods -------------------------
getModuleInfo <- function() {
checkmate::assert_file_exists("MetaData.json")
return(ParallelLogger::loadSettingsFromJson("MetaData.json"))
}
PatientLevelPredictionValidationModule <- R6::R6Class(
classname = "PatientLevelPredictionValidationModule",
inherit = Strategus::StrategusModule,
public = list(
tablePrefix = "plp",
initialize = function() {
super$initialize()
},
execute = function(connectionDetails, analysisSpecifications, executionSettings) {
super$execute(connectionDetails, analysisSpecifications, executionSettings)
checkmate::assertClass(executionSettings, "CdmExecutionSettings")

getModelInfo <- function(strategusOutputPath) {
modelDesigns <- list.files(strategusOutputPath, pattern = "modelDesign.json",
recursive = TRUE, full.names = TRUE)
model <- NULL
for (modelFilePath in modelDesigns) {
directory <- dirname(modelFilePath)
modelDesign <- ParallelLogger::loadSettingsFromJson(modelFilePath)

if (is.null(model)) {
model <- data.frame(
target_id = modelDesign$targetId,
outcome_id = modelDesign$outcomeId,
modelPath = directory)
} else {
model <- rbind(model,
data.frame(
target_id = modelDesign$targetId,
outcome_id = modelDesign$outcomeId,
modelPath = directory))
}
}

models <- model %>%
dplyr::group_by(.data$target_id, .data$outcome_id) %>%
dplyr::summarise(modelPath = list(.data$modelPath), .groups = "drop")
return(models)
}
private$.message("Executing PatientLevelPrediction Validation")
jobContext <- private$jobContext
# check the model locations are valid and apply model

getSharedResourceByClassName <- function(sharedResources, className) {
returnVal <- NULL
for (i in 1:length(sharedResources)) {
if (className %in% class(sharedResources[[i]])) {
returnVal <- sharedResources[[i]]
break
}
}
invisible(returnVal)
}
workFolder <- jobContext$moduleExecutionSettings$workSubFolder
resultsFolder <- jobContext$moduleExecutionSettings$resultsSubFolder
upperResultDir <- dirname(workFolder)
modelTransferFolder <- sort(dir(upperResultDir,
pattern = "ModelTransferModule"
), decreasing = TRUE)[1]

# this updates the cohort table details in covariates
updateCovariates <- function(plpModel, cohortTable, cohortDatabaseSchema){

covSettings <- plpModel$modelDesign$covariateSettings
# if a single setting make it into a list to force consistency
if (inherits(covSettings, 'covariateSettings')) {
covSettings <- list(covSettings)
}

for (i in 1:length(covSettings)) {
if ('cohortTable' %in% names(covSettings[[i]])) {
covSettings[[i]]$cohortTable <- cohortTable
}
if ('cohortDatabaseSchema' %in% names(covSettings[[i]])) {
covSettings[[i]]$cohortDatabaseSchema <- cohortDatabaseSchema
# hack to use output folder for model transfer
modelSaveLocation <- file.path(upperResultDir, modelTransferFolder, "models")
modelInfo <- getModelInfo(modelSaveLocation)

designs <- list()
for (i in seq_len(nrow(modelInfo))) {
df <- modelInfo[i, ]

design <- PatientLevelPrediction::createValidationDesign(
targetId = df$target_id[1],
outcomeId = df$outcome_id[1],
plpModelList = as.list(df$modelPath),
restrictPlpDataSettings = jobContext$settings[[1]]$restrictPlpDataSettings,
populationSettings = jobContext$settings[[1]]$populationSettings
)
designs <- c(designs, design)
}
databaseNames <- c()
databaseNames <- c(
databaseNames,
paste0(jobContext$moduleExecutionSettings$connectionDetailsReference)
)

databaseDetails <- PatientLevelPrediction::createDatabaseDetails(
connectionDetails = jobContext$moduleExecutionSettings$connectionDetails,
cdmDatabaseSchema = jobContext$moduleExecutionSettings$cdmDatabaseSchema,
cohortDatabaseSchema = jobContext$moduleExecutionSettings$workDatabaseSchema,
cdmDatabaseName = jobContext$moduleExecutionSettings$connectionDetailsReference,
cdmDatabaseId = jobContext$moduleExecutionSettings$databaseId,
tempEmulationSchema = jobContext$moduleExecutionSettings$tempEmulationSchema,
cohortTable = jobContext$moduleExecutionSettings$cohortTableNames$cohortTable,
outcomeDatabaseSchema = jobContext$moduleExecutionSettings$workDatabaseSchema,
outcomeTable = jobContext$moduleExecutionSettings$cohortTableNames$cohortTable
)

PatientLevelPrediction::validateExternal(
validationDesignList = designs,
databaseDetails = databaseDetails,
logSettings = PatientLevelPrediction::createLogSettings(
verbosity = "INFO",
logName = "validatePLP"
),
outputFolder = workFolder
)

sqliteConnectionDetails <- DatabaseConnector::createConnectionDetails(
dbms = "sqlite",
server = file.path(workFolder, "sqlite", "databaseFile.sqlite")
)

PatientLevelPrediction::extractDatabaseToCsv(
connectionDetails = sqliteConnectionDetails,
databaseSchemaSettings = PatientLevelPrediction::createDatabaseSchemaSettings(
resultSchema = "main",
tablePrefix = "",
targetDialect = "sqlite",
tempEmulationSchema = NULL
),
csvFolder = resultsFolder,
fileAppend = NULL
)
},
createModuleSpecifications = function(settings) {
specifications <- super$createModuleSpecifications(settings)
return(specifications)
}
}

plpModel$modelDesign$covariateSettings <- covSettings

return(plpModel)
}
),
private = list(
getModelInfo = function(modelLocations) {
modelDesigns <- list.files(modelLocations,
pattern = "modelDesign.json",
recursive = TRUE, full.names = TRUE
)
model <- NULL
for (modelFilePath in modelDesigns) {
directory <- dirname(modelFilePath)
modelDesign <- ParallelLogger::loadSettingsFromJson(modelFilePath)

createCohortDefinitionSetFromJobContext <- function(sharedResources, settings) {
cohortDefinitions <- list()
if (length(sharedResources) <= 0) {
stop("No shared resources found")
}
cohortDefinitionSharedResource <- getSharedResourceByClassName(sharedResources = sharedResources,
className = "CohortDefinitionSharedResources")
if (is.null(cohortDefinitionSharedResource)) {
stop("Cohort definition shared resource not found!")
}
cohortDefinitions <- cohortDefinitionSharedResource$cohortDefinitions
if (length(cohortDefinitions) <= 0) {
stop("No cohort definitions found")
}
cohortDefinitionSet <- CohortGenerator::createEmptyCohortDefinitionSet()
for (i in 1:length(cohortDefinitions)) {
cohortJson <- cohortDefinitions[[i]]$cohortDefinition
cohortDefinitionSet <- rbind(cohortDefinitionSet, data.frame(
cohortId = as.integer(cohortDefinitions[[i]]$cohortId),
cohortName = cohortDefinitions[[i]]$cohortName,
json = cohortJson,
stringsAsFactors = FALSE
))
}
return(cohortDefinitionSet)
}
if (is.null(model)) {
model <- data.frame(
targetId = modelDesign$targetId,
outcomeId = modelDesign$outcomeId,
modelPath = directory
)
} else {
model <- rbind(
model,
data.frame(
targetId = modelDesign$targetId,
outcomeId = modelDesign$outcomeId,
modelPath = directory
)
)
}
}

# Module methods -------------------------
execute <- function(jobContext) {
library(PatientLevelPrediction)
rlang::inform("Validating inputs")
inherits(jobContext, 'list')

if (is.null(jobContext$settings)) {
stop("Analysis settings not found in job context")
}
if (is.null(jobContext$sharedResources)) {
stop("Shared resources not found in job context")
}
if (is.null(jobContext$moduleExecutionSettings)) {
stop("Execution settings not found in job context")
}

workFolder <- jobContext$moduleExecutionSettings$workSubFolder
resultsFolder <- jobContext$moduleExecutionSettings$resultsSubFolder

rlang::inform("Executing PLP Validation")
moduleInfo <- getModuleInfo()

# find where cohortDefinitions are as sharedResources is a list
cohortDefinitionSet <- createCohortDefinitionSetFromJobContext(
sharedResources = jobContext$sharedResources,
settings = jobContext$settings
)

# check the model locations are valid and apply model
upperWorkDir <- dirname(workFolder)
modelTransferFolder <- sort(dir(upperWorkDir, pattern = 'ModelTransferModule'), decreasing = T)[1]

modelSaveLocation <- file.path( upperWorkDir, modelTransferFolder, 'models') # hack to use work folder for model transfer
modelInfo <- getModelInfo(modelSaveLocation)

designs <- list()
for (i in seq_len(nrow(modelInfo))) {
df <- modelInfo[i, ]

design <- PatientLevelPrediction::createValidationDesign(
targetId = df$target_id[1],
outcomeId = df$outcome_id[1],
plpModelList = as.list(df$modelPath),
restrictPlpDataSettings = jobContext$settings[[1]]$restrictPlpDataSettings,
populationSettings = jobContext$settings[[1]]$populationSettings
)
designs <- c(designs, design)
}
databaseNames <- c()
databaseNames <- c(databaseNames, paste0(jobContext$moduleExecutionSettings$connectionDetailsReference))

databaseDetails <- PatientLevelPrediction::createDatabaseDetails(
connectionDetails = jobContext$moduleExecutionSettings$connectionDetails,
cdmDatabaseSchema = jobContext$moduleExecutionSettings$cdmDatabaseSchema,
cohortDatabaseSchema = jobContext$moduleExecutionSettings$workDatabaseSchema,
cdmDatabaseName = jobContext$moduleExecutionSettings$connectionDetailsReference,
cdmDatabaseId = jobContext$moduleExecutionSettings$databaseId,
tempEmulationSchema = jobContext$moduleExecutionSettings$tempEmulationSchema,
cohortTable = jobContext$moduleExecutionSettings$cohortTableNames$cohortTable,
outcomeDatabaseSchema = jobContext$moduleExecutionSettings$workDatabaseSchema,
outcomeTable = jobContext$moduleExecutionSettings$cohortTableNames$cohortTable
)

PatientLevelPrediction::validateExternal(
validationDesignList = designs,
databaseDetails = databaseDetails,
logSettings = PatientLevelPrediction::createLogSettings(verbosity = 'INFO', logName = 'validatePLP'),
outputFolder = workFolder
)

sqliteConnectionDetails <- DatabaseConnector::createConnectionDetails(
dbms = 'sqlite',
server = file.path(workFolder, "sqlite", "databaseFile.sqlite")
)

PatientLevelPrediction::extractDatabaseToCsv(
connectionDetails = sqliteConnectionDetails,
databaseSchemaSettings = PatientLevelPrediction::createDatabaseSchemaSettings(
resultSchema = 'main',
tablePrefix = '',
targetDialect = 'sqlite',
tempEmulationSchema = NULL
),
csvFolder = resultsFolder,
fileAppend = NULL
models <- model %>%
dplyr::group_by(.data$targetId, .data$outcomeId) %>%
dplyr::summarise(modelPath = list(.data$modelPath), .groups = "drop")
return(models)
}
)
}
)

0 comments on commit 67041d9

Please sign in to comment.