diff --git a/viloca/local_haplotype_inference/use_quality_scores/analyze_results.py b/viloca/local_haplotype_inference/use_quality_scores/analyze_results.py index ed0a6b2..4093ff2 100644 --- a/viloca/local_haplotype_inference/use_quality_scores/analyze_results.py +++ b/viloca/local_haplotype_inference/use_quality_scores/analyze_results.py @@ -31,7 +31,7 @@ def haplotypes_to_fasta(state_curr_dict, output_dir): ave_reads = state_curr_dict["weight" + str(k)] if ave_reads==0: # this haplotype will not be reported as there are no reads - # supporting it. + # supporting it. continue head = ( diff --git a/viloca/local_haplotype_inference/use_quality_scores/cavi.py b/viloca/local_haplotype_inference/use_quality_scores/cavi.py index 3ff3c58..bfce3d9 100644 --- a/viloca/local_haplotype_inference/use_quality_scores/cavi.py +++ b/viloca/local_haplotype_inference/use_quality_scores/cavi.py @@ -114,8 +114,9 @@ def run_cavi( iter = 0 converged = False elbo = 0 + min_number_iterations = 10 state_curr_dict = state_init_dict - while converged is False: + while (converged is False) or (iter < min_number_iterations): if iter <= 1: digamma_alpha_sum = digamma(state_curr_dict["alpha"].sum(axis=0)) @@ -155,7 +156,7 @@ def run_cavi( break elif (history_elbo[-2] > elbo) and np.abs(elbo - history_elbo[-2]) > 1e-08: exit_message = "Error: ELBO is decreasing." - break + #break elif np.abs(elbo - history_elbo[-2]) < convergence_threshold: converged = True exit_message = "ELBO converged." diff --git a/viloca/local_haplotype_inference/use_quality_scores/run_dpm_mfa.py b/viloca/local_haplotype_inference/use_quality_scores/run_dpm_mfa.py index b62a84f..eb0c692 100644 --- a/viloca/local_haplotype_inference/use_quality_scores/run_dpm_mfa.py +++ b/viloca/local_haplotype_inference/use_quality_scores/run_dpm_mfa.py @@ -33,7 +33,7 @@ def main( K, alpha0, alphabet="ACGT-", - unique_modus=False, + unique_modus=True, convergence_threshold=1e-03, ): @@ -46,7 +46,7 @@ def main( reference_binary, ref_id = preparation.load_reference_seq(fref_in, alphabet) reads_list, qualities = preparation.load_fasta_and_qualities( - freads_in, fname_qualities, alphabet, False + freads_in, fname_qualities, alphabet, unique_modus ) reads_seq_binary, reads_weights = preparation.reads_list_to_array(reads_list) reads_log_error_proba = preparation.compute_reads_log_error_proba( @@ -147,9 +147,9 @@ def main( sys.argv[1], sys.argv[2], sys.argv[3], - int(sys.argv[4]), + sys.argv[4], int(sys.argv[5]), - float(sys.argv[6]), - sys.argv[7], + int(sys.argv[6]), + float(sys.argv[7]), ) # freads_in, fref_in, output_dir, n_starts, K, alpha0, alphabet = 'ACGT-'