diff --git a/tests/test_series.py b/tests/test_series.py index 207f438..cf1061c 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -4,6 +4,8 @@ from sleap_roots.series import Series, find_all_series from pathlib import Path from typing import Literal +from contextlib import redirect_stdout +import io @pytest.fixture @@ -56,6 +58,67 @@ def csv_path(tmp_path): ) return csv_path +def test_primary_prediction_not_found(tmp_path): + dummy_video_path = tmp_path / "dummy_video.mp4" + dummy_video_path.write_text("This is a dummy video file.") + + # Create a dummy Series instance with a non-existent primary prediction file + output = io.StringIO() + with redirect_stdout(output): + Series.load(h5_path=str(dummy_video_path), primary_name="nonexistent") + + # format file path string for assert statement + path_obj = Path(dummy_video_path) + parent_dir = path_obj.parent + filename_without_extension = path_obj.stem + new_file_path = parent_dir / filename_without_extension + + assert output.getvalue() == f"Primary prediction file not found: {new_file_path}.nonexistent.predictions.slp\n" + +def test_lateral_prediction_not_found(tmp_path): + dummy_video_path = tmp_path / "dummy_video.mp4" + dummy_video_path.write_text("This is a dummy video file.") + + # Create a dummy Series instance with a non-existent primary prediction file + output = io.StringIO() + with redirect_stdout(output): + Series.load(h5_path=str(dummy_video_path), lateral_name="nonexistent") + + # format file path string for assert statement + path_obj = Path(dummy_video_path) + parent_dir = path_obj.parent + filename_without_extension = path_obj.stem + new_file_path = parent_dir / filename_without_extension + + assert output.getvalue() == f"Lateral prediction file not found: {new_file_path}.nonexistent.predictions.slp\n" + +def test_crown_prediction_not_found(tmp_path): + dummy_video_path = tmp_path / "dummy_video.mp4" + dummy_video_path.write_text("This is a dummy video file.") + + # Create a dummy Series instance with a non-existent primary prediction file + output = io.StringIO() + with redirect_stdout(output): + Series.load(h5_path=str(dummy_video_path), crown_name="nonexistent") + + # format file path string for assert statement + path_obj = Path(dummy_video_path) + parent_dir = path_obj.parent + filename_without_extension = path_obj.stem + new_file_path = parent_dir / filename_without_extension + + assert output.getvalue() == f"Crown prediction file not found: {new_file_path}.nonexistent.predictions.slp\n" + +def test_video_loading_error(tmp_path): + # Create a dummy Series instance with an invalid video file path + invalid_video_path = tmp_path / "invalid_video.mp4" + + output = io.StringIO() + with redirect_stdout(output): + Series.load(h5_path=str(invalid_video_path)) + + # Check if the correct error message is output + assert output.getvalue() == f"Error loading video file {invalid_video_path}: File not found\n" def test_series_name(dummy_series): expected_name = "dummy_video" # Based on the dummy_video_path fixture @@ -85,6 +148,16 @@ def test_expected_count(series_instance, csv_path): assert series_instance.expected_count == 10 +def test_expected_count_error(series_instance, tmp_path): + series_instance.csv_path = tmp_path / "invalid" + + output = io.StringIO() + with redirect_stdout(output): + series_instance.expected_count + # Check if the correct error message is output + assert output.getvalue() == "CSV path is not set or the file does not exist.\n" + + def test_qc_cylinder(series_instance, csv_path): series_instance.csv_path = csv_path assert series_instance.qc_fail == 0