Skip to content

Commit

Permalink
enable unique_modus
Browse files Browse the repository at this point in the history
  • Loading branch information
LaraFuhrmann committed Nov 28, 2024
1 parent 018c036 commit 987ceaa
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 42 deletions.
81 changes: 44 additions & 37 deletions viloca/local_haplotype_inference/learn_error_params/cavi.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def multistart_cavi(
reads_weights,
n_starts,
output_dir,
record_history
):

pool = mp.Pool(mp.cpu_count())
Expand All @@ -51,6 +52,7 @@ def multistart_cavi(
reads_weights,
start,
output_dir,
record_history
),
callback=collect_result,
)
Expand All @@ -72,6 +74,7 @@ def run_cavi(
reads_weights,
start_id,
output_dir,
record_history
):

"""
Expand All @@ -85,16 +88,7 @@ def run_cavi(
"alphabet": alphabet,
} # N= #reads, K= #components

history_alpha = []
history_mean_log_pi = []
history_theta_c = []
history_theta_d = []
history_mean_log_theta = []
history_gamma_a = []
history_gamma_b = []
history_mean_log_gamma = []
history_mean_haplo = []
history_mean_cluster = []

history_elbo = []

state_init_dict = initialization.draw_init_state(
Expand Down Expand Up @@ -130,8 +124,8 @@ def run_cavi(
converged = False
elbo = 0
state_curr_dict = state_init_dict
k = 0
while converged is False:
min_number_iterations = 10
while (converged is False) or (iter < min_number_iterations):

if iter <= 1:
digamma_alpha_sum = digamma(state_curr_dict["alpha"].sum(axis=0))
Expand Down Expand Up @@ -160,7 +154,7 @@ def run_cavi(
state_init_dict,
state_curr_dict,
)

if iter % 2 == 0:
history_elbo.append(elbo)
history_mean_log_pi.append(state_curr_dict["mean_log_pi"])
Expand All @@ -184,11 +178,11 @@ def run_cavi(
break
elif np.abs(elbo - history_elbo[-2]) < 1e-03:
converged = True
k += 1
iter += 1
message = "ELBO converged."
exitflag = 0
else:
k = 0
iter = 0

# if k%10==0: # every 10th parameter set is saved to history
state_curr_dict.update({"elbo": elbo})
Expand All @@ -198,28 +192,41 @@ def run_cavi(

state_curr_dict.update({"elbo": elbo})

dict_result.update(
{
"exit_message": message,
"exitflag": exitflag,
"n_iterations": iter,
"converged": converged,
"elbo": elbo,
"history_mean_log_theta": history_mean_log_theta,
"history_elbo": history_elbo,
"history_alpha": history_alpha,
"history_mean_log_pi": history_mean_log_pi,
"history_theta_c": history_theta_c,
"history_alpha": history_alpha,
"history_theta_d": history_theta_d,
"history_mean_log_theta": history_mean_log_theta,
"history_gamma_a": history_gamma_a,
"history_gamma_b": history_gamma_b,
"history_mean_log_gamma": history_mean_log_gamma,
"history_mean_haplo": history_mean_haplo,
"history_mean_cluster": history_mean_cluster,
}
)
if record_history:
dict_result.update(
{
"exit_message": message,
"exitflag": exitflag,
"n_iterations": iter,
"converged": converged,
"elbo": elbo,
"history_mean_log_theta": history_mean_log_theta,
"history_elbo": history_elbo,
"history_alpha": history_alpha,
"history_mean_log_pi": history_mean_log_pi,
"history_theta_c": history_theta_c,
"history_alpha": history_alpha,
"history_theta_d": history_theta_d,
"history_mean_log_theta": history_mean_log_theta,
"history_gamma_a": history_gamma_a,
"history_gamma_b": history_gamma_b,
"history_mean_log_gamma": history_mean_log_gamma,
"history_mean_haplo": history_mean_haplo,
"history_mean_cluster": history_mean_cluster,
}
)
else:
dict_result.update(
{
"exit_message": message,
"exitflag": exitflag,
"n_iterations": iter,
"converged": converged,
"elbo": elbo,
"history_elbo": history_elbo,
}
)


# dict_result.update(state_curr_dict)
summary = analyze_results.summarize_results(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def gzip_file(f_name):
return f_out.name


def main(freads_in, fref_in, output_dir, n_starts, K, alpha0, alphabet="ACGT-", unique_modus=False, record_history=False):
def main(freads_in, fref_in, output_dir, n_starts, K, alpha0, alphabet="ACGT-", unique_modus=True, record_history=False):

window_id = freads_in.split("/")[-1][:-4] # freads_in is absolute path

Expand All @@ -39,7 +39,7 @@ def main(freads_in, fref_in, output_dir, n_starts, K, alpha0, alphabet="ACGT-",
# Read in reads
reference_seq, ref_id = preparation.load_reference_seq(fref_in)
reference_binary = preparation.reference2binary(reference_seq, alphabet)
reads_list = preparation.load_fasta2reads_list(freads_in, alphabet, False)
reads_list = preparation.load_fasta2reads_list(freads_in, alphabet, unique_modus)
reads_seq_binary, reads_weights = preparation.reads_list_to_array(reads_list)

if n_starts >1:
Expand Down
3 changes: 2 additions & 1 deletion viloca/local_haplotype_inference/use_quality_scores/cavi.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def run_cavi(
converged = False
elbo = 0
state_curr_dict = state_init_dict
while converged is False:
min_number_iterations = 10
while (converged is False) or (iter < min_number_iterations):

if iter <= 1:
digamma_alpha_sum = digamma(state_curr_dict["alpha"].sum(axis=0))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def main(
K,
alpha0,
alphabet="ACGT-",
unique_modus=False,
unique_modus=True,
convergence_threshold=1e-03,
record_history=False,
):
Expand All @@ -47,7 +47,7 @@ def main(
reference_binary, ref_id = preparation.load_reference_seq(fref_in, alphabet)

reads_list, qualities = preparation.load_fasta_and_qualities(
freads_in, fname_qualities, alphabet, False
freads_in, fname_qualities, alphabet, unique_modus
)
reads_seq_binary, reads_weights = preparation.reads_list_to_array(reads_list)
reads_log_error_proba = preparation.compute_reads_log_error_proba(
Expand Down

0 comments on commit 987ceaa

Please sign in to comment.