Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
xin-huang committed Apr 25, 2024
1 parent a2f8968 commit 38f821c
Show file tree
Hide file tree
Showing 11 changed files with 224 additions and 94 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,8 @@ ssh-key
ssh-key.pub
terraform.tfstate
terraform.tfstate.backup
bestfits_file.bestfits
dadi_models.py
demes_file.yml
sfs.fs
demo-popt.bestfits
62 changes: 61 additions & 1 deletion dadi_cli/parsers/argument_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,37 @@ def positive_int(value: str) -> int:
return ivalue


def positive_number(value: str) -> float:
def nonnegative_int(value: str) -> int:
"""
Validates if the provided string represents a nonnegative integer.
Parameters
----------
value : str
The value to validate.
Returns
-------
int
The validated nonnegative integer.
Raises
------
argparse.ArgumentTypeError
If the value is not a valid integer or nonnegative integer.
"""
if value is not None:
try:
ivalue = int(value)
except ValueError:
raise argparse.ArgumentTypeError(f"{value} is not a valid integer")
if ivalue < 0:
raise argparse.ArgumentTypeError(f"{value} is not a nonnegative integer")
return ivalue


def positive_num(value: str) -> float:
"""
Validates if the provided string represents a positive number.
Expand Down Expand Up @@ -61,6 +91,36 @@ def positive_number(value: str) -> float:
return fvalue


def nonnegative_num(value: str) -> float:
"""
Validates if the provided string represents a nonnegative number.
Parameters
----------
value : str
The value to validate.
Returns
-------
float
The validated nonnegative number.
Raises
------
argparse.ArgumentTypeError
If the value is not a valid number or nonnegative number.
"""
if value is not None:
try:
fvalue = float(value)
except ValueError:
raise argparse.ArgumentTypeError(f"{value} is not a valid number")
if fvalue < 0:
raise argparse.ArgumentTypeError(f"{value} is not a nonnegative number")
return fvalue


def existed_file(value: str) -> str:
"""
Validates if the provided string is a path to an existing file.
Expand Down
5 changes: 3 additions & 2 deletions dadi_cli/parsers/common_arguments.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import argparse, multiprocessing, sys
import numpy as np
from dadi_cli.parsers.argument_validation import *


Expand Down Expand Up @@ -290,7 +291,7 @@ def add_inference_argument(parser) -> None:

parser.add_argument(
"--gpus",
type=positive_int,
type=nonnegative_int,
default=0,
help="Number of GPUs to use in multiprocessing. Default: 0.",
)
Expand Down
4 changes: 2 additions & 2 deletions dadi_cli/parsers/generate_cache_parsers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import argparse
import argparse, multiprocessing
from dadi_cli.parsers.common_arguments import *
from dadi_cli.parsers.argument_validation import *
from dadi_cli.GenerateCache import *
Expand Down Expand Up @@ -88,7 +88,7 @@ def add_generate_cache_parsers(subparsers) -> None:

parser.add_argument(
"--gpus",
type=positive_int,
type=nonnegative_int,
default=0,
help="Number of GPUs to use in multiprocessing. Default: 0.",
)
Expand Down
20 changes: 11 additions & 9 deletions dadi_cli/parsers/infer_dm_parsers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import argparse, inspect, nlopt, os, random, sys, time,
import argparse, inspect, nlopt, os, random, sys, time
import work_queue as wq
from multiprocessing import Process, Queue
from sys import exit
from dadi_cli.parsers.common_arguments import *
from dadi_cli.parsers.argument_validation import *
from dadi_cli.InferDM import *
from dadi_cli.utilities import *
from dadi_cli.BestFit import get_bestfit_params
from dadi_cli.BestFit import *
from dadi_cli.Models import *


def _run_infer_dm(args) -> None:
Expand Down Expand Up @@ -56,13 +58,13 @@ def _run_infer_dm(args) -> None:
args.misid = not (fs.folded or args.nomisid)

if not args.model_file and args.constants != -1:
args.constants = _check_params(
args.constants = check_params(
args.constants, args.model, "--constant", args.misid
)
if not args.model_file and args.lbounds != -1:
args.lbounds = _check_params(args.lbounds, args.model, "--lbounds", args.misid)
args.lbounds = check_params(args.lbounds, args.model, "--lbounds", args.misid)
if not args.model_file and args.ubounds != -1:
args.ubounds = _check_params(args.ubounds, args.model, "--ubounds", args.misid)
args.ubounds = check_params(args.ubounds, args.model, "--ubounds", args.misid)

if args.p0 == -1:
args.p0 = _calc_p0_from_bounds(args.lbounds, args.ubounds)
Expand Down Expand Up @@ -191,13 +193,13 @@ def _run_infer_dm(args) -> None:
# Create workers
workers = [
Process(
target=_worker_func,
target=worker_func,
args=(infer_global_opt, in_queue, out_queue, worker_args)
)
for ii in range(args.cpus)
]
workers.extend([
Process(target=_worker_func,
Process(target=worker_func,
args=(infer_global_opt, in_queue, out_queue, worker_args, True))
for ii in range(args.gpus)
])
Expand Down Expand Up @@ -312,13 +314,13 @@ def _run_infer_dm(args) -> None:
# Create workers
workers = [
Process(
target=_worker_func,
target=worker_func,
args=(infer_demography, in_queue, out_queue, worker_args)
)
for ii in range(args.cpus)
]
workers.extend([
Process(target=_worker_func,
Process(target=worker_func,
args=(infer_demography, in_queue, out_queue, worker_args, True))
for ii in range(args.gpus)
])
Expand Down
2 changes: 1 addition & 1 deletion dadi_cli/parsers/pdf_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dadi_cli.Pdfs import *


def run_pdf(args) -> None:
def _run_pdf(args) -> None:
"""
"""
if args.names is None:
Expand Down
2 changes: 1 addition & 1 deletion dadi_cli/parsers/stat_dfe_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dadi_cli.Stat import godambe_stat_dfe


def run_stat_dfe(args) -> None:
def _run_stat_dfe(args) -> None:
"""
"""
# # Code kept just in case user requests functionality if the future
Expand Down
2 changes: 1 addition & 1 deletion dadi_cli/parsers/stat_dm_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dadi_cli.Stat import godambe_stat_demograpy


def run_stat_demography(args) -> None:
def _run_stat_demography(args) -> None:
"""
"""
# # Code kept just in case user requests functionality if the future
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os.path
from setuptools import setup
from setuptools import setup, find_packages

# The directory containing this file
HERE = os.path.abspath(os.path.dirname(__file__))
Expand All @@ -11,6 +11,7 @@
# This call to setup() does all the work
setup(
name="dadi-cli",
python_requires='>=3.9',
version="0.9.4b",
description="A command line interface for dadi",
long_description=README,
Expand All @@ -24,7 +25,7 @@
"Programming Language :: Python",
"Programming Language :: Python :: 3.9",
],
packages=["dadi_cli"],
packages=find_packages(),
include_package_data=True,
# install_requires=[
# "dadi"
Expand Down
52 changes: 30 additions & 22 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import dadi_cli.__main__ as cli
import os


try:
if not os.path.exists("./tests/test_results"):
os.makedirs("./tests/test_results")
Expand All @@ -10,10 +11,10 @@


def test_generate_fs_args():
parser = cli.dadi_cli_parser()
parser = cli._dadi_cli_parser()
cmd = "GenerateFs"
vcf = "test.vcf"
pop_infos = "test.info"
vcf = "tests/example_data/1KG.YRI.CEU.biallelic.synonymous.snps.withanc.strict.short.vcf"
pop_infos = "tests/example_data/four.popfile.txt"
output = "test.output"
args = parser.parse_args(
[
Expand All @@ -38,10 +39,10 @@ def test_generate_fs_args():


def test_generate_cache_args():
parser = cli.dadi_cli_parser()
parser = cli._dadi_cli_parser()
cmd = "GenerateCache"
model = "split_mig_fix_T_one_s"
model_file = "tests/example_data/example_models"
model_file = "tests/example_data/example_models.py"
grids = [20, 40, 60]
demo_popt = "tests/example_data/example.split_mig_fix_T.demo.params.InferDM.bestfits"
gamma_bounds = [1e-4, 10.0]
Expand Down Expand Up @@ -93,8 +94,9 @@ def test_generate_cache_args():
assert args.demo_popt == demo_popt
assert args.additional_gammas == additional_gammas


def test_simulate_dm_args():
parser = cli.dadi_cli_parser()
parser = cli._dadi_cli_parser()
cmd = "SimulateDM"
model = "two_epoch"
model_file = None
Expand Down Expand Up @@ -132,8 +134,9 @@ def test_simulate_dm_args():
assert args.output == output
assert args.inference_file == inference_file


def test_simulate_dfe_args():
parser = cli.dadi_cli_parser()
parser = cli._dadi_cli_parser()
cmd = "SimulateDFE"
cache1d = "tests/example_data/cache_split_mig_1d.bpkl"
cache2d = "tests/example_data/cache_split_mig_2d.bpkl"
Expand Down Expand Up @@ -173,15 +176,17 @@ def test_simulate_dfe_args():
assert args.nomisid == nomisid
assert args.output == output


try:
import demes
skip = False
except:
skip = True


@pytest.mark.skipif(skip, reason="Could not load Demes")
def test_simulate_demes_args():
parser = cli.dadi_cli_parser()
parser = cli._dadi_cli_parser()
cmd = "SimulateDemes"
demes_file = "examples/data/gutenkunst_ooa.yml"
sample_sizes = [10, 10, 10]
Expand Down Expand Up @@ -213,13 +218,13 @@ def test_simulate_demes_args():


@pytest.mark.parametrize("model, model_file, nomisid",
[
("split_mig_fix_T", "tests/example_data/example_models", False),
("snm_1d", None, True)
]
)
[
("split_mig_fix_T", "tests/example_data/example_models.py", False),
("snm_1d", None, True)
]
)
def test_infer_dm_args(model, model_file, nomisid):
parser = cli.dadi_cli_parser()
parser = cli._dadi_cli_parser()
cmd = "InferDM"
fs = "tests/example_data/two_epoch_syn.fs"
p0 = [1, 1, 0.01, 1, 0.05]
Expand Down Expand Up @@ -300,7 +305,8 @@ def test_infer_dm_args(model, model_file, nomisid):
assert args.delta_ll == delta_ll
assert args.gpus == gpus
assert args.model == model
assert args.model_file == str(model_file)
#assert args.model_file == str(model_file)
assert args.model_file == model_file
assert args.grids == grids
assert args.nomisid == False
assert args.constants == -1
Expand All @@ -311,7 +317,7 @@ def test_infer_dm_args(model, model_file, nomisid):


def test_infer_dfe_args():
parser = cli.dadi_cli_parser()
parser = cli._dadi_cli_parser()
cmd = "InferDFE"
fs_mix = "tests/example_data/split_mig_non_mix.fs"
fs_1d_lognorm = "tests/example_data/split_mig_non_1d.fs"
Expand Down Expand Up @@ -425,8 +431,9 @@ def test_infer_dfe_args():
assert args.pdf_file == None
assert args.bestfit_p0 == None


def test_bestfit_args():
parser = cli.dadi_cli_parser()
parser = cli._dadi_cli_parser()
cmd = "BestFit"
input_prefix = "tests/example_data/example.split_mig.demo.params.InferDM"
args = parser.parse_args(
Expand All @@ -449,7 +456,7 @@ def test_bestfit_args():


def test_plot_args():
parser = cli.dadi_cli_parser()
parser = cli._dadi_cli_parser()
cmd = "Plot"
fs = "tests/example_data/split_mig_non_1d.fs"
fs2 = "tests/example_data/split_mig_non_2d.fs"
Expand Down Expand Up @@ -519,9 +526,8 @@ def test_plot_args():
assert args.ratio == ratio



def test_stat_dm_args():
parser = cli.dadi_cli_parser()
parser = cli._dadi_cli_parser()
cmd = "StatDM"
fs = "tests/example_data/split_mig_syn.fs"
model = "split_mig"
Expand Down Expand Up @@ -564,8 +570,9 @@ def test_stat_dm_args():
assert args.bootstrapping_dir == bootstrapping_dir
assert args.logscale == logscale


def test_stat_dfe_args():
parser = cli.dadi_cli_parser()
parser = cli._dadi_cli_parser()
cmd = "StatDFE"
fs = "tests/example_data/split_mig_non_mix.fs"
cache1d = "tests/example_data/cache_split_mig_1d.bpkl"
Expand Down Expand Up @@ -618,8 +625,9 @@ def test_stat_dfe_args():
assert args.bootstrapping_non_dir == bootstrapping_nonsynonymous_dir
assert args.logscale == logscale


def test_model_args():
parser = cli.dadi_cli_parser()
parser = cli._dadi_cli_parser()
cmd = "Model"
args = parser.parse_args([cmd, "--names", "two_epoch"])
assert args.names == "two_epoch"
Expand Down
Loading

0 comments on commit 38f821c

Please sign in to comment.