diff --git a/src/conf/scenario2.yaml b/src/conf/scenario2.yaml index 1aae7ea..a3e6166 100644 --- a/src/conf/scenario2.yaml +++ b/src/conf/scenario2.yaml @@ -5,7 +5,6 @@ cache_dir: ${oc.env:PWD}/cache result_dir: ${oc.env:PWD}/results ignore_patterns: - "n_jobs" - - simulation: sim_params: sim_tr: 0.1 diff --git a/src/snkf/cli/main.py b/src/snkf/cli/main.py index e3e994e..6207f99 100755 --- a/src/snkf/cli/main.py +++ b/src/snkf/cli/main.py @@ -39,7 +39,7 @@ def main_app(cfg: DictConfig) -> None: logging.captureWarnings(True) cache_dir = Path(cfg.cache_dir or os.getcwd()) - hash_sim = hash_config(cfg.simulation, cfg.ignore_patterns) + hash_sim = hash_config(cfg.simulation, *getattr(cfg, "ignore_patterns", [])) sim_file = cache_dir / f"{hash_sim}.pkl" # 1. Simulate (use cache if available) with PerfLogger(logger, name="Simulation"): diff --git a/src/snkf/handlers/__init__.py b/src/snkf/handlers/__init__.py index 54eaf89..392fc81 100644 --- a/src/snkf/handlers/__init__.py +++ b/src/snkf/handlers/__init__.py @@ -8,6 +8,7 @@ import pkgutil from .base import ( + AbstractHandler, H, handler, get_handler, @@ -28,6 +29,7 @@ "handler", "get_handler", "list_handlers", - "requires_field" "AbstractHandler", + "requires_field", + "AbstractHandler", "HandlerChain", ] diff --git a/src/snkf/handlers/acquisition/workers.py b/src/snkf/handlers/acquisition/workers.py index d0e8ccc..274b904 100644 --- a/src/snkf/handlers/acquisition/workers.py +++ b/src/snkf/handlers/acquisition/workers.py @@ -19,7 +19,7 @@ from tqdm.auto import tqdm from snkf.simulation import SimData, LazySimArray - +from snkf.utils import DuplicateFilter from ._tools import TrajectoryGeneratorType # from mrinufft import get_operator @@ -265,9 +265,10 @@ def work_generator( data_acq: np.ndarray | LazySimArray, kspace_bulk_gen: Generator ) -> Generator[tuple, None, None]: """Set up all the work.""" - for sim_frame_idx, shot_batch, shot_pos in kspace_bulk_gen: - sim_frame = np.complex64(data_acq[sim_frame_idx]) # heavy to compute - yield sim_frame, shot_batch, shot_pos + with DuplicateFilter(logging.getLogger("simulation")): + for sim_frame_idx, shot_batch, shot_pos in kspace_bulk_gen: + sim_frame = np.complex64(data_acq[sim_frame_idx]) # heavy to compute + yield sim_frame, shot_batch, shot_pos def _single_worker( @@ -278,7 +279,7 @@ def _single_worker( smaps: np.ndarray, ) -> tuple[np.ndarray, tuple[int, int], np.ndarray]: """Perform a shot acquisition.""" - with (warnings.catch_warnings(),): + with warnings.catch_warnings(): warnings.filterwarnings( "ignore", category=UserWarning, diff --git a/src/snkf/handlers/base.py b/src/snkf/handlers/base.py index 47fbd7c..a0530f1 100644 --- a/src/snkf/handlers/base.py +++ b/src/snkf/handlers/base.py @@ -12,7 +12,7 @@ from typing_extensions import TypedDict import yaml -from ..simulation import SimData, SimParams +from ..simulation import SimData, SimParams, UndefinedArrayError CallbackType = Callable[[SimData, SimData], Any] @@ -390,7 +390,9 @@ def class_wrapper(cls: type[AbstractHandler]) -> type[AbstractHandler]: @functools.wraps(old_handle) def wrap_handler(self: AbstractHandler, sim: SimData) -> SimData: - if getattr(sim, field_name, None) is None: + try: + getattr(sim, field_name) + except UndefinedArrayError as e: if callable(factory): setattr(sim, field_name, factory(sim)) else: @@ -398,7 +400,7 @@ def wrap_handler(self: AbstractHandler, sim: SimData) -> SimData: f"'{field_name}' is missing in simulation" "and no way of computing it provided." ) - raise ValueError(msg) + raise ValueError(msg) from e return old_handle(self, sim) diff --git a/src/snkf/handlers/phantom/phantom.py b/src/snkf/handlers/phantom/phantom.py index b2d333f..e7557dd 100644 --- a/src/snkf/handlers/phantom/phantom.py +++ b/src/snkf/handlers/phantom/phantom.py @@ -203,6 +203,14 @@ def _handle(self, sim: SimData) -> SimData: os.makedirs(static_path.parent, exist_ok=True) np.save(static_path, static_vol) + roi: NDArray + if os.path.exists(roi_path) and not self.force: + roi = np.load(roi_path) + else: + roi = self._make_roi(sim) + np.save(roi_path, roi) + + # update the simulation if -1 in sim.shape: old_shape = sim.shape sim._meta.shape = static_vol.shape @@ -224,24 +232,15 @@ def _handle(self, sim: SimData) -> SimData: self.log.warning(f"sim.fov was {sim.fov}, it is now {new_fov}.") sim._meta.fov = tuple(new_fov) - roi: np.ndarray - data_ref: LazySimArray | NDArray - if os.path.exists(roi_path) and not self.force: - roi = np.load(roi_path) - else: - roi = self._make_roi(sim) - np.save(roi_path, roi) + sim.static_vol = static_vol + sim.roi = roi - if sim.lazy and sim.static_vol is not None: - data_ref = LazySimArray(static_vol, sim.n_frames) - elif sim.static_vol is not None: - data_ref = np.repeat(static_vol[None, ...], sim.n_frames, axis=0) + # create the data ref field. + if sim.lazy: + sim.data_ref = LazySimArray(static_vol, sim.n_frames) else: - raise ValueError("Could not initialize data_ref ") + sim.data_ref = np.repeat(static_vol[None, ...], sim.n_frames, axis=0) - sim.static_vol = static_vol - sim.roi = roi - sim.data_ref = data_ref self.log.debug(f"roi shape: {sim.roi.shape}") self.log.debug(f"data_ref shape: {sim.data_ref.shape}") return sim @@ -267,12 +266,16 @@ def _make_roi(self, sim: SimData) -> np.ndarray: BRAINWEB_OCCIPITAL_ROI, ) + shape: tuple[int, ...] | None = sim.shape + if shape == (-1, -1, -1): + shape = None roi = get_mri( self.sub_id, brainweb_dir=self.brainweb_folder, contrast="fuzzy", - shape=sim.shape, + shape=shape, bbox=self.bbox, + output_res=self.res, )[..., 2] occ_roi = BRAINWEB_OCCIPITAL_ROI.copy() @@ -302,8 +305,8 @@ def _make_roi(self, sim: SimData) -> np.ndarray: ) occ_roi["center"] = ( occ_roi["center"][0] - scaled_bbox[0], - occ_roi["center"][1] - scaled_bbox[1], - occ_roi["center"][2] - scaled_bbox[2], + occ_roi["center"][1] - scaled_bbox[2], + occ_roi["center"][2] - scaled_bbox[4], ) self.log.debug("ROI shape is ", occ_roi) roi_zoom = np.array(roi.shape) / np.array(occ_roi["shape"]) diff --git a/src/snkf/reconstructors/__init__.py b/src/snkf/reconstructors/__init__.py index 754bb76..bbd2c5f 100644 --- a/src/snkf/reconstructors/__init__.py +++ b/src/snkf/reconstructors/__init__.py @@ -1,5 +1,5 @@ """Reconstructor interfaces for the simulator.""" - +from .base import get_reconstructor from .pysap import ( SequentialReconstructor, @@ -9,7 +9,8 @@ __all__ = [ "RECONSTRUCTOR", - "get_reconstructor" "BaseReconstructor", + "get_reconstructor", + "BaseReconstructor", "SequentialReconstructor", "ZeroFilledReconstructor", "LowRankPlusSparseReconstructor", diff --git a/src/snkf/simulation/__init__.py b/src/snkf/simulation/__init__.py index 0c13b46..75517f4 100644 --- a/src/snkf/simulation/__init__.py +++ b/src/snkf/simulation/__init__.py @@ -2,11 +2,12 @@ This module contains the core simulation objects definition. """ -from .simulation import SimData, SimParams, LazySimArray +from .simulation import SimData, SimParams, LazySimArray, UndefinedArrayError __all__ = [ "SimData", "SimParams", "LazySimArray", + "UndefinedArrayError", ] diff --git a/src/snkf/simulation/simulation.py b/src/snkf/simulation/simulation.py index d83662b..a88fe4e 100644 --- a/src/snkf/simulation/simulation.py +++ b/src/snkf/simulation/simulation.py @@ -5,7 +5,7 @@ """ from __future__ import annotations -from typing import Literal, Any, Union +from typing import Literal, Any, Union, Sequence from dataclasses import InitVar, dataclass, field import copy import pickle @@ -24,6 +24,10 @@ OptionalArray = Union[NDArray, None] +class UndefinedArrayError(ValueError): + """Raise when an array is undefined.""" + + @dataclass class SimParams: """Simulation metadata.""" @@ -44,11 +48,11 @@ class SimParams: """Field of view of the volume in mm""" fov: tuple[float, ...] = field(init=False) - def __post_init__(self, fov_init: tuple[float, ...] | float) -> None: + def __post_init__(self, fov_init: Sequence[float] | float) -> None: if isinstance(fov_init, float): self.fov = (fov_init,) * len(self.shape) - elif isinstance(fov_init, tuple): - self.fov = fov_init + elif isinstance(fov_init, Sequence): + self.fov = tuple(fov_init) class SimData: @@ -129,7 +133,7 @@ def __init__( rng=rng, extra_infos=extra_infos, ) - self._static_vol: NDArray | None + self._static_vol: NDArray | None = None self._roi: NDArray | None = None self._data_ref: NDArray | None | LazySimArray = None self._data_acq: NDArray | None | LazySimArray = None @@ -144,7 +148,7 @@ def static_vol(self) -> NDArray: """Static volume.""" if self._static_vol is not None: return self._static_vol - raise ValueError("static_vol is not defined") + raise UndefinedArrayError("static_vol is not defined") @static_vol.setter def static_vol(self, value: NDArray) -> None: @@ -155,7 +159,7 @@ def kspace_data(self) -> NDArray: """Static volume.""" if self._kspace_data is not None: return self._kspace_data - raise ValueError("static_vol is not defined") + raise UndefinedArrayError("static_vol is not defined") @kspace_data.setter def kspace_data(self, value: NDArray) -> None: @@ -166,7 +170,7 @@ def kspace_mask(self) -> NDArray: """Static volume.""" if self._kspace_mask is not None: return self._kspace_mask - raise ValueError("static_vol is not defined") + raise UndefinedArrayError("static_vol is not defined") @kspace_mask.setter def kspace_mask(self, value: NDArray) -> None: @@ -177,7 +181,7 @@ def data_ref(self) -> NDArray | LazySimArray: """Static volume.""" if self._data_ref is not None: return self._data_ref - raise ValueError("data_ref is not defined") + raise UndefinedArrayError("data_ref is not defined") @data_ref.setter def data_ref(self, value: NDArray) -> None: @@ -188,7 +192,7 @@ def data_acq(self) -> NDArray | LazySimArray: """Acquired Volume.""" if self._data_acq is not None: return self._data_acq - raise ValueError("data_acq is not defined") + raise UndefinedArrayError("data_acq is not defined") @data_acq.setter def data_acq(self, value: NDArray) -> None: @@ -199,7 +203,7 @@ def roi(self) -> NDArray: """Reference data volume.""" if self._roi is not None: return self._roi - raise ValueError("static_vol is not defined") + raise UndefinedArrayError("roi is not defined") @roi.setter def roi(self, value: NDArray) -> None: diff --git a/src/snkf/utils/__init__.py b/src/snkf/utils/__init__.py index f2c21b5..f4f2401 100644 --- a/src/snkf/utils/__init__.py +++ b/src/snkf/utils/__init__.py @@ -1,6 +1,6 @@ """Utilities tools for snkf.""" from .typing import RngType, AnyShape -from .utils import validate_rng, cplx_type, real_type +from .utils import validate_rng, cplx_type, real_type, DuplicateFilter __all__ = [ "AnyShape", @@ -8,4 +8,5 @@ "validate_rng", "cplx_type", "real_type", + "DuplicateFilter", ] diff --git a/src/snkf/utils/utils.py b/src/snkf/utils/utils.py index 8c9dc78..766fd0f 100644 --- a/src/snkf/utils/utils.py +++ b/src/snkf/utils/utils.py @@ -3,11 +3,38 @@ import logging import numpy as np from numpy.typing import DTypeLike +from typing import Any from snkf.utils.typing import RngType sim_logger = logging.getLogger("simulation") +class DuplicateFilter(logging.Filter): + """ + Filters away duplicate log messages. + + https://stackoverflow.com/a/60462619 + """ + + def __init__(self, logger: logging.Logger): + self.msgs: set[str] = set() + self.logger = logger + + def filter(self, record: logging.LogRecord) -> bool: + """Filter duplicate records.""" + msg = str(record.msg) + is_duplicate = msg in self.msgs + if not is_duplicate: + self.msgs.add(msg) + return not is_duplicate + + def __enter__(self): + self.logger.addFilter(self) + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): + self.logger.removeFilter(self) + + def validate_rng(rng: RngType = None) -> np.random.Generator: """Validate Random Number Generator.""" if isinstance(rng, (int, list)): # TODO use pattern matching @@ -29,12 +56,18 @@ def cplx_type(dtype: DTypeLike) -> DTypeLike: np.complex64 """ d = np.dtype(dtype) - if d.type is np.float64: + if d.type is np.complex64: + return np.complex64 + elif d.type is np.complex128: + return np.complex128 + elif d.type is np.float64: return np.complex128 elif d.type is np.float32: return np.complex64 else: - sim_logger.warning("unsupported dtype, use matching complex64", stack_info=True) + sim_logger.warning( + f"unsupported dtype {d}, using default complex64", stack_info=True + ) return np.complex64 @@ -51,8 +84,14 @@ def real_type( d = np.dtype(dtype) if d.type is np.complex64: return np.dtype("float32") + elif d.type is np.float32: + return np.dtype("float32") elif d.type is np.complex128: return np.dtype("float64") + elif d.type is np.float64: + return np.dtype("float64") else: - sim_logger.warning("unsupported dtype, use matching float32", stack_info=True) + sim_logger.warning( + f"unsupported dtype ({d}) using default float32", stack_info=True + ) return np.dtype("float32")