From b285ecca31c79ea19366689e4090cd22261cf8ce Mon Sep 17 00:00:00 2001 From: Johannes Hofmanninger Date: Thu, 15 Jun 2023 08:01:23 +0200 Subject: [PATCH] test with path provided --- tests/test_mask.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/tests/test_mask.py b/tests/test_mask.py index 2e89597..825c1ba 100644 --- a/tests/test_mask.py +++ b/tests/test_mask.py @@ -1,10 +1,11 @@ import os +import shutil import numpy as np -import pydicom as pyd import pytest +import torch -from lungmask.mask import LMInferer, apply, apply_fused +from lungmask.mask import MODEL_URLS, LMInferer from lungmask.utils import read_dicoms @@ -13,8 +14,32 @@ def fixture_testvol(): return read_dicoms(os.path.join(os.path.dirname(__file__), "testdata"))[0] -def test_LMInferer(fixture_testvol): +@pytest.fixture(scope="session") +def fixture_weights_path_R231(tmpdir_factory): + # we make sure the model is there + torch.hub.load_state_dict_from_url( + MODEL_URLS["R231"][0], progress=True, map_location=torch.device("cpu") + ) + modelbasename = os.path.basename(MODEL_URLS["R231"][0]) + modelpath = os.path.join(torch.hub.get_dir(), "checkpoints", modelbasename) + tmppath = str(tmpdir_factory.mktemp("weights").join(modelbasename)) + shutil.copy(modelpath, tmppath) + return tmppath + + +def test_LMInferer(fixture_testvol, fixture_weights_path_R231): + inferer = LMInferer( + force_cpu=True, + tqdm_disable=True, + ) + res = inferer.apply(fixture_testvol) + assert np.all(np.unique(res, return_counts=True)[1] == [423000, 64752, 36536]) + + # here, we provide a path to the R231 weights but specify LTRCLobes (6 channel) as modelname + # The modelname should be ignored and a 3 channel output should be generated inferer = LMInferer( + modelname="LTRCLobes", + modelpath=fixture_weights_path_R231, force_cpu=True, tqdm_disable=True, )