Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Base extractor #367

Open
wants to merge 5 commits into
base: segmentation
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactored imagingextractor to use baseimagingextractormixin
pauladkisson committed Sep 30, 2024
commit a1f0c74bcb3af96e02f7d0816df7f750e177c872
71 changes: 71 additions & 0 deletions tests/mixins/base_extractor_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pytest
import numpy as np


class BaseExtractorMixin:
def test_get_image_size(self, extractor, expected_image_size):
image_size = extractor.get_image_size()
assert image_size == expected_image_size

def test_get_num_frames(self, extractor, expected_num_frames):
num_frames = extractor.get_num_frames()
assert num_frames == expected_num_frames

def test_get_sampling_frequency(self, extractor, expected_sampling_frequency):
sampling_frequency = extractor.get_sampling_frequency()
assert sampling_frequency == expected_sampling_frequency

@pytest.mark.parametrize("sampling_frequency", [1, 2, 3])
def test_frame_to_time_no_times(self, extractor, sampling_frequency):
extractor._times = None
extractor._sampling_frequency = sampling_frequency
times = extractor.frame_to_time(frames=[0, 1])
expected_times = np.array([0, 1]) / sampling_frequency
assert np.array_equal(times, expected_times)

def test_frame_to_time_with_times(self, extractor):
expected_times = np.array([0, 1])
extractor._times = expected_times
times = extractor.frame_to_time(frames=[0, 1])

assert np.array_equal(times, expected_times)

@pytest.mark.parametrize("sampling_frequency", [1, 2, 3])
def test_time_to_frame_no_times(self, extractor, sampling_frequency):
extractor._times = None
extractor._sampling_frequency = sampling_frequency
times = np.array([0, 1]) / sampling_frequency
frames = extractor.time_to_frame(times=times)
expected_frames = np.array([0, 1])
assert np.array_equal(frames, expected_frames)

def test_time_to_frame_with_times(self, extractor):
extractor._times = np.array([0, 1])
times = np.array([0, 1])
frames = extractor.time_to_frame(times=times)
expected_frames = np.array([0, 1])
assert np.array_equal(frames, expected_frames)

def test_set_times(self, extractor):
times = np.arange(extractor.get_num_frames())
extractor.set_times(times)
assert np.array_equal(extractor._times, times)

def test_set_times_invalid_length(self, extractor):
with pytest.raises(AssertionError):
extractor.set_times(np.arange(extractor.get_num_frames() + 1))

@pytest.mark.parametrize("times", [None, np.array([0, 1])])
def test_has_time_vector(self, times, extractor):
extractor._times = times
if times is None:
assert not extractor.has_time_vector()
else:
assert extractor.has_time_vector()

def test_copy_times(self, extractor, extractor2):
expected_times = np.arange(extractor.get_num_frames())
extractor._times = expected_times
extractor2.copy_times(extractor)
assert np.array_equal(extractor2._times, expected_times)
assert extractor2._times is not expected_times
80 changes: 15 additions & 65 deletions tests/mixins/imaging_extractor_mixin.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
import pytest
import numpy as np
from .base_extractor_mixin import BaseExtractorMixin


class ImagingExtractorMixin:
def test_get_image_size(self, imaging_extractor, expected_video):
image_size = imaging_extractor.get_image_size()
assert image_size == (expected_video.shape[1], expected_video.shape[2])
class ImagingExtractorMixin(BaseExtractorMixin):
@pytest.fixture(scope="function")
def extractor(self, imaging_extractor):
return imaging_extractor

def test_get_num_frames(self, imaging_extractor, expected_video):
num_frames = imaging_extractor.get_num_frames()
assert num_frames == expected_video.shape[0]
@pytest.fixture(scope="function")
def extractor2(self, imaging_extractor2):
return imaging_extractor2

def test_get_sampling_frequency(self, imaging_extractor, expected_sampling_frequency):
sampling_frequency = imaging_extractor.get_sampling_frequency()
assert sampling_frequency == expected_sampling_frequency
@pytest.fixture(scope="function")
def expected_image_size(self, expected_video):
return expected_video.shape[1], expected_video.shape[2]

@pytest.fixture(scope="function")
def expected_num_frames(self, expected_video):
return expected_video.shape[0]

def test_get_dtype(self, imaging_extractor, expected_video):
dtype = imaging_extractor.get_dtype()
@@ -60,61 +65,6 @@ def test_get_frames_invalid_frame_idxs(self, imaging_extractor):
with pytest.raises(AssertionError):
imaging_extractor.get_frames(frame_idxs=[0.5])

@pytest.mark.parametrize("sampling_frequency", [1, 2, 3])
def test_frame_to_time_no_times(self, imaging_extractor, sampling_frequency):
imaging_extractor._times = None
imaging_extractor._sampling_frequency = sampling_frequency
times = imaging_extractor.frame_to_time(frames=[0, 1])
expected_times = np.array([0, 1]) / sampling_frequency
assert np.array_equal(times, expected_times)

def test_frame_to_time_with_times(self, imaging_extractor):
expected_times = np.array([0, 1])
imaging_extractor._times = expected_times
times = imaging_extractor.frame_to_time(frames=[0, 1])

assert np.array_equal(times, expected_times)

@pytest.mark.parametrize("sampling_frequency", [1, 2, 3])
def test_time_to_frame_no_times(self, imaging_extractor, sampling_frequency):
imaging_extractor._times = None
imaging_extractor._sampling_frequency = sampling_frequency
times = np.array([0, 1]) / sampling_frequency
frames = imaging_extractor.time_to_frame(times=times)
expected_frames = np.array([0, 1])
assert np.array_equal(frames, expected_frames)

def test_time_to_frame_with_times(self, imaging_extractor):
imaging_extractor._times = np.array([0, 1])
times = np.array([0, 1])
frames = imaging_extractor.time_to_frame(times=times)
expected_frames = np.array([0, 1])
assert np.array_equal(frames, expected_frames)

def test_set_times(self, imaging_extractor):
times = np.arange(imaging_extractor.get_num_frames())
imaging_extractor.set_times(times)
assert np.array_equal(imaging_extractor._times, times)

def test_set_times_invalid_length(self, imaging_extractor):
with pytest.raises(AssertionError):
imaging_extractor.set_times(np.arange(imaging_extractor.get_num_frames() + 1))

@pytest.mark.parametrize("times", [None, np.array([0, 1])])
def test_has_time_vector(self, times, imaging_extractor):
imaging_extractor._times = times
if times is None:
assert not imaging_extractor.has_time_vector()
else:
assert imaging_extractor.has_time_vector()

def test_copy_times(self, imaging_extractor, imaging_extractor2):
expected_times = np.arange(imaging_extractor.get_num_frames())
imaging_extractor._times = expected_times
imaging_extractor2.copy_times(imaging_extractor)
assert np.array_equal(imaging_extractor2._times, expected_times)
assert imaging_extractor2._times is not expected_times

def test_eq(self, imaging_extractor, imaging_extractor2):
assert imaging_extractor == imaging_extractor2