Skip to content

Commit

Permalink
refactor: move quad_level to init (#245)
Browse files Browse the repository at this point in the history
* refactor: move quad_level to init

* update unit tests for LensSource

* fiw more unit tests

* yaml needs null not None

* yaml sim cant accept none

* refactor: lens source gridding system

* fix no source mu shape

* fix mu with no source

* actually fix it this time

* fix example usage docstring

* Add psf as parameter

* sim test quad level set

* sim should work

* get outputs right, shape and normalization

* add check conv2d vs fft

* fix fft roll in lens source

* update docs with new quad_level argument fmt

* remove unnecessary 'off' option for psf_mode
  • Loading branch information
ConnorStoneAstro authored Jul 27, 2024
1 parent 9806a84 commit f72d8ad
Show file tree
Hide file tree
Showing 9 changed files with 244 additions and 130 deletions.
12 changes: 6 additions & 6 deletions docs/source/examples/Example_ImageFit_LM.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
"cosmology.to(dtype=torch.float32)\n",
"\n",
"upsample_factor = 1\n",
"quad_level = 3\n",
"thx, thy = caustics.utils.meshgrid(\n",
" pixelscale / upsample_factor,\n",
" upsample_factor * numPix,\n",
Expand Down Expand Up @@ -117,6 +118,7 @@
" pixels_x=numPix,\n",
" pixelscale=pixelscale,\n",
" upsample_factor=upsample_factor,\n",
" quad_level=quad_level,\n",
" z_s=2.0,\n",
")"
]
Expand All @@ -128,9 +130,7 @@
"source": [
"## Sample some mock data\n",
"\n",
"Here we write out the true values for all the parameters in the model. In total there are 21 parameters, so this is quite a complex model already! We then plot the data so we can see what it is we re trying to fit.\n",
"\n",
"Note that when we sample the simulator we call it with `quad_level=7`. This means the simulator will use gaussian quadrature sub-pixel integration to ensure the brightness of each pixel is very accurately computed."
"Here we write out the true values for all the parameters in the model. In total there are 21 parameters, so this is quite a complex model already! We then plot the data so we can see what it is we re trying to fit."
]
},
{
Expand Down Expand Up @@ -178,7 +178,7 @@
"print(true_params)\n",
"\n",
"# simulate lens, crop extra evaluation for PSF\n",
"true_system = sim(allparams, quad_level=7) # simulate at high resolution\n",
"true_system = sim(allparams)\n",
"\n",
"fig, axarr = plt.subplots(1, 2, figsize=(15, 8))\n",
"axarr[0].imshow(\n",
Expand Down Expand Up @@ -231,7 +231,7 @@
"res = caustics.utils.batch_lm(\n",
" batch_inits,\n",
" obs_system.reshape(-1).repeat(10, 1),\n",
" lambda x: sim(x, quad_level=3).reshape(-1),\n",
" lambda x: sim(x).reshape(-1),\n",
" C=variance.reshape(-1).repeat(10, 1),\n",
")\n",
"best_fit = res[0][np.argmin(res[2].numpy())]\n",
Expand Down Expand Up @@ -288,7 +288,7 @@
"outputs": [],
"source": [
"# Compute jacobian\n",
"J = torch.func.jacfwd(lambda x: sim(x, quad_level=3))(best_fit)\n",
"J = torch.func.jacfwd(lambda x: sim(x))(best_fit)\n",
"fig, axarr = plt.subplots(3, 7, figsize=(21, 9))\n",
"for i, ax in enumerate(axarr.flatten()):\n",
" ax.imshow(J[..., i], origin=\"lower\")\n",
Expand Down
7 changes: 3 additions & 4 deletions docs/source/examples/Example_ImageFit_NUTS.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
"cosmology.to(dtype=torch.float32)\n",
"\n",
"upsample_factor = 1\n",
"quad_level = 3\n",
"thx, thy = caustics.utils.meshgrid(\n",
" pixelscale / upsample_factor,\n",
" upsample_factor * numPix,\n",
Expand Down Expand Up @@ -133,9 +134,7 @@
"source": [
"## Sample some mock data\n",
"\n",
"Here we write out the true values for all the parameters in the model. In total there are 21 parameters, so this is quite a complex model already! We then plot the data so we can see what it is we re trying to fit.\n",
"\n",
"Note that when we sample the simulator we call it with quad_level=7. This means the simulator will use gaussian quadrature sub-pixel integration to ensure the brightness of each pixel is very accurately computed."
"Here we write out the true values for all the parameters in the model. In total there are 21 parameters, so this is quite a complex model already! We then plot the data so we can see what it is we re trying to fit."
]
},
{
Expand Down Expand Up @@ -182,7 +181,7 @@
"print(true_params)\n",
"\n",
"# simulate lens, crop extra evaluation for PSF\n",
"true_system = sim(allparams, quad_level=7) # simulate at high resolution\n",
"true_system = sim(allparams) # simulate at high resolution\n",
"\n",
"fig, axarr = plt.subplots(1, 2, figsize=(15, 8))\n",
"axarr[0].imshow(\n",
Expand Down
10 changes: 5 additions & 5 deletions docs/source/intro.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ x = torch.tensor([
5.0, -0.2, 0.0, 0.8, 0.0, 1., 1.0, 10.0
]) # fmt: skip

minisim = caustics.LensSource(
lens=sie, source=src, lens_light=lnslt, pixelscale=0.05, pixels_x=100
sim = caustics.LensSource(
lens=sie, source=src, lens_light=lnslt, pixelscale=0.05, pixels_x=100, quad_level=3
)
plt.imshow(minisim(x, quad_level=3), origin="lower")
plt.imshow(sim(x), origin="lower")
plt.show()
```

Expand All @@ -54,7 +54,7 @@ plt.show()
newx = x.repeat(20, 1)
newx += torch.normal(mean=0, std=0.1 * torch.ones_like(newx))

images = torch.vmap(minisim)(newx)
images = torch.vmap(sim)(newx)

fig, axarr = plt.subplots(4, 5, figsize=(20, 16))
for ax, im in zip(axarr.flatten(), images):
Expand All @@ -67,7 +67,7 @@ plt.show()
### Automatic Differentiation

```python
J = torch.func.jacfwd(minisim)(x)
J = torch.func.jacfwd(sim)(x)

# Plot the new images
fig, axarr = plt.subplots(3, 7, figsize=(20, 9))
Expand Down
4 changes: 2 additions & 2 deletions docs/source/tutorials/InterfaceIntroduction_oop.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@
"outputs": [],
"source": [
"sim = caustics.LensSource(\n",
" lens=sie, source=src, lens_light=lnslt, pixelscale=0.05, pixels_x=100\n",
" lens=sie, source=src, lens_light=lnslt, pixelscale=0.05, pixels_x=100, quad_level=3\n",
")"
]
},
Expand Down Expand Up @@ -174,7 +174,7 @@
"outputs": [],
"source": [
"# Substitute minisim with sim for yaml method\n",
"image = sim(x, quad_level=3).detach().cpu().numpy()\n",
"image = sim(x).detach().cpu().numpy()\n",
"\n",
"plt.imshow(image, origin=\"lower\")\n",
"plt.axis(\"off\")\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/source/tutorials/InterfaceIntroduction_yaml.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
"source": [
"### Plot the Results!\n",
"\n",
"This section is mostly self explanatory. We evaluate the simulator configuration by calling it like a function. There are several possible arguments to the simulator function, here we use `quad_level` which indicates the depth of quadrature integration to use for each pixel (how accurate to make the flux)."
"This section is mostly self explanatory. We evaluate the simulator configuration by calling it like a function. "
]
},
{
Expand All @@ -110,7 +110,7 @@
"outputs": [],
"source": [
"# Here we sample an image\n",
"image = sim(x, quad_level=3).detach().cpu().numpy()\n",
"image = sim(x).detach().cpu().numpy()\n",
"\n",
"plt.imshow(image, origin=\"lower\")\n",
"plt.axis(\"off\")\n",
Expand Down
1 change: 1 addition & 0 deletions docs/source/tutorials/example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ simulator:
lens_light: *lnslt
pixelscale: 0.05
pixels_x: 100
quad_level: 3
Loading

0 comments on commit f72d8ad

Please sign in to comment.