diff --git a/sc2ts/inference.py b/sc2ts/inference.py index a13b517..fd455f5 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -1867,10 +1867,12 @@ def run_hmm( if mismatch_threshold is None: mismatch_threshold = 100 + directions = ["forward", "reverse"] + if direction not in directions: + raise ValueError(f"Direction must be one of {directions}") + ds = _dataset.Dataset(dataset_path) ts = tszip.load(ts_path) - if len(strains) == 0: - return progress_title = "Match" samples = preprocess( list(strains), diff --git a/tests/test_inference.py b/tests/test_inference.py index e91dbca..60cdc1e 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1207,6 +1207,26 @@ def test_match_recombinant(self, fx_ts_map): assert m.path[1].right == ts.sequence_length +class TestRunHmm: + + @pytest.mark.parametrize("direction", ["F", "R", "forwards", "backwards", "", None]) + def test_bad_direction(self, fx_dataset, fx_ts_map, direction): + strain = "SRR11597164" + ts = fx_ts_map["2020-02-01"] + with pytest.raises(ValueError, match="Direction must be one of"): + sc2ts.run_hmm( + fx_dataset.path, + ts.path, + [strain], + direction=direction, + num_mismatches=3, + ) + + def test_no_strains(self, fx_dataset, fx_ts_map): + ts = fx_ts_map["2020-02-01"] + assert len(sc2ts.run_hmm(fx_dataset.path, ts.path, [], num_mismatches=3)) == 0 + + class TestCharacteriseRecombinants: def test_example_1(self, fx_ts_map):