Skip to content

Commit

Permalink
Merge pull request #453 from jeromekelleher/match-python-api
Browse files Browse the repository at this point in the history
Implement run_hmm API endpoint
  • Loading branch information
jeromekelleher authored Dec 17, 2024
2 parents 562c467 + 4f61e2f commit 12842dd
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 122 deletions.
100 changes: 11 additions & 89 deletions sc2ts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,63 +475,7 @@ def tally_lineages(ts, metadata, verbose):
df.to_csv(sys.stdout, sep="\t", index=False)


@dataclasses.dataclass(frozen=True)
class HmmRun:
strain: str
num_mismatches: int
direction: str
match: sc2ts.HmmMatch

def asdict(self):
d = dataclasses.asdict(self)
d["match"] = dataclasses.asdict(self.match)
return d

def asjson(self):
return json.dumps(self.asdict())


@dataclasses.dataclass(frozen=True)
class MatchWork:
ts_path: str
samples: List
num_mismatches: int
direction: str


def _match_worker(work):
msg = (
f"k={work.num_mismatches} n={len(work.samples)} "
f"{work.direction} {work.ts_path}"
)
logger.info(f"Start: {msg}")
ts = tszip.load(work.ts_path)
sc2ts.match_tsinfer(
samples=work.samples,
ts=ts,
num_mismatches=work.num_mismatches,
mismatch_threshold=100,
# FIXME!
deletions_as_missing=False,
num_threads=0,
show_progress=False,
mirror_coordinates=work.direction == "reverse",
)
runs = []
for sample in work.samples:
runs.append(
HmmRun(
strain=sample.strain,
num_mismatches=work.num_mismatches,
direction=work.direction,
match=sample.hmm_match,
)
)
logger.info(f"Finish: {msg}")
return runs


@click.command(name="match")
@click.command()
@click.argument("dataset", type=click.Path(exists=True, dir_okay=False))
@click.argument("ts_path", type=click.Path(exists=True, dir_okay=False))
@click.argument("strains", nargs=-1)
Expand Down Expand Up @@ -559,7 +503,7 @@ def _match_worker(work):
@click.option("--progress/--no-progress", default=True)
@click.option("-v", "--verbose", count=True)
@click.option("-l", "--log-file", default=None, type=click.Path(dir_okay=False))
def _match(
def run_hmm(
dataset,
ts_path,
strains,
Expand All @@ -573,44 +517,21 @@ def _match(
log_file,
):
"""
Run matches for a specified set of strains, outputting details to stdout as JSON.
Run matches for a specified set of strains, outputing details to stdout as JSON.
"""
setup_logging(verbose, log_file)
ts = tszip.load(ts_path)
ds = sc2ts.Dataset(dataset)
if len(strains) == 0:
return
progress_title = "Match"
samples = sc2ts.preprocess(
list(strains),
dataset=ds,
show_progress=progress,
progress_title=progress_title,
keep_sites=ts.sites_position.astype(int),
)
for sample in samples:
if sample.haplotype is None:
raise ValueError(f"No alignment stored for {sample.strain}")

sc2ts.match_tsinfer(
samples=samples,
ts=ts,
runs = sc2ts.run_hmm(
dataset,
ts_path,
strains=strains,
num_mismatches=num_mismatches,
deletions_as_missing=deletions_as_missing,
mismatch_threshold=mismatch_threshold,
direction=direction,
num_threads=num_threads,
show_progress=progress,
progress_title=progress_title,
progress_phase="HMM",
mirror_coordinates=direction == "reverse",
)
for sample in samples:
run = HmmRun(
strain=sample.strain,
num_mismatches=num_mismatches,
direction=direction,
match=sample.hmm_match,
)
for run in runs:
print(run.asjson())


Expand Down Expand Up @@ -649,5 +570,6 @@ def cli():

cli.add_command(infer)
cli.add_command(validate)
cli.add_command(_match)
cli.add_command(run_hmm)

cli.add_command(tally_lineages)
70 changes: 70 additions & 0 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1834,6 +1834,76 @@ def add_root_edge(ts, flags=0):
return tables.tree_sequence()


@dataclasses.dataclass(frozen=True)
class HmmRun:
strain: str
num_mismatches: int
direction: str
match: HmmMatch

def asdict(self):
d = dataclasses.asdict(self)
d["match"] = dataclasses.asdict(self.match)
return d

def asjson(self):
return json.dumps(self.asdict())


def run_hmm(
dataset_path,
ts_path,
strains,
*,
num_mismatches,
direction="forward",
mismatch_threshold=None,
deletions_as_missing=None,
num_threads=0,
show_progress=False,
):
if deletions_as_missing is None:
deletions_as_missing = False
if mismatch_threshold is None:
mismatch_threshold = 100

ds = _dataset.Dataset(dataset_path)
ts = tszip.load(ts_path)
if len(strains) == 0:
return
progress_title = "Match"
samples = preprocess(
list(strains),
dataset=ds,
show_progress=show_progress,
progress_title=progress_title,
keep_sites=ts.sites_position.astype(int),
)
match_tsinfer(
samples=samples,
ts=ts,
num_mismatches=num_mismatches,
deletions_as_missing=deletions_as_missing,
mismatch_threshold=mismatch_threshold,
num_threads=num_threads,
show_progress=show_progress,
progress_title=progress_title,
progress_phase="HMM",
mirror_coordinates=direction == "reverse",
)
ret = []
for sample in samples:
ret.append(
HmmRun(
strain=sample.strain,
num_mismatches=num_mismatches,
direction=direction,
match=sample.hmm_match,
)
)
return ret


def get_group_strains(ts):
"""
Returns the strain IDs for samples gathered by sample group ID.
Expand Down
8 changes: 4 additions & 4 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_viridian_metadata(
)


class TestMatch:
class TestRunHmm:

def test_single_defaults(self, tmp_path, fx_ts_map, fx_dataset):
strain = "ERR4206593"
Expand All @@ -88,7 +88,7 @@ def test_single_defaults(self, tmp_path, fx_ts_map, fx_dataset):
runner = ct.CliRunner(mix_stderr=False)
result = runner.invoke(
cli.cli,
f"match {fx_dataset.path} {ts_path} {strain}",
f"run-hmm {fx_dataset.path} {ts_path} {strain}",
catch_exceptions=False,
)
assert result.exit_code == 0
Expand All @@ -110,7 +110,7 @@ def test_multi_defaults(self, tmp_path, fx_ts_map, fx_dataset):
runner = ct.CliRunner(mix_stderr=False)
result = runner.invoke(
cli.cli,
f"match {fx_dataset.path} {ts_path} " + " ".join(strains),
f"run-hmm {fx_dataset.path} {ts_path} " + " ".join(strains),
catch_exceptions=False,
)
assert result.exit_code == 0
Expand All @@ -134,7 +134,7 @@ def test_single_options(self, tmp_path, fx_ts_map, fx_dataset):
runner = ct.CliRunner(mix_stderr=False)
result = runner.invoke(
cli.cli,
f"match {fx_dataset.path} {ts_path} {strain}"
f"run-hmm {fx_dataset.path} {ts_path} {strain}"
" --direction=reverse --num-mismatches=5 --num-threads=4",
" --no-deletions-as-missing",
catch_exceptions=False,
Expand Down
54 changes: 25 additions & 29 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,28 +1100,29 @@ class TestMatchingDetails:
("strain", "parent"), [("SRR11597207", 34), ("ERR4205570", 47)]
)
@pytest.mark.parametrize("num_mismatches", [2, 3, 4])
@pytest.mark.parametrize("direction", ["forward", "reverse"])
def test_exact_matches(
self,
fx_ts_map,
fx_dataset,
strain,
parent,
num_mismatches,
direction,
):
ts = fx_ts_map["2020-02-10"]
samples = sc2ts.preprocess(

runs = sc2ts.run_hmm(
fx_dataset.path,
ts.path,
[strain],
fx_dataset,
keep_sites=ts.sites_position.astype(int),
)
sc2ts.match_tsinfer(
samples=samples,
ts=ts,
num_mismatches=num_mismatches,
mismatch_threshold=num_mismatches,
num_threads=0,
direction=direction,
)
s = samples[0].hmm_match
assert len(runs) == 1
assert runs[0].num_mismatches == num_mismatches
assert runs[0].direction == direction
s = runs[0].match
assert len(s.mutations) == 0
assert len(s.path) == 1
assert s.path[0].parent == parent
Expand All @@ -1144,18 +1145,16 @@ def test_one_mismatch(
num_mismatches,
):
ts = fx_ts_map["2020-02-10"]
samples = sc2ts.preprocess(
runs = sc2ts.run_hmm(
fx_dataset.path,
ts.path,
[strain],
fx_dataset,
keep_sites=ts.sites_position.astype(int),
)
sc2ts.match_tsinfer(
samples=samples,
ts=ts,
num_mismatches=num_mismatches,
mismatch_threshold=1,
)
s = samples[0].hmm_match
assert len(runs) == 1
assert runs[0].num_mismatches == num_mismatches
assert runs[0].direction == "forward"
s = runs[0].match
assert len(s.mutations) == 1
assert s.mutations[0].site_position == position
assert s.mutations[0].derived_state == derived_state
Expand All @@ -1171,25 +1170,22 @@ def test_two_mismatches(
):
strain = "SRR11597164"
ts = fx_ts_map["2020-02-01"]
samples = sc2ts.preprocess(
runs = sc2ts.run_hmm(
fx_dataset.path,
ts.path,
[strain],
fx_dataset,
keep_sites=ts.sites_position.astype(int),
)
sc2ts.match_tsinfer(
samples=samples,
ts=ts,
num_mismatches=num_mismatches,
mismatch_threshold=2,
)
s = samples[0].hmm_match
assert len(runs) == 1
assert runs[0].num_mismatches == num_mismatches
assert runs[0].direction == "forward"
s = runs[0].match
assert len(s.path) == 1
assert s.path[0].parent == 1
assert len(s.mutations) == 2

def test_match_recombinant(self, fx_ts_map):
ts, s = recombinant_example_1(fx_ts_map)

sc2ts.match_tsinfer(
samples=[s],
ts=ts,
Expand Down

0 comments on commit 12842dd

Please sign in to comment.