From 28083a1e39d581cad2ff2057828477efa771364a Mon Sep 17 00:00:00 2001 From: jarathomas Date: Wed, 8 May 2024 13:29:53 -0400 Subject: [PATCH] bug squash: fixed indiv probabilities --- src/insilicova/_sampler/sampler.cpp | 10 +++++----- src/insilicova/api.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/insilicova/_sampler/sampler.cpp b/src/insilicova/_sampler/sampler.cpp index b90269e..c9b6487 100644 --- a/src/insilicova/_sampler/sampler.cpp +++ b/src/insilicova/_sampler/sampler.cpp @@ -396,11 +396,11 @@ void Sampler::count_current(double *ptr_indic, } } - // loop over all inidviduals + // loop over all individuals for (int n = 0; n < N; ++n) { // cause of this death int c_current = (int) y_new(n); - // add to toal counts + // add to total counts ptr_count_c[c_current] += 1; // loop over all symptoms for (int s = 0; s < S; ++s) { @@ -515,7 +515,7 @@ void Sampler::trunc_beta2(py::array_t prior_a, double prior_b, } upper = std::min(upper, trunc_max); } - // if range is invalide, use higher case + // if range is invalid, use higher case if (lower >= upper) { proxy_new_pbase(s, c.cast()) = upper; } else { @@ -617,7 +617,7 @@ void Sampler::trunc_beta(py::array_t prior_a, double prior_b, } upper = std::min(upper, trunc_max); } - // if range is invalide, use higher case + // if range is invalid, use higher case if (lower >= upper) { proxy_new_pbase(s.cast(), c) = upper; } else { @@ -893,7 +893,7 @@ void Sampler::fit(py::array_t prior_a, double prior_b, } } - // reinitiate after checking impossible + // re-initiate after checking impossible if (!is_added) { for (int sub = 0; sub < N_sub; ++sub) { int fix = 0; diff --git a/src/insilicova/api.py b/src/insilicova/api.py index 013667d..92f0d83 100644 --- a/src/insilicova/api.py +++ b/src/insilicova/api.py @@ -1327,7 +1327,6 @@ def _sample_posterior(self): N_thin = int((N_gibbs - burn) / (thin)) probbase_gibbs = np.zeros((N_thin, S, C)) levels_gibbs = np.zeros((N_thin, N_level)) - pnb_mean = np.zeros((N, C)) p_gibbs = np.zeros((N_thin, N_sub, C)) pnb_mean = np.zeros((N, C)) naccept = [0] * N_sub @@ -1396,10 +1395,11 @@ def _sample_posterior(self): # self.burnin = self.n_sim / 2 N_gibbs = int(np.trunc(N_gibbs * (2 ** (add - 1)))) burn = 0 - N_thin = int((N_gibbs - burn) / (thin)) + N_thin = int((N_gibbs - burn) / thin) probbase_gibbs = np.zeros((N_thin, S, C)) levels_gibbs = np.zeros((N_thin, N_level)) p_gibbs = np.zeros((N_thin, N_sub, C)) + pnb_mean = np.zeros((N, C)) warnings.warn( f"Not all causes with CSMF > {self.conv_csmf} are " f"convergent.\n Increase chain length with another "