From 4f61e2f749ca492e61b7b0699bc882082a72e4ae Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 17 Dec 2024 10:56:07 +0000 Subject: [PATCH] Implement run_hmm API endpoint Rename CLI match -> run-hmm --- sc2ts/cli.py | 100 +++++----------------------------------- sc2ts/inference.py | 70 ++++++++++++++++++++++++++++ tests/test_cli.py | 8 ++-- tests/test_inference.py | 54 ++++++++++------------ 4 files changed, 110 insertions(+), 122 deletions(-) diff --git a/sc2ts/cli.py b/sc2ts/cli.py index 79af3b8..36bfdb1 100644 --- a/sc2ts/cli.py +++ b/sc2ts/cli.py @@ -475,63 +475,7 @@ def tally_lineages(ts, metadata, verbose): df.to_csv(sys.stdout, sep="\t", index=False) -@dataclasses.dataclass(frozen=True) -class HmmRun: - strain: str - num_mismatches: int - direction: str - match: sc2ts.HmmMatch - - def asdict(self): - d = dataclasses.asdict(self) - d["match"] = dataclasses.asdict(self.match) - return d - - def asjson(self): - return json.dumps(self.asdict()) - - -@dataclasses.dataclass(frozen=True) -class MatchWork: - ts_path: str - samples: List - num_mismatches: int - direction: str - - -def _match_worker(work): - msg = ( - f"k={work.num_mismatches} n={len(work.samples)} " - f"{work.direction} {work.ts_path}" - ) - logger.info(f"Start: {msg}") - ts = tszip.load(work.ts_path) - sc2ts.match_tsinfer( - samples=work.samples, - ts=ts, - num_mismatches=work.num_mismatches, - mismatch_threshold=100, - # FIXME! - deletions_as_missing=False, - num_threads=0, - show_progress=False, - mirror_coordinates=work.direction == "reverse", - ) - runs = [] - for sample in work.samples: - runs.append( - HmmRun( - strain=sample.strain, - num_mismatches=work.num_mismatches, - direction=work.direction, - match=sample.hmm_match, - ) - ) - logger.info(f"Finish: {msg}") - return runs - - -@click.command(name="match") +@click.command() @click.argument("dataset", type=click.Path(exists=True, dir_okay=False)) @click.argument("ts_path", type=click.Path(exists=True, dir_okay=False)) @click.argument("strains", nargs=-1) @@ -559,7 +503,7 @@ def _match_worker(work): @click.option("--progress/--no-progress", default=True) @click.option("-v", "--verbose", count=True) @click.option("-l", "--log-file", default=None, type=click.Path(dir_okay=False)) -def _match( +def run_hmm( dataset, ts_path, strains, @@ -573,44 +517,21 @@ def _match( log_file, ): """ - Run matches for a specified set of strains, outputting details to stdout as JSON. + Run matches for a specified set of strains, outputing details to stdout as JSON. """ setup_logging(verbose, log_file) - ts = tszip.load(ts_path) - ds = sc2ts.Dataset(dataset) - if len(strains) == 0: - return - progress_title = "Match" - samples = sc2ts.preprocess( - list(strains), - dataset=ds, - show_progress=progress, - progress_title=progress_title, - keep_sites=ts.sites_position.astype(int), - ) - for sample in samples: - if sample.haplotype is None: - raise ValueError(f"No alignment stored for {sample.strain}") - sc2ts.match_tsinfer( - samples=samples, - ts=ts, + runs = sc2ts.run_hmm( + dataset, + ts_path, + strains=strains, num_mismatches=num_mismatches, - deletions_as_missing=deletions_as_missing, mismatch_threshold=mismatch_threshold, + direction=direction, num_threads=num_threads, show_progress=progress, - progress_title=progress_title, - progress_phase="HMM", - mirror_coordinates=direction == "reverse", ) - for sample in samples: - run = HmmRun( - strain=sample.strain, - num_mismatches=num_mismatches, - direction=direction, - match=sample.hmm_match, - ) + for run in runs: print(run.asjson()) @@ -649,5 +570,6 @@ def cli(): cli.add_command(infer) cli.add_command(validate) -cli.add_command(_match) +cli.add_command(run_hmm) + cli.add_command(tally_lineages) diff --git a/sc2ts/inference.py b/sc2ts/inference.py index 61e65f0..a13b517 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -1834,6 +1834,76 @@ def add_root_edge(ts, flags=0): return tables.tree_sequence() +@dataclasses.dataclass(frozen=True) +class HmmRun: + strain: str + num_mismatches: int + direction: str + match: HmmMatch + + def asdict(self): + d = dataclasses.asdict(self) + d["match"] = dataclasses.asdict(self.match) + return d + + def asjson(self): + return json.dumps(self.asdict()) + + +def run_hmm( + dataset_path, + ts_path, + strains, + *, + num_mismatches, + direction="forward", + mismatch_threshold=None, + deletions_as_missing=None, + num_threads=0, + show_progress=False, +): + if deletions_as_missing is None: + deletions_as_missing = False + if mismatch_threshold is None: + mismatch_threshold = 100 + + ds = _dataset.Dataset(dataset_path) + ts = tszip.load(ts_path) + if len(strains) == 0: + return + progress_title = "Match" + samples = preprocess( + list(strains), + dataset=ds, + show_progress=show_progress, + progress_title=progress_title, + keep_sites=ts.sites_position.astype(int), + ) + match_tsinfer( + samples=samples, + ts=ts, + num_mismatches=num_mismatches, + deletions_as_missing=deletions_as_missing, + mismatch_threshold=mismatch_threshold, + num_threads=num_threads, + show_progress=show_progress, + progress_title=progress_title, + progress_phase="HMM", + mirror_coordinates=direction == "reverse", + ) + ret = [] + for sample in samples: + ret.append( + HmmRun( + strain=sample.strain, + num_mismatches=num_mismatches, + direction=direction, + match=sample.hmm_match, + ) + ) + return ret + + def get_group_strains(ts): """ Returns the strain IDs for samples gathered by sample group ID. diff --git a/tests/test_cli.py b/tests/test_cli.py index 7f0e603..9469ac7 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -78,7 +78,7 @@ def test_viridian_metadata( ) -class TestMatch: +class TestRunHmm: def test_single_defaults(self, tmp_path, fx_ts_map, fx_dataset): strain = "ERR4206593" @@ -88,7 +88,7 @@ def test_single_defaults(self, tmp_path, fx_ts_map, fx_dataset): runner = ct.CliRunner(mix_stderr=False) result = runner.invoke( cli.cli, - f"match {fx_dataset.path} {ts_path} {strain}", + f"run-hmm {fx_dataset.path} {ts_path} {strain}", catch_exceptions=False, ) assert result.exit_code == 0 @@ -110,7 +110,7 @@ def test_multi_defaults(self, tmp_path, fx_ts_map, fx_dataset): runner = ct.CliRunner(mix_stderr=False) result = runner.invoke( cli.cli, - f"match {fx_dataset.path} {ts_path} " + " ".join(strains), + f"run-hmm {fx_dataset.path} {ts_path} " + " ".join(strains), catch_exceptions=False, ) assert result.exit_code == 0 @@ -134,7 +134,7 @@ def test_single_options(self, tmp_path, fx_ts_map, fx_dataset): runner = ct.CliRunner(mix_stderr=False) result = runner.invoke( cli.cli, - f"match {fx_dataset.path} {ts_path} {strain}" + f"run-hmm {fx_dataset.path} {ts_path} {strain}" " --direction=reverse --num-mismatches=5 --num-threads=4", " --no-deletions-as-missing", catch_exceptions=False, diff --git a/tests/test_inference.py b/tests/test_inference.py index 82b5ec1..e91dbca 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1100,6 +1100,7 @@ class TestMatchingDetails: ("strain", "parent"), [("SRR11597207", 34), ("ERR4205570", 47)] ) @pytest.mark.parametrize("num_mismatches", [2, 3, 4]) + @pytest.mark.parametrize("direction", ["forward", "reverse"]) def test_exact_matches( self, fx_ts_map, @@ -1107,21 +1108,21 @@ def test_exact_matches( strain, parent, num_mismatches, + direction, ): ts = fx_ts_map["2020-02-10"] - samples = sc2ts.preprocess( + + runs = sc2ts.run_hmm( + fx_dataset.path, + ts.path, [strain], - fx_dataset, - keep_sites=ts.sites_position.astype(int), - ) - sc2ts.match_tsinfer( - samples=samples, - ts=ts, num_mismatches=num_mismatches, - mismatch_threshold=num_mismatches, - num_threads=0, + direction=direction, ) - s = samples[0].hmm_match + assert len(runs) == 1 + assert runs[0].num_mismatches == num_mismatches + assert runs[0].direction == direction + s = runs[0].match assert len(s.mutations) == 0 assert len(s.path) == 1 assert s.path[0].parent == parent @@ -1144,18 +1145,16 @@ def test_one_mismatch( num_mismatches, ): ts = fx_ts_map["2020-02-10"] - samples = sc2ts.preprocess( + runs = sc2ts.run_hmm( + fx_dataset.path, + ts.path, [strain], - fx_dataset, - keep_sites=ts.sites_position.astype(int), - ) - sc2ts.match_tsinfer( - samples=samples, - ts=ts, num_mismatches=num_mismatches, - mismatch_threshold=1, ) - s = samples[0].hmm_match + assert len(runs) == 1 + assert runs[0].num_mismatches == num_mismatches + assert runs[0].direction == "forward" + s = runs[0].match assert len(s.mutations) == 1 assert s.mutations[0].site_position == position assert s.mutations[0].derived_state == derived_state @@ -1171,25 +1170,22 @@ def test_two_mismatches( ): strain = "SRR11597164" ts = fx_ts_map["2020-02-01"] - samples = sc2ts.preprocess( + runs = sc2ts.run_hmm( + fx_dataset.path, + ts.path, [strain], - fx_dataset, - keep_sites=ts.sites_position.astype(int), - ) - sc2ts.match_tsinfer( - samples=samples, - ts=ts, num_mismatches=num_mismatches, - mismatch_threshold=2, ) - s = samples[0].hmm_match + assert len(runs) == 1 + assert runs[0].num_mismatches == num_mismatches + assert runs[0].direction == "forward" + s = runs[0].match assert len(s.path) == 1 assert s.path[0].parent == 1 assert len(s.mutations) == 2 def test_match_recombinant(self, fx_ts_map): ts, s = recombinant_example_1(fx_ts_map) - sc2ts.match_tsinfer( samples=[s], ts=ts,