Skip to content

Latest commit

 

History

History
296 lines (261 loc) · 9.22 KB

paper_pub.org

File metadata and controls

296 lines (261 loc) · 9.22 KB

Paper Publications


Experiments

## Bootstrap experiment.
## For every bootstrap count in `bootn` this script
##   1. learns an ensemble on the full noisy training data,
##   2. learns an ensemble on the same data without U1.out / U2.out,
##   3. saves both ensembles under boottestv2/boot_<n>/,
##   4. runs latentDiscovery() on the reduced ensemble.
## NOTE(review): `datalist`, getEnsemble2() and latentDiscovery() are not
## defined here; presumably they come from the load()/source() calls below.
setwd("/Users/fred/Documents/projects/latvar/")
source("LatentConfounderBNlearnv2.R")
load("final_model_nolvp_novp.RData", verbose = TRUE)
library(doParallel)

cl <- makeCluster(5) ## for multi-threading
registerDoParallel(cl)

train = datalist$data_noisy
trainlv = train %>% dplyr::select(-U1.out, -U2.out)

## Forbidden arcs: nothing leaves Z and nothing enters U1.out / U2.out.
blacklist = rbind(
    data.frame(from = "Z", to = colnames(train)),
    data.frame(from = colnames(train), to = "U1.out"),
    data.frame(from = colnames(train), to = "U2.out")
)
blacklistlv = rbind(data.frame(from = "Z", to = colnames(trainlv)))

##bootn = c(5, 10, 15, 20, 30, 40, 50)
bootn = c(30, 40, 50)
niter = 20

for (nn in bootn) {
    message("Bootstrap # = ", nn)
    bootdir = file.path("boottestv2", paste0("boot_", nn))
    ## showWarnings = FALSE: rerunning the experiment must not warn just
    ## because the output directory is already there.
    dir.create(bootdir, recursive = TRUE, showWarnings = FALSE)

    message("running full data")
    set.seed(213)
    ens_alldata_boot = getEnsemble2(
        train,
        blacklist = blacklist,
        restart = 100,
        Nboot = nn,
        prior = "vsp",
        score = "bge",
        algorithm = 'hc',
        parallel = TRUE
    )

    message("running with missing data")
    ens_missing_boot = getEnsemble2(
        trainlv,
        blacklist = blacklistlv,
        restart = 100,
        Nboot = nn,
        prior = "vsp",
        score = "bge",
        algorithm = 'hc',
        parallel = TRUE
    )

    save(
        ens_alldata_boot,
        ens_missing_boot,
        file = file.path(bootdir, paste0('ensembles_boot_', nn, ".RData"))
    )

    message("run latent discovery")
    ##    try({
    ## "Z" stays an unnamed argument (its formal name is not visible here);
    ## it is only moved next to the other positional argument, which does
    ## not change how R matches it.
    latdiscv <- latentDiscovery(
        ens_missing_boot,
        "Z",
        nItera = niter,
        data = trainlv,
        workpath = bootdir,
        freqCutoff = 0.01,
        maxpath = 1,
        alpha = 0.05,
        scale. = TRUE,
        method = "linear",
        latent_iterations = 100,
        include_downstream = TRUE,
        truecoef = datalist$coef %>% filter(output == "Z"),
        truelatent = datalist$test %>% dplyr::select("U1.out", "U2.out"),
        testdata = datalist$test_noisy %>% dplyr::select(-U1.out, -U2.out),
        include_output = TRUE,
        multiple_comparison_correction = TRUE,
        debug = FALSE,
        parallel = TRUE,
        Nsamples = 10
    )
    ##    })
}

## Release the worker processes and fall back to the sequential backend so
## that any later %dopar% call does not hit a stopped cluster.
stopCluster(cl)
foreach::registerDoSEQ()

Plots

R2 plot

## Load packages and the latent-discovery results for the 50-bootstrap run
library(tidyverse)
library("latex2exp")
##latVar = readRDS("/Users/fred/Documents/projects/latvar/pca_ens50/latVars.RDS")
latVar = readRDS("/Users/fred/Documents/projects/latvar/boottestv2/boot_50/latVars.RDS")

## Alternative labels (kept for reference):
## mylabels = c(
##     "R2a_U1.out" = ("$R^2$ for $U_1$"),
##     "R2a_U2.out" = ("$R^2$ for $U_2$")
## )
mylabels = c(
    "R2a_U1.out" = TeX("$U_1$"),
    "R2a_U2.out" = TeX("$U_2$")
)

## Long format: one row per (Iteration, R2 column) pair
r2_wide = dplyr::select(
    latVar$details$Diagnostics,
    Iteration, R2a_U1.out, R2a_U2.out
)
tstrep = gather(r2_wide, var, val, -Iteration)
tstrep = mutate(tstrep, var = factor(var))

## Colours and legend labels shared by the colour and fill scales
latent_palette = c("skyblue", "orange")
latent_legend = expression(U[1], U[2])

ggp = ggplot(tstrep, aes(x = Iteration, y = val, colour = var, fill = var)) +
    ggbeeswarm::geom_beeswarm() +
    geom_smooth(span = 0.8, alpha = 0.1) +
    scale_colour_manual(values = latent_palette, labels = latent_legend) +
    scale_fill_manual(values = latent_palette, labels = latent_legend) +
    ylab(TeX("R^2")) +
    ##ggtitle(TeX("Latent Variables Prediction $R^2$")) +
    ggpubr::theme_pubclean() +
    theme(
        strip.text = element_text(size = rel(1.5)),
        legend.text = element_text(size = rel(1.1)),
        legend.position = c(0.5, 0.2),
        legend.direction = "horizontal",
        legend.title = element_blank(),
        axis.title = element_text(size = rel(1)),
        plot.title = element_text(size = rel(1.5), hjust = 0.5)
    )
ggp

## Write the figure to PDF, then show it again on the active device
pdf("./images/fig_paper_r2lat_rep.pdf", width = 5, height = 5)
print(ggp)
dev.off()
ggp

./images/fig_paper_r2lat_rep.png

Coef plot

## Plot coefficient errors per latent-discovery iteration.
##
## Points and smoothers show, for every iteration in
## `latvar$details$Diagnostics`, the absolute coefficient error of each input
## whose true coefficient is >= 0.02, plus the per-iteration RMSE. Dashed
## horizontal lines show the same quantities for the ensemble learned on all
## variables (`ens_alldata`), restricted to `drivers`.
##
## Arguments:
##   latvar       result object with a `details$Diagnostics` table holding
##                the columns Repeat, Iteration, Coef and Error_rmse.
##   ens_alldata  ensemble passed to getScores() as the reference model.
##   out          name of the output variable; used both for getScores() and
##                to pick the matching rows of the true coefficients.
##   drivers      inputs whose reference errors are drawn as dashed lines.
##
## Returns the ggplot object (nothing is printed here).
##
## NOTE: reads the global `datalist$coef` for the true coefficients.
plotCoefError = function(latvar, ens_alldata, out ='Z',
                         drivers = c("V1", "V2")){
    diagnostics = latvar$details$Diagnostics

    ## absolute coefficient error per iteration and input
    coef_errors = diagnostics %>%
        dplyr::select(Repeat, Iteration, Coef) %>%
        unnest(Coef) %>%
        filter(True_coef >= 0.02) %>%
        mutate(Error = abs(error)) %>%
        dplyr::select(Iteration, Variable = input, Error)

    ## RMSE per iteration, shaped like the coefficient errors
    rmse_errors = diagnostics %>%
        dplyr::select(Iteration, Error_rmse) %>%
        mutate(Variable = "RMSE") %>%
        dplyr::select(Iteration, Variable, Error = Error_rmse)

    allerrorresp = rbind(coef_errors, rmse_errors)

    ## get performance of model with all data
    ## The true coefficients follow `out` (previously hard-coded to "Z", which
    ## silently scored the wrong output for any non-default `out`). `!!out`
    ## forces the function argument even if the table has a column named `out`.
    score_alldata = getScores(ens = ens_alldata,
                              output = out,
                              truecoef = datalist$coef %>%
                                  filter(output == !!out),
                              lat_estimates = NULL)
    error_alldata = score_alldata$Coef[[1]] %>% filter(input %in% drivers)
    rmse_alldata = score_alldata$Error_rmse
    allerror_alldata = rbind(
        error_alldata %>%
            mutate(error = abs(error)) %>%
            dplyr::select(Variable = input, Error = error),
        tribble(
            ~ Variable, ~ Error,
            'RMSE', rmse_alldata
        )
    )

    ## plot
    ggp = allerrorresp %>%
        ggplot(aes(x = Iteration, y = Error,
                   colour = Variable, fill = Variable)) +
        ggbeeswarm::geom_beeswarm() +
        geom_smooth(se = TRUE) +
        geom_hline(data = allerror_alldata,
                   aes(yintercept = Error, colour = Variable),
                   linetype = 'dashed') +
        ## annotate("text", x = 0, y = max(allerror_alldata$Error), label = "Obs. All Variables",
        ##          vjust = -1) +
        ylab("Error in Coefficients") +
        ylim(0, 1) +
        ggpubr::theme_pubclean() +
        theme(strip.text = element_text(size = rel(1)),
              axis.title = element_text(size = rel(1)),
              legend.position = c(0.5, 0.7),
              legend.text = element_text(size = rel(1)),
              legend.direction = 'horizontal',
              plot.title = element_text(size = rel(1.5), hjust = 0.5))
    return(ggp)
}

## Restore the ensembles saved by the 50-bootstrap experiment run
load("/Users/fred/Documents/projects/latvar/boottestv2/boot_50/ensembles_boot_50.RData", verbose = TRUE)


##graphics.off()

## Build the figure once and reuse it for the PDF and the on-screen copy
## (plotCoefError() re-scores the ensemble on every call).
ggp_errors = plotCoefError(latVar, ens_alldata_boot)

## Explicit print(): inside source()/Rscript a bare call is not auto-printed
## and would leave the PDF empty. Matches the print(ggp) used for the R2 figure.
pdf("images/fig_paper_errors.pdf", width = 5, height = 5)
print(ggp_errors)
dev.off()

ggp_errors


./images/fig_paper_errors.png

Network plots

full data

## Network around "Z" for the ensemble learned on the full data.
## This call passes saveToFile = TRUE together with a PDF filename.
plot.bnlearn_ens(
    ens_alldata_boot, "Z", 0,
    cutoff = 0.4,
    maxpath = 2,
    nodesep = 0.001,
    freqth = 0,
    sep = 0.1,
    coef_filter = 0.5,
    label_pad = 3,
    layout = "dot",
    edge_labels = "coef",
    edgeweights = TRUE,
    ## edges V1 -> Z and V2 -> Z are coloured red
    edge_color = tribble(
        ~inp, ~out, ~color,
        "V1", "Z", "red",
        "V2", "Z", "red"
    ),
    filename = "images/estimated_network_fulldata.pdf",
    saveToFile = TRUE,
    ## fill colours keyed by regular expressions
    fill = list(
        "U.*" = "darksalmon",
        "Z" = "yellow",
        "^V1$|V2$" = "skyblue"
    )
)

images/fig_graph_fulldata.png

with missing data

## Network around "Z" for the ensemble learned without U1.out / U2.out.
## saveToFile is FALSE here even though a filename is supplied.
plot.bnlearn_ens(
    ens_missing_boot, "Z", 0,
    cutoff = 0.4,
    maxpath = 2,
    nodesep = 0.001,
    freqth = 0,
    sep = 0.1,
    coef_filter = 0.5,
    label_pad = 3,
    layout = "dot",
    edge_labels = "coef",
    edgeweights = TRUE,
    ## edges V1 -> Z and V2 -> Z are coloured red
    edge_color = tribble(
        ~inp, ~out, ~color,
        "V1", "Z", "red",
        "V2", "Z", "red"
    ),
    filename = "images/estimated_network_missingdata.pdf",
    saveToFile = FALSE,
    ## fill colours keyed by regular expressions
    fill = list(
        "U.*" = "darksalmon",
        "Z" = "yellow",
        "^V1$|V2$" = "skyblue"
    )
)

images/fig_graph_missing.png

With inferred latent variables

## Network around "Z" for the 21st ensemble stored in the diagnostics table.
## saveToFile is FALSE here even though a filename is supplied.
plot.bnlearn_ens(
    latVar$details$Diagnostics$Ensemble[[21]], "Z", 0,
    cutoff = 0.4,
    maxpath = 2,
    nodesep = 0.01,
    freqth = 0,
    sep = 0.1,
    layout = "dot",
    edge_labels = "coef",
    coef_filter = 0.5,
    label_pad = 3,
    edgeweights = TRUE,
    edgelabelsFilter = 0.5,
    ## edges V1 -> Z and V2 -> Z are coloured red
    edge_color = tribble(
        ~inp, ~out, ~color,
        "V1", "Z", "red",
        "V2", "Z", "red"
    ),
    filename = "images/estimated_network_infered.pdf",
    saveToFile = FALSE,
    ## fill colours keyed by regular expressions
    fill = list(
        "U.*" = "darksalmon",
        "Z" = "yellow",
        "^V1$|V2$" = "skyblue"
    )
)

images/fig_graph_infered.png