Skip to content

Commit

Permalink
Added light stack module (#272)
Browse files Browse the repository at this point in the history
  • Loading branch information
ConnorStoneAstro authored Oct 18, 2024
1 parent 5afeab7 commit fb0b0e9
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/caustics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
Pixelated,
PixelatedTime,
Sersic,
LightStack,
StarSource,
)
from . import utils
Expand Down Expand Up @@ -69,6 +70,7 @@
"Pixelated",
"PixelatedTime",
"Sersic",
"LightStack",
"StarSource",
"utils",
"LensSource",
Expand Down
3 changes: 2 additions & 1 deletion src/caustics/light/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .base import Source
from .pixelated import Pixelated
from .sersic import Sersic
from .light_stack import LightStack
from .pixelated_time import PixelatedTime
from .star_source import StarSource

__all__ = ["Source", "Pixelated", "PixelatedTime", "Sersic", "StarSource"]
__all__ = ["Source", "Pixelated", "PixelatedTime", "Sersic", "LightStack", "StarSource"]
92 changes: 92 additions & 0 deletions src/caustics/light/light_stack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# mypy: disable-error-code="operator,union-attr"
from typing import Optional, Annotated, List

import torch

from .base import Source, NameType
from ..parametrized import unpack
from ..packed import Packed

__all__ = ("LightStack",)


class LightStack(Source):
"""
``LightStack`` is a subclass of the abstract class ``Source`` which takes
the sum of multiple ``Source`` models to make a single brightness model.
Attributes
-----------
light_models: List[Source]
A list of light models to sum.
"""

def __init__(
self,
light_models: Annotated[
List[Source], "a list of light models to sum their brightnesses"
],
name: NameType = None,
):
"""
Constructs the ``LightStack`` object to sum multiple light models.
Parameters
----------
name: str
The name of the source.
light_models: List[Source]
A list of light models to sum.
"""
super().__init__(name=name)
self.light_models = light_models
for model in light_models:
self.add_parametrized(model)

@unpack
def brightness(
self,
x,
y,
*args,
params: Optional["Packed"] = None,
**kwargs,
):
"""
Implements the `brightness` method for `Sersic`. The brightness at a given point is
determined by the Sersic profile formula.
Parameters
----------
x: Tensor
The x-coordinate(s) at which to calculate the source brightness.
This could be a single value or a tensor of values.
*Unit: arcsec*
y: Tensor
The y-coordinate(s) at which to calculate the source brightness.
This could be a single value or a tensor of values.
*Unit: arcsec*
params: Packed, optional
Dynamic parameter container.
Returns
-------
Tensor
The brightness of the source at the given point(s).
The output tensor has the same shape as `x` and `y`.
*Unit: flux*
"""

brightness = torch.zeros_like(x)
for light_model in self.light_models:
brightness += light_model.brightness(x, y, params=params, **kwargs)
return brightness
30 changes: 30 additions & 0 deletions tests/test_light_stack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import torch

import caustics


def test_stack_sersic(device):
res = 0.05
nx = 100
ny = 100
thx, thy = caustics.utils.meshgrid(res, nx, ny, device=device)

models = []
params = []
for i in range(3):
sersic = caustics.Sersic(
name=f"sersic_{i}",
)
sersic.to(device=device)
models.append(sersic)
params.append(
torch.tensor([0.0 + 0.2 * i, 0.0, 0.5, 3.14 / 2, 2.0, 1.0 + 0.5 * i, 10.0])
)

stack = caustics.LightStack(light_models=models, name="stack")

brightness = stack.brightness(thx, thy, params=params)

assert brightness.shape == (nx, ny)
assert torch.all(brightness >= 0.0).item()
assert torch.all(torch.isfinite(brightness)).item()

0 comments on commit fb0b0e9

Please sign in to comment.