From b64a724db78fd0b694fb5d98028cc5300745d2b6 Mon Sep 17 00:00:00 2001 From: Rosie Wood Date: Mon, 2 Sep 2024 14:35:23 +0100 Subject: [PATCH] add tests for occlusion --- tests/test_post_processing/test_occlusion.py | 26 +++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/tests/test_post_processing/test_occlusion.py b/tests/test_post_processing/test_occlusion.py index bc6b8d2d..69991b83 100644 --- a/tests/test_post_processing/test_occlusion.py +++ b/tests/test_post_processing/test_occlusion.py @@ -1,7 +1,7 @@ from __future__ import annotations import os -from pathlib import Path +import pathlib from typing import Callable import pandas as pd @@ -12,16 +12,18 @@ from torchvision import transforms from mapreader.process.occlusion_analysis import OcclusionAnalyzer +from mapreader.utils.load_frames import eval_dataframe @pytest.fixture def sample_dir(): - return Path(__file__).resolve().parent.parent / "sample_files" + return pathlib.Path(__file__).resolve().parent.parent / "sample_files" @pytest.fixture def patch_df(sample_dir): - return pd.read_csv(f"{sample_dir}/post_processing_patch_df.csv", index_col=0) + df = pd.read_csv(f"{sample_dir}/post_processing_patch_df.csv", index_col=0) + return eval_dataframe(df) @pytest.fixture @@ -46,6 +48,22 @@ def test_init_path(sample_dir, model): assert isinstance(analyzer.patch_df.iloc[0]["pixel_bounds"], tuple) +def test_init_pathlib(sample_dir, model): + patch_df = pathlib.Path(f"{sample_dir}/post_processing_patch_df.csv") + analyzer = OcclusionAnalyzer(patch_df, model) + assert isinstance(analyzer, OcclusionAnalyzer) + assert len(analyzer) == 81 + assert isinstance(analyzer.patch_df.iloc[0]["pixel_bounds"], tuple) + + +def test_init_geojson(sample_dir, model): + patch_df = f"{sample_dir}/post_processing_patch_df.geojson" + analyzer = OcclusionAnalyzer(patch_df, model) + assert isinstance(analyzer, OcclusionAnalyzer) + assert len(analyzer) == 81 + assert isinstance(analyzer.patch_df.iloc[0]["pixel_bounds"], tuple) + + def test_init_dataframe_transform(patch_df, model): transform = transforms.ToTensor() analyzer = OcclusionAnalyzer(patch_df, model, transform=transform) @@ -62,6 +80,8 @@ def test_init_fake_path_error(model): def test_init_error(model): + with pytest.raises(ValueError, match="path to a CSV/TSV/etc or geojson"): + OcclusionAnalyzer("fake.file", model) with pytest.raises(ValueError, match="as a string"): OcclusionAnalyzer({"image_id": "patch"}, model)