-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5afeab7
commit fb0b0e9
Showing
4 changed files
with
126 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,8 @@ | ||
from .base import Source | ||
from .pixelated import Pixelated | ||
from .sersic import Sersic | ||
from .light_stack import LightStack | ||
from .pixelated_time import PixelatedTime | ||
from .star_source import StarSource | ||
|
||
__all__ = ["Source", "Pixelated", "PixelatedTime", "Sersic", "StarSource"] | ||
__all__ = ["Source", "Pixelated", "PixelatedTime", "Sersic", "LightStack", "StarSource"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# mypy: disable-error-code="operator,union-attr" | ||
from typing import Optional, Annotated, List | ||
|
||
import torch | ||
|
||
from .base import Source, NameType | ||
from ..parametrized import unpack | ||
from ..packed import Packed | ||
|
||
__all__ = ("LightStack",) | ||
|
||
|
||
class LightStack(Source): | ||
""" | ||
``LightStack`` is a subclass of the abstract class ``Source`` which takes | ||
the sum of multiple ``Source`` models to make a single brightness model. | ||
Attributes | ||
----------- | ||
light_models: List[Source] | ||
A list of light models to sum. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
light_models: Annotated[ | ||
List[Source], "a list of light models to sum their brightnesses" | ||
], | ||
name: NameType = None, | ||
): | ||
""" | ||
Constructs the ``LightStack`` object to sum multiple light models. | ||
Parameters | ||
---------- | ||
name: str | ||
The name of the source. | ||
light_models: List[Source] | ||
A list of light models to sum. | ||
""" | ||
super().__init__(name=name) | ||
self.light_models = light_models | ||
for model in light_models: | ||
self.add_parametrized(model) | ||
|
||
@unpack | ||
def brightness( | ||
self, | ||
x, | ||
y, | ||
*args, | ||
params: Optional["Packed"] = None, | ||
**kwargs, | ||
): | ||
""" | ||
Implements the `brightness` method for `Sersic`. The brightness at a given point is | ||
determined by the Sersic profile formula. | ||
Parameters | ||
---------- | ||
x: Tensor | ||
The x-coordinate(s) at which to calculate the source brightness. | ||
This could be a single value or a tensor of values. | ||
*Unit: arcsec* | ||
y: Tensor | ||
The y-coordinate(s) at which to calculate the source brightness. | ||
This could be a single value or a tensor of values. | ||
*Unit: arcsec* | ||
params: Packed, optional | ||
Dynamic parameter container. | ||
Returns | ||
------- | ||
Tensor | ||
The brightness of the source at the given point(s). | ||
The output tensor has the same shape as `x` and `y`. | ||
*Unit: flux* | ||
""" | ||
|
||
brightness = torch.zeros_like(x) | ||
for light_model in self.light_models: | ||
brightness += light_model.brightness(x, y, params=params, **kwargs) | ||
return brightness |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import torch | ||
|
||
import caustics | ||
|
||
|
||
def test_stack_sersic(device): | ||
res = 0.05 | ||
nx = 100 | ||
ny = 100 | ||
thx, thy = caustics.utils.meshgrid(res, nx, ny, device=device) | ||
|
||
models = [] | ||
params = [] | ||
for i in range(3): | ||
sersic = caustics.Sersic( | ||
name=f"sersic_{i}", | ||
) | ||
sersic.to(device=device) | ||
models.append(sersic) | ||
params.append( | ||
torch.tensor([0.0 + 0.2 * i, 0.0, 0.5, 3.14 / 2, 2.0, 1.0 + 0.5 * i, 10.0]) | ||
) | ||
|
||
stack = caustics.LightStack(light_models=models, name="stack") | ||
|
||
brightness = stack.brightness(thx, thy, params=params) | ||
|
||
assert brightness.shape == (nx, ny) | ||
assert torch.all(brightness >= 0.0).item() | ||
assert torch.all(torch.isfinite(brightness)).item() |