Skip to content

Commit

Permalink
dev(narugo): add test for aicorrupt seq
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Jan 26, 2024
1 parent 8582561 commit 0db6c82
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions test/corrupt/test_aicorrupt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import glob
import os.path

import pytest
from hbutils.testing import isolated_directory

Expand All @@ -10,11 +13,21 @@ def aicorrupt_o():
return get_testfile('aicorrupt', 'o')


@pytest.fixture()
def aicorrupt_o_files(aicorrupt_o):
return glob.glob(os.path.join(aicorrupt_o, '*.png'))


@pytest.fixture()
def aicorrupt_x():
return get_testfile('aicorrupt', 'x')


@pytest.fixture()
def aicorrupt_x_files(aicorrupt_x):
return glob.glob(os.path.join(aicorrupt_x, '*.png'))


@pytest.fixture()
def aicorrupt_metrics():
return AICorruptMetrics()
Expand All @@ -27,6 +40,19 @@ def test_score(self, aicorrupt_o, aicorrupt_x, aicorrupt_metrics):
assert aicorrupt_metrics.score(aicorrupt_o) >= 0.97
assert aicorrupt_metrics.score(aicorrupt_x) < 0.05

def test_score_files(self, aicorrupt_o_files, aicorrupt_x_files, aicorrupt_metrics):
assert aicorrupt_metrics.score(aicorrupt_o_files) >= 0.97
assert aicorrupt_metrics.score(aicorrupt_x_files) < 0.05

def test_score_files_seq(self, aicorrupt_o_files, aicorrupt_x_files, aicorrupt_metrics):
seq = aicorrupt_metrics.score(aicorrupt_o_files, mode='seq')
assert seq.shape == (len(aicorrupt_o_files),)
assert seq.mean().item() >= 0.97

seq = aicorrupt_metrics.score(aicorrupt_x_files, mode='seq')
assert seq.shape == (len(aicorrupt_x_files),)
assert seq.mean().item() < 0.05

@isolated_directory()
def test_aicorrupt_failed(self, aicorrupt_metrics):
with pytest.raises(FileNotFoundError):
Expand Down

0 comments on commit 0db6c82

Please sign in to comment.