Skip to content

Commit

Permalink
feat: Added pixelated cube with time dimension (#217)
Browse files Browse the repository at this point in the history
  • Loading branch information
ConnorStoneAstro authored Jun 19, 2024
1 parent d871428 commit 6d2d31b
Show file tree
Hide file tree
Showing 6 changed files with 354 additions and 3 deletions.
1 change: 1 addition & 0 deletions .codespell-whitelist
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ sie
childs
dout
din
tht
8 changes: 7 additions & 1 deletion src/caustics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
MassSheet,
TNFW,
)
from .light import Source, Pixelated, Sersic # PROBESDataset conflicts with .data
from .light import (
Source,
Pixelated,
PixelatedTime,
Sersic,
) # PROBESDataset conflicts with .data
from .data import HDF5Dataset, IllustrisKappaDataset, PROBESDataset
from . import utils
from .sims import Lens_Source, Simulator
Expand Down Expand Up @@ -56,6 +61,7 @@
"TNFW",
"Source",
"Pixelated",
"PixelatedTime",
"Sersic",
"HDF5Dataset",
"IllustrisKappaDataset",
Expand Down
3 changes: 2 additions & 1 deletion src/caustics/light/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
from .pixelated import Pixelated
from .probes import PROBESDataset
from .sersic import Sersic
from .pixelated_time import PixelatedTime

__all__ = ["Source", "Pixelated", "PROBESDataset", "Sersic"]
__all__ = ["Source", "Pixelated", "PixelatedTime", "PROBESDataset", "Sersic"]
195 changes: 195 additions & 0 deletions src/caustics/light/pixelated_time.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# mypy: disable-error-code="operator,union-attr"
from typing import Optional, Union, Annotated

from torch import Tensor

from ..utils import interp3d
from .base import Source, NameType
from ..parametrized import unpack
from ..packed import Packed

__all__ = ("PixelatedTime",)


class PixelatedTime(Source):
"""
``PixelatedTime`` is a subclass of the abstract class ``Source``. It
represents the brightness profile of source with a pixelated grid of
intensity values that also may vary over time.
This class provides a concrete implementation of the ``brightness`` method
required by the ``Source`` superclass. In this implementation, brightness is
determined by interpolating values from the provided source image.
Attributes
----------
x0 : Tensor, optional
The x-coordinate of the source image's center.
*Unit: arcsec*
y0 : Tensor, optional
The y-coordinate of the source image's center.
*Unit: arcsec*
cube : Tensor, optional
The source image cube from which brightness values will be interpolated.
*Unit: flux*
pixelscale : Tensor, optional
The pixelscale of the source image in the lens plane.
*Unit: arcsec/pixel*
t_end : Tensor, optional
The end time of the source image cube. Time in the cube is assumed to be
in the range (0, t_end) in seconds.
*Unit: seconds*
shape : Tuple of ints, optional
The shape of the source image and time dim.
"""

def __init__(
self,
cube: Annotated[
Optional[Tensor],
"The source image cube from which brightness values will be interpolated.",
True,
] = None,
x0: Annotated[
Optional[Union[Tensor, float]],
"The x-coordinate of the source image's center.",
True,
] = None,
y0: Annotated[
Optional[Union[Tensor, float]],
"The y-coordinate of the source image's center.",
True,
] = None,
pixelscale: Annotated[
Optional[Union[Tensor, float]],
"The pixelscale of the source image in the lens plane",
False,
"arcsec/pixel",
] = None,
t_end: Annotated[
Optional[Union[Tensor, float]],
"The end time of the source image cube.",
False,
"seconds",
] = None,
shape: Annotated[
Optional[tuple[int, ...]], "The shape of the source image."
] = None,
name: NameType = None,
):
"""
Constructs the `PixelatedTime` object with the given parameters.
Parameters
----------
name : str
The name of the source.
x0 : Tensor, optional
The x-coordinate of the source image's center.
*Unit: arcsec*
y0 : Tensor, optional
The y-coordinate of the source image's center.
*Unit: arcsec*
cube : Tensor, optional
The source cube from which brightness values will be interpolated. Note the indexing of the cube should be cube[time][y][x]
pixelscale : Tensor, optional
The pixelscale of the source image in the lens plane.
*Unit: arcsec/pixel*
shape : Tuple of ints, optional
The shape of the source image.
"""
if cube is not None and cube.ndim not in [3, 4]:
raise ValueError(
f"image must be 3D or 4D (channels first). Received a {cube.ndim}D tensor)"
)
elif shape is not None and len(shape) not in [3, 4]:
raise ValueError(
f"shape must be specify 3D or 4D tensors. Received shape={shape}"
)
super().__init__(name=name)
self.add_param("x0", x0)
self.add_param("y0", y0)
self.add_param("cube", cube, shape)
self.pixelscale = pixelscale
self.t_end = t_end

@unpack
def brightness(
self,
x,
y,
t,
*args,
params: Optional["Packed"] = None,
x0: Optional[Tensor] = None,
y0: Optional[Tensor] = None,
cube: Optional[Tensor] = None,
**kwargs,
):
"""
Implements the `brightness` method for `Pixelated`.
The brightness at a given point is determined
by interpolating values from the source image.
Parameters
----------
x : Tensor
The x-coordinate(s) at which to calculate the source brightness.
This could be a single value or a tensor of values.
*Unit: arcsec*
y : Tensor
The y-coordinate(s) at which to calculate the source brightness.
This could be a single value or a tensor of values.
*Unit: arcsec*
t : Tensor
The time coordinate(s) at which to calculate the source brightness.
This could be a single value or a tensor of values.
*Unit: seconds*
params : Packed, optional
A dictionary containing additional parameters that might be required to
calculate the brightness.
Returns
-------
Tensor
The brightness of the source at the given coordinate(s).
The brightness is determined by interpolating values
from the source image.
*Unit: flux*
"""
fov_x = self.pixelscale * cube.shape[2]
fov_y = self.pixelscale * cube.shape[1]
return interp3d(
cube,
(x - x0).view(-1) / fov_x * 2,
(y - y0).view(-1) / fov_y * 2,
(t - self.t_end / 2).view(-1) / self.t_end * 2,
).reshape(x.shape)
117 changes: 117 additions & 0 deletions src/caustics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,123 @@ def interp2d(
return result


def interp3d(
cu: Tensor,
x: Tensor,
y: Tensor,
t: Tensor,
method: str = "linear",
padding_mode: str = "zeros",
) -> Tensor:
"""
Interpolates a 3D image at specified coordinates.
Similar to `torch.nn.functional.grid_sample` with `align_corners=False`.
Parameters
----------
cu: Tensor
A 3D tensor representing the cube.
x: Tensor
A 0D or 1D tensor of x coordinates at which to interpolate.
y: Tensor
A 0D or 1D tensor of y coordinates at which to interpolate.
t: Tensor
A 0D or 1D tensor of t coordinates at which to interpolate.
method: (str, optional)
Interpolation method. Either 'nearest' or 'linear'. Defaults to 'linear'.
padding_mode: (str, optional)
Defines the padding mode when out-of-bound indices are encountered.
Either 'zeros' or 'extrapolate'. Defaults to 'zeros'.
Raises
------
ValueError
If `cu` is not a 3D tensor.
ValueError
If `x` is not a 0D or 1D tensor.
ValueError
If `y` is not a 0D or 1D tensor.
ValueError
If `t` is not a 0D or 1D tensor.
ValueError
If `padding_mode` is not 'extrapolate' or 'zeros'.
ValueError
If `method` is not 'nearest' or 'linear'.
Returns
-------
Tensor
Tensor with the same shape as `x` and `y` containing the interpolated values.
"""
if cu.ndim != 3:
raise ValueError(f"im must be 3D (received {cu.ndim}D tensor)")
if x.ndim > 1:
raise ValueError(f"x must be 0 or 1D (received {x.ndim}D tensor)")
if y.ndim > 1:
raise ValueError(f"y must be 0 or 1D (received {y.ndim}D tensor)")
if t.ndim > 1:
raise ValueError(f"t must be 0 or 1D (received {t.ndim}D tensor)")
if padding_mode not in ["extrapolate", "zeros"]:
raise ValueError(f"{padding_mode} is not a valid padding mode")

idxs_out_of_bounds = (y < -1) | (y > 1) | (x < -1) | (x > 1) | (t < -1) | (t > 1)
# Convert coordinates to pixel indices
d, h, w = cu.shape
x = 0.5 * ((x + 1) * w - 1)
y = 0.5 * ((y + 1) * h - 1)
t = 0.5 * ((t + 1) * d - 1)

if method == "nearest":
result = cu[
t.round().long().clamp(0, d - 1),
y.round().long().clamp(0, h - 1),
x.round().long().clamp(0, w - 1),
]
elif method == "linear":
x0 = x.floor().long()
y0 = y.floor().long()
t0 = t.floor().long()
x1 = x0 + 1
y1 = y0 + 1
t1 = t0 + 1
x0 = x0.clamp(0, w - 2)
x1 = x1.clamp(1, w - 1)
y0 = y0.clamp(0, h - 2)
y1 = y1.clamp(1, h - 1)
t0 = t0.clamp(0, d - 2)
t1 = t1.clamp(1, d - 1)

fa = cu[t0, y0, x0]
fb = cu[t0, y1, x0]
fc = cu[t0, y0, x1]
fd = cu[t0, y1, x1]
fe = cu[t1, y0, x0]
ff = cu[t1, y1, x0]
fg = cu[t1, y0, x1]
fh = cu[t1, y1, x1]

xd = x - x0
yd = y - y0
td = t - t0

c00 = fa * (1 - xd) + fc * xd
c01 = fe * (1 - xd) + fg * xd
c10 = fb * (1 - xd) + fd * xd
c11 = ff * (1 - xd) + fh * xd

c0 = c00 * (1 - yd) + c10 * yd
c1 = c01 * (1 - yd) + c11 * yd

result = c0 * (1 - td) + c1 * td
else:
raise ValueError(f"{method} is not a valid interpolation method")

if padding_mode == "zeros": # else padding_mode == "extrapolate"
result = torch.where(idxs_out_of_bounds, torch.zeros_like(result), result)

return result


def vmap_n(
func: Callable,
depth: int = 1,
Expand Down
Loading

0 comments on commit 6d2d31b

Please sign in to comment.