diff --git a/pmd_beamphysics/wavefront.py b/pmd_beamphysics/wavefront.py index 100828b..0a1c846 100644 --- a/pmd_beamphysics/wavefront.py +++ b/pmd_beamphysics/wavefront.py @@ -8,6 +8,7 @@ import pathlib from typing import Any, List, Optional, Sequence, Tuple, Union +import h5py import matplotlib import matplotlib.axes import matplotlib.pyplot as plt @@ -232,20 +233,7 @@ def is_odd(value: int) -> bool: return value % 2 == 1 -def _fix_fft_dimension(dim: int): - """Get a dimension that's efficient for the FFT - and also odd for symmetry.""" - while True: - next_dim = scipy.fft.next_fast_len(dim, real=False) - if next_dim is None: - raise ValueError(f"Unable to get the next valid dimension for: {dim}") - dim = next_dim - if is_odd(dim): - break - dim += 1 - return dim - - -def _fix_grid_padding(grid: int, pad: int) -> Tuple[int, int]: +def _fix_grid_padding(grid: int, pad: int) -> int: """ Fix gridding and padding values for symmetry and FFT efficiency. @@ -260,29 +248,29 @@ def _fix_grid_padding(grid: int, pad: int) -> Tuple[int, int]: Returns ------- - int - Adjusted data gridding. int Adjusted data padding. """ - # Grid must be odd for us: - if not is_odd(grid): - grid += 1 - # Ensure that our FFT dimension is odd and optimal for scipy's FFT: - dim = _fix_fft_dimension(grid + 2 * pad) - assert is_odd(dim), "FFT dimension not odd?" + def is_good(dim: int) -> bool: + return is_odd(dim) and scipy.fft.next_fast_len(dim, real=False) == dim + + while not is_good(grid + pad): + dim = scipy.fft.next_fast_len(grid + pad + 1, real=False) + if dim is None: + raise ValueError( + f"Unable to get the next valid FFT length for: {grid=} {pad=}" + ) - # Fix padding based on our optimal dimension: - pad = (dim - grid) // 2 - assert not is_odd(dim - grid), "End dimension not even as expected?" - return grid, pad + pad = dim - grid + + return pad def fix_padding( grid: Sequence[int], pad: Sequence[int], -) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: +) -> Tuple[int, ...]: """ Fix gridding and padding values for symmetry and FFT efficiency. @@ -297,8 +285,6 @@ def fix_padding( Returns ------- - tuple of ints - Adjusted data gridding. tuple of ints Adjusted data padding. """ @@ -308,21 +294,20 @@ def fix_padding( f"Got: {len(grid)} and {len(pad)}" ) - result = [[], []] + final_padding = [] for dim, (dim_grid, dim_pad) in enumerate(zip(grid, pad)): - new_grid, new_pad = _fix_grid_padding(dim_grid, dim_pad) - logger.debug( - "Grid[%d] %d -> %d pad %d -> %d", - dim, - dim_grid, - new_grid, - dim_pad, - new_pad, - ) - result[0].append(new_grid) - result[1].append(new_pad) + new_pad = _fix_grid_padding(dim_grid, dim_pad) + if new_pad != dim_pad: + logger.debug( + "Grid[dim=%d] grid=%d pad=%d -> adjusted padding=%d", + dim, + dim_grid, + dim_pad, + new_pad, + ) + final_padding.append(new_pad) - return tuple(result[0]), tuple(result[1]) + return tuple(final_padding) def get_shifts( @@ -584,8 +569,7 @@ def fix(self) -> WavefrontPadding: Such that the total array size will be: `grid + 2 * pad`. """ - grid, pad = fix_padding(self.grid, self.pad) - return WavefrontPadding(grid, pad) + return WavefrontPadding(self.grid, fix_padding(self.grid, self.pad)) def get_padded_shape(self, field_rspace: np.ndarray) -> Tuple[int, ...]: """Get the padded shape given a 3D field rspace array.""" @@ -593,10 +577,10 @@ def get_padded_shape(self, field_rspace: np.ndarray) -> Tuple[int, ...]: if field_rspace.ndim != nd: raise ValueError(f"`field_rspace` is not an {nd}D array") - if not all(is_odd(dim) for dim in field_rspace.shape): - raise ValueError( - f"`field_rspace` dimensions are not all odd numbers: {field_rspace.shape}" - ) + # if not all(is_odd(dim) for dim in field_rspace.shape): + # raise ValueError( + # f"`field_rspace` dimensions are not all odd numbers: {field_rspace.shape}" + # ) return tuple(dim + 2 * pad for dim, pad in zip(field_rspace.shape, self.pad)) @@ -636,6 +620,12 @@ def __init__( ) -> None: if not pad: pad = (40,) + (100,) * (field_rspace.ndim - 1) + + if len(ranges) != field_rspace.ndim: + raise ValueError( + "'ranges' must have the same number of dimensions as `field_rspace`; " + "each should describe the cartesian range of the corresponding axis." + ) self._phasors = None self._field_rspace = field_rspace self._field_rspace_shape = field_rspace.shape @@ -1003,7 +993,13 @@ def plot( sum_axis = { # TODO: when standardized, this will be xyz instead of txy - "xy": 0, + # "xy": 0, + # (1, 2): 0, + "xy": 2, + (0, 1): 2, + "xz": 1, + (0, 2): 1, + "yz": 0, (1, 2): 0, }[plane] @@ -1055,3 +1051,25 @@ def plot(dat, title: str): fig.savefig(save) return fig, axs + + @classmethod + def from_genesis4( + cls, h5: Union[h5py.File, pathlib.Path, str], pad: int = 100 + ) -> Wavefront: + from genesis.version4 import FieldFile + + field = FieldFile.from_file(h5) + + _nx, _ny, nz = field.dfl.shape + z_low = field.param.refposition + z_high = z_low + field.param.slicespacing * nz # TODO: off by one? + return cls( + field_rspace=field.dfl, + wavelength=field.param.wavelength, + ranges=[ + (-field.param.gridsize, field.param.gridsize), + (-field.param.gridsize, field.param.gridsize), + (z_low, z_high), + ], + pad=(pad, pad, pad), + )