Skip to content

Commit

Permalink
Integrate dadi LowPass functions and some streamlining code (#116)
Browse files Browse the repository at this point in the history
* Make loading work_queue and cyvcf2 optional

* Add exception for error cyvcf2 '_bcf_float_missing'

* Enable --interactive flag

* Updates to GenerateFs for LowPass

* Update Infer subcommands for LowPass

* Update Plot for LowPass

* Update unit tests for LowPass

* Remove --nomisid requirement from Plot and Stat subcommands, and instead use param names from inference output file

* Update tests for change to subcommands (besides GenerateCache) that take popt file

* Add skip when comparing model names to dadi <=2.3.6
  • Loading branch information
tjstruck authored Aug 14, 2024
1 parent 549d190 commit 540491d
Show file tree
Hide file tree
Showing 29 changed files with 1,523 additions and 136 deletions.
56 changes: 39 additions & 17 deletions dadi_cli/GenerateFs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import dadi, random
from cyvcf2 import VCF


def generate_fs(
Expand All @@ -8,12 +7,13 @@ def generate_fs(
pop_ids: list[str],
pop_info: str,
projections: list[int],
subsample: bool,
subsample: list[int],
polarized: bool,
marginalize_pops: list[str],
bootstrap: int,
chunk_size: int,
masking: str,
calc_coverage: bool,
seed: int,
) -> None:
"""
Expand All @@ -31,8 +31,9 @@ def generate_fs(
Name of the file containing population information.
projections : list[int]
List of sample sizes after projection.
subsample : bool
If True, spectrum is generated with sub-samples; otherwise, spectrum is generated with all samples.
subsample : list
If filled, spectrum is generated based on number of requested sub-sampled individuals for each population;
otherwise, spectrum is generated with all samples.
polarized : bool
If True, unfolded spectrum is generated; otherwise, folded spectrum is generated.
marginalize_pops : list[str]
Expand All @@ -46,6 +47,8 @@ def generate_fs(
'singleton' - Masks singletons in each population,
'shared' - Masks singletons in each population and those shared across populations,
'' - No masking is applied.
calc_coverage : bool
If True, a data dictionary with coverage information is generated as <output>.coverage.pickle.
seed : int
Seed for generating random numbers. If None, a random seed is used.
Expand All @@ -56,25 +59,44 @@ def generate_fs(
If the VCF file does not contain the AA INFO field and `polarized` is True.
"""
if len(pop_ids) != len(projections):
raise ValueError("The lengths of `pop_ids` and `projections` must match.")

if polarized:
if not VCF(vcf).contains('AA'):
raise ValueError(
f'The AA (Ancestral allele) INFO field cannot be found in the header of {vcf}, ' +
'but an unfolded frequency spectrum is requested.'
)

if subsample:
try:
from cyvcf2 import VCF
if not VCF(vcf).contains('AA'):
raise ValueError(
f'The AA (Ancestral allele) INFO field cannot be found in the header of {vcf}, ' +
'but an unfolded frequency spectrum is requested.'
)
except ModuleNotFoundError:
print("Unable to load cyvcf2 and check if ancestral alleles are in provided VCF.\n"+
"Generated FS may be empty if ancestral allele not found.")
except ImportError:
print("Error importing cyvcf2")

if subsample != []:
subsample_dict = {}
for i in range(len(pop_ids)):
subsample_dict[pop_ids[i]] = projections[i]
dd = dadi.Misc.make_data_dict_vcf(
vcf_filename=vcf, popinfo_filename=pop_info, subsample=subsample_dict
subsample_dict[pop_ids[i]] = subsample[i]
# dadi will store the number of chromosomes in ploidy
dd, ploidy = dadi.Misc.make_data_dict_vcf(
vcf_filename=vcf, popinfo_filename=pop_info, subsample=subsample_dict, calc_coverage=calc_coverage, extract_ploidy=True
)
# multiply number of individuals subsamples by the ploidy to get sample size
projections = [individuals*ploidy for individuals in subsample]
print(projections, ploidy, subsample)
else:
dd = dadi.Misc.make_data_dict_vcf(vcf_filename=vcf, popinfo_filename=pop_info)
dd = dadi.Misc.make_data_dict_vcf(vcf_filename=vcf, popinfo_filename=pop_info, calc_coverage=calc_coverage)

# Moved this lower, since using subsamples make projections not required
if len(pop_ids) != len(projections):
raise ValueError("The lengths of `pop_ids` and `projections` must match.")

if calc_coverage:
import pickle
coverage_dd = {chrom_pos:{'coverage':dd[chrom_pos]['coverage']} for chrom_pos in dd}
print(f"\nSaving coverage dictionary in pickle named:\n{output}.coverage.pickle\n")
pickle.dump(coverage_dd, open(f"{output}.coverage.pickle","wb"))

if bootstrap is None:
fs = dadi.Spectrum.from_data_dict(
Expand Down
15 changes: 15 additions & 0 deletions dadi_cli/InferDFE.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def infer_dfe(
lower_bounds: list[float],
fixed_params: list[float],
misid: bool,
cov_args: list,
cov_inbreeding: list,
cuda: bool,
maxeval: int,
maxtime: int,
Expand Down Expand Up @@ -109,6 +111,19 @@ def infer_dfe(
if misid:
func = dadi.Numerics.make_anc_state_misid_func(func)

if cov_args != []:
try:
from dadi.LowPass.LowPass import make_low_pass_func_GATK_multisample as func_cov
except ModuleNotFoundError:
raise ImportError("ERROR:\nCurrent dadi version does not support coverage model\n")
nseq = [int(ele) for ele in cov_args[1:]]
if cov_inbreeding == []:
Fx = None
else:
Fx = cov_inbreeding

func = func_cov(func, cov_args[0], fs.pop_ids, nseq, fs.sample_sizes, Fx=Fx)

p0_len = len(p0)
lower_bounds = convert_to_None(lower_bounds, p0_len)
upper_bounds = convert_to_None(upper_bounds, p0_len)
Expand Down
32 changes: 32 additions & 0 deletions dadi_cli/InferDM.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ def infer_demography(
lower_bounds: list[float],
fixed_params: list[float],
misid: bool,
cov_args: list,
cov_inbreeding: list,
cuda: bool,
maxeval: int,
maxtime: int,
Expand Down Expand Up @@ -82,6 +84,20 @@ def infer_demography(
func = dadi.Numerics.make_anc_state_misid_func(func)

func_ex = dadi.Numerics.make_extrap_func(func)

if cov_args != []:
try:
from dadi.LowPass.LowPass import make_low_pass_func_GATK_multisample as func_cov
except ModuleNotFoundError:
raise ImportError("ERROR:\nCurrent dadi version does not support coverage model\n")
nseq = [int(ele) for ele in cov_args[1:]]
if cov_inbreeding == []:
Fx = None
else:
Fx = cov_inbreeding

func_ex = func_cov(func_ex, cov_args[0], fs.pop_ids, nseq, fs.sample_sizes, Fx=Fx)

p0_len = len(p0)
lower_bounds = convert_to_None(lower_bounds, p0_len)
upper_bounds = convert_to_None(upper_bounds, p0_len)
Expand Down Expand Up @@ -131,6 +147,8 @@ def infer_global_opt(
lower_bounds: list[float],
fixed_params: list[float],
misid: bool,
cov_args: list,
cov_inbreeding: list,
cuda: bool,
maxeval: int,
maxtime: int,
Expand Down Expand Up @@ -191,6 +209,20 @@ def infer_global_opt(

func_ex = dadi.Numerics.make_extrap_func(func)

if cov_args != []:
try:
from dadi.LowPass.LowPass import make_low_pass_func_GATK_multisample as func_cov
import pickle
except ModuleNotFoundError:
raise ImportError("ERROR:\nCurrent dadi version does not support coverage model\n")
nseq = [int(ele) for ele in cov_args[1:]]
if cov_inbreeding == []:
Fx = None
else:
Fx = cov_inbreeding

func_ex = func_cov(func_ex, cov_args[0], fs.pop_ids, nseq, fs.sample_sizes, Fx=Fx)

p0_len = len(p0)
lower_bounds = convert_to_None(lower_bounds, p0_len)
upper_bounds = convert_to_None(upper_bounds, p0_len)
Expand Down
28 changes: 14 additions & 14 deletions dadi_cli/Models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
from inspect import getmembers, isfunction

duplicated_models = ["snm", "bottlegrowth"]
duplicated_sele_models = [
"IM",
"IM_pre",
"IM_pre_single_gamma",
"IM_single_gamma",
"split_asym_mig",
"split_asym_mig_single_gamma",
"split_mig",
"three_epoch",
"two_epoch",
"split_mig_single_gamma",
]
# duplicated_sele_models = [
# "IM",
# "IM_pre",
# "IM_pre_single_gamma",
# "IM_single_gamma",
# "split_asym_mig",
# "split_asym_mig_single_gamma",
# "split_mig",
# "three_epoch",
# "two_epoch",
# "split_mig_single_gamma",
# ]
oned_models = [m[0] for m in getmembers(dadi.Demographics1D, isfunction)]
twod_models = [m[0] for m in getmembers(dadi.Demographics2D, isfunction)]
try:
Expand All @@ -29,8 +29,8 @@
oned_models.remove(m)
for m in duplicated_models:
twod_models.remove(m)
for m in duplicated_sele_models:
sele_models.remove(m)
# for m in duplicated_sele_models:
# sele_models.remove(m)


def get_model(
Expand Down
51 changes: 37 additions & 14 deletions dadi_cli/Plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def plot_fitted_demography(
func: callable,
popt: str,
projections: list[int],
nomisid: bool,
cov_args: list,
cov_inbreeding: list,
output: str,
vmin: float,
resid_range: list[float],
Expand All @@ -153,8 +154,6 @@ def plot_fitted_demography(
Path to the file containing the best-fit parameters for the demographic model.
projections : list[int]
List of integers representing the sample sizes after projection, used to adjust the plots.
nomisid : bool
If True, ancestral state misidentification is considered in the modeling; if False, it is not.
output : str
Path where the comparison plot will be saved. The file format is inferred from the file extension.
vmin : float
Expand All @@ -171,15 +170,27 @@ def plot_fitted_demography(
"""

popt, _ = get_opts_and_theta(popt)

popt, _, param_names = get_opts_and_theta(popt, post_infer=True)
fs = dadi.Spectrum.from_file(fs)
if not nomisid:
if param_names[-2] == 'misid':
func = dadi.Numerics.make_anc_state_misid_func(func)
func_ex = dadi.Numerics.make_extrap_func(func)
ns = fs.sample_sizes
pts_l = pts_l_func(ns)

if cov_args != []:
try:
from dadi.LowPass.LowPass import make_low_pass_func_GATK_multisample as func_cov
except ModuleNotFoundError:
raise ImportError("ERROR:\nCurrent dadi version does not support coverage model\n")
nseq = [int(ele) for ele in cov_args[1:]]
if cov_inbreeding == []:
Fx = None
else:
Fx = cov_inbreeding

func_ex = func_cov(func_ex, cov_args[0], fs.pop_ids, nseq, fs.sample_sizes, Fx=Fx)

model = func_ex(popt, ns, pts_l)

fig = plt.figure(219033)
Expand Down Expand Up @@ -218,7 +229,8 @@ def plot_fitted_dfe(
projections: list[int],
pdf: str,
pdf2: str,
nomisid: bool,
cov_args: list,
cov_inbreeding: list,
output: str,
vmin: float,
resid_range: list[float],
Expand All @@ -243,8 +255,6 @@ def plot_fitted_dfe(
Name of the 1D probability density function file for modeling the DFE.
pdf2 : str
Name of the 2D probability density function file for modeling the DFE.
nomisid : bool
If True, includes ancestral state misidentification in the modeling; if False, it does not.
output : str
Path where the comparison plot will be saved. The file format is inferred from the file extension.
vmin : float
Expand All @@ -260,7 +270,7 @@ def plot_fitted_dfe(
If comparison with more than three populations.
"""
sele_popt, theta = get_opts_and_theta(sele_popt)
sele_popt, theta, param_names = get_opts_and_theta(sele_popt, post_infer=True)

fs = dadi.Spectrum.from_file(fs)

Expand All @@ -283,8 +293,21 @@ def plot_fitted_dfe(
if (cache1d != None) and (cache2d != None):
func = dadi.DFE.mixture

if not nomisid:
if param_names[-2] == 'misid':
func = dadi.Numerics.make_anc_state_misid_func(func)

if cov_args != []:
try:
from dadi.LowPass.LowPass import make_low_pass_func_GATK_multisample as func_cov
except ModuleNotFoundError:
raise ImportError("ERROR:\nCurrent dadi version does not support coverage model\n")
nseq = [int(ele) for ele in cov_args[1:]]
if cov_inbreeding == []:
Fx = None
else:
Fx = cov_inbreeding

func = func_cov(func, cov_args[0], fs.pop_ids, nseq, fs.sample_sizes, Fx=Fx)
# Get expected SFS for MLE
if (cache1d != None) and (cache2d != None):
model = func(sele_popt, None, spectra1d, spectra2d, pdf, pdf2, theta, None)
Expand All @@ -297,22 +320,22 @@ def plot_fitted_dfe(
projections = ns
fs = fs.project(projections)
model = model.project(projections)
dadi.Plotting.plot_1d_comp_Poisson(model, fs)
dadi.Plotting.plot_1d_comp_Poisson(model, fs, show=show)
if len(ns) == 2:
if projections == None:
projections = ns
fs = fs.project(projections)
model = model.project(projections)
dadi.Plotting.plot_2d_comp_Poisson(
model, fs, vmin=vmin, resid_range=resid_range
model, fs, vmin=vmin, resid_range=resid_range, show=show
)
if len(ns) == 3:
if projections == None:
projections = ns
fs = fs.project(projections)
model = model.project(projections)
dadi.Plotting.plot_3d_comp_Poisson(
model, fs, vmin=vmin, resid_range=resid_range
model, fs, vmin=vmin, resid_range=resid_range, show=show
)
if len(ns) > 3:
raise ValueError("dadi-cli does not support comparing fs and model with more than three populations")
Expand Down
Loading

0 comments on commit 540491d

Please sign in to comment.