Skip to content

Commit

Permalink
bug fixing
Browse files Browse the repository at this point in the history
  • Loading branch information
paquiteau committed Jan 23, 2024
1 parent a1bbee2 commit e6493b9
Show file tree
Hide file tree
Showing 11 changed files with 100 additions and 47 deletions.
1 change: 0 additions & 1 deletion src/conf/scenario2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ cache_dir: ${oc.env:PWD}/cache
result_dir: ${oc.env:PWD}/results
ignore_patterns:
- "n_jobs"
-
simulation:
sim_params:
sim_tr: 0.1
Expand Down
2 changes: 1 addition & 1 deletion src/snkf/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def main_app(cfg: DictConfig) -> None:
logging.captureWarnings(True)

cache_dir = Path(cfg.cache_dir or os.getcwd())
hash_sim = hash_config(cfg.simulation, cfg.ignore_patterns)
hash_sim = hash_config(cfg.simulation, *getattr(cfg, "ignore_patterns", []))
sim_file = cache_dir / f"{hash_sim}.pkl"
# 1. Simulate (use cache if available)
with PerfLogger(logger, name="Simulation"):
Expand Down
4 changes: 3 additions & 1 deletion src/snkf/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pkgutil

from .base import (
AbstractHandler,
H,
handler,
get_handler,
Expand All @@ -28,6 +29,7 @@
"handler",
"get_handler",
"list_handlers",
"requires_field" "AbstractHandler",
"requires_field",
"AbstractHandler",
"HandlerChain",
]
11 changes: 6 additions & 5 deletions src/snkf/handlers/acquisition/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from tqdm.auto import tqdm

from snkf.simulation import SimData, LazySimArray

from snkf.utils import DuplicateFilter
from ._tools import TrajectoryGeneratorType

# from mrinufft import get_operator
Expand Down Expand Up @@ -265,9 +265,10 @@ def work_generator(
data_acq: np.ndarray | LazySimArray, kspace_bulk_gen: Generator
) -> Generator[tuple, None, None]:
"""Set up all the work."""
for sim_frame_idx, shot_batch, shot_pos in kspace_bulk_gen:
sim_frame = np.complex64(data_acq[sim_frame_idx]) # heavy to compute
yield sim_frame, shot_batch, shot_pos
with DuplicateFilter(logging.getLogger("simulation")):
for sim_frame_idx, shot_batch, shot_pos in kspace_bulk_gen:
sim_frame = np.complex64(data_acq[sim_frame_idx]) # heavy to compute
yield sim_frame, shot_batch, shot_pos


def _single_worker(
Expand All @@ -278,7 +279,7 @@ def _single_worker(
smaps: np.ndarray,
) -> tuple[np.ndarray, tuple[int, int], np.ndarray]:
"""Perform a shot acquisition."""
with (warnings.catch_warnings(),):
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=UserWarning,
Expand Down
8 changes: 5 additions & 3 deletions src/snkf/handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing_extensions import TypedDict
import yaml

from ..simulation import SimData, SimParams
from ..simulation import SimData, SimParams, UndefinedArrayError


CallbackType = Callable[[SimData, SimData], Any]
Expand Down Expand Up @@ -390,15 +390,17 @@ def class_wrapper(cls: type[AbstractHandler]) -> type[AbstractHandler]:

@functools.wraps(old_handle)
def wrap_handler(self: AbstractHandler, sim: SimData) -> SimData:
if getattr(sim, field_name, None) is None:
try:
getattr(sim, field_name)
except UndefinedArrayError as e:
if callable(factory):
setattr(sim, field_name, factory(sim))
else:
msg = (
f"'{field_name}' is missing in simulation"
"and no way of computing it provided."
)
raise ValueError(msg)
raise ValueError(msg) from e

return old_handle(self, sim)

Expand Down
39 changes: 21 additions & 18 deletions src/snkf/handlers/phantom/phantom.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,14 @@ def _handle(self, sim: SimData) -> SimData:
os.makedirs(static_path.parent, exist_ok=True)
np.save(static_path, static_vol)

roi: NDArray
if os.path.exists(roi_path) and not self.force:
roi = np.load(roi_path)
else:
roi = self._make_roi(sim)
np.save(roi_path, roi)

# update the simulation
if -1 in sim.shape:
old_shape = sim.shape
sim._meta.shape = static_vol.shape
Expand All @@ -224,24 +232,15 @@ def _handle(self, sim: SimData) -> SimData:
self.log.warning(f"sim.fov was {sim.fov}, it is now {new_fov}.")
sim._meta.fov = tuple(new_fov)

roi: np.ndarray
data_ref: LazySimArray | NDArray
if os.path.exists(roi_path) and not self.force:
roi = np.load(roi_path)
else:
roi = self._make_roi(sim)
np.save(roi_path, roi)
sim.static_vol = static_vol
sim.roi = roi

if sim.lazy and sim.static_vol is not None:
data_ref = LazySimArray(static_vol, sim.n_frames)
elif sim.static_vol is not None:
data_ref = np.repeat(static_vol[None, ...], sim.n_frames, axis=0)
# create the data ref field.
if sim.lazy:
sim.data_ref = LazySimArray(static_vol, sim.n_frames)
else:
raise ValueError("Could not initialize data_ref ")
sim.data_ref = np.repeat(static_vol[None, ...], sim.n_frames, axis=0)

sim.static_vol = static_vol
sim.roi = roi
sim.data_ref = data_ref
self.log.debug(f"roi shape: {sim.roi.shape}")
self.log.debug(f"data_ref shape: {sim.data_ref.shape}")
return sim
Expand All @@ -267,12 +266,16 @@ def _make_roi(self, sim: SimData) -> np.ndarray:
BRAINWEB_OCCIPITAL_ROI,
)

shape: tuple[int, ...] | None = sim.shape
if shape == (-1, -1, -1):
shape = None
roi = get_mri(
self.sub_id,
brainweb_dir=self.brainweb_folder,
contrast="fuzzy",
shape=sim.shape,
shape=shape,
bbox=self.bbox,
output_res=self.res,
)[..., 2]

occ_roi = BRAINWEB_OCCIPITAL_ROI.copy()
Expand Down Expand Up @@ -302,8 +305,8 @@ def _make_roi(self, sim: SimData) -> np.ndarray:
)
occ_roi["center"] = (
occ_roi["center"][0] - scaled_bbox[0],
occ_roi["center"][1] - scaled_bbox[1],
occ_roi["center"][2] - scaled_bbox[2],
occ_roi["center"][1] - scaled_bbox[2],
occ_roi["center"][2] - scaled_bbox[4],
)
self.log.debug("ROI shape is ", occ_roi)
roi_zoom = np.array(roi.shape) / np.array(occ_roi["shape"])
Expand Down
5 changes: 3 additions & 2 deletions src/snkf/reconstructors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Reconstructor interfaces for the simulator."""

from .base import get_reconstructor

from .pysap import (
SequentialReconstructor,
Expand All @@ -9,7 +9,8 @@

__all__ = [
"RECONSTRUCTOR",
"get_reconstructor" "BaseReconstructor",
"get_reconstructor",
"BaseReconstructor",
"SequentialReconstructor",
"ZeroFilledReconstructor",
"LowRankPlusSparseReconstructor",
Expand Down
3 changes: 2 additions & 1 deletion src/snkf/simulation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
This module contains the core simulation objects definition.
"""
from .simulation import SimData, SimParams, LazySimArray
from .simulation import SimData, SimParams, LazySimArray, UndefinedArrayError


__all__ = [
"SimData",
"SimParams",
"LazySimArray",
"UndefinedArrayError",
]
26 changes: 15 additions & 11 deletions src/snkf/simulation/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

from __future__ import annotations
from typing import Literal, Any, Union
from typing import Literal, Any, Union, Sequence
from dataclasses import InitVar, dataclass, field
import copy
import pickle
Expand All @@ -24,6 +24,10 @@
OptionalArray = Union[NDArray, None]


class UndefinedArrayError(ValueError):
"""Raise when an array is undefined."""


@dataclass
class SimParams:
"""Simulation metadata."""
Expand All @@ -44,11 +48,11 @@ class SimParams:
"""Field of view of the volume in mm"""
fov: tuple[float, ...] = field(init=False)

def __post_init__(self, fov_init: tuple[float, ...] | float) -> None:
def __post_init__(self, fov_init: Sequence[float] | float) -> None:
if isinstance(fov_init, float):
self.fov = (fov_init,) * len(self.shape)
elif isinstance(fov_init, tuple):
self.fov = fov_init
elif isinstance(fov_init, Sequence):
self.fov = tuple(fov_init)


class SimData:
Expand Down Expand Up @@ -129,7 +133,7 @@ def __init__(
rng=rng,
extra_infos=extra_infos,
)
self._static_vol: NDArray | None
self._static_vol: NDArray | None = None
self._roi: NDArray | None = None
self._data_ref: NDArray | None | LazySimArray = None
self._data_acq: NDArray | None | LazySimArray = None
Expand All @@ -144,7 +148,7 @@ def static_vol(self) -> NDArray:
"""Static volume."""
if self._static_vol is not None:
return self._static_vol
raise ValueError("static_vol is not defined")
raise UndefinedArrayError("static_vol is not defined")

@static_vol.setter
def static_vol(self, value: NDArray) -> None:
Expand All @@ -155,7 +159,7 @@ def kspace_data(self) -> NDArray:
"""Static volume."""
if self._kspace_data is not None:
return self._kspace_data
raise ValueError("static_vol is not defined")
raise UndefinedArrayError("static_vol is not defined")

@kspace_data.setter
def kspace_data(self, value: NDArray) -> None:
Expand All @@ -166,7 +170,7 @@ def kspace_mask(self) -> NDArray:
"""Static volume."""
if self._kspace_mask is not None:
return self._kspace_mask
raise ValueError("static_vol is not defined")
raise UndefinedArrayError("static_vol is not defined")

@kspace_mask.setter
def kspace_mask(self, value: NDArray) -> None:
Expand All @@ -177,7 +181,7 @@ def data_ref(self) -> NDArray | LazySimArray:
"""Static volume."""
if self._data_ref is not None:
return self._data_ref
raise ValueError("data_ref is not defined")
raise UndefinedArrayError("data_ref is not defined")

@data_ref.setter
def data_ref(self, value: NDArray) -> None:
Expand All @@ -188,7 +192,7 @@ def data_acq(self) -> NDArray | LazySimArray:
"""Acquired Volume."""
if self._data_acq is not None:
return self._data_acq
raise ValueError("data_acq is not defined")
raise UndefinedArrayError("data_acq is not defined")

@data_acq.setter
def data_acq(self, value: NDArray) -> None:
Expand All @@ -199,7 +203,7 @@ def roi(self) -> NDArray:
"""Reference data volume."""
if self._roi is not None:
return self._roi
raise ValueError("static_vol is not defined")
raise UndefinedArrayError("roi is not defined")

@roi.setter
def roi(self, value: NDArray) -> None:
Expand Down
3 changes: 2 additions & 1 deletion src/snkf/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Utilities tools for snkf."""
from .typing import RngType, AnyShape
from .utils import validate_rng, cplx_type, real_type
from .utils import validate_rng, cplx_type, real_type, DuplicateFilter

__all__ = [
"AnyShape",
"RngType",
"validate_rng",
"cplx_type",
"real_type",
"DuplicateFilter",
]
45 changes: 42 additions & 3 deletions src/snkf/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,38 @@
import logging
import numpy as np
from numpy.typing import DTypeLike
from typing import Any
from snkf.utils.typing import RngType

sim_logger = logging.getLogger("simulation")


class DuplicateFilter(logging.Filter):
"""
Filters away duplicate log messages.
https://stackoverflow.com/a/60462619
"""

def __init__(self, logger: logging.Logger):
self.msgs: set[str] = set()
self.logger = logger

def filter(self, record: logging.LogRecord) -> bool:
"""Filter duplicate records."""
msg = str(record.msg)
is_duplicate = msg in self.msgs
if not is_duplicate:
self.msgs.add(msg)
return not is_duplicate

def __enter__(self):
self.logger.addFilter(self)

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
self.logger.removeFilter(self)


def validate_rng(rng: RngType = None) -> np.random.Generator:
"""Validate Random Number Generator."""
if isinstance(rng, (int, list)): # TODO use pattern matching
Expand All @@ -29,12 +56,18 @@ def cplx_type(dtype: DTypeLike) -> DTypeLike:
np.complex64
"""
d = np.dtype(dtype)
if d.type is np.float64:
if d.type is np.complex64:
return np.complex64
elif d.type is np.complex128:
return np.complex128
elif d.type is np.float64:
return np.complex128
elif d.type is np.float32:
return np.complex64
else:
sim_logger.warning("unsupported dtype, use matching complex64", stack_info=True)
sim_logger.warning(
f"unsupported dtype {d}, using default complex64", stack_info=True
)
return np.complex64


Expand All @@ -51,8 +84,14 @@ def real_type(
d = np.dtype(dtype)
if d.type is np.complex64:
return np.dtype("float32")
elif d.type is np.float32:
return np.dtype("float32")
elif d.type is np.complex128:
return np.dtype("float64")
elif d.type is np.float64:
return np.dtype("float64")
else:
sim_logger.warning("unsupported dtype, use matching float32", stack_info=True)
sim_logger.warning(
f"unsupported dtype ({d}) using default float32", stack_info=True
)
return np.dtype("float32")

0 comments on commit e6493b9

Please sign in to comment.