Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions medsegpy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ class Config(object):
VAL_DATASET = ""
TEST_DATASET = ""

# Inference only, with no evaluation
INFERENCE_ONLY = False

# Cross-Validation-Parameters
USE_CROSS_VALIDATION = False
CV_FILE = ""
Expand Down Expand Up @@ -550,6 +553,7 @@ def summary(self, additional_vars=None):
"TRAIN_DATASET",
"VAL_DATASET",
"TEST_DATASET",
"INFERENCE_ONLY",
"",
"CATEGORIES",
"",
Expand Down
19 changes: 16 additions & 3 deletions medsegpy/data/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from medsegpy.config import Config
from medsegpy.modeling import Model
from medsegpy.utils import env
from pydicom import dcmread

from .data_utils import add_background_labels, collect_mask, compute_patches
from .transforms import apply_transform_gens, build_preprocessing
Expand Down Expand Up @@ -242,8 +243,13 @@ def _load_input(self, dataset_dict):
if self._cached_data is not None:
image, mask = self._cached_data[(image_file, sem_seg_file)]
else:
with h5py.File(image_file, "r") as f:
image = f["data"][:]
if image_file.endswith('.dcm'):
ds = dcmread(image_file)
image = ds.pixel_array
else:
with h5py.File(image_file, "r") as f:
image = f["data"][:]

if image.shape[-1] != 1:
image = image[..., np.newaxis]

Expand Down Expand Up @@ -323,7 +329,11 @@ def _restructure_data(self, vols: Sequence[np.ndarray]):
axes = (1, 2, 0)
if v.ndim > 3:
axes = axes + tuple(i for i in range(3, v.ndim))
new_vols.append(v.transpose(axes))
# new_vols.append(v.transpose(axes))
if v.ndim == 1:
new_vols.append(v)
else:
new_vols.append(v.transpose(axes))
vols = (np.squeeze(v) for v in new_vols)
return tuple(vols)

Expand All @@ -337,6 +347,9 @@ def inference(self, model: Model, **kwargs):

workers = kwargs.pop("workers", self._cfg.NUM_WORKERS)
use_multiprocessing = kwargs.pop("use_multiprocessing", workers > 1)

kwargs["batch_size"] = 1

for scan_id in scan_ids:
self._dataset_dicts = scan_to_dict_mapping[scan_id]

Expand Down
5 changes: 4 additions & 1 deletion medsegpy/evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,10 @@ def inference_on_dataset(

eval_start = time.perf_counter()
logger.info("Begin evaluation...")
results = {e.__class__.__name__: e.evaluate() for e in evaluator}
if any([e._config.INFERENCE_ONLY for e in evaluator]):
results = None
else:
results = {e.__class__.__name__: e.evaluate() for e in evaluator}
total_eval_time = time.perf_counter() - eval_start
logger.info("Time Elapsed: {:.4f} seconds".format(total_compute_time + total_eval_time))
# An evaluator may return None when not in main process.
Expand Down
19 changes: 10 additions & 9 deletions medsegpy/evaluation/sem_seg_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,16 @@ def eval_single_scan(self, input, output, labels, time_elapsed):

metrics_kwargs = {"spacing": spacing} if spacing is not None else {}

summary = metrics_manager(
scan_id, y_true=y_true, y_pred=labels, x=x, runtime=time_elapsed, **metrics_kwargs
)

logger_info_str = "Scan #{:03d} (name = {}, {:0.2f}s) = {}".format(
scan_cnt, scan_id, time_elapsed, summary
)
self._results_str = self._results_str + logger_info_str + "\n"
logger.info(logger_info_str)
if not self._config.INFERENCE_ONLY:
summary = metrics_manager(
scan_id, y_true=y_true, y_pred=labels, x=x, runtime=time_elapsed, **metrics_kwargs
)

logger_info_str = "Scan #{:03d} (name = {}, {:0.2f}s) = {}".format(
scan_cnt, scan_id, time_elapsed, summary
)
self._results_str = self._results_str + logger_info_str + "\n"
logger.info(logger_info_str)

if output_dir and save_raw_data:
save_name = "{}/{}.pred".format(output_dir, scan_id)
Expand Down
4 changes: 4 additions & 0 deletions medsegpy/modeling/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def inference_generator(
workers=1,
use_multiprocessing=False,
verbose=0,
batch_size=None
):
return self.inference_generator_static(
self, generator, steps, max_queue_size, workers, use_multiprocessing, verbose
Expand All @@ -58,6 +59,7 @@ def inference_generator_static(
workers=1,
use_multiprocessing=False,
verbose=0,
batch_size=None
):
"""Generates predictions for the input samples from a data generator
and returns inputs, ground truth, and predictions.
Expand Down Expand Up @@ -116,6 +118,7 @@ def inference_generator_static(
workers=workers,
use_multiprocessing=use_multiprocessing,
verbose=verbose,
batch_size=batch_size
)
else:
return model._inference_generator_tf1(
Expand Down Expand Up @@ -295,6 +298,7 @@ def _inference_generator_tf2(
batch_x, batch_y, batch_x_raw = _extract_inference_inputs(next(iterator))
# tmp_batch_outputs = predict_function(iterator)
tmp_batch_outputs = model.predict(batch_x)

if data_handler.should_sync:
context.async_wait() # noqa: F821
batch_outputs = tmp_batch_outputs # No error, now safe to assign.
Expand Down