Skip to content

Commit

Permalink
add tests for occlusion
Browse files Browse the repository at this point in the history
  • Loading branch information
rwood-97 committed Sep 2, 2024
1 parent 82c2214 commit b64a724
Showing 1 changed file with 23 additions and 3 deletions.
26 changes: 23 additions & 3 deletions tests/test_post_processing/test_occlusion.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import os
from pathlib import Path
import pathlib
from typing import Callable

import pandas as pd
Expand All @@ -12,16 +12,18 @@
from torchvision import transforms

from mapreader.process.occlusion_analysis import OcclusionAnalyzer
from mapreader.utils.load_frames import eval_dataframe


@pytest.fixture
def sample_dir():
return Path(__file__).resolve().parent.parent / "sample_files"
return pathlib.Path(__file__).resolve().parent.parent / "sample_files"


@pytest.fixture
def patch_df(sample_dir):
return pd.read_csv(f"{sample_dir}/post_processing_patch_df.csv", index_col=0)
df = pd.read_csv(f"{sample_dir}/post_processing_patch_df.csv", index_col=0)
return eval_dataframe(df)


@pytest.fixture
Expand All @@ -46,6 +48,22 @@ def test_init_path(sample_dir, model):
assert isinstance(analyzer.patch_df.iloc[0]["pixel_bounds"], tuple)


def test_init_pathlib(sample_dir, model):
patch_df = pathlib.Path(f"{sample_dir}/post_processing_patch_df.csv")
analyzer = OcclusionAnalyzer(patch_df, model)
assert isinstance(analyzer, OcclusionAnalyzer)
assert len(analyzer) == 81
assert isinstance(analyzer.patch_df.iloc[0]["pixel_bounds"], tuple)


def test_init_geojson(sample_dir, model):
patch_df = f"{sample_dir}/post_processing_patch_df.geojson"
analyzer = OcclusionAnalyzer(patch_df, model)
assert isinstance(analyzer, OcclusionAnalyzer)
assert len(analyzer) == 81
assert isinstance(analyzer.patch_df.iloc[0]["pixel_bounds"], tuple)


def test_init_dataframe_transform(patch_df, model):
transform = transforms.ToTensor()
analyzer = OcclusionAnalyzer(patch_df, model, transform=transform)
Expand All @@ -62,6 +80,8 @@ def test_init_fake_path_error(model):


def test_init_error(model):
with pytest.raises(ValueError, match="path to a CSV/TSV/etc or geojson"):
OcclusionAnalyzer("fake.file", model)
with pytest.raises(ValueError, match="as a string"):
OcclusionAnalyzer({"image_id": "patch"}, model)

Expand Down

0 comments on commit b64a724

Please sign in to comment.