From 1476cd34229a44e135b3bc04cd97e4519eca52a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20K=C3=B6hler?= <27728103+Ceyron@users.noreply.github.com> Date: Wed, 12 Jun 2024 14:18:30 +0200 Subject: [PATCH] Improved Viz module (including new wrapped volume renderer) (#5) * Lighweight wrap around volume rendering * Add gamma correction * Change default bg color to white * Change to batch rendering by default * Add routine for spatio-temporal of 2d states * Change default to batch rendering * Add single channel animation routine * Faceting 3d states * Facet 3d animation * Implement spatio temporal facet * Remove arguments not yet understood by vape * Add documentation * Forward changes in interface * Remove ax arg * Adapt return structure * add Docs * Add docs * Adapt interfaces * Change return structure * Final change to plotting interface * close figure facets * Add arguments for compatibility * Add dummy function * Add docs * Add dummy function * Remove debug print * Include temporal grid with state 2d animation * Add docs remove ax * Add missing time suptitle * Add docs * Allow changing cmap * Start writing tests for viz routines * Fix formatting --- exponax/viz/__init__.py | 26 +++- exponax/viz/_animate.py | 136 ++++++++++++++++++- exponax/viz/_animate_facet.py | 197 ++++++++++++++++++++++++++- exponax/viz/_plot.py | 242 ++++++++++++++++++++++++++++++++-- exponax/viz/_plot_facet.py | 238 ++++++++++++++++++++++++++++++++- exponax/viz/_volume.py | 139 +++++++++++++++++++ tests/test_viz.py | 35 +++++ 7 files changed, 989 insertions(+), 24 deletions(-) create mode 100644 exponax/viz/_volume.py create mode 100644 tests/test_viz.py diff --git a/exponax/viz/__init__.py b/exponax/viz/__init__.py index 1f86891..35e7284 100644 --- a/exponax/viz/__init__.py +++ b/exponax/viz/__init__.py @@ -17,18 +17,33 @@ can be animated over another axis (some notion of time). """ -from ._animate import animate_spatio_temporal, animate_state_1d, animate_state_2d +from ._animate import ( + animate_spatio_temporal, + animate_state_1d, + animate_state_2d, + animate_state_3d, +) from ._animate_facet import ( animate_spatial_temporal_facet, animate_state_1d_facet, animate_state_2d_facet, + animate_state_3d_facet, +) +from ._plot import ( + plot_spatio_temporal, + plot_spatio_temporal_2d, + plot_state_1d, + plot_state_2d, + plot_state_3d, ) -from ._plot import plot_spatio_temporal, plot_state_1d, plot_state_2d from ._plot_facet import ( + plot_spatio_temporal_2d_facet, plot_spatio_temporal_facet, plot_state_1d_facet, plot_state_2d_facet, + plot_state_3d_facet, ) +from ._volume import volume_render_state_3d # from IPython.display import HTML @@ -45,4 +60,11 @@ "animate_state_2d_facet", "animate_spatio_temporal", "animate_spatial_temporal_facet", + "volume_render_state_3d", + "plot_state_3d", + "plot_spatio_temporal_2d", + "animate_state_3d", + "plot_state_3d_facet", + "animate_state_3d_facet", + "plot_spatio_temporal_2d_facet", ] diff --git a/exponax/viz/_animate.py b/exponax/viz/_animate.py index b8a380e..1736f87 100644 --- a/exponax/viz/_animate.py +++ b/exponax/viz/_animate.py @@ -1,11 +1,14 @@ -from typing import TypeVar +from typing import Literal, TypeVar, Union +import jax import jax.numpy as jnp import matplotlib.pyplot as plt from jaxtyping import Array, Float from matplotlib.animation import FuncAnimation +from .._utils import wrap_bc from ._plot import plot_spatio_temporal, plot_state_1d, plot_state_2d +from ._volume import volume_render_state_3d, zigzag_alpha N = TypeVar("N") @@ -91,6 +94,7 @@ def animate_spatio_temporal( trjs: Float[Array, "S T C N"], *, vlim: tuple[float, float] = (-1.0, 1.0), + cmap: str = "RdBu_r", domain_extent: float = None, dt: float = None, include_init: bool = False, @@ -114,6 +118,7 @@ def animate_spatio_temporal( - `trjs`: The trajectory of states to animate. Must be a four-axis array with shape `(n_timesteps_outer, n_time_steps, n_channels, n_spatial)`. - `vlim`: The limits of the colorbar. Default is `(-1, 1)`. + - `cmap`: The colormap to use. Default is `"RdBu_r"`. - `domain_extent`: The extent of the spatial domain. Default is `None`. This affects the x-axis limits of the plot. - `dt`: The time step between each frame. Default is `None`. If provided, @@ -136,6 +141,7 @@ def animate_spatio_temporal( plot_spatio_temporal( trjs[0], vlim=vlim, + cmap=cmap, domain_extent=domain_extent, dt=dt, include_init=include_init, @@ -148,6 +154,7 @@ def animate(i): plot_spatio_temporal( trjs[i], vlim=vlim, + cmap=cmap, domain_extent=domain_extent, dt=dt, include_init=include_init, @@ -166,6 +173,7 @@ def animate_state_2d( trj: Float[Array, "T 1 N N"], *, vlim: tuple[float, float] = (-1.0, 1.0), + cmap: str = "RdBu_r", domain_extent: float = None, dt: float = None, include_init: bool = False, @@ -186,6 +194,7 @@ def animate_state_2d( - `trj`: The trajectory of states to animate. Must be a four-axis array with shape `(n_timesteps, 1, n_spatial, n_spatial)`. - `vlim`: The limits of the colorbar. Default is `(-1, 1)`. + - `cmap`: The colormap to use. Default is `"RdBu_r"`. - `domain_extent`: The extent of the spatial domain. Default is `None`. This affects the x- and y-axis limits of the plot. - `dt`: The time step between each frame. Default is `None`. If provided, @@ -205,31 +214,146 @@ def animate_state_2d( fig, ax = plt.subplots() - if dt is not None: - time_range = (0, dt * trj.shape[0]) - if not include_init: - time_range = (dt, time_range[1]) + if include_init: + temporal_grid = jnp.arange(trj.shape[0]) else: - time_range = (0, trj.shape[0] - 1) + temporal_grid = jnp.arange(1, trj.shape[0] + 1) + + if dt is not None: + temporal_grid *= dt plot_state_2d( trj[0], vlim=vlim, + cmap=cmap, domain_extent=domain_extent, ax=ax, ) + ax.set_title(f"t = {temporal_grid[0]:.2f}") def animate(i): ax.clear() plot_state_2d( trj[i], vlim=vlim, + cmap=cmap, domain_extent=domain_extent, ax=ax, ) + ax.set_title(f"t = {temporal_grid[i]:.2f}") plt.close(fig) ani = FuncAnimation(fig, animate, frames=trj.shape[0], interval=100, blit=False) return ani + + +def animate_state_3d( + trj: Float[Array, "T 1 N N N"], + *, + vlim: tuple[float, float] = (-1.0, 1.0), + domain_extent: float = None, + dt: float = None, + include_init: bool = False, + bg_color: Union[ + Literal["black"], + Literal["white"], + tuple[jnp.int8, jnp.int8, jnp.int8, jnp.int8], + ] = "white", + resolution: int = 384, + cmap: str = "RdBu_r", + transfer_function: callable = zigzag_alpha, + distance_scale: float = 10.0, + gamma_correction: float = 2.4, + chunk_size: int = 64, + **kwargs, +): + """ + Animate a trajectory of 3d states as volume renderings. + + Requires the input to be a five-axis array with a leading time axis, a + channel axis, and three spatial axes. Only the zeroth dimension in the + channel axis is plotted. + + Periodic boundary conditions will be applied to the spatial axes (the state + is wrapped around). + + **Arguments**: + + - `trj`: The trajectory of states to animate. Must be a five-axis array with + shape `(n_timesteps, 1, n_spatial, n_spatial, n_spatial)`. + - `vlim`: The limits of the colorbar. Default is `(-1, 1)`. + - `domain_extent`: (Unused as of now) + - `dt`: The time step between each frame. Default is `None`. If provided, + a title will be displayed with the current time. If not provided, just + the frames are counted. + - `include_init`: Whether to the state starts at an initial condition (t=0) + or at the first frame in the trajectory. This affects is the the time + range is [0, (T-1)dt] or [dt, Tdt]. Default is `False`. + - `bg_color`: The background color. Either `"black"`, `"white"`, or a tuple + of RGBA values. Default is `"white"`. + - `resolution`: The resolution of the output image (affects render time). + Default is `384`. + - `cmap`: The colormap to use. Default is `"RdBu_r"`. + - `transfer_function`: The transfer function to use. Default is `zigzag_alpha`. + - `distance_scale`: The distance scale. Default is `10.0`. + - `gamma_correction`: The gamma correction. Default is `2.4`. + - `chunk_size`: The chunk size. Default is `64`. + + **Returns**: + + - `ani`: The animation object. + + **Note:** + + - This function requires the `vape` volume renderer package. + """ + if trj.ndim != 5: + raise ValueError("trj must be a five-axis array.") + + fig, ax = plt.subplots() + + if include_init: + temporal_grid = jnp.arange(trj.shape[0]) + else: + temporal_grid = jnp.arange(1, trj.shape[0] + 1) + + if dt is not None: + temporal_grid *= dt + + trj_wrapped = jax.vmap(wrap_bc)(trj) + trj_wrapped_no_channel = trj_wrapped[:, 0] + + imgs = volume_render_state_3d( + trj_wrapped_no_channel, + vlim=vlim, + bg_color=bg_color, + resolution=resolution, + cmap=cmap, + transfer_function=transfer_function, + distance_scale=distance_scale, + gamma_correction=gamma_correction, + chunk_size=chunk_size, + **kwargs, + ) + + ax.imshow(imgs[0]) + ax.axis("off") + ax.set_title(f"t = {temporal_grid[0]:.2f}") + + def animate(i): + ax.clear() + ax.imshow(imgs[i]) + ax.axis("off") + ax.set_title(f"t = {temporal_grid[i]:.2f}") + + ani = FuncAnimation(fig, animate, frames=trj.shape[0], interval=100, blit=False) + + plt.close(fig) + + return ani + + +def animate_spatio_temporal_2d(): + raise NotImplementedError("This function is not yet implemented.") diff --git a/exponax/viz/_animate_facet.py b/exponax/viz/_animate_facet.py index dcbd10f..f364b3f 100644 --- a/exponax/viz/_animate_facet.py +++ b/exponax/viz/_animate_facet.py @@ -1,11 +1,14 @@ -from typing import TypeVar, Union +from typing import Literal, TypeVar, Union +import jax import jax.numpy as jnp import matplotlib.pyplot as plt from jaxtyping import Array, Float from matplotlib.animation import FuncAnimation +from .._utils import wrap_bc from ._plot import plot_state_1d, plot_state_2d +from ._volume import volume_render_state_3d, zigzag_alpha N = TypeVar("N") @@ -17,6 +20,8 @@ def animate_state_1d_facet( labels: list[str] = None, titles: list[str] = None, domain_extent: float = None, + dt: float = None, + include_init: bool = False, grid: tuple[int, int] = (3, 3), figsize: tuple[float, float] = (10, 10), **kwargs, @@ -56,6 +61,14 @@ def animate_state_1d_facet( if trj.ndim != 4: raise ValueError("states must be a four-axis array.") + if include_init: + temporal_grid = jnp.arange(trj.shape[1]) + else: + temporal_grid = jnp.arange(1, trj.shape[1] + 1) + + if dt is not None: + temporal_grid *= dt + fig, ax_s = plt.subplots(*grid, figsize=figsize) num_subplots = trj.shape[0] @@ -74,6 +87,7 @@ def animate_state_1d_facet( else: if titles is not None: ax.set_title(titles[j]) + title = fig.suptitle(f"t = {temporal_grid[0]:.2f}") def animate(i): for j, ax in enumerate(ax_s.flatten()): @@ -91,6 +105,7 @@ def animate(i): else: if titles is not None: ax.set_title(titles[j]) + title.set_text(f"t = {temporal_grid[i]:.2f}") ani = FuncAnimation(fig, animate, frames=trj.shape[1], interval=100, blit=False) @@ -102,6 +117,7 @@ def animate_spatial_temporal_facet( *, facet_over_channels: bool = True, vlim: tuple[float, float] = (-1.0, 1.0), + cmap: str = "RdBu_r", domain_extent: float = None, dt: float = None, include_init: bool = False, @@ -139,6 +155,7 @@ def animate_spatial_temporal_facet( - `facet_over_channels`: Whether to facet over the channel axis or the batch axis. Default is `True`. - `vlim`: The limits of the colorbar. Default is `(-1, 1)`. + - `cmap`: The colormap to use. Default is `"RdBu_r"`. - `domain_extent`: The extent of the spatial domain. Default is `None`. This affects the x-axis limits of the plot. - `dt`: The time step between each frame. Default is `None`. If provided, @@ -170,6 +187,10 @@ def animate_state_2d_facet( *, facet_over_channels: bool = True, vlim: tuple[float, float] = (-1.0, 1.0), + cmap: str = "RdBu_r", + domain_extent: float = None, + dt: float = None, + include_init: bool = False, grid: tuple[int, int] = (3, 3), figsize: tuple[float, float] = (10, 10), titles=None, @@ -200,6 +221,13 @@ def animate_state_2d_facet( - `facet_over_channels`: Whether to facet over the channel axis or the batch axis. Default is `True`. - `vlim`: The limits of the colorbar. Default is `(-1, 1)`. + - `cmap`: The colormap to use. Default is `"RdBu_r"`. + - `domain_extent`: The extent of the spatial domain. Default is `None`. This + affects the x-axis and y-axis limits of the plot. + - `dt`: The time step between each frame. Default is `None`. + - `include_init`: Whether to the state starts at an initial condition (t=0) + or at the first frame in the trajectory. This affects is the the time + range is [0, (T-1)dt] or [dt, Tdt]. Default is `False`. - `grid`: The grid of subplots. Default is `(3, 3)`. - `figsize`: The size of the figure. Default is `(10, 10)`. - `titles`: The titles for each subplot. Default is `None`. @@ -219,16 +247,27 @@ def animate_state_2d_facet( trj = jnp.swapaxes(trj, 0, 1) trj = trj[:, :, None] + if include_init: + temporal_grid = jnp.arange(trj.shape[1]) + else: + temporal_grid = jnp.arange(1, trj.shape[1] + 1) + + if dt is not None: + temporal_grid *= dt + fig, ax_s = plt.subplots(*grid, sharex=True, sharey=True, figsize=figsize) for j, ax in enumerate(ax_s.flatten()): plot_state_2d( trj[j, 0], vlim=vlim, + cmap=cmap, ax=ax, + domain_extent=domain_extent, ) if titles is not None: ax.set_title(titles[j]) + title = fig.suptitle(f"t = {temporal_grid[0]:.2f}") def animate(i): for j, ax in enumerate(ax_s.flatten()): @@ -236,13 +275,169 @@ def animate(i): plot_state_2d( trj[j, i], vlim=vlim, + cmap=cmap, ax=ax, ) if titles is not None: ax.set_title(titles[j]) + title.set_text(f"t = {temporal_grid[i]:.2f}") plt.close(fig) ani = FuncAnimation(fig, animate, frames=trj.shape[1], interval=100, blit=False) return ani + + +def animate_state_3d_facet( + trj: Union[Float[Array, "T C N N N"], Float[Array, "B T 1 N N N"]], + *, + facet_over_channels: bool = True, + vlim: tuple[float, float] = (-1.0, 1.0), + grid: tuple[int, int] = (3, 3), + figsize: tuple[float, float] = (10, 10), + titles=None, + domain_extent: float = None, + dt: float = None, + include_init: bool = False, + bg_color: Union[ + Literal["black"], + Literal["white"], + tuple[jnp.int8, jnp.int8, jnp.int8, jnp.int8], + ] = "white", + resolution: int = 384, + cmap: str = "RdBu_r", + transfer_function: callable = zigzag_alpha, + distance_scale: float = 10.0, + gamma_correction: float = 2.4, + chunk_size: int = 64, + **kwargs, +): + """ + Animate a facet of trajectories of 3d states as volume renderings. + + Requires the input to be either a five-axis array or a six-axis array: + + - If `facet_over_channels` is `True`, the input must be a five-axis array + with a leading time axis, a channel axis, and three spatial axes. Each + faceted subplot displays a different channel. + - If `facet_over_channels` is `False`, the input must be a six-axis array + with a leading batch axis, a time axis, a channel axis, and three spatial + axes. Each faceted subplot displays a different batch. Only the zeroth + dimension in the channel axis is plotted. + + **Arguments**: + + - `trj`: The trajectory of states to animate. Must be a five-axis array with + shape `(n_timesteps, n_channels, n_spatial, n_spatial, n_spatial)` if + `facet_over_channels` is `True`, or a six-axis array with shape + `(n_batches, n_timesteps, n_channels, n_spatial, n_spatial, n_spatial)` + if `facet_over_channels` is `False`. + - `facet_over_channels`: Whether to facet over the channel axis or the batch + axis. Default is `True`. + - `vlim`: The limits of the colorbar. Default is `(-1, 1)`. + - `grid`: The grid of subplots. Default is `(3, 3)`. + - `figsize`: The size of the figure. Default is `(10, 10)`. + - `titles`: The titles for each subplot. Default is `None`. + - `domain_extent`: The extent of the spatial domain. Default is `None`. This + affects the x-axis and y-axis limits of the plot. + - `dt`: The time step between each frame. Default is `None`. + - `include_init`: Whether to the state starts at an initial condition (t=0) + or at the first frame in the trajectory. This affects is the the time + range is [0, (T-1)dt] or [dt, Tdt]. Default is `False`. + - `bg_color`: The background color. Either `"black"`, `"white"`, or a tuple + of RGBA values. Default is `"white"`. + - `resolution`: The resolution of the output image (affects render time). + Default is `384`. + - `cmap`: The colormap to use. Default is `"RdBu_r"`. + - `transfer_function`: The transfer function to use, i.e., how values within + the `vlim` range are mapped to alpha values. Default is `zigzag_alpha`. + - `distance_scale`: The distance scale of the volume renderer. Default is + `10.0`. + - `gamma_correction`: The gamma correction to apply to the image. Default is + `2.4`. + - `chunk_size`: The number of images to render at once. Default is `64`. + + **Returns**: + + - `ani`: The animation object. + + **Note:** + + - This function requires the `vape` volume renderer package. + """ + if facet_over_channels: + if trj.ndim != 5: + raise ValueError("trj must be a five-axis array.") + else: + if trj.ndim != 6: + raise ValueError("trj must be a six-axis array.") + + if facet_over_channels: + trj = jnp.swapaxes(trj, 0, 1) + trj = trj[:, :, None] + + trj_wrapped = jax.vmap(jax.vmap(wrap_bc))(trj) + + imgs = [] + for facet_entry_trj in trj_wrapped: + facet_entry_trj_no_channel = facet_entry_trj[:, 0] + imgs.append( + volume_render_state_3d( + facet_entry_trj_no_channel, + vlim=vlim, + bg_color=bg_color, + resolution=resolution, + cmap=cmap, + transfer_function=transfer_function, + distance_scale=distance_scale, + gamma_correction=gamma_correction, + chunk_size=chunk_size, + **kwargs, + ) + ) + + # shape = (B, T, resolution, resolution, 3) + imgs = jnp.stack(imgs) + + if include_init: + temporal_grid = jnp.arange(trj.shape[1]) + else: + temporal_grid = jnp.arange(1, trj.shape[1] + 1) + + if dt is not None: + temporal_grid *= dt + + fig, ax_s = plt.subplots(*grid, figsize=figsize) + + # num_subplots = trj.shape[0] + + for j, ax in enumerate(ax_s.flatten()): + ax.imshow(imgs[j, 0]) + ax.axis("off") + # if j >= num_subplots: + # ax.remove() + # else: + if titles is not None: + ax.set_title(titles[j]) + title = fig.suptitle(f"t = {temporal_grid[0]:.2f}") + + def animate(i): + for j, ax in enumerate(ax_s.flatten()): + ax.clear() + ax.imshow(imgs[j, i]) + ax.axis("off") + if titles is not None: + ax.set_title(titles[j]) + title.set_text(f"t = {temporal_grid[i]:.2f}") + + ani = FuncAnimation(fig, animate, frames=trj.shape[1], interval=100, blit=False) + + plt.close(fig) + + return ani + + +def animate_spatio_temporal_2d_facet(): + # TODO + raise NotImplementedError("Not implemented yet.") diff --git a/exponax/viz/_plot.py b/exponax/viz/_plot.py index d37b30e..412e38d 100644 --- a/exponax/viz/_plot.py +++ b/exponax/viz/_plot.py @@ -1,10 +1,12 @@ -from typing import TypeVar +from typing import Literal, TypeVar, Union import jax +import jax.numpy as jnp import matplotlib.pyplot as plt from jaxtyping import Array, Float from .._utils import make_grid, wrap_bc +from ._volume import volume_render_state_3d, zigzag_alpha N = TypeVar("N") @@ -43,8 +45,8 @@ def plot_state_1d( **Returns:** - - If `ax` is not provided, returns a tuple with the figure, axis, and plot - object. Otherwise, returns the plot object. + - If `ax` is not provided, returns the figure. Otherwise, returns the plot + object. """ if state.ndim != 2: raise ValueError("state must be a two-axis array.") @@ -60,7 +62,10 @@ def plot_state_1d( grid = make_grid(1, domain_extent, num_points, full=True) if ax is None: + return_all = True fig, ax = plt.subplots() + else: + return_all = False p = ax.plot(grid[0], state_wrapped.T, label=labels, **kwargs) ax.set_ylim(vlim) @@ -70,8 +75,9 @@ def plot_state_1d( ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) - if ax is None: - return fig, ax, p + if return_all: + plt.close(fig) + return fig else: return p @@ -80,6 +86,7 @@ def plot_spatio_temporal( trj: Float[Array, "T 1 N"], *, vlim: tuple[float, float] = (-1.0, 1.0), + cmap: str = "RdBu_r", ax=None, domain_extent: float = None, dt: float = None, @@ -104,6 +111,7 @@ def plot_spatio_temporal( be the time axis, the second axis the channel axis, and the third axis the spatial axis. - `vlim`: The limits of the color scale. + - `cmap`: The colormap to use. - `ax`: The axis to plot on. If not provided, a new figure will be created. - `domain_extent`: The extent of the spatial domain. If not provided, the domain extent will be the number of points in the spatial axis. This @@ -116,8 +124,8 @@ def plot_spatio_temporal( **Returns:** - - If `ax` is not provided, returns a tuple with the figure, axis, and image - object. Otherwise, returns the image object. + - If `ax` is not provided, returns the figure. Otherwise, returns the image + object. """ if trj.ndim != 3: raise ValueError("trj must be a two-axis array.") @@ -139,12 +147,15 @@ def plot_spatio_temporal( if ax is None: fig, ax = plt.subplots() + return_all = True + else: + return_all = False im = ax.imshow( trj_wrapped[:, 0, :].T, vmin=vlim[0], vmax=vlim[1], - cmap="RdBu_r", + cmap=cmap, origin="lower", aspect="auto", extent=(*time_range, *space_range), @@ -153,13 +164,18 @@ def plot_spatio_temporal( ax.set_xlabel("Time") ax.set_ylabel("Space") - return im + if return_all: + plt.close(fig) + return fig + else: + return im def plot_state_2d( state: Float[Array, "1 N N"], *, vlim: tuple[float, float] = (-1.0, 1.0), + cmap: str = "RdBu_r", domain_extent: float = None, ax=None, **kwargs, @@ -180,6 +196,7 @@ def plot_state_2d( - `state`: The state to plot as a three axis array. The first axis should be the channel axis, and the subsequent two axes the spatial axes. - `vlim`: The limits of the color scale. + - `cmap`: The colormap to use. - `domain_extent`: The extent of the spatial domain. If not provided, the domain extent will be the number of points in the spatial axes. This adjusts the x and y axes. @@ -188,8 +205,8 @@ def plot_state_2d( **Returns:** - - If `ax` is not provided, returns a tuple with the figure, axis, and image - object. Otherwise, returns the image object. + - If `ax` is not provided, returns the figure. Otherwise, returns the image + object. """ if state.ndim != 3: raise ValueError("state must be a three-axis array.") @@ -204,12 +221,15 @@ def plot_state_2d( if ax is None: fig, ax = plt.subplots() + return_all = True + else: + return_all = False im = ax.imshow( state_wrapped.T, vmin=vlim[0], vmax=vlim[1], - cmap="RdBu_r", + cmap=cmap, origin="lower", aspect="auto", extent=(*space_range, *space_range), @@ -219,4 +239,200 @@ def plot_state_2d( ax.set_ylabel("x_1") ax.set_aspect("equal") - return im + if return_all: + plt.close(fig) + return fig + else: + return im + + +def plot_state_3d( + state: Float[Array, "1 N N N"], + *, + vlim: tuple[float, float] = (-1.0, 1.0), + domain_extent: float = None, + ax=None, + bg_color: Union[ + Literal["black"], + Literal["white"], + tuple[jnp.int8, jnp.int8, jnp.int8, jnp.int8], + ] = "white", + resolution: int = 384, + cmap: str = "RdBu_r", + transfer_function: callable = zigzag_alpha, + distance_scale: float = 10.0, + gamma_correction: float = 2.4, + **kwargs, +): + """ + Visualizes a three-dimensional state as a volume rendering. + + Requires the input to be a real array with four axes: a leading channel axis, + and three subsequent spatial axes. This function will visualize the zeroth + channel. For plotting multiple channels at the same time, see + `plot_state_3d_facet`. + + Periodic boundary conditions will be applied to the spatial axes (the state + is wrapped around). + + **Arguments:** + + - `state`: The state to plot as a four axis array. The first axis should be + the channel axis, and the subsequent three axes the spatial axes. + - `vlim`: The limits of the color scale. + - `domain_extent`: (Unused as of now) + - `ax`: The axis to plot on. If not provided, a new figure will be created. + - `bg_color`: The background color. Either `"black"`, `"white"`, or a tuple + of RGBA values. + - `resolution`: The resolution of the output image (affects render time). + - `cmap`: The colormap to use. + - `transfer_function`: The transfer function to use, i.e., how values within + the `vlim` range are mapped to alpha values. + - `distance_scale`: The distance scale of the volume renderer. + - `gamma_correction`: The gamma correction to apply to the image. + + **Returns:** + + - If `ax` is not provided, returns the figure. Otherwise, returns the image + object. + + **Note:** + + - This function requires the `vape` volume renderer package. + """ + if state.ndim != 4: + raise ValueError("state must be a four-axis array.") + + one_channel_state = state[0:1] + one_channel_state_wrapped = wrap_bc(one_channel_state) + + imgs = volume_render_state_3d( + one_channel_state_wrapped, + vlim=vlim, + bg_color=bg_color, + resolution=resolution, + cmap=cmap, + transfer_function=transfer_function, + distance_scale=distance_scale, + gamma_correction=gamma_correction, + **kwargs, + ) + + img = imgs[0] + + if ax is None: + fig, ax = plt.subplots() + return_all = True + else: + return_all = False + + im = ax.imshow(img) + ax.axis("off") + + if return_all: + plt.close(fig) + return fig + else: + return im + + +def plot_spatio_temporal_2d( + trj: Float[Array, "T 1 N N"], + *, + vlim: tuple[float, float] = (-1.0, 1.0), + ax=None, + domain_extent: float = None, + dt: float = None, + include_init: bool = False, + bg_color: Union[ + Literal["black"], + Literal["white"], + tuple[jnp.int8, jnp.int8, jnp.int8, jnp.int8], + ] = "white", + resolution: int = 384, + cmap: str = "RdBu_r", + transfer_function: callable = zigzag_alpha, + distance_scale: float = 10.0, + gamma_correction: float = 2.4, + **kwargs, +): + """ + Plot a trajectory of a 2d state as a spatio-temporal plot visualized by a + volume render (space in in plain parallel to screen, and time in the + direction into the screen). + + Requires the input to be a real array with four axes: a leading time axis, a + channel axis, and two subsequent spatial axes. Only the leading dimension in + the channel axis will be plotted. See `plot_spatio_temporal_facet` for + plotting multiple trajectories (e.g. for problems consisting of multiple + channels like Burgers simulations). + + Periodic boundary conditions will be applied to the spatial axes (the state + is wrapped around). + + **Arguments:** + + - `trj`: The trajectory to plot as a four axis array. The first axis should + be the time axis, the second axis the channel axis, and the third and + fourth axes the spatial axes. + - `vlim`: The limits of the color scale. + - `ax`: The axis to plot on. If not provided, a new figure will be created. + - `domain_extent`: (Unused as of now) + - `dt`: (Unused as of now) + - `include_init`: (Unused as of now) + - `bg_color`: The background color. Either `"black"`, `"white"`, or a tuple + of RGBA values. + - `resolution`: The resolution of the output image (affects render time). + - `cmap`: The colormap to use. + - `transfer_function`: The transfer function to use, i.e., how values within + the `vlim` range are mapped to alpha values. + - `distance_scale`: The distance scale of the volume renderer. + - `gamma_correction`: The gamma correction to apply to the image. + + **Returns:** + + - If `ax` is not provided, returns the figure. Otherwise, returns the image + object. + + **Note:** + + - This function requires the `vape` volume renderer package. + """ + if trj.ndim != 4: + raise ValueError("trj must be a four-axis array.") + + trj_one_channel = trj[:, 0:1] + trj_one_channel_wrapped = jax.vmap(wrap_bc)(trj_one_channel) + + trj_reshaped_to_3d = jnp.flip( + jnp.array(trj_one_channel_wrapped.transpose(1, 2, 3, 0)), 3 + ) + + imgs = volume_render_state_3d( + trj_reshaped_to_3d, + vlim=vlim, + bg_color=bg_color, + resolution=resolution, + cmap=cmap, + transfer_function=transfer_function, + distance_scale=distance_scale, + gamma_correction=gamma_correction, + **kwargs, + ) + + img = imgs[0] + + if ax is None: + fig, ax = plt.subplots() + return_all = True + else: + return_all = False + + im = ax.imshow(img) + ax.axis("off") + + if return_all: + plt.close(fig) + return fig + else: + return im diff --git a/exponax/viz/_plot_facet.py b/exponax/viz/_plot_facet.py index 9493449..39970a0 100644 --- a/exponax/viz/_plot_facet.py +++ b/exponax/viz/_plot_facet.py @@ -1,10 +1,17 @@ -from typing import TypeVar, Union +from typing import Literal, TypeVar, Union import jax.numpy as jnp import matplotlib.pyplot as plt from jaxtyping import Array, Float -from ._plot import plot_spatio_temporal, plot_state_1d, plot_state_2d +from ._plot import ( + plot_spatio_temporal, + plot_spatio_temporal_2d, + plot_state_1d, + plot_state_2d, + plot_state_3d, +) +from ._volume import zigzag_alpha N = TypeVar("N") @@ -76,6 +83,8 @@ def plot_state_1d_facet( else: ax.remove() + plt.close(fig) + return fig @@ -84,6 +93,7 @@ def plot_spatio_temporal_facet( *, facet_over_channels: bool = True, vlim: tuple[float, float] = (-1.0, 1.0), + cmap: str = "RdBu_r", grid: tuple[int, int] = (3, 3), figsize: tuple[float, float] = (10, 10), titles: list[str] = None, @@ -116,6 +126,7 @@ def plot_spatio_temporal_facet( - `facet_over_channels`: Whether to facet over the channel axis (three axes) or the batch axis (four axes). - `vlim`: The limits of the color scale. + - `cmap`: The colormap to use. - `grid`: The grid layout for the facet plot. This should be a tuple with two integers. If the number of trajectories is less than the product of the grid, the remaining axes will be removed. @@ -155,6 +166,7 @@ def plot_spatio_temporal_facet( plot_spatio_temporal( single_trj, vlim=vlim, + cmap=cmap, ax=ax, domain_extent=domain_extent, dt=dt, @@ -167,6 +179,8 @@ def plot_spatio_temporal_facet( if titles is not None: ax.set_title(titles[i]) + plt.close(fig) + return fig @@ -175,6 +189,7 @@ def plot_state_2d_facet( *, facet_over_channels: bool = True, vlim: tuple[float, float] = (-1.0, 1.0), + cmap: str = "RdBu_r", grid: tuple[int, int] = (3, 3), figsize: tuple[float, float] = (10, 10), titles: list[str] = None, @@ -204,6 +219,7 @@ def plot_state_2d_facet( - `facet_over_channels`: Whether to facet over the channel axis (three axes) or the batch axis (four axes). - `vlim`: The limits of the color scale. + - `cmap`: The colormap to use. - `grid`: The grid layout for the facet plot. This should be a tuple with two integers. If the number of states is less than the product of the grid, the remaining axes will be removed. @@ -234,8 +250,224 @@ def plot_state_2d_facet( plot_state_2d( states[i], vlim=vlim, + cmap=cmap, + ax=ax, + domain_extent=domain_extent, + **kwargs, + ) + if i >= num_subplots: + ax.remove() + else: + if titles is not None: + ax.set_title(titles[i]) + + plt.close(fig) + + return fig + + +def plot_state_3d_facet( + states: Union[Float[Array, "C N N N"], Float[Array, "B 1 N N N"]], + *, + facet_over_channels: bool = True, + vlim: tuple[float, float] = (-1.0, 1.0), + grid: tuple[int, int] = (3, 3), + figsize: tuple[float, float] = (10, 10), + titles: list[str] = None, + domain_extent: float = None, + bg_color: Union[ + Literal["black"], + Literal["white"], + tuple[jnp.int8, jnp.int8, jnp.int8, jnp.int8], + ] = "white", + resolution: int = 384, + cmap: str = "RdBu_r", + transfer_function: callable = zigzag_alpha, + distance_scale: float = 10.0, + gamma_correction: float = 2.4, + **kwargs, +): + """ + Plot a facet of 3d states as volume renders. + + Requires the input to be a real array with four or five axes: a leading + batch axis, a channel axis, and three subsequent spatial axes. The facet + will be done over the batch axis, requires the `facet_over_channels` + argument to be `False`. Only the zeroth channel for each state will be + plotted. + + Periodic boundary conditions will be applied to the spatial axes (the state + is wrapped around). + + **Arguments:** + + - `states`: The states to plot as a four or five axis array. See above for + the requirements. + - `facet_over_channels`: Whether to facet over the channel axis (four axes) + or the batch axis (five axes). + - `vlim`: The limits of the color scale. + - `grid`: The grid layout for the facet plot. This should be a tuple with + two integers. If the number of states is less than the product of the + grid, the remaining axes will be removed. + - `figsize`: The size of the figure. + - `titles`: The titles for each plot. This should be a list of strings with + the same length as the number of states. + - `domain_extent`: (Unused as of now) + - `bg_color`: The background color. Either `"black"`, `"white"`, or a tuple + of RGBA values. + - `resolution`: The resolution of the output image (affects render time). + - `cmap`: The colormap to use. + - `transfer_function`: The transfer function to use, i.e., how values within + the `vlim` range are mapped to alpha values. + - `distance_scale`: The distance scale of the volume renderer. + - `gamma_correction`: The gamma correction to apply to the image. + + **Returns:** + + - The figure. + + **Note:** + + - This function requires the `vape` volume renderer package. + """ + if facet_over_channels: + if states.ndim != 4: + raise ValueError("states must be a four-axis array.") + states = states[:, None, :, :, :] + else: + if states.ndim != 5: + raise ValueError("states must be a five-axis array.") + + fig, ax_s = plt.subplots(*grid, figsize=figsize) + + num_subplots = states.shape[0] + + for i, ax in enumerate(ax_s.flatten()): + plot_state_3d( + states[i], + vlim=vlim, + domain_extent=domain_extent, + ax=ax, + bg_color=bg_color, + resolution=resolution, + cmap=cmap, + transfer_function=transfer_function, + distance_scale=distance_scale, + gamma_correction=gamma_correction, + **kwargs, + ) + if i >= num_subplots: + ax.remove() + else: + if titles is not None: + ax.set_title(titles[i]) + + plt.close(fig) + + return fig + + +def plot_spatio_temporal_2d_facet( + trjs: Union[Float[Array, "T C N N"], Float[Array, "B T 1 N N"]], + *, + facet_over_channels: bool = True, + vlim: tuple[float, float] = (-1.0, 1.0), + grid: tuple[int, int] = (3, 3), + figsize: tuple[float, float] = (10, 10), + titles: list[str] = None, + domain_extent: float = None, + dt: float = None, + include_init: bool = False, + bg_color: Union[ + Literal["black"], + Literal["white"], + tuple[jnp.int8, jnp.int8, jnp.int8, jnp.int8], + ] = "white", + resolution: int = 384, + cmap: str = "RdBu_r", + transfer_function: callable = zigzag_alpha, + distance_scale: float = 10.0, + gamma_correction: float = 2.4, + **kwargs, +): + """ + Plot a facet of spatio-temporal trajectories. + + Requires the input to be a real array with either four or five axes: + + * Four axes: a leading time axis, a channel axis, and two subsequent spatial + axes. The faceting is performed over the channel axis. Requires the + `facet_over_channels` argument to be `True` (default). + * Five axes: a leading batch axis, a time axis, a channel axis, and two + subsequent spatial axes. The faceting is performed over the batch axis. + Requires the `facet_over_channels` argument to be `False`. + + Periodic boundary conditions will be applied to the spatial axes (the state + is wrapped around). + + **Arguments:** + + - `trjs`: The trajectories to plot as a four or five axis array. See above + for the requirements. + - `facet_over_channels`: Whether to facet over the channel axis (four axes) + or the batch axis (five axes). + - `vlim`: The limits of the color scale. + - `grid`: The grid layout for the facet plot. This should be a tuple with + two integers. If the number of trajectories is less than the product of + the grid, the remaining axes will be removed. + - `figsize`: The size of the figure. + - `titles`: The titles for each plot. This should be a list of strings with + the same length as the number of trajectories. + - `domain_extent`: (Unused as of now) + - `dt`: (Unused as of now) + - `include_init`: (Unused as of now) + - `bg_color`: The background color. Either `"black"`, `"white"`, or a tuple + of RGBA values. + - `resolution`: The resolution of the output image (affects render time). + - `cmap`: The colormap to use. + - `transfer_function`: The transfer function to use, i.e., how values within + the `vlim` range are mapped to alpha values. + - `distance_scale`: The distance scale of the volume renderer. + - `gamma_correction`: The gamma correction to apply to the image. + + **Returns:** + + - The figure. + + **Note:** + + - This function requires the `vape` volume renderer package. + """ + if facet_over_channels: + if trjs.ndim != 4: + raise ValueError("trjs must be a four-axis array.") + else: + if trjs.ndim != 5: + raise ValueError("trjs must be a five-axis array.") + + fig, ax_s = plt.subplots(*grid, figsize=figsize) + + if facet_over_channels: + trjs = jnp.swapaxes(trjs, 0, 1) + trjs = trjs[:, :, None, :, :] + + num_subplots = trjs.shape[0] + + for i, ax in enumerate(ax_s.flatten()): + single_trj = trjs[i] + plot_spatio_temporal_2d( + single_trj, + vlim=vlim, ax=ax, domain_extent=domain_extent, + dt=dt, + include_init=include_init, + bg_color=bg_color, + resolution=resolution, + cmap=cmap, + transfer_function=transfer_function, + distance_scale=distance_scale, + gamma_correction=gamma_correction, **kwargs, ) if i >= num_subplots: @@ -244,4 +476,6 @@ def plot_state_2d_facet( if titles is not None: ax.set_title(titles[i]) + plt.close(fig) + return fig diff --git a/exponax/viz/_volume.py b/exponax/viz/_volume.py new file mode 100644 index 0000000..3726f8a --- /dev/null +++ b/exponax/viz/_volume.py @@ -0,0 +1,139 @@ +""" +High-Level abstractions around the vape volume renderer. +""" + +import copy +from typing import Literal, Union + +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +from jaxtyping import Array, Float +from matplotlib.colors import LinearSegmentedColormap, ListedColormap + + +def triangle_wave(x, p): + return 2 * jnp.abs(x / p - jnp.floor(x / p + 0.5)) + + +def zigzag_alpha(cmap, min_alpha=0.0): + """changes the alpha channel of a colormap to be linear (0->0, 1->1) + + Args: + cmap (Colormap): colormap + + Returns:a + Colormap: new colormap + """ + if isinstance(cmap, ListedColormap): + colors = copy.deepcopy(cmap.colors) + for i, a in enumerate(colors): + a.append( + (triangle_wave(i / (cmap.N - 1), 0.5) * (1 - min_alpha)) + min_alpha + ) + return ListedColormap(colors, cmap.name) + elif isinstance(cmap, LinearSegmentedColormap): + segmentdata = copy.deepcopy(cmap._segmentdata) + segmentdata["alpha"] = jnp.array( + [ + [0.0, 0.0, 0.0], + [0.25, 1.0, 1.0], + [0.5, 0.0, 0.0], + [0.75, 1.0, 1.0], + [1.0, 0.0, 0.0], + ] + ) + return LinearSegmentedColormap(cmap.name, segmentdata) + else: + raise TypeError( + "cmap must be either a ListedColormap or a LinearSegmentedColormap" + ) + + +def chunk_list(lst, n): + for i in range(0, len(lst), n): + yield lst[i : i + n] + + +def volume_render_state_3d( + states: Float[Array, "B N N N"], + *, + vlim: tuple[float, float] = (-1.0, 1.0), + bg_color: Union[ + Literal["black"], + Literal["white"], + tuple[jnp.int8, jnp.int8, jnp.int8, jnp.int8], + ] = "white", + resolution: int = 384, + cmap: str = "RdBu_r", + transfer_function: callable = zigzag_alpha, + distance_scale: float = 10.0, + gamma_correction: float = 2.4, + chunk_size: int = 64, +) -> Float[Array, "B resolution resolution 3"]: + """ + (Batched) rendering using the vape volume renderer. + + **Arguments:** + + - `states`: The states to render, shape `(B, N, N, N)`. To just render one + image, this array must have a leading singleton axis (i.e., has shape + `(1, N, N, N)`), then extract the one image from the returned array. + - `vlim`: The min and max values for the colormap. + - `bg_color`: The background color. Either `"black"`, `"white"`, or a tuple + of RGBA values. + - `resolution`: The resolution of the output image (affects render time). + - `cmap`: The colormap to use. + - `transfer_function`: The transfer function to use, i.e., how values within + the `vlim` range are mapped to alpha values. + - `distance_scale`: The distance scale of the volume renderer. + - `gamma_correction`: The gamma correction to apply to the image. + - `chunk_size`: The number of images to render at once. + + **Returns:** + + - `imgs`: The rendered images, in terms of RBG-images (channels-last) and a + leading batch axis, shape `(B, resolution, resolution, 3)`. + """ + if states.ndim != 4: + raise ValueError("state must be a four-axis array.") + try: + import vape + except ImportError: + raise ImportError("This function requires the `vape` volume renderer package.") + + if bg_color == "black": + bg_color = (0, 0, 0, 255) + elif bg_color == "white": + bg_color = (255, 255, 255, 255) + + # Need to convert to numpy array + states = np.array(states).astype(np.float32) + + cmap_with_alpha_transfer = transfer_function(plt.get_cmap(cmap)) + + num_images = states.shape[0] + + imgs = [] + for time_steps in chunk_list(range(num_images), chunk_size): + if num_images == 1: + sub_time_steps = [0.0] + else: + sub_time_steps = [i / (num_images - 1) for i in time_steps] + imgs_this_batch = vape.render( + states, + cmap=cmap_with_alpha_transfer, + time=sub_time_steps, + width=resolution, + height=resolution, + background=bg_color, + vmin=vlim[0], + vmax=vlim[1], + distance_scale=distance_scale, + ) + imgs.append(imgs_this_batch) + + imgs = np.concatenate(imgs, axis=0) + imgs = ((imgs / 255.0) ** (gamma_correction) * 255).astype(np.uint8) + + return imgs diff --git a/tests/test_viz.py b/tests/test_viz.py new file mode 100644 index 0000000..045ac57 --- /dev/null +++ b/tests/test_viz.py @@ -0,0 +1,35 @@ +import jax +import matplotlib.pyplot as plt + +import exponax as ex + + +def test_plot_state_1d(): + state = jax.random.normal( + jax.random.PRNGKey(0), + (10, 100), + ) + + fig = ex.viz.plot_state_1d(state) + plt.close(fig) + + +def test_plot_spatio_temporal(): + trj = jax.random.normal(jax.random.PRNGKey(0), (100, 1, 64)) + + fig = ex.viz.plot_spatio_temporal(trj) + plt.close(fig) + + +def test_plot_state_2d(): + state = jax.random.normal(jax.random.PRNGKey(0), (1, 100, 100)) + + fig = ex.viz.plot_state_2d(state) + plt.close(fig) + + +def test_plot_state_3d(): + state = jax.random.normal(jax.random.PRNGKey(0), (1, 32, 32, 32)) + + fig = ex.viz.plot_state_3d(state) + plt.close(fig)