From 987ceaa840bc8fc5e8b182efc6597069bd0da700 Mon Sep 17 00:00:00 2001 From: LaraFuhrmann <55209716+LaraFuhrmann@users.noreply.github.com> Date: Thu, 28 Nov 2024 15:21:37 +0100 Subject: [PATCH] enable unique_modus --- .../learn_error_params/cavi.py | 81 ++++++++++--------- .../learn_error_params/run_dpm_mfa.py | 4 +- .../use_quality_scores/cavi.py | 3 +- .../use_quality_scores/run_dpm_mfa.py | 4 +- 4 files changed, 50 insertions(+), 42 deletions(-) diff --git a/viloca/local_haplotype_inference/learn_error_params/cavi.py b/viloca/local_haplotype_inference/learn_error_params/cavi.py index a88036a..d0aec73 100644 --- a/viloca/local_haplotype_inference/learn_error_params/cavi.py +++ b/viloca/local_haplotype_inference/learn_error_params/cavi.py @@ -34,6 +34,7 @@ def multistart_cavi( reads_weights, n_starts, output_dir, + record_history ): pool = mp.Pool(mp.cpu_count()) @@ -51,6 +52,7 @@ def multistart_cavi( reads_weights, start, output_dir, + record_history ), callback=collect_result, ) @@ -72,6 +74,7 @@ def run_cavi( reads_weights, start_id, output_dir, + record_history ): """ @@ -85,16 +88,7 @@ def run_cavi( "alphabet": alphabet, } # N= #reads, K= #components - history_alpha = [] - history_mean_log_pi = [] - history_theta_c = [] - history_theta_d = [] - history_mean_log_theta = [] - history_gamma_a = [] - history_gamma_b = [] - history_mean_log_gamma = [] - history_mean_haplo = [] - history_mean_cluster = [] + history_elbo = [] state_init_dict = initialization.draw_init_state( @@ -130,8 +124,8 @@ def run_cavi( converged = False elbo = 0 state_curr_dict = state_init_dict - k = 0 - while converged is False: + min_number_iterations = 10 + while (converged is False) or (iter < min_number_iterations): if iter <= 1: digamma_alpha_sum = digamma(state_curr_dict["alpha"].sum(axis=0)) @@ -160,7 +154,7 @@ def run_cavi( state_init_dict, state_curr_dict, ) - + if iter % 2 == 0: history_elbo.append(elbo) history_mean_log_pi.append(state_curr_dict["mean_log_pi"]) @@ -184,11 +178,11 @@ def run_cavi( break elif np.abs(elbo - history_elbo[-2]) < 1e-03: converged = True - k += 1 + iter += 1 message = "ELBO converged." exitflag = 0 else: - k = 0 + iter = 0 # if k%10==0: # every 10th parameter set is saved to history state_curr_dict.update({"elbo": elbo}) @@ -198,28 +192,41 @@ def run_cavi( state_curr_dict.update({"elbo": elbo}) - dict_result.update( - { - "exit_message": message, - "exitflag": exitflag, - "n_iterations": iter, - "converged": converged, - "elbo": elbo, - "history_mean_log_theta": history_mean_log_theta, - "history_elbo": history_elbo, - "history_alpha": history_alpha, - "history_mean_log_pi": history_mean_log_pi, - "history_theta_c": history_theta_c, - "history_alpha": history_alpha, - "history_theta_d": history_theta_d, - "history_mean_log_theta": history_mean_log_theta, - "history_gamma_a": history_gamma_a, - "history_gamma_b": history_gamma_b, - "history_mean_log_gamma": history_mean_log_gamma, - "history_mean_haplo": history_mean_haplo, - "history_mean_cluster": history_mean_cluster, - } - ) + if record_history: + dict_result.update( + { + "exit_message": message, + "exitflag": exitflag, + "n_iterations": iter, + "converged": converged, + "elbo": elbo, + "history_mean_log_theta": history_mean_log_theta, + "history_elbo": history_elbo, + "history_alpha": history_alpha, + "history_mean_log_pi": history_mean_log_pi, + "history_theta_c": history_theta_c, + "history_alpha": history_alpha, + "history_theta_d": history_theta_d, + "history_mean_log_theta": history_mean_log_theta, + "history_gamma_a": history_gamma_a, + "history_gamma_b": history_gamma_b, + "history_mean_log_gamma": history_mean_log_gamma, + "history_mean_haplo": history_mean_haplo, + "history_mean_cluster": history_mean_cluster, + } + ) + else: + dict_result.update( + { + "exit_message": message, + "exitflag": exitflag, + "n_iterations": iter, + "converged": converged, + "elbo": elbo, + "history_elbo": history_elbo, + } + ) + # dict_result.update(state_curr_dict) summary = analyze_results.summarize_results( diff --git a/viloca/local_haplotype_inference/learn_error_params/run_dpm_mfa.py b/viloca/local_haplotype_inference/learn_error_params/run_dpm_mfa.py index 45c9141..aff6f3e 100644 --- a/viloca/local_haplotype_inference/learn_error_params/run_dpm_mfa.py +++ b/viloca/local_haplotype_inference/learn_error_params/run_dpm_mfa.py @@ -29,7 +29,7 @@ def gzip_file(f_name): return f_out.name -def main(freads_in, fref_in, output_dir, n_starts, K, alpha0, alphabet="ACGT-", unique_modus=False, record_history=False): +def main(freads_in, fref_in, output_dir, n_starts, K, alpha0, alphabet="ACGT-", unique_modus=True, record_history=False): window_id = freads_in.split("/")[-1][:-4] # freads_in is absolute path @@ -39,7 +39,7 @@ def main(freads_in, fref_in, output_dir, n_starts, K, alpha0, alphabet="ACGT-", # Read in reads reference_seq, ref_id = preparation.load_reference_seq(fref_in) reference_binary = preparation.reference2binary(reference_seq, alphabet) - reads_list = preparation.load_fasta2reads_list(freads_in, alphabet, False) + reads_list = preparation.load_fasta2reads_list(freads_in, alphabet, unique_modus) reads_seq_binary, reads_weights = preparation.reads_list_to_array(reads_list) if n_starts >1: diff --git a/viloca/local_haplotype_inference/use_quality_scores/cavi.py b/viloca/local_haplotype_inference/use_quality_scores/cavi.py index 7ff06a3..e78330d 100644 --- a/viloca/local_haplotype_inference/use_quality_scores/cavi.py +++ b/viloca/local_haplotype_inference/use_quality_scores/cavi.py @@ -113,7 +113,8 @@ def run_cavi( converged = False elbo = 0 state_curr_dict = state_init_dict - while converged is False: + min_number_iterations = 10 + while (converged is False) or (iter < min_number_iterations): if iter <= 1: digamma_alpha_sum = digamma(state_curr_dict["alpha"].sum(axis=0)) diff --git a/viloca/local_haplotype_inference/use_quality_scores/run_dpm_mfa.py b/viloca/local_haplotype_inference/use_quality_scores/run_dpm_mfa.py index 8f2b3e1..e2678bf 100644 --- a/viloca/local_haplotype_inference/use_quality_scores/run_dpm_mfa.py +++ b/viloca/local_haplotype_inference/use_quality_scores/run_dpm_mfa.py @@ -33,7 +33,7 @@ def main( K, alpha0, alphabet="ACGT-", - unique_modus=False, + unique_modus=True, convergence_threshold=1e-03, record_history=False, ): @@ -47,7 +47,7 @@ def main( reference_binary, ref_id = preparation.load_reference_seq(fref_in, alphabet) reads_list, qualities = preparation.load_fasta_and_qualities( - freads_in, fname_qualities, alphabet, False + freads_in, fname_qualities, alphabet, unique_modus ) reads_seq_binary, reads_weights = preparation.reads_list_to_array(reads_list) reads_log_error_proba = preparation.compute_reads_log_error_proba(