-
Notifications
You must be signed in to change notification settings - Fork 0
/
ROC_v2.R
132 lines (115 loc) · 3.97 KB
/
ROC_v2.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
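# ROC simulation: compares the FNAC+, SNAC+, AS and LR test statistics under a
# null model with K_H0 communities and a configurable alternative model.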
library(Matrix)
library(tibble)
library(dplyr)
library(ggplot2)
library(nett)
# ---- Simulation settings ----
run_sims <- TRUE   # TRUE: run the simulation; FALSE: reload the saved .rds file
lvm_alt <- TRUE    # whether to use DCLVM as the alternative model
nruns <- 1000      # number of simulated adjacency matrices
DC <- TRUE         # whether the model has the degree-correction
balanced <- FALSE  # whether community sizes are balanced
K <- 4             # number of communities
n <- 10000         # size argument passed to the samplers (2000 is a smaller setting)
oir <- 0.1         # out-in ratio
# Average degree: 8 for the DCLVM alternative, 15 for the DCSBM alternative.
lambda <- if (lvm_alt) 8 else 15
K_H0 <- K          # number of communities in H0
# Other settings used for K_H1: K + 1 (main), K - 1 (appendix).
K_H1 <- K          # number of communities in Ha (appendix)
# Degree parameters handed to the samplers: Pareto draws from EnvStats::rpareto
# when degree-correction is on, otherwise the constant 1.
# NOTE(review): no RNG seed is set in this script, so the draw differs per run.
theta <- if (DC) {
  EnvStats::rpareto(n, location = 3 / 4, shape = 4)
} else {
  1
}
# pri_func(K) gives the (unnormalised) community weights for K communities:
# all ones in the balanced case, otherwise 1, 2, ..., K.
pri_func <- if (balanced) {
  function(K) rep(1, K)
} else {
  function(K) 1:K
}
# Stem of the results file name; encodes the settings above, with every "."
# replaced by "p" so the tag contains no dots.
file_tag <- gsub(
  "\\.",
  "p",
  sprintf(
    "roc6_%dvs%d_n%d_lam%d_oir%2.2f_nruns%d_%s%s_%s_updated",
    K_H0, K_H1, n, lambda, oir, nruns,
    if (DC) "dc" else "",
    if (lvm_alt) "lvm" else "sbm",
    if (balanced) "bal" else "unbal"
  )
)
# Compute the test statistic of every method for one adjacency matrix.
#
# A: adjacency matrix handed unchanged to the nett routines below.
# Returns a tibble with one row per method and the columns `method`, `tstat`
# and `twosided` (always FALSE here).
apply_methods <- function(A) {
  # Label vectors from spec_clust for K_H0, K_H0 + 1 and K_H1 groups. The call
  # order matches the original in case the routines consume random numbers.
  labels_h0 <- spec_clust(A, K_H0)
  labels_h0_plus <- spec_clust(A, K_H0 + 1)
  labels_h1 <- spec_clust(A, K_H1)

  fnac_stat <- nac_test(A, K_H0, z = labels_h0, y = labels_h0_plus)$stat
  snac_stat <- snac_test(A, K_H0, z = labels_h0)$stat
  as_stat <- adj_spec_test(A, K_H0, z = labels_h0)
  # Column names z0 / z1 are kept so the label matrix is unchanged.
  lr_stat <- eval_dcsbm_loglr(
    A,
    cbind(z0 = labels_h0, z1 = labels_h1),
    poi = TRUE
  )

  tibble::tribble(
    ~method, ~tstat, ~twosided,
    "FNAC+", fnac_stat, FALSE,
    "SNAC+", snac_stat, FALSE,
    "AS", as_stat, FALSE,
    "LR", lr_stat, FALSE
  )
}
# Sample from the DCLVM alternative via sample_dclvm().
#
# n:      number of latent positions to draw
# Ktru:   number of groups; also the dimension of the latent positions
# lambda: passed to sample_dclvm() as its second argument
# theta:  passed to sample_dclvm() as its third argument
# Returns whatever sample_dclvm() returns.
gen_lvm <- function(n, Ktru, lambda, theta) {
  # Group label per unit, drawn with weights pri_func(Ktru).
  group <- sample(Ktru, n, replace = TRUE, prob = pri_func(Ktru))
  # Group centres are the rows of 2 * identity.
  centres <- 2 * diag(Ktru)
  # Standard normal noise, one row per unit.
  noise <- matrix(rnorm(n * Ktru), nrow = n)
  sample_dclvm(centres[group, ] + noise, lambda, theta)
}
# Draw one adjacency matrix under H0: sample_dcpp() with K_H0 communities.
# (gen_lvm(n, K_H0, lambda, theta) is an alternative null generator.)
gen_null_data <- function() {
  null_net <- sample_dcpp(
    n, lambda, K_H0,
    oir = oir, theta,
    pri = pri_func(K_H0)
  )
  null_net$adj
}
# Draw one adjacency matrix under the alternative: gen_lvm() when lvm_alt is
# set, otherwise sample_dcpp() with K_H1 communities.
gen_alt_data <- function() {
  if (lvm_alt) {
    return(gen_lvm(n, K_H1, lambda, theta))
  }
  alt_net <- sample_dcpp(
    n, lambda, K_H1,
    oir = oir, theta,
    pri = pri_func(K_H1)
  )
  alt_net$adj
}
# Reload the cached results, or run the simulation and cache them as
# "<file_tag>.rds".
if (!run_sims) {
  roc_res <- readRDS(paste0(file_tag, ".rds"))
} else {
  # NOTE(review): core_count is fixed at 24 regardless of the host machine.
  roc_res <- simulate_roc(
    apply_methods,
    gen_null_data = gen_null_data,
    gen_alt_data = gen_alt_data,
    nruns = nruns,
    core_count = 24
  )
  printf("time = %3.3f", roc_res$elapsed_time)
  saveRDS(roc_res, paste0(file_tag, ".rds"))
}
# Plot ROC curves (TPR against FPR), one coloured line per method.
#
# roc_results:  data frame with the columns `method`, `FPR` and `TPR`.
# method_names: optional vector of factor levels (and hence legend order) for
#               `method`; when NULL the default factor levels are used.
# Rows whose method is "LR" are always dropped before plotting.
# Returns the ggplot object.
plot_roc <- function(roc_results, method_names = NULL) {
  curves <- if (is.null(method_names)) {
    dplyr::mutate(roc_results, method = factor(method))
  } else {
    dplyr::mutate(roc_results, method = factor(method, levels = method_names))
  }
  # leave out LR when K_H0 = K_H1
  curves <- dplyr::filter(curves, method != "LR")

  # Eight-colour palette supplied to the manual colour scale.
  cbbPalette <- c(
    "#000000", "#E69F00", "#56B4E9", "#009E73",
    "#F0E442", "#0072B2", "#D55E00", "#CC79A7"
  )

  # Legend without background or title, positioned at (0.8, 0.2), large fonts.
  legend_theme <- ggplot2::theme(
    legend.background = ggplot2::element_blank(),
    legend.title = ggplot2::element_blank(),
    legend.position = c(0.8, 0.2),
    legend.text = ggplot2::element_text(size = 25),
    text = ggplot2::element_text(size = 26)
  )

  axis_range <- c(0, 1.01)

  ggplot2::ggplot(
    curves,
    ggplot2::aes(x = FPR, y = TPR, color = method, linetype = method)
  ) +
    ggplot2::scale_colour_manual(values = cbbPalette) +
    ggplot2::geom_line(size = 2) +
    ggplot2::theme_bw() +
    # NOTE(review): legend_theme below sets the text size again (26), which
    # presumably supersedes this value.
    ggplot2::theme(text = ggplot2::element_text(size = 18)) +
    ggplot2::coord_fixed(ratio = 1) +
    # Dashed diagonal for reference.
    ggplot2::geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
    ggplot2::scale_x_continuous(limits = axis_range, expand = c(0, 0)) +
    ggplot2::scale_y_continuous(limits = axis_range, expand = c(0, 0)) +
    legend_theme +
    ggplot2::guides(
      colour = ggplot2::guide_legend(keywidth = 4, keyheight = 1)
    )
}
# Plot the ROC table with the methods in a fixed order.
plot_res <- roc_res$roc
plot_res[["method"]] <- factor(
  plot_res[["method"]],
  levels = c("FNAC+", "SNAC+", "AS", "LR")
)
plot_roc(plot_res)
# To save the figure:
# ggsave(paste0(file_tag, ".pdf"), width = 8.22, height = 6.94)