diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index a3222bf0..70ba13ef 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -57,7 +57,7 @@ jobs: ls -ltrh ls -ltrh dist - name: Publish to Test PyPI - uses: pypa/gh-action-pypi-publish@v1.9.0 + uses: pypa/gh-action-pypi-publish@v1.10.2 with: repository-url: https://test.pypi.org/legacy/ verbose: true @@ -95,5 +95,5 @@ jobs: name: artifact path: dist - - uses: pypa/gh-action-pypi-publish@v1.9.0 + - uses: pypa/gh-action-pypi-publish@v1.10.2 if: startsWith(github.ref, 'refs/tags') diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 244d6c92..7451bef5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -74,8 +74,9 @@ jobs: if: ${{ matrix.python-version == '3.10' && matrix.os == 'ubuntu-latest'}} uses: codecov/codecov-action@v4 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} with: - token: ${{ secrets.CODECOV_TOKEN }} files: ./coverage.xml flags: unittests name: codecov-umbrella diff --git a/README.md b/README.md index c4730fae..fe1576e8 100644 --- a/README.md +++ b/README.md @@ -47,10 +47,10 @@ x = torch.tensor([ 5.0, -0.2, 0.0, 0.8, 0.0, 1., 1.0, 10.0 ]) # fmt: skip -minisim = caustics.LensSource( +sim = caustics.LensSource( lens=sie, source=src, lens_light=lnslt, pixelscale=0.05, pixels_x=100 ) -plt.imshow(minisim(x, quad_level=3), origin="lower") +plt.imshow(sim(x), origin="lower") plt.axis("off") plt.show() ``` @@ -63,7 +63,7 @@ plt.show() newx = x.repeat(20, 1) newx += torch.normal(mean=0, std=0.1 * torch.ones_like(newx)) -images = torch.vmap(minisim)(newx) +images = torch.vmap(sim)(newx) fig, axarr = plt.subplots(4, 5, figsize=(20, 16)) for ax, im in zip(axarr.flatten(), images): @@ -76,7 +76,7 @@ plt.show() ### Automatic Differentiation ```python -J = torch.func.jacfwd(minisim)(x) +J = torch.func.jacfwd(sim)(x) # Plot the new images fig, axarr = plt.subplots(3, 7, figsize=(20, 9)) diff --git a/docs/requirements.txt b/docs/requirements.txt index 209d48f9..a6b3e0ae 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,3 +1,4 @@ +emcee ipywidgets jupyter-book matplotlib diff --git a/docs/source/_toc.yml b/docs/source/_toc.yml index d4a66389..9ad25f2e 100644 --- a/docs/source/_toc.yml +++ b/docs/source/_toc.yml @@ -21,11 +21,10 @@ chapters: - file: tutorials/InvertLensEquation - file: tutorials/Parameters - file: tutorials/Simulators - - file: tutorials/Playground - file: examples/index sections: - file: examples/Example_ImageFit_LM - - file: examples/Example_ImageFit_NUTS + - file: examples/Example_ImageFit_MCMC - file: examples/Example_QSOLensFit - file: contributing - file: frequently_asked_questions diff --git a/docs/source/examples/Example_ImageFit_LM.ipynb b/docs/source/examples/Example_ImageFit_LM.ipynb index 41f06858..d06aa40d 100644 --- a/docs/source/examples/Example_ImageFit_LM.ipynb +++ b/docs/source/examples/Example_ImageFit_LM.ipynb @@ -56,6 +56,7 @@ "cosmology.to(dtype=torch.float32)\n", "\n", "upsample_factor = 1\n", + "quad_level = 3\n", "thx, thy = caustics.utils.meshgrid(\n", " pixelscale / upsample_factor,\n", " upsample_factor * numPix,\n", @@ -117,6 +118,7 @@ " pixels_x=numPix,\n", " pixelscale=pixelscale,\n", " upsample_factor=upsample_factor,\n", + " quad_level=quad_level,\n", " z_s=2.0,\n", ")" ] @@ -128,16 +130,18 @@ "source": [ "## Sample some mock data\n", "\n", - "Here we write out the true values for all the parameters in the model. In total there are 21 parameters, so this is quite a complex model already! We then plot the data so we can see what it is we re trying to fit.\n", - "\n", - "Note that when we sample the simulator we call it with `quad_level=7`. This means the simulator will use gaussian quadrature sub-pixel integration to ensure the brightness of each pixel is very accurately computed." + "Here we write out the true values for all the parameters in the model. In total there are 21 parameters, so this is quite a complex model already! We then plot the data so we can see what it is we re trying to fit." ] }, { "cell_type": "code", "execution_count": null, "id": "7", - "metadata": {}, + "metadata": { + "tags": [ + "hide-input" + ] + }, "outputs": [], "source": [ "# Generate the mock data\n", @@ -178,7 +182,7 @@ "print(true_params)\n", "\n", "# simulate lens, crop extra evaluation for PSF\n", - "true_system = sim(allparams, quad_level=7) # simulate at high resolution\n", + "true_system = sim(allparams)\n", "\n", "fig, axarr = plt.subplots(1, 2, figsize=(15, 8))\n", "axarr[0].imshow(\n", @@ -231,7 +235,7 @@ "res = caustics.utils.batch_lm(\n", " batch_inits,\n", " obs_system.reshape(-1).repeat(10, 1),\n", - " lambda x: sim(x, quad_level=3).reshape(-1),\n", + " lambda x: sim(x).reshape(-1),\n", " C=variance.reshape(-1).repeat(10, 1),\n", ")\n", "best_fit = res[0][np.argmin(res[2].numpy())]\n", @@ -242,7 +246,11 @@ "cell_type": "code", "execution_count": null, "id": "10", - "metadata": {}, + "metadata": { + "tags": [ + "hide-input" + ] + }, "outputs": [], "source": [ "print(best_fit, allparams)\n", @@ -288,7 +296,7 @@ "outputs": [], "source": [ "# Compute jacobian\n", - "J = torch.func.jacfwd(lambda x: sim(x, quad_level=3))(best_fit)\n", + "J = torch.func.jacfwd(lambda x: sim(x))(best_fit)\n", "fig, axarr = plt.subplots(3, 7, figsize=(21, 9))\n", "for i, ax in enumerate(axarr.flatten()):\n", " ax.imshow(J[..., i], origin=\"lower\")\n", @@ -310,7 +318,11 @@ "cell_type": "code", "execution_count": null, "id": "14", - "metadata": {}, + "metadata": { + "tags": [ + "hide-cell" + ] + }, "outputs": [], "source": [ "def corner_plot_covariance(\n", diff --git a/docs/source/examples/Example_ImageFit_NUTS.ipynb b/docs/source/examples/Example_ImageFit_MCMC.ipynb similarity index 55% rename from docs/source/examples/Example_ImageFit_NUTS.ipynb rename to docs/source/examples/Example_ImageFit_MCMC.ipynb index e53bcbdf..3a8a613a 100644 --- a/docs/source/examples/Example_ImageFit_NUTS.ipynb +++ b/docs/source/examples/Example_ImageFit_MCMC.ipynb @@ -5,9 +5,9 @@ "id": "0", "metadata": {}, "source": [ - "# Modelling a Lens image using a No U-Turn Sampler\n", + "# Modelling a Lens image using MCMC\n", "\n", - "In this hypothetical scenario we have an image of galaxy galaxy strong lensing and we would like to recover a model of this scene. Thus we will need to determine parameters for the background source light, the lensing galaxy light, and the lensing galaxy mass distribution. A common technique for analyzing strong lensing systems is a Markov Chain Monte-Carlo which can explore the parameter space and provide us with important metrics about the model and uncertainty on all parameters. Since caustics is differentiable we have access to especially efficient gradient based MCMC algorithms. A very convenient algorithm is the No U-Turn Sampler, or NUTS, which uses derivarives to efficiently explore the likelihood distribution by treating it like a potential that a point mass is exploring. The NUST version we use as implemented in the Pyro package has no tunable parameters, thus we can simply give it a start point and it will explore for as many iterations as we give it. What's more, NUTS is so efficient that very often the autocorrelation length for the samples is approximately 1, meaning that each sample is independent from all the others! This is especially handy in the complex non-linear space of strong lensing models." + "In this hypothetical scenario we have an image of galaxy galaxy strong lensing and we would like to recover a model of this scene. Thus we will need to determine parameters for the background source light, the lensing galaxy light, and the lensing galaxy mass distribution. A common technique for analyzing strong lensing systems is a Markov Chain Monte-Carlo which can explore the parameter space and provide us with important metrics about the model and uncertainty on all parameters. Since caustics is differentiable we have access to especially efficient gradient based MCMC algorithms. A very convenient algorithm is the No U-Turn Sampler, or NUTS, which uses derivatives to efficiently explore the likelihood distribution by treating it like a potential that a point mass is exploring. The NUTS version we use as implemented in the Pyro package has no tunable parameters, thus we can simply give it a start point and it will explore for as many iterations as we give it. What's more, NUTS is so efficient that very often the autocorrelation length for the samples is approximately 1, meaning that each sample is independent from all the others! This is especially handy in the complex non-linear space of strong lensing models." ] }, { @@ -21,13 +21,16 @@ "import numpy as np\n", "import torch\n", "import matplotlib.pyplot as plt\n", + "from matplotlib import colormaps\n", "from matplotlib.patches import Ellipse\n", "from scipy.stats import norm\n", + "from tqdm.notebook import tqdm\n", "\n", "import pyro\n", "import pyro.distributions as dist\n", "from pyro.infer import MCMC as pyro_MCMC\n", - "from pyro.infer import NUTS as pyro_NUTS" + "from pyro.infer import NUTS as pyro_NUTS\n", + "import emcee" ] }, { @@ -61,6 +64,7 @@ "cosmology.to(dtype=torch.float32)\n", "\n", "upsample_factor = 1\n", + "quad_level = 3\n", "thx, thy = caustics.utils.meshgrid(\n", " pixelscale / upsample_factor,\n", " upsample_factor * numPix,\n", @@ -133,16 +137,18 @@ "source": [ "## Sample some mock data\n", "\n", - "Here we write out the true values for all the parameters in the model. In total there are 21 parameters, so this is quite a complex model already! We then plot the data so we can see what it is we re trying to fit.\n", - "\n", - "Note that when we sample the simulator we call it with quad_level=7. This means the simulator will use gaussian quadrature sub-pixel integration to ensure the brightness of each pixel is very accurately computed." + "Here we write out the true values for all the parameters in the model. In total there are 21 parameters, so this is quite a complex model already! We then plot the data so we can see what it is we re trying to fit." ] }, { "cell_type": "code", "execution_count": null, "id": "7", - "metadata": {}, + "metadata": { + "tags": [ + "hide-input" + ] + }, "outputs": [], "source": [ "# Generate the mock data\n", @@ -182,7 +188,7 @@ "print(true_params)\n", "\n", "# simulate lens, crop extra evaluation for PSF\n", - "true_system = sim(allparams, quad_level=7) # simulate at high resolution\n", + "true_system = sim(allparams) # simulate at high resolution\n", "\n", "fig, axarr = plt.subplots(1, 2, figsize=(15, 8))\n", "axarr[0].imshow(\n", @@ -226,7 +232,11 @@ "cell_type": "code", "execution_count": null, "id": "9", - "metadata": {}, + "metadata": { + "tags": [ + "hide-output" + ] + }, "outputs": [], "source": [ "def step(model, prior):\n", @@ -246,7 +256,7 @@ " \"jit_compile\": True,\n", " \"ignore_jit_warnings\": True,\n", " \"step_size\": 1e-3,\n", - " \"full_mass\": True,\n", + " \"full_mass\": False,\n", " \"adapt_step_size\": True,\n", " \"adapt_mass_matrix\": True,\n", "}\n", @@ -282,16 +292,31 @@ "metadata": {}, "outputs": [], "source": [ - "chain = mcmc.get_samples()[\"x\"]\n", - "chain = chain.numpy()\n", + "chain_nuts = mcmc.get_samples()[\"x\"]\n", + "chain_nuts = chain_nuts.numpy()\n", "\n", - "plt.plot(\n", - " range(len(chain)),\n", - " (chain - np.mean(chain, axis=0)) / np.std(chain, axis=0)\n", - " + 5 * np.arange(len(allparams)),\n", - ")\n", + "normed_chains = (chain_nuts - np.mean(chain_nuts, axis=0)) / np.std(chain_nuts, axis=0)\n", + "for i in range(chain_nuts.shape[1]):\n", + " plt.plot(normed_chains[:, i], color=colormaps[\"viridis\"](i / chain_nuts.shape[1]))\n", "plt.title(\"Chain for each parameter\")\n", - "plt.show()" + "plt.show()\n", + "\n", + "print(\n", + " \"Autocorrelation time: \",\n", + " np.mean(\n", + " emcee.autocorr.integrated_time(\n", + " chain_nuts, has_walkers=False, tol=10, quiet=True\n", + " )\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e6ebd65b", + "metadata": {}, + "source": [ + "As is common for NUTS sampling, the average autocorrelation time for the parameters is around 1, meaning that essentially every sample is independent. This is unlike a Metropolis-Hastings sampler like the one used in emcee (see the other example tutorial)." ] }, { @@ -308,7 +333,11 @@ "cell_type": "code", "execution_count": null, "id": "13", - "metadata": {}, + "metadata": { + "tags": [ + "hide-cell" + ] + }, "outputs": [], "source": [ "def corner_plot(\n", @@ -409,20 +438,336 @@ "metadata": {}, "outputs": [], "source": [ - "fig = corner_plot(chain, true_values=allparams.numpy())\n", + "fig = corner_plot(chain_nuts, true_values=allparams.numpy())\n", "plt.show()" ] }, + { + "cell_type": "markdown", + "id": "bd256dff", + "metadata": {}, + "source": [ + "## Fit using emcee\n", + "\n", + "We now model the data using emcee which handles standard Metropolis-Hastings MCMC sampling (plus a few tricks). First we need to construct a log likelihood function. In our case this is just the squared residuals, divided by the variance in each pixel. The rest is specific to the emcee implementation." + ] + }, { "cell_type": "code", "execution_count": null, "id": "16", "metadata": {}, "outputs": [], + "source": [ + "vsim = torch.vmap(sim)\n", + "\n", + "\n", + "def density(x):\n", + " # Log-likelihood function\n", + " model = vsim(torch.as_tensor(x))\n", + " log_likelihood_value = -0.5 * torch.sum(\n", + " ((model - obs_system) ** 2) / variance, dim=(1, 2)\n", + " )\n", + " log_likelihood_value = torch.nan_to_num(log_likelihood_value, nan=-np.inf)\n", + " return log_likelihood_value.numpy()\n", + "\n", + "\n", + "nwalkers = 64\n", + "ndim = len(allparams)\n", + "\n", + "sampler = emcee.EnsembleSampler(nwalkers, ndim, density, vectorize=True)\n", + "\n", + "x0 = allparams + 0.01 * torch.randn(nwalkers, ndim)\n", + "print(\"burn-in\")\n", + "state = sampler.run_mcmc(x0, 100, skip_initial_state_check=True) # burn-in\n", + "sampler.reset()\n", + "print(\"production\")\n", + "state = sampler.run_mcmc(state, 1000) # production" + ] + }, + { + "cell_type": "markdown", + "id": "d9ed9291", + "metadata": {}, + "source": [ + "We have taken 64000 samples in this demo, in general you would want many more (each chain needs to run longer than 1000 steps). Its always a good idea to plot the chains and check that they look uncorrelated. Here we can see the non zero autocorrelation length for one of the chains even over 1000 steps. This indicates we should run the chains much longer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4d87abb8", + "metadata": {}, + "outputs": [], + "source": [ + "chain_mh = sampler.get_chain()\n", + "normed_chains = (chain_mh[:, 0] - np.mean(chain_mh[:, 0], axis=0)) / np.std(\n", + " chain_mh[:, 0], axis=0\n", + ")\n", + "for i in range(chain_mh.shape[2]):\n", + " plt.plot(normed_chains[:, i], color=colormaps[\"viridis\"](i / chain_mh.shape[2]))\n", + "plt.title(\"Chain for each parameter\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "480e849f", + "metadata": {}, + "source": [ + "Since the autocorrelation length is >1, we can compute an effective sample size to determine how many equivalent independent points we have drawn. As the warning suggests, in this demo we cannot compute the actual autocorrelation length, the autocorrelation length increases as we draw more samples (you can test this by changing the 1000 above to a larger number). Assuming that the autocorrelation is actually of a similar length to the chain (1000), this means we have drawn approximately 64 independent samples (one for each walker) which is similar to the NUTS example, though for NUTS we knew each sample was independent." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c5b0924f", + "metadata": {}, + "outputs": [], + "source": [ + "print(\n", + " \"Autocorrelation time: \",\n", + " np.mean(emcee.autocorr.integrated_time(chain_mh, quiet=True)),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a51c577b", + "metadata": {}, + "source": [ + "Similar to the NUTS example, we may plot the samples in a corner plot. However, we thin the samples first so that the number of points is not overwhelming to plot. As you can see in the subfigures there is still a bloby structure of the samples, suggesting that the chains were not run long enough to converge." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "afae21ff", + "metadata": {}, + "outputs": [], + "source": [ + "N = chain_mh.shape[0] * chain_mh.shape[1]\n", + "fig = corner_plot(\n", + " np.concatenate(chain_mh, axis=0)[:: int(N / 200)],\n", + " true_values=allparams.numpy(),\n", + ")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "1a4a2be6", + "metadata": {}, + "source": [ + "# Fit with MALA sampling\n", + "\n", + "Metropolis Adjusted Langevin Algorithm sampling is the half way point between NUTS and MH, it uses gradient information to make an efficient proposal distribution for a MH step. We have written a basic implementation below for demo purposes. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0a6de8f", + "metadata": { + "tags": [ + "hide-cell" + ] + }, + "outputs": [], + "source": [ + "def mala_sampler(\n", + " initial_state, log_prob, log_prob_grad, num_samples, epsilon, mass_matrix\n", + "):\n", + " \"\"\"Metropolis Adjusted Langevin Algorithm (MALA) sampler with batch dimension.\n", + "\n", + " Args:\n", + " - initial_state (numpy array): Initial states of the chains, shape (num_chains, dim).\n", + " - log_prob (function): Function to compute the log probabilities of the current states.\n", + " - log_prob_grad (function): Function to compute the gradients of the log probabilities.\n", + " - num_samples (int): Number of samples to generate.\n", + " - epsilon (float): Step size for the Langevin dynamics.\n", + " - mass_matrix (numpy array): Mass matrix, shape (dim, dim), used to scale the dynamics.\n", + "\n", + "\n", + " Returns:\n", + " - samples (numpy array): Array of sampled values, shape (num_samples, num_chains, dim).\n", + " \"\"\"\n", + " num_chains, dim = initial_state.shape\n", + " samples = np.zeros((num_samples, num_chains, dim))\n", + " x_current = np.array(initial_state)\n", + " current_log_prob = log_prob(x_current)\n", + " inv_mass_matrix = np.linalg.inv(mass_matrix)\n", + " chol_inv_mass_matrix = np.linalg.cholesky(inv_mass_matrix)\n", + "\n", + " pbar = tqdm(range(num_samples))\n", + " acceptance_rate = np.zeros([0])\n", + " for i in pbar:\n", + " gradients = log_prob_grad(x_current)\n", + " noise = np.dot(np.random.randn(num_chains, dim), chol_inv_mass_matrix.T)\n", + " proposal = (\n", + " x_current\n", + " + 0.5 * epsilon**2 * np.dot(gradients, inv_mass_matrix)\n", + " + epsilon * noise\n", + " )\n", + "\n", + " # proposal = x_current + 0.5 * epsilon**2 * gradients + epsilon * np.random.randn(num_chains, *dim)\n", + " proposal_log_prob = log_prob(proposal)\n", + " # Metropolis-Hastings acceptance criterion, computed for each chain\n", + " acceptance_log_prob = proposal_log_prob - current_log_prob\n", + " accept = np.log(np.random.rand(num_chains)) < acceptance_log_prob\n", + " acceptance_rate = np.concatenate([acceptance_rate, accept])\n", + " pbar.set_description(f\"Acceptance rate: {acceptance_rate.mean():.2f}\")\n", + "\n", + " # Update states where accepted\n", + " x_current[accept] = proposal[accept]\n", + " current_log_prob[accept] = proposal_log_prob[accept]\n", + "\n", + " samples[i] = x_current\n", + "\n", + " return samples" + ] + }, + { + "cell_type": "markdown", + "id": "e934cbb9", + "metadata": {}, + "source": [ + "Here we run the MALA sampler after a small burn-in. We cheat a little bit and use the previous sampler to construct a mass matrix, this makes MALA more efficient but you could just as easily set the mass matrix to identity for the burn-in then use the burn-in samples to get a mediocre mass matrix, it only requires more fiddling with parameters (epsilon)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52c53b76", + "metadata": {}, + "outputs": [], + "source": [ + "def density_grad(x):\n", + " x = torch.as_tensor(x, dtype=torch.float32)\n", + " x.requires_grad = True\n", + " model = vsim(x)\n", + " log_likelihood_value = -0.5 * torch.sum(\n", + " ((model - obs_system) ** 2) / variance, dim=(1, 2)\n", + " )\n", + " log_likelihood_value = torch.nan_to_num(log_likelihood_value, nan=-np.inf)\n", + " log_likelihood_value.sum().backward()\n", + " return x.grad.numpy()\n", + "\n", + "\n", + "nwalkers = 32\n", + "x0 = allparams + 0.01 * torch.randn(nwalkers, ndim)\n", + "mass_matrix = np.linalg.inv(np.cov(chain_mh.reshape(-1, ndim), rowvar=False))\n", + "\n", + "chain_burnin_mala = mala_sampler(\n", + " initial_state=x0,\n", + " log_prob=density,\n", + " log_prob_grad=density_grad,\n", + " num_samples=100,\n", + " epsilon=3e-1,\n", + " mass_matrix=mass_matrix,\n", + ") # burn-in\n", + "\n", + "chain_mala = mala_sampler(\n", + " initial_state=chain_burnin_mala[-1],\n", + " log_prob=density,\n", + " log_prob_grad=density_grad,\n", + " num_samples=1000,\n", + " epsilon=3e-1,\n", + " mass_matrix=mass_matrix,\n", + ") # production" + ] + }, + { + "cell_type": "markdown", + "id": "e483c657", + "metadata": {}, + "source": [ + "PLotting the chains we see that they mix much better than the MH sampler, but less than NUTS, as would be expected." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fca7a07", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "normed_chains = (chain_mala[:, 0] - np.mean(chain_mala[:, 0], axis=0)) / np.std(\n", + " chain_mala[:, 0], axis=0\n", + ")\n", + "for i in range(chain_mala.shape[2]):\n", + " plt.plot(normed_chains[:, i], color=colormaps[\"viridis\"](i / chain_mala.shape[2]))\n", + "plt.title(\"Chain for each parameter\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "57c8399f", + "metadata": {}, + "source": [ + "The autocorrelation length is better than MH, but worse than NUTS as expected. Again the effective sample size can't be trusted and is probably comparable to our NUTS example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d058064a", + "metadata": {}, + "outputs": [], + "source": [ + "print(\n", + " \"Autocorrelation time: \",\n", + " np.mean(emcee.autocorr.integrated_time(chain_mala, quiet=True)),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "30c45f4f", + "metadata": {}, + "source": [ + "The corner plot much like the MH example shows that more samples are needed given the visible blobs and gaps in the sampling." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e97dead", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "N = chain_mh.shape[0] * chain_mh.shape[1]\n", + "fig = corner_plot(\n", + " np.concatenate(chain_mh, axis=0)[:: int(N / 200)],\n", + " true_values=allparams.numpy(),\n", + ")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f84bcb90", + "metadata": {}, + "outputs": [], "source": [] } ], "metadata": { + "kernelspec": { + "display_name": "PY39", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -432,7 +777,8 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" + "pygments_lexer": "ipython3", + "version": "3.9.5" } }, "nbformat": 4, diff --git a/docs/source/examples/Example_QSOLensFit.ipynb b/docs/source/examples/Example_QSOLensFit.ipynb index dff248e0..c9febffa 100644 --- a/docs/source/examples/Example_QSOLensFit.ipynb +++ b/docs/source/examples/Example_QSOLensFit.ipynb @@ -106,11 +106,22 @@ " dtype=torch.float32,\n", ")\n", "\n", - "fig, ax = plt.subplots()\n", - "\n", "A = lens.jacobian_lens_equation(thx, thy, z_s, params)\n", - "detA = torch.linalg.det(A)\n", - "\n", + "detA = torch.linalg.det(A)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2263184", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots()\n", "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", "# Get the path from the matplotlib contour plot of the critical line\n", "paths = CS.allsegs[0]\n", @@ -190,7 +201,11 @@ "cell_type": "code", "execution_count": null, "id": "11", - "metadata": {}, + "metadata": { + "tags": [ + "hide-input" + ] + }, "outputs": [], "source": [ "fig, ax = plt.subplots()\n", diff --git a/docs/source/interfaceindex.rst b/docs/source/interfaceindex.rst index 64c1eb58..10551691 100644 --- a/docs/source/interfaceindex.rst +++ b/docs/source/interfaceindex.rst @@ -5,6 +5,6 @@ There are three user interfaces to Caustics: configuration file, object-oriented Here we have built the same tutorial three times over so you can see how the three interfaces compare. They provide progressively more freedom and power, but also require more knowledge of the underlying system. -1. Configuration File: :doc:`tutorials/BasicIntroduction_yaml` -2. Object-Oriented: :doc:`tutorials/BasicIntroduction_oop` -3. Functional: :doc:`tutorials/BasicIntroduction_func` +1. Configuration File: :doc:`tutorials/InterfaceIntroduction_yaml` +2. Object-Oriented: :doc:`tutorials/InterfaceIntroduction_oop` +3. Functional: :doc:`tutorials/InterfaceIntroduction_func` diff --git a/docs/source/intro.md b/docs/source/intro.md index 2a438810..33cff80d 100644 --- a/docs/source/intro.md +++ b/docs/source/intro.md @@ -1,5 +1,7 @@ # Welcome to Caustics’ documentation! +![Logo GIF](../../media/caustics_logo.gif) + The lensing pipeline of the future: GPU-accelerated, automatically-differentiable, highly modular and extensible. @@ -39,10 +41,10 @@ x = torch.tensor([ 5.0, -0.2, 0.0, 0.8, 0.0, 1., 1.0, 10.0 ]) # fmt: skip -minisim = caustics.LensSource( - lens=sie, source=src, lens_light=lnslt, pixelscale=0.05, pixels_x=100 +sim = caustics.LensSource( + lens=sie, source=src, lens_light=lnslt, pixelscale=0.05, pixels_x=100, quad_level=3 ) -plt.imshow(minisim(x, quad_level=3), origin="lower") +plt.imshow(sim(x), origin="lower") plt.show() ``` @@ -54,7 +56,7 @@ plt.show() newx = x.repeat(20, 1) newx += torch.normal(mean=0, std=0.1 * torch.ones_like(newx)) -images = torch.vmap(minisim)(newx) +images = torch.vmap(sim)(newx) fig, axarr = plt.subplots(4, 5, figsize=(20, 16)) for ax, im in zip(axarr.flatten(), images): @@ -67,7 +69,7 @@ plt.show() ### Automatic Differentiation ```python -J = torch.func.jacfwd(minisim)(x) +J = torch.func.jacfwd(sim)(x) # Plot the new images fig, axarr = plt.subplots(3, 7, figsize=(20, 9)) diff --git a/docs/source/tutorials/InterfaceIntroduction_func.ipynb b/docs/source/tutorials/InterfaceIntroduction_func.ipynb index 8c90c395..ffae44c3 100644 --- a/docs/source/tutorials/InterfaceIntroduction_func.ipynb +++ b/docs/source/tutorials/InterfaceIntroduction_func.ipynb @@ -196,8 +196,19 @@ "metadata": {}, "outputs": [], "source": [ - "J = torch.func.jacfwd(sim)(x) # Substitute minisim with sim for the yaml method\n", - "\n", + "J = torch.func.jacfwd(sim)(x)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(3, 7, figsize=(20, 9))\n", "labels = [\n", " \"z_s\",\n", diff --git a/docs/source/tutorials/InterfaceIntroduction_oop.ipynb b/docs/source/tutorials/InterfaceIntroduction_oop.ipynb index aaca80f7..dbb251a4 100644 --- a/docs/source/tutorials/InterfaceIntroduction_oop.ipynb +++ b/docs/source/tutorials/InterfaceIntroduction_oop.ipynb @@ -129,7 +129,7 @@ "outputs": [], "source": [ "sim = caustics.LensSource(\n", - " lens=sie, source=src, lens_light=lnslt, pixelscale=0.05, pixels_x=100\n", + " lens=sie, source=src, lens_light=lnslt, pixelscale=0.05, pixels_x=100, quad_level=3\n", ")" ] }, @@ -174,7 +174,7 @@ "outputs": [], "source": [ "# Substitute minisim with sim for yaml method\n", - "image = sim(x, quad_level=3).detach().cpu().numpy()\n", + "image = sim(x).detach().cpu().numpy()\n", "\n", "plt.imshow(image, origin=\"lower\")\n", "plt.axis(\"off\")\n", @@ -226,8 +226,19 @@ "source": [ "# Now lets compute the jacobian of the simulator wrt each parameter\n", "J = torch.func.jacfwd(sim)(x)\n", - "# The shape of J is (npixels y, npixels x, nparameters)\n", - "\n", + "# The shape of J is (npixels y, npixels x, nparameters)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(3, 7, figsize=(20, 9))\n", "labels = tuple(sim.state_dict().keys())[3:]\n", "for i, ax in enumerate(axarr.flatten()):\n", diff --git a/docs/source/tutorials/InterfaceIntroduction_yaml.ipynb b/docs/source/tutorials/InterfaceIntroduction_yaml.ipynb index 6eef1777..d262ea12 100644 --- a/docs/source/tutorials/InterfaceIntroduction_yaml.ipynb +++ b/docs/source/tutorials/InterfaceIntroduction_yaml.ipynb @@ -100,7 +100,7 @@ "source": [ "### Plot the Results!\n", "\n", - "This section is mostly self explanatory. We evaluate the simulator configuration by calling it like a function. There are several possible arguments to the simulator function, here we use `quad_level` which indicates the depth of quadrature integration to use for each pixel (how accurate to make the flux)." + "This section is mostly self explanatory. We evaluate the simulator configuration by calling it like a function. " ] }, { @@ -110,7 +110,7 @@ "outputs": [], "source": [ "# Here we sample an image\n", - "image = sim(x, quad_level=3).detach().cpu().numpy()\n", + "image = sim(x).detach().cpu().numpy()\n", "\n", "plt.imshow(image, origin=\"lower\")\n", "plt.axis(\"off\")\n", @@ -163,8 +163,19 @@ "source": [ "# Now lets compute the jacobian of the simulator wrt each parameter\n", "J = torch.func.jacfwd(sim)(x)\n", - "# The shape of J is (npixels y, npixels x, nparameters)\n", - "\n", + "# The shape of J is (npixels y, npixels x, nparameters)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(3, 7, figsize=(20, 9))\n", "labels = tuple(sim.state_dict().keys())[3:]\n", "for i, ax in enumerate(axarr.flatten()):\n", diff --git a/docs/source/tutorials/InvertLensEquation.ipynb b/docs/source/tutorials/InvertLensEquation.ipynb index 77720582..edd6dca9 100644 --- a/docs/source/tutorials/InvertLensEquation.ipynb +++ b/docs/source/tutorials/InvertLensEquation.ipynb @@ -7,7 +7,7 @@ "source": [ "# Inverting the Lens Equation\n", "\n", - "The lens equation $\\vec{\\beta} = \\vec{\\theta} - \\vec{\\alpha}(\\vec{\\theta})$ allows us to find a point in the source plane given a point in the image plane. However, sometimes we know a point in the source plane and would like to see where it ends up in the image plane. This is not easy to do since a point in the source plane may map to multiple locations in the image plane. There is no closed form function to invert the lens equation, in large part because the deflection angle $\\vec{\\alpha}$ depends on the position in the image plane $\\vec{\\theta}$. To invert the lens equation, we will need to rely on optimization and a little luck to find all the images for a given source plane point. Below we will demonstrate how this is done in caustic!" + "The lens equation $\\vec{\\beta} = \\vec{\\theta} - \\vec{\\alpha}(\\vec{\\theta})$ allows us to find a point in the source plane given a point in the image plane. However, sometimes we know a point in the source plane and would like to see where it ends up in the image plane. This is not easy to do since a point in the source plane may map to multiple locations in the image plane. There is no closed form function to invert the lens equation, in large part because the deflection angle $\\vec{\\alpha}$ depends on the position in the image plane $\\vec{\\theta}$. To invert the lens equation, we will need to rely on optimization and a iterative procedures to find all the images for a given source plane point. Below we will demonstrate how this is done in caustic!" ] }, { @@ -23,6 +23,8 @@ "\n", "import torch\n", "import matplotlib.pyplot as plt\n", + "from matplotlib.patches import Polygon\n", + "from matplotlib.collections import PatchCollection\n", "import numpy as np\n", "\n", "import caustics" @@ -64,6 +66,14 @@ ")" ] }, + { + "cell_type": "markdown", + "id": "2163ed78", + "metadata": {}, + "source": [ + "Here we run the forward raytracing for our particular lens model. In caustics we provide a convenient `forward_raytrace` function which can be called for any lens model. Internally, this constructs a number of triangles in the image plane, raytraces them to the source plane and identifies which ones contain the desired source plane position. Iteratively subdividing the triangles eventually converges on image plane positions which map to the desired source plane position. See further down for more detail." + ] + }, { "cell_type": "code", "execution_count": null, @@ -82,11 +92,23 @@ "bx, by = lens.raytrace(x, y, z_s)" ] }, + { + "cell_type": "markdown", + "id": "462b2e8f", + "metadata": {}, + "source": [ + "When we raytrace the coordinates we get out from `forward_raytrace` it is not too surprising that they all give source plane positions very close to the desired source plane position. Here we plot them so you can see:" + ] + }, { "cell_type": "code", "execution_count": null, "id": "4", - "metadata": {}, + "metadata": { + "tags": [ + "hide-input" + ] + }, "outputs": [], "source": [ "fig, ax = plt.subplots()\n", @@ -110,20 +132,581 @@ "ax.scatter(x, y, color=\"b\", label=\"forward raytrace\", zorder=10)\n", "ax.scatter(bx, by, color=\"r\", marker=\"x\", label=\"source plane\", zorder=9)\n", "ax.scatter([sp_x.item()], [sp_y.item()], color=\"g\", label=\"true pos\", zorder=8)\n", + "ax.set_axis_off()\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "6641535a", + "metadata": {}, + "source": [ + "It is also often not necessary to model the central demagnified region since it is so faint (approximately a 100,000 times fainter in this case) that it doesn't contribute measurably to the flux of an image. We can very easily check the magnification of every point and remove the unnecessary one." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5d20df7", + "metadata": {}, + "outputs": [], + "source": [ + "m = lens.magnification(x, y, z_s)\n", + "print(m.detach().cpu().tolist())\n", + "N_m = torch.argsort(m)\n", + "\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "# Get the path from the matplotlib contour plot of the critical line\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "\n", + "plt.scatter(x[N_m[1:]], y[N_m[1:]], color=\"b\", label=\"magnified\")\n", + "plt.scatter(x[N_m[0]], y[N_m[0]], color=\"r\", label=\"de-magnified\")\n", + "plt.axis(\"off\")\n", "plt.legend()\n", "plt.show()" ] }, + { + "cell_type": "markdown", + "id": "0163d5f7", + "metadata": {}, + "source": [ + "## Lets take a look\n", + "\n", + "Using the `LensSource` simulator and the forward raytracing coordinates we can focus our calculations on the regions of interest for each image. Note however that the regions can overlap, which they do very slightly in this case." + ] + }, { "cell_type": "code", "execution_count": null, "id": "5", "metadata": {}, "outputs": [], + "source": [ + "src = caustics.Sersic(\n", + " x0=0.2, y0=0.2, q=0.9, phi=0.0, n=1.0, Re=0.05, Ie=1.0, name=\"source\"\n", + ")\n", + "\n", + "sim = caustics.LensSource(\n", + " lens=lens, source=src, z_s=z_s, x0=None, y0=None, pixelscale=0.005, pixels_x=100\n", + ")\n", + "\n", + "# Plot the source and lens\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "# Get the path from the matplotlib contour plot of the critical line\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "for i in range(len(x)):\n", + " ax.imshow(\n", + " sim([x[i], y[i]]),\n", + " extent=(\n", + " -sim.pixelscale * sim.pixels_x / 2 + x[i],\n", + " sim.pixelscale * sim.pixels_x / 2 + x[i],\n", + " -sim.pixelscale * sim.pixels_y / 2 + y[i],\n", + " sim.pixelscale * sim.pixels_y / 2 + y[i],\n", + " ),\n", + " origin=\"lower\",\n", + " )\n", + "ax.set_xlim([-1.5, 2])\n", + "ax.set_ylim([-1.5, 2])\n", + "ax.set_axis_off()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "edb5fa5e", + "metadata": {}, + "source": [ + "This is much more efficient than evaluating a whole image. Below you can see the same setup but we see how the simulator spends a lot of pixels evaluating low flux areas that don't matter much for modelling." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0689864", + "metadata": {}, + "outputs": [], + "source": [ + "sim_wide = caustics.LensSource(\n", + " lens=lens, source=src, z_s=z_s, pixelscale=0.005, pixels_x=1000\n", + ")\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "ax.imshow(\n", + " sim_wide({}),\n", + " origin=\"lower\",\n", + " extent=(\n", + " -sim_wide.pixelscale * sim_wide.pixels_x / 2,\n", + " sim_wide.pixelscale * sim_wide.pixels_x / 2,\n", + " -sim_wide.pixelscale * sim_wide.pixels_y / 2,\n", + " sim_wide.pixelscale * sim_wide.pixels_y / 2,\n", + " ),\n", + ")\n", + "ax.set_xlim([-1.5, 2])\n", + "ax.set_ylim([-1.5, 2])\n", + "ax.set_axis_off()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "7a4c8fd5", + "metadata": {}, + "source": [ + "## How forward_raytrace works\n", + "\n", + "All forward raytracing methods are imperfect as they involve iterative solutions which require enough resolution to pick out all the relevant image plane positions. To start, lets consider a more naive algorithm, simply placing random points in the image plane, then running a root-finding algorithm to get the source plane positions to line up." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e06ac42", + "metadata": {}, + "outputs": [], + "source": [ + "Ninit = 100\n", + "x_init = (torch.rand(Ninit) - 0.5) * fov\n", + "y_init = (torch.rand(Ninit) - 0.5) * fov\n", + "\n", + "\n", + "def raytrace(x, y):\n", + " return lens.raytrace(x, y, z_s)\n", + "\n", + "\n", + "final = caustics.lenses.func.forward_raytrace_rootfind(\n", + " x_init, y_init, sp_x, sp_y, raytrace\n", + ")\n", + "x_final, y_final = final[..., 0], final[..., 1]\n", + "\n", + "# Pick only points that converged\n", + "bx_final, by_final = raytrace(x_final, y_final)\n", + "R = torch.sqrt((sp_x - bx_final) ** 2 + (sp_y - by_final) ** 2)\n", + "x_final = x_final[R < 1e-3]\n", + "y_final = y_final[R < 1e-3]" + ] + }, + { + "cell_type": "markdown", + "id": "2abb217e", + "metadata": {}, + "source": [ + "Here we easily find the four magnified images, but the central demagnified image is (often) not found by this method since a point has to get lucky enough to start very close to the correct position in order for the gradient based root finder to work." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2fe0e6f", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "colors = [\"tab:red\", \"tab:blue\", \"tab:green\", \"tab:orange\", \"tab:purple\"]\n", + "for c in colors:\n", + " if x_final.shape[0] == 0:\n", + " break\n", + " R = ((x_final[0] - x_final) ** 2 + (y_final[0] - y_final) ** 2).sqrt()\n", + " ax.scatter(x_init[R < 0.1], y_init[R < 0.1], color=c)\n", + " ax.scatter(x_final[0], y_final[0], color=\"k\", s=200, marker=\"*\")\n", + " ax.scatter(x_final[0], y_final[0], color=c, s=100, marker=\"*\")\n", + " x_init = x_init[R >= 0.1]\n", + " y_init = y_init[R >= 0.1]\n", + " x_final = x_final[R >= 0.1]\n", + " y_final = y_final[R >= 0.1]\n", + "ax.axes.set_axis_off()\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "b4c5d47b", + "metadata": {}, + "source": [ + "Let's now look at a more clever algorithm. We will map triangles in the image plane to triangles in the source plane, we may then explore recursively, any triangles which enclose the desired source point. Due to the non-linearity of the gravitational lensing transformation, we will also search the neighbor of any triangle that seems to have found an image position. First we highlight in green, any triangles which contain the source point, then expand to all their neighbors." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5677ef22", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "n = 10\n", + "s = torch.stack((sp_x, sp_y))\n", + "# Construct a tiling of the image plane (squares at this point)\n", + "X, Y = torch.meshgrid(\n", + " torch.linspace(-fov / 2, fov / 2, n),\n", + " torch.linspace(-fov / 2, fov / 2, n),\n", + " indexing=\"ij\",\n", + ")\n", + "E1 = torch.stack((X, Y), dim=-1)\n", + "# build the upper and lower triangles within the squares of the grid\n", + "E1 = torch.cat(\n", + " (\n", + " torch.stack((E1[:-1, :-1], E1[:-1, 1:], E1[1:, 1:]), dim=-2),\n", + " torch.stack((E1[:-1, :-1], E1[1:, :-1], E1[1:, 1:]), dim=-2),\n", + " ),\n", + " dim=0,\n", + ").reshape(-1, 3, 2)\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "S = raytrace(E1[..., 0], E1[..., 1])\n", + "S = torch.stack(S, dim=-1)\n", + "\n", + "# Identify triangles that contain the source plane point\n", + "locate1 = torch.vmap(caustics.lenses.func.triangle_contains, in_dims=(0, None))(S, s)\n", + "patches = []\n", + "for e, loc in zip(E1, locate1):\n", + " patches.append(\n", + " Polygon(\n", + " e,\n", + " fill=loc,\n", + " alpha=0.4 if loc else 1,\n", + " color=\"tab:green\" if loc else \"k\",\n", + " linewidth=1,\n", + " )\n", + " )\n", + "p = PatchCollection(patches, match_original=True)\n", + "ax.add_collection(p)\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "ax.set_axis_off()\n", + "plt.show()\n", + "\n", + "# Get all the neighbors and upsample the triangles\n", + "E2 = E1[locate1]\n", + "E2 = torch.vmap(caustics.lenses.func.triangle_neighbors)(E2)\n", + "E2 = E2.reshape(-1, 3, 2)\n", + "E2 = caustics.lenses.func.remove_triangle_duplicates(E2)\n", + "# Upsample the triangles\n", + "E2 = torch.vmap(caustics.lenses.func.triangle_upsample)(E2)\n", + "E2 = E2.reshape(-1, 3, 2)\n", + "\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "S = raytrace(E2[..., 0], E2[..., 1])\n", + "S = torch.stack(S, dim=-1)\n", + "\n", + "# Identify triangles that contain the source plane point\n", + "locate2 = torch.vmap(caustics.lenses.func.triangle_contains, in_dims=(0, None))(S, s)\n", + "patches = []\n", + "for e, loc in zip(E2, locate2):\n", + " patches.append(\n", + " Polygon(\n", + " e,\n", + " fill=loc,\n", + " alpha=0.4 if loc else 1,\n", + " color=\"tab:green\" if loc else \"k\",\n", + " linewidth=0.5,\n", + " )\n", + " )\n", + "p = PatchCollection(patches, match_original=True)\n", + "ax.add_collection(p)\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "ax.set_axis_off()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "61fdd482", + "metadata": {}, + "source": [ + "The process repeats until the triangles have converged to a very small area, at which point we then run a root finding algorithm to get the final points. The central region is a very unstable optimum, so we need to use the triangle method for several iterations before we can run the root finder to get the exact optimal point." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ef54c41", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "# Get all the neighbors and upsample the triangles\n", + "E3 = E2[locate2]\n", + "E3 = torch.vmap(caustics.lenses.func.triangle_neighbors)(E3)\n", + "E3 = E3.reshape(-1, 3, 2)\n", + "E3 = caustics.lenses.func.remove_triangle_duplicates(E3)\n", + "# Upsample the triangles\n", + "E3 = torch.vmap(caustics.lenses.func.triangle_upsample)(E3)\n", + "E3 = E3.reshape(-1, 3, 2)\n", + "\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "S = raytrace(E3[..., 0], E3[..., 1])\n", + "S = torch.stack(S, dim=-1)\n", + "\n", + "# Identify triangles that contain the source plane point\n", + "locate3 = torch.vmap(caustics.lenses.func.triangle_contains, in_dims=(0, None))(S, s)\n", + "patches = []\n", + "for e, loc in zip(E3, locate3):\n", + " patches.append(\n", + " Polygon(\n", + " e,\n", + " fill=loc,\n", + " alpha=0.4 if loc else 1,\n", + " color=\"tab:green\" if loc else \"k\",\n", + " linewidth=0.5,\n", + " )\n", + " )\n", + "p = PatchCollection(patches, match_original=True)\n", + "ax.add_collection(p)\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "ax.set_axis_off()\n", + "plt.show()\n", + "\n", + "# Get all the neighbors and upsample the triangles\n", + "E4 = E3[locate3]\n", + "E4 = torch.vmap(caustics.lenses.func.triangle_neighbors)(E4)\n", + "E4 = E4.reshape(-1, 3, 2)\n", + "E4 = caustics.lenses.func.remove_triangle_duplicates(E4)\n", + "# Upsample the triangles\n", + "E4 = torch.vmap(caustics.lenses.func.triangle_upsample)(E4)\n", + "E4 = E4.reshape(-1, 3, 2)\n", + "\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "S = raytrace(E4[..., 0], E4[..., 1])\n", + "S = torch.stack(S, dim=-1)\n", + "\n", + "# Identify triangles that contain the source plane point\n", + "locate4 = torch.vmap(caustics.lenses.func.triangle_contains, in_dims=(0, None))(S, s)\n", + "patches = []\n", + "for e, loc in zip(E4, locate4):\n", + " patches.append(\n", + " Polygon(\n", + " e,\n", + " fill=loc,\n", + " alpha=0.4 if loc else 1,\n", + " color=\"tab:green\" if loc else \"k\",\n", + " linewidth=0.5,\n", + " )\n", + " )\n", + "p = PatchCollection(patches, match_original=True)\n", + "ax.add_collection(p)\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "ax.set_axis_off()\n", + "plt.show()\n", + "\n", + "# Get all the neighbors and upsample the triangles\n", + "E5 = E4[locate4]\n", + "E5 = torch.vmap(caustics.lenses.func.triangle_neighbors)(E5)\n", + "E5 = E5.reshape(-1, 3, 2)\n", + "E5 = caustics.lenses.func.remove_triangle_duplicates(E5)\n", + "# Upsample the triangles\n", + "E5 = torch.vmap(caustics.lenses.func.triangle_upsample)(E5)\n", + "E5 = E5.reshape(-1, 3, 2)\n", + "\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "S = raytrace(E5[..., 0], E5[..., 1])\n", + "S = torch.stack(S, dim=-1)\n", + "\n", + "# Identify triangles that contain the source plane point\n", + "locate5 = torch.vmap(caustics.lenses.func.triangle_contains, in_dims=(0, None))(S, s)\n", + "patches = []\n", + "for e, loc in zip(E5, locate5):\n", + " patches.append(\n", + " Polygon(\n", + " e,\n", + " fill=loc,\n", + " alpha=0.4 if loc else 1,\n", + " color=\"tab:green\" if loc else \"k\",\n", + " linewidth=0.5,\n", + " )\n", + " )\n", + "p = PatchCollection(patches, match_original=True)\n", + "ax.add_collection(p)\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "ax.set_axis_off()\n", + "plt.show()\n", + "\n", + "# Get all the neighbors and upsample the triangles\n", + "E6 = E5[locate5]\n", + "E6 = torch.vmap(caustics.lenses.func.triangle_neighbors)(E6)\n", + "E6 = E6.reshape(-1, 3, 2)\n", + "E6 = caustics.lenses.func.remove_triangle_duplicates(E6)\n", + "# Upsample the triangles\n", + "E6 = torch.vmap(caustics.lenses.func.triangle_upsample)(E6)\n", + "E6 = E6.reshape(-1, 3, 2)\n", + "\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "S = raytrace(E6[..., 0], E6[..., 1])\n", + "S = torch.stack(S, dim=-1)\n", + "\n", + "# Identify triangles that contain the source plane point\n", + "locate6 = torch.vmap(caustics.lenses.func.triangle_contains, in_dims=(0, None))(S, s)\n", + "patches = []\n", + "for e, loc in zip(E6, locate6):\n", + " patches.append(\n", + " Polygon(\n", + " e,\n", + " fill=loc,\n", + " alpha=0.4 if loc else 1,\n", + " color=\"tab:green\" if loc else \"k\",\n", + " linewidth=0.5,\n", + " )\n", + " )\n", + "p = PatchCollection(patches, match_original=True)\n", + "ax.add_collection(p)\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "ax.set_axis_off()\n", + "plt.show()\n", + "\n", + "\n", + "# Run the root finding algorithm\n", + "E7 = E6[locate6].sum(dim=1) / 3\n", + "E7 = caustics.lenses.func.forward_raytrace_rootfind(\n", + " E7[..., 0], E7[..., 1], s[0], s[1], raytrace\n", + ")\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "ax.scatter(E7[..., 0], E7[..., 1], color=\"k\", s=100, marker=\"*\")\n", + "ax.scatter(E7[..., 0], E7[..., 1], color=\"tab:green\", s=50, marker=\"*\")\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "ax.set_axis_off()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e421807b", + "metadata": {}, + "outputs": [], "source": [] } ], "metadata": { + "kernelspec": { + "display_name": "PY39", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -133,7 +716,8 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" + "pygments_lexer": "ipython3", + "version": "3.9.5" } }, "nbformat": 4, diff --git a/docs/source/tutorials/LensZoo.ipynb b/docs/source/tutorials/LensZoo.ipynb index 53e5fcd7..51d9cdc5 100644 --- a/docs/source/tutorials/LensZoo.ipynb +++ b/docs/source/tutorials/LensZoo.ipynb @@ -94,7 +94,20 @@ " pixels_x=n_pix,\n", " z_s=z_s,\n", " upsample_factor=2,\n", - ")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bc5596c8", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "# axarr[0].imshow(np.log10(convergence.numpy()), origin = \"lower\")\n", "axarr[0].axis(\"off\")\n", @@ -135,7 +148,20 @@ " pixels_x=n_pix,\n", " z_s=z_s,\n", " upsample_factor=2,\n", - ")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c06143aa", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "convergence = avg_pool2d(\n", " lens.convergence(thx, thy, z_s, z_l=z_l).squeeze()[None, None], upsample_factor\n", @@ -181,7 +207,20 @@ " pixels_x=n_pix,\n", " z_s=z_s,\n", " upsample_factor=2,\n", - ")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "87b924bb", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "convergence = avg_pool2d(\n", " lens.convergence(thx, thy, z_s, z_l=z_l).squeeze()[None, None], upsample_factor\n", @@ -228,7 +267,20 @@ " pixels_x=n_pix,\n", " z_s=z_s,\n", " upsample_factor=2,\n", - ")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56818b62", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "convergence = avg_pool2d(\n", " lens.convergence(thx, thy, z_s, z_l=z_l).squeeze()[None, None], upsample_factor\n", @@ -275,7 +327,20 @@ " pixels_x=n_pix,\n", " z_s=z_s,\n", " upsample_factor=2,\n", - ")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "102eac0c", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "convergence = avg_pool2d(\n", " lens.convergence(thx, thy, z_s, z_l=z_l).squeeze()[None, None], upsample_factor\n", @@ -325,7 +390,20 @@ " pixels_x=n_pix,\n", " z_s=z_s,\n", " upsample_factor=2,\n", - ")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "984b8ece", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "convergence = avg_pool2d(\n", " lens.convergence(thx, thy, z_s, z_l=z_l).squeeze()[None, None], upsample_factor\n", @@ -375,7 +453,20 @@ " pixels_x=n_pix,\n", " z_s=z_s,\n", " upsample_factor=2,\n", - ")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bfe89486", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "convergence = avg_pool2d(\n", " lens.convergence(thx, thy, z_s, z_l=z_l).squeeze()[None, None], upsample_factor\n", @@ -420,7 +511,20 @@ " pixels_x=n_pix,\n", " z_s=z_s,\n", " upsample_factor=2,\n", - ")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6964846c", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "# convergence = avg_pool2d(lens.convergence(thx, thy, z_s, z_l).squeeze()[None, None], upsample_factor).squeeze()\n", "# axarr[0].imshow(np.log10(convergence.numpy()), origin = \"lower\")\n", @@ -462,7 +566,20 @@ " pixels_x=n_pix,\n", " z_s=z_s,\n", " upsample_factor=2,\n", - ")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "convergence = avg_pool2d(\n", " lens.convergence(thx, thy, z_s, z_l=z_l).squeeze()[None, None], upsample_factor\n", @@ -475,14 +592,6 @@ "axarr[1].set_title(\"Lensed Sersic\")\n", "plt.show()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "21", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/docs/source/tutorials/MultiplaneDemo.ipynb b/docs/source/tutorials/MultiplaneDemo.ipynb index 6b7dff66..4f3d4296 100644 --- a/docs/source/tutorials/MultiplaneDemo.ipynb +++ b/docs/source/tutorials/MultiplaneDemo.ipynb @@ -100,14 +100,25 @@ { "cell_type": "code", "execution_count": null, - "id": "5", + "id": "14291ce3", "metadata": {}, "outputs": [], "source": [ "# Effective reduced deflection angles for the multiplane lens system\n", - "ax, ay = lens.effective_reduced_deflection_angle(thx, thy, z_s)\n", - "\n", - "# Plot\n", + "ax, ay = lens.effective_reduced_deflection_angle(thx, thy, z_s)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(12, 5))\n", "im = axarr[0].imshow(\n", " ax, extent=(thx[0][0], thx[0][-1], thy[0][0], thy[-1][0]), origin=\"lower\"\n", @@ -141,8 +152,20 @@ "A = lens.jacobian_lens_equation(thx, thy, z_s)\n", "\n", "# Here we compute detA at every point\n", - "detA = torch.linalg.det(A)\n", - "\n", + "detA = torch.linalg.det(A)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57bb5d70", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "# Plot the critical line\n", "im = plt.imshow(\n", " np.log10(np.abs(detA.detach().cpu().numpy())),\n", @@ -159,7 +182,11 @@ "cell_type": "code", "execution_count": null, "id": "8", - "metadata": {}, + "metadata": { + "tags": [ + "hide-input" + ] + }, "outputs": [], "source": [ "# For completeness, here are the caustics!\n", diff --git a/docs/source/tutorials/Playground.ipynb b/docs/source/tutorials/Playground.ipynb deleted file mode 100644 index 482ef4f2..00000000 --- a/docs/source/tutorials/Playground.ipynb +++ /dev/null @@ -1,230 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "0", - "metadata": {}, - "source": [ - "# Lensing playground\n", - "\n", - "This is just a fun notebook where you can interactively change lensing parameters. It is a great way to build some intuition around lensing and make some cool pictures!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "import torch\n", - "from torch.nn.functional import avg_pool2d\n", - "import matplotlib.pyplot as plt\n", - "from ipywidgets import interact\n", - "import numpy as np\n", - "\n", - "import caustics" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2", - "metadata": {}, - "outputs": [], - "source": [ - "n_pix = 100\n", - "res = 0.05\n", - "upsample_factor = 2\n", - "fov = res * n_pix\n", - "thx, thy = caustics.utils.meshgrid(\n", - " res / upsample_factor,\n", - " upsample_factor * n_pix,\n", - " dtype=torch.float32,\n", - ")\n", - "z_l = torch.tensor(0.5, dtype=torch.float32)\n", - "z_s = torch.tensor(1.5, dtype=torch.float32)\n", - "cosmology = caustics.FlatLambdaCDM(name=\"cosmo\")\n", - "cosmology.to(dtype=torch.float32)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3", - "metadata": {}, - "outputs": [], - "source": [ - "# SIE lens model, kappa map, alpha map, magnification, time delay, caustics\n", - "\n", - "\n", - "def plot_lens_metrics(thx0, thy0, q, phi, b):\n", - " lens = caustics.SIE(\n", - " cosmology=cosmology,\n", - " z_l=z_l,\n", - " x0=thx0,\n", - " y0=thy0,\n", - " q=q,\n", - " phi=phi,\n", - " b=b,\n", - " )\n", - " fig, axarr = plt.subplots(2, 3, figsize=(9, 6))\n", - " kappa = avg_pool2d(\n", - " lens.convergence(thx, thy, z_s)[None, None, :, :], upsample_factor\n", - " )[0, 0]\n", - " axarr[0][0].imshow(torch.log10(kappa), origin=\"lower\")\n", - " axarr[0][0].set_title(\"log(convergence)\")\n", - " psi = avg_pool2d(lens.potential(thx, thy, z_s)[None, None, :, :], upsample_factor)[\n", - " 0, 0\n", - " ]\n", - " axarr[0][1].imshow(psi, origin=\"lower\")\n", - " axarr[0][1].set_title(\"potential\")\n", - " timedelay = avg_pool2d(\n", - " lens.time_delay(thx, thy, z_s)[None, None, :, :], upsample_factor\n", - " )[0, 0]\n", - " axarr[0][2].imshow(timedelay, origin=\"lower\")\n", - " axarr[0][2].set_title(\"time delay\")\n", - " magnification = avg_pool2d(\n", - " lens.magnification(thx, thy, z_s)[None, None, :, :], upsample_factor\n", - " )[0, 0]\n", - " axarr[1][0].imshow(torch.log10(magnification), origin=\"lower\")\n", - " axarr[1][0].set_title(\"log(magnification)\")\n", - " alpha = lens.reduced_deflection_angle(thx, thy, z_s)\n", - " alpha0 = avg_pool2d(alpha[0][None, None, :, :], upsample_factor)[0, 0]\n", - " alpha1 = avg_pool2d(alpha[1][None, None, :, :], upsample_factor)[0, 0]\n", - " axarr[1][1].imshow(alpha0, origin=\"lower\")\n", - " axarr[1][1].set_title(\"deflection angle x\")\n", - " axarr[1][2].imshow(alpha1, origin=\"lower\")\n", - " axarr[1][2].set_title(\"deflection angle y\")\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4", - "metadata": {}, - "outputs": [], - "source": [ - "p = interact(\n", - " plot_lens_metrics,\n", - " thx0=(-2.5, 2.5, 0.1),\n", - " thy0=(-2.5, 2.5, 0.1),\n", - " q=(0.01, 0.99, 0.01),\n", - " phi=(0.0, np.pi, np.pi / 25),\n", - " b=(0.1, 2.0, 0.1),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5", - "metadata": {}, - "outputs": [], - "source": [ - "# Sersic source, demo lensed source\n", - "def plot_lens_distortion(\n", - " x0_lens,\n", - " y0_lens,\n", - " q_lens,\n", - " phi_lens,\n", - " b_lens,\n", - " x0_src,\n", - " y0_src,\n", - " q_src,\n", - " phi_src,\n", - " n_src,\n", - " Re_src,\n", - " Ie_src,\n", - "):\n", - " lens = caustics.SIE(\n", - " cosmology,\n", - " z_l,\n", - " x0=x0_lens,\n", - " y0=y0_lens,\n", - " q=q_lens,\n", - " phi=phi_lens,\n", - " b=b_lens,\n", - " )\n", - " source = caustics.Sersic(\n", - " x0=x0_src,\n", - " y0=y0_src,\n", - " q=q_src,\n", - " phi=phi_src,\n", - " n=n_src,\n", - " Re=Re_src,\n", - " Ie=Ie_src,\n", - " )\n", - " fig, axarr = plt.subplots(1, 3, figsize=(18, 6))\n", - " brightness = avg_pool2d(\n", - " source.brightness(thx, thy)[None, None, :, :], upsample_factor\n", - " )[0, 0]\n", - " axarr[0].imshow(brightness, origin=\"lower\")\n", - " axarr[0].set_title(\"Sersic source\")\n", - " kappa = avg_pool2d(\n", - " lens.convergence(thx, thy, z_s)[None, None, :, :], upsample_factor\n", - " )[0, 0]\n", - " axarr[1].imshow(torch.log10(kappa), origin=\"lower\")\n", - " axarr[1].set_title(\"lens log(convergence)\")\n", - " beta_x, beta_y = lens.raytrace(thx, thy, z_s)\n", - " mu = avg_pool2d(\n", - " source.brightness(beta_x, beta_y)[None, None, :, :], upsample_factor\n", - " )[0, 0]\n", - " axarr[2].imshow(mu, origin=\"lower\")\n", - " axarr[2].set_title(\"Sersic lensed\")\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6", - "metadata": {}, - "outputs": [], - "source": [ - "p = interact(\n", - " plot_lens_distortion,\n", - " x0_lens=(-2.5, 2.5, 0.1),\n", - " y0_lens=(-2.5, 2.5, 0.1),\n", - " q_lens=(0.01, 0.99, 0.01),\n", - " phi_lens=(0.0, np.pi, np.pi / 25),\n", - " b_lens=(0.1, 2.0, 0.1),\n", - " x0_src=(-2.5, 2.5, 0.1),\n", - " y0_src=(-2.5, 2.5, 0.1),\n", - " q_src=(0.01, 0.99, 0.01),\n", - " phi_src=(0.0, np.pi, np.pi / 25),\n", - " n_src=(0.5, 4, 0.1),\n", - " Re_src=(0.1, 2, 0.1),\n", - " Ie_src=(0.1, 2.0, 0.1),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/source/tutorials/VisualizeCaustics.ipynb b/docs/source/tutorials/VisualizeCaustics.ipynb index ec758794..d487a2fe 100644 --- a/docs/source/tutorials/VisualizeCaustics.ipynb +++ b/docs/source/tutorials/VisualizeCaustics.ipynb @@ -113,7 +113,11 @@ "cell_type": "code", "execution_count": null, "id": "6", - "metadata": {}, + "metadata": { + "tags": [ + "hide-input" + ] + }, "outputs": [], "source": [ "# Get the path from the matplotlib contour plot of the critical line\n", diff --git a/docs/source/tutorials/example.yml b/docs/source/tutorials/example.yml index c9a457f9..c161d782 100644 --- a/docs/source/tutorials/example.yml +++ b/docs/source/tutorials/example.yml @@ -26,3 +26,4 @@ simulator: lens_light: *lnslt pixelscale: 0.05 pixels_x: 100 + quad_level: 3 diff --git a/docs/source/websim.md b/docs/source/websim.md new file mode 100644 index 00000000..377f8a4d --- /dev/null +++ b/docs/source/websim.md @@ -0,0 +1,28 @@ +# Live Caustics Playground + +Here is a little online simulator run using caustics! It's slow because it's +running on a free server, but it's a good way to play with the simulator without +having to install anything. + +Pro tip: check out the "Pixelated" source to lens any image you want! + +
+ +
+ +If the window doesn't show up just follow this link: +[caustics-webapp](https://ciela-institute-caustics-webapp-guistreamlit-app-yanhhm.streamlit.app) + +For frequent simulator users (e.g., if you plan on exploring the parameter space +of a lens), we recommend installing the simulator locally and running it in your +browser. Follow the steps below: + +1. Install Caustics. Please follow the instructions on the + [install page](https://caustics.readthedocs.io/en/latest/install.html). +2. `pip install streamlit` +3. `git clone https://github.com/Ciela-Institute/caustics-webapp.git` +4. Move into the `caustics-webapp/gui/` directory and run the following command: + `streamlit run streamlit_app.py --server.enableXsrfProtection=false` + +If you were successful in installing the simulator, Step 4 should automatically +open the simulator in your default browser. diff --git a/docs/source/websim.rst b/docs/source/websim.rst deleted file mode 100644 index f74e1e27..00000000 --- a/docs/source/websim.rst +++ /dev/null @@ -1,17 +0,0 @@ -Live Caustics Playground -======================== - -Here is a little online simulator run using caustics! It's slow because it's running on a free server, but it's a good way to play with the simulator without having to install anything. - -Pro tip: check out the "Pixelated" source to lens any image you want! - -`Launch the simulator! `_ - -For frequent simulator users (e.g., if you plan on exploring the parameter space of a lens), we recommend installing the simulator locally and running it in your browser. Follow the steps below: - -1. Install Caustics. Please follow the instructions on the `install page `_. -2. ``pip install streamlit`` -3. ``git clone https://github.com/Ciela-Institute/caustics-webapp.git`` -4. Move into the ``caustics-webapp/gui/`` directory and run the following command: ``streamlit run streamlit_app.py --server.enableXsrfProtection=false`` - -If you were successful in installing the simulator, Step 4 should automatically open the simulator in your default browser. diff --git a/media/caustics_logo.gif b/media/caustics_logo.gif new file mode 100644 index 00000000..18c022f7 Binary files /dev/null and b/media/caustics_logo.gif differ diff --git a/src/caustics/lenses/base.py b/src/caustics/lenses/base.py index 11673e3b..ee819f2b 100644 --- a/src/caustics/lenses/base.py +++ b/src/caustics/lenses/base.py @@ -130,13 +130,16 @@ def forward_raytrace( z_s: Tensor, *args, params: Optional["Packed"] = None, - epsilon=1e-2, - n_init=100, - fov=5.0, + epsilon: float = 1e-3, + x0: Optional[Tensor] = None, + y0: Optional[Tensor] = None, + fov: float = 5.0, + divisions: int = 100, **kwargs, ) -> tuple[Tensor, Tensor]: """ - Perform a forward ray-tracing operation which maps from the source plane to the image plane. + Perform a forward ray-tracing operation which maps from the source plane + to the image plane. Parameters ---------- @@ -159,18 +162,20 @@ def forward_raytrace( Dynamic parameter container for the lens model. Defaults to None. epsilon: Tensor - maximum distance between two images (arcsec) before they are considered the same image. + maximum distance between two images (arcsec) before they are + considered the same image. *Unit: arcsec* - n_init: int - number of random initialization points used to try and find image plane points. - fov: float the field of view in which the initial random samples are taken. *Unit: arcsec* + divisions: int + the number of divisions of the fov on each axis when constructing + the grid to perform in the triangle search. + Returns ------- x_component: Tensor @@ -183,18 +188,41 @@ def forward_raytrace( *Unit: arcsec* """ - - # TODO make FOV more general so that it doesn't have to be centered on zero,zero - if fov is None: - raise ValueError("fov must be given to generate initial guesses") - + raytrace = partial(self.raytrace, params=params, z_s=z_s) + if x0 is None: + x0 = torch.zeros((), device=bx.device, dtype=bx.dtype) + if y0 is None: + y0 = torch.zeros((), device=by.device, dtype=by.dtype) + # X = torch.stack((x0, y0)).repeat(4, 1) + # X[0] -= fov / 2 + # X[1][0] -= fov / 2 + # X[1][1] += fov / 2 + # X[2][0] += fov / 2 + # X[2][1] -= fov / 2 + # X[3] += fov / 2 + + # Sx, Sy = raytrace(X[..., 0], X[..., 1]) + # S = torch.stack((Sx, Sy)).T + # res1, ap1 = func.triangle_search( + # torch.stack((bx, by)), + # X[:3], + # S[:3], + # raytrace, + # epsilon, + # torch.zeros((0, 2)), + # ) + # res2, ap2 = func.triangle_search( + # torch.stack((bx, by)), + # X[1:], + # S[1:], + # raytrace, + # epsilon, + # torch.zeros((0, 2)), + # ) + # res = torch.cat((res1, res2), dim=0) + # return res[:, 0], res[:, 1], torch.cat((ap1, ap2), dim=0) return func.forward_raytrace( - bx, - by, - partial(self.raytrace, params=params, z_s=z_s), - epsilon, - n_init, - fov, + torch.stack((bx, by)), raytrace, x0, y0, fov, divisions, epsilon ) diff --git a/src/caustics/lenses/func/__init__.py b/src/caustics/lenses/func/__init__.py index f9bfc03c..f46be3df 100644 --- a/src/caustics/lenses/func/__init__.py +++ b/src/caustics/lenses/func/__init__.py @@ -1,5 +1,12 @@ from .base import ( forward_raytrace, + triangle_contains, + triangle_area, + triangle_neighbors, + triangle_upsample, + triangle_equals, + remove_triangle_duplicates, + forward_raytrace_rootfind, physical_from_reduced_deflection_angle, reduced_from_physical_deflection_angle, time_delay_arcsec2_to_days, @@ -60,6 +67,13 @@ __all__ = ( "forward_raytrace", + "triangle_contains", + "triangle_area", + "triangle_neighbors", + "triangle_upsample", + "triangle_equals", + "remove_triangle_duplicates", + "forward_raytrace_rootfind", "physical_from_reduced_deflection_angle", "reduced_from_physical_deflection_angle", "time_delay_arcsec2_to_days", diff --git a/src/caustics/lenses/func/base.py b/src/caustics/lenses/func/base.py index e86a579b..c9d8ff7a 100644 --- a/src/caustics/lenses/func/base.py +++ b/src/caustics/lenses/func/base.py @@ -4,38 +4,136 @@ from ...constants import arcsec_to_rad, c_Mpc_s, days_to_seconds -def forward_raytrace(bx, by, raytrace, epsilon, n_init, fov): +def triangle_contains(p, v): """ - Perform a forward ray-tracing operation which maps from the source plane to the image plane. + determine if point v is inside triangle p. Where p is a (3,2) tensor, and v + is a (2,) tensor. + """ + p01 = p[1] - p[0] + p02 = p[2] - p[0] + dp0p02 = p[0][0] * p02[1] - p[0][1] * p02[0] + dp0p01 = p[0][0] * p01[1] - p[0][1] * p01[0] + dp01p02 = p01[0] * p02[1] - p01[1] * p02[0] + dvp02 = v[0] * p02[1] - v[1] * p02[0] + dvp01 = v[0] * p01[1] - v[1] * p01[0] + a = (dvp02 - dp0p02) / dp01p02 + b = -(dvp01 - dp0p01) / dp01p02 + return (a >= 0) & (b >= 0) & (a + b <= 1) + + +def triangle_area(p): + """ + Determine the area of triangle p where p is a (3,2) tensor. + """ + return ( + 0.5 + * ( + p[0][0] * (p[1][1] - p[2][1]) + + p[1][0] * (p[2][1] - p[0][1]) + + p[2][0] * (p[0][1] - p[1][1]) + ).abs() + ) - Parameters - ---------- - bx: Tensor - Tensor of x coordinate in the source plane. - *Unit: arcsec* +def triangle_neighbors(p): + """ + Build a set of neighbors for triangle p where p is a (3,2) tensor. The + neighbors all have the same shape as p, but are various translations and + reflections of p that share a common edge or vertex. + """ + p01 = p[1] - p[0] + p02 = p[2] - p[0] + p12 = p[2] - p[1] + pref = -(p - p[0]) + p[0] + return torch.stack( + ( + p, + p + p01, + p - p01, + p + p02, + p - p02, + p + p12, + p - p12, + pref, + pref + p01, + pref + 2 * p01, + pref + p02, + pref + 2 * p02, + pref + p01 + p02, + ), + dim=0, + ) - by: Tensor - Tensor of y coordinate in the source plane. - *Unit: arcsec* +def triangle_upsample(p): + """ + Upsample triangle p where p is a (3,2) tensor. The upsampled triangles are + all triangles internal to p built by taking the midpoints of the edges of p. + """ + p01 = (p[1] + p[0]) / 2 + p02 = (p[2] + p[0]) / 2 + p12 = (p[2] + p[1]) / 2 + return torch.stack( + ( + torch.stack((p[0], p01, p02), dim=0), + torch.stack((p01, p[1], p12), dim=0), + torch.stack((p02, p12, p[2]), dim=0), + torch.stack((p01, p12, p02), dim=0), + ), + dim=0, + ) - raytrace: function - function that takes in the x and y coordinates in the image plane and returns the x and y coordinates in the source plane. - epsilon: Tensor - maximum distance between two images (arcsec) before they are considered the same image. +def triangle_equals(p1, p2): + """ + Determine if two triangles are equal. Where p1 and p2 are (3,2) tensors. + """ + return torch.all((p1 - p2).abs() < 1e-6) + + +def remove_triangle_duplicates(p): + unique_triangles = torch.zeros((0, 3, 2)) + B = p.shape[0] + batch_triangle_equals = torch.vmap(triangle_equals, in_dims=(None, 0)) + for i in range(B): + # Compare current triangle with all triangles in the unique list + if i == 0 or not batch_triangle_equals(p[i], unique_triangles).any(): + unique_triangles = torch.cat((unique_triangles, p[i].unsqueeze(0)), dim=0) + + return unique_triangles + + +def forward_raytrace_rootfind(ix, iy, bx, by, raytrace): + """ + Perform a forward ray-tracing operation which maps from the source plane to + the image plane. + + Parameters + ---------- + ix: Tensor + Tensor of x coordinate in the image plane. This initializes the + ray-tracing optimization. Should have shape (B, 2). *Unit: arcsec* - n_init: int - number of random initialization points used to try and find image plane points. + iy: Tensor + Tensor of y coordinate in the image plane. This initializes the + ray-tracing optimization. Should have shape (B, 2). + + bx: Tensor + Tensor of x coordinate in the source plane. Should be a scalar. - fov: float - the field of view in which the initial random samples are taken. + *Unit: arcsec* + + by: Tensor + Tensor of y coordinate in the source plane. Should be a scalar. *Unit: arcsec* + raytrace: function + function that takes in the x and y coordinates in the image plane and + returns the x and y coordinates in the source plane. + Returns ------- x_component: Tensor @@ -48,37 +146,70 @@ def forward_raytrace(bx, by, raytrace, epsilon, n_init, fov): *Unit: arcsec* """ - bxy = torch.stack((bx, by)).repeat(n_init, 1) # has shape (n_init, Dout:2) - - # Random starting points in image plane - guesses = ( - torch.as_tensor(fov, dtype=bx.dtype) - * (torch.rand(n_init, 2, dtype=bx.dtype) - 0.5) - ).to( - device=bxy.device - ) # Has shape (n_init, Din:2) - + ixy = torch.stack((ix, iy), dim=1) # has shape (B, Din:2) + bxy = torch.stack((bx, by)).repeat(ix.shape[0], 1) # has shape (B, Dout:2) # Optimize guesses in image plane x, l, c = batch_lm( # noqa: E741 Unused `l` variable - guesses, + ixy, bxy, lambda *a, **k: torch.stack( raytrace(a[0][..., 0], a[0][..., 1], *a[1:], **k), dim=-1 ), ) + return x - # Clip points that didn't converge - x = x[c < 1e-2 * epsilon**2] - # Cluster results into n-images - res = [] - while len(x) > 0: - res.append(x[0]) - d = torch.linalg.norm(x - x[0], dim=-1) - x = x[d > epsilon] +def forward_raytrace(s, raytrace, x0, y0, fov, n, epsilon): - res = torch.stack(res, dim=0) - return res[..., 0], res[..., 1] + # Construct a tiling of the image plane (squares at this point) + X, Y = torch.meshgrid( + torch.linspace(x0 - fov / 2, x0 + fov / 2, n), + torch.linspace(y0 - fov / 2, y0 + fov / 2, n), + indexing="ij", + ) + E = torch.stack((X, Y), dim=-1) + # build the upper and lower triangles within the squares of the grid + E = torch.cat( + ( + torch.stack((E[:-1, :-1], E[:-1, 1:], E[1:, 1:]), dim=-2), + torch.stack((E[:-1, :-1], E[1:, :-1], E[1:, 1:]), dim=-2), + ), + dim=0, + ).reshape(-1, 3, 2) + + i = 0 + while True: + + # Expand the search to neighboring triangles + if i > 0: # no need for neighbors in the first iteration + E = torch.vmap(triangle_neighbors)(E) + E = E.reshape(-1, 3, 2) + E = remove_triangle_duplicates(E) + # Upsample the triangles + E = torch.vmap(triangle_upsample)(E) + E = E.reshape(-1, 3, 2) + + S = raytrace(E[..., 0], E[..., 1]) + S = torch.stack(S, dim=-1) + + # Identify triangles that contain the source plane point + locate = torch.vmap(triangle_contains, in_dims=(0, None))(S, s) + E = E[locate] + i += 1 + + if triangle_area(E[0]) > epsilon**2: + # Rootfind the source plane point in the triangle + Emid = E.sum(dim=1) / 3 + Emid = forward_raytrace_rootfind( + Emid[..., 0], Emid[..., 1], s[0], s[1], raytrace + ) + Smid = raytrace(Emid[..., 0], Emid[..., 1]) + Smid = torch.stack(Smid, dim=-1) + if torch.all(torch.vmap(triangle_contains)(E, Emid)) and torch.allclose( + Smid, s, atol=epsilon + ): + break + return Emid[..., 0], Emid[..., 1] def physical_from_reduced_deflection_angle(ax, ay, d_s, d_ls): diff --git a/src/caustics/sims/lens_source.py b/src/caustics/sims/lens_source.py index 05a13d7d..008859b2 100644 --- a/src/caustics/sims/lens_source.py +++ b/src/caustics/sims/lens_source.py @@ -22,25 +22,28 @@ class LensSource(Simulator): """Lens image of a source. - Straightforward simulator to sample a lensed image of a source - object. Constructs a sampling grid internally based on the - pixelscale and gridding parameters. It can automatically upscale - and fine sample an image. This is the most straightforward - simulator to view the image if you already have a lens and source - chosen. + Straightforward simulator to sample a lensed image of a source object. + Constructs a sampling grid internally based on the pixelscale and gridding + parameters. It can automatically upscale and fine sample an image. This is + the most straightforward simulator to view the image if you already have a + lens and source chosen. - Example usage:: + Example usage: + + .. code:: python import matplotlib.pyplot as plt import caustics cosmo = caustics.FlatLambdaCDM() - lens = caustics.lenses.SIS(cosmology = cosmo, x0 = 0., y0 = 0., th_ein = 1.) - source = caustics.sources.Sersic(x0 = 0., y0 = 0., q = 0.5, phi = 0.4, n = 2., Re = 1., Ie = 1.) - sim = caustics.sims.LensSource(lens, source, pixelscale = 0.05, gridx = 100, gridy = 100, upsample_factor = 2, z_s = 1.) + lens = caustics.lenses.SIS(cosmology=cosmo, x0=0.0, y0=0.0, th_ein=1.0) + source = caustics.sources.Sersic(x0=0.0, y0=0.0, q=0.5, phi=0.4, n=2.0, Re=1.0, Ie=1.0) + sim = caustics.sims.LensSource( + lens, source, pixelscale=0.05, pixels_x=100, upsample_factor=2, z_s=1.0 + ) img = sim() - plt.imshow(img, origin = "lower") + plt.imshow(img, origin="lower") plt.show() Attributes @@ -56,13 +59,21 @@ class LensSource(Simulator): lens_light: Source, optional caustics light object which defines the lensing object's light psf: Tensor, optional - An image to convolve with the scene. Note that if ``upsample_factor > 1`` the psf must also be at the higher resolution. + An image to convolve with the scene. Note that if ``upsample_factor > + 1`` the psf must also be at the higher resolution. pixels_y: Optional[int] - number of pixels on the y-axis for the sampling grid. If left as ``None`` then this will simply be equal to ``gridx`` + number of pixels on the y-axis for the sampling grid. If left as + ``None`` then this will simply be equal to ``gridx`` upsample_factor (default 1) - Amount of upsampling to model the image. For example ``upsample_factor = 2`` indicates that the image will be sampled at double the resolution then summed back to the original resolution (given by pixelscale and gridx/y). - psf_pad: Boolean(default True) - If convolving the PSF it is important to sample the model in a larger FOV equal to half the PSF size in order to account for light that scatters from outside the requested FOV inwards. Internally this padding will be added before sampling, then cropped off before returning the final image to the user. + Amount of upsampling to model the image. For example ``upsample_factor = + 2`` indicates that the image will be sampled at double the resolution + then summed back to the original resolution (given by pixelscale and + gridx/y). + quad_level: int (default None) + sub pixel integration resolution. This will use Gaussian quadrature to + sample the image at a higher resolution, then integrate the image back + to the original resolution. This is useful for high accuracy integration + of the image, but may increase memory usage and runtime. z_s: optional redshift of the source name: string (default "sim") @@ -70,10 +81,28 @@ class LensSource(Simulator): Notes: ----- - - The simulator will automatically pad the image to half the PSF size to ensure valid convolution. This is done by default, but can be turned off by setting ``psf_pad = False``. This is only relevant if you are using a PSF. - - The upsample factor will increase the resolution of the image by the given factor. For example, ``upsample_factor = 2`` will sample the image at double the resolution, then sum back to the original resolution. This is used when a PSF is provided at high resolution than the original image. Not that the when a PSF is used, the upsample_factor must equal the PSF upsampling level. - - For arbitrary pixel integration accuracy using the quad_level parameter. This will use Gaussian quadrature to sample the image at a higher resolution, then integrate the image back to the original resolution. This is useful for high accuracy integration of the image, but is not recommended for large images as it will be slow. The quad_level and upsample_factor can be used together to achieve high accuracy integration of the image convolved with a PSF. - - A `Pixelated` light source is defined by bilinear interpolation of the provided image. This means that sub-pixel integration is not required for accurate integration of the pixels. However, if you are using a PSF then you should still use upsample_factor (if your PSF is supersampled) to ensure that everything is sampled at the PSF resolution. + - The simulator will automatically pad the image to half the PSF size to + ensure valid convolution. This is done by default, but can be turned off + by setting ``psf_pad = False``. This is only relevant if you are using a + PSF. + - The upsample factor will increase the resolution of the image by the given + factor. For example, ``upsample_factor = 2`` will sample the image at + double the resolution, then sum back to the original resolution. This is + used when a PSF is provided at high resolution than the original image. + Not that the when a PSF is used, the upsample_factor must equal the PSF + upsampling level. + - For arbitrary pixel integration accuracy using the quad_level parameter. + This will use Gaussian quadrature to sample the image at a higher + resolution, then integrate the image back to the original resolution. This + is useful for high accuracy integration of the image, but is not + recommended for large images as it will be slow. The quad_level and + upsample_factor can be used together to achieve high accuracy integration + of the image convolved with a PSF. + - A `Pixelated` light source is defined by bilinear interpolation of the + provided image. This means that sub-pixel integration is not required for + accurate integration of the pixels. However, if you are using a PSF then + you should still use upsample_factor (if your PSF is supersampled) to + ensure that everything is sampled at the PSF resolution. """ # noqa: E501 @@ -91,18 +120,21 @@ def __init__( Optional[Source], "caustics light object which defines the lensing object's light", ] = None, - psf: Annotated[Optional[Tensor], "An image to convolve with the scene"] = None, pixels_y: Annotated[ Optional[int], "number of pixels on the y-axis for the sampling grid" ] = None, upsample_factor: Annotated[int, "Amount of upsampling to model the image"] = 1, - psf_pad: Annotated[bool, "Flag to apply padding to psf"] = True, + quad_level: Annotated[Optional[int], "sub pixel integration resolution"] = None, psf_mode: Annotated[ Literal["fft", "conv2d"], "Mode for convolving psf" ] = "fft", + psf_shape: Annotated[Optional[tuple[int, ...]], "The shape of the psf"] = None, z_s: Annotated[ Optional[Union[Tensor, float]], "Redshift of the source", True ] = None, + psf: Annotated[ + Optional[Union[Tensor, list]], "An image to convolve with the scene", True + ] = [[1.0]], x0: Annotated[ Optional[Union[Tensor, float]], "center of the fov for the lens source image", @@ -121,54 +153,125 @@ def __init__( self.lens = lens self.source = source self.lens_light = lens_light - if psf is None: - self.psf = None - else: - self.psf = torch.as_tensor(psf) - self.psf /= psf.sum() # ensure normalized + + # Configure PSF + self._psf_mode = psf_mode + if psf is not None: + psf = torch.as_tensor(psf) + self._psf_shape = psf.shape if psf is not None else psf_shape + + # Build parameters self.add_param("z_s", z_s) + self.add_param("psf", psf, self.psf_shape) self.add_param("x0", x0) self.add_param("y0", y0) - self.pixelscale = pixelscale + self._pixelscale = pixelscale # Image grid - if pixels_y is None: - pixels_y = pixels_x - self.gridding = (pixels_x, pixels_y) - - # PSF padding if needed - self.psf_mode = psf_mode - if psf_pad and self.psf is not None: - self.psf_pad = (self.psf.shape[1] // 2 + 1, self.psf.shape[0] // 2 + 1) - else: - self.psf_pad = (0, 0) + self._pixels_x = pixels_x + self._pixels_y = pixels_x if pixels_y is None else pixels_y + self._upsample_factor = upsample_factor + self._quad_level = quad_level # Build the imaging grid - self.upsample_factor = upsample_factor - self.n_pix = ( - self.gridding[0] + self.psf_pad[0] * 2, - self.gridding[1] + self.psf_pad[1] * 2, - ) - self.grid = meshgrid( - pixelscale / self.upsample_factor, - self.n_pix[0] * self.upsample_factor, - self.n_pix[1] * self.upsample_factor, - ) - - if self.psf is not None: - self.psf_fft = self._fft2_padded(self.psf) + self._build_grid() def to( self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None ): super().to(device, dtype) - if self.psf is not None: - self.psf = self.psf.to(device, dtype) - self.psf_fft = self.psf_fft.to(device, dtype) - self.grid = tuple(x.to(device, dtype) for x in self.grid) + self._grid = tuple(x.to(device, dtype) for x in self._grid) # type: ignore[has-type] + self._weights = self._weights.to(device, dtype) # type: ignore[has-type] return self + @property + def upsample_factor(self): + return self._upsample_factor + + @upsample_factor.setter + def upsample_factor(self, value): + self._upsample_factor = value + self._build_grid() + + @property + def pixels_x(self): + return self._pixels_x + + @pixels_x.setter + def pixels_x(self, value): + self._pixels_x = value + self._build_grid() + + @property + def pixels_y(self): + return self._pixels_y + + @pixels_y.setter + def pixels_y(self, value): + self._pixels_y = value + self._build_grid() + + @property + def quad_level(self): + return self._quad_level + + @quad_level.setter + def quad_level(self, value): + self._quad_level = value + self._build_grid() + + @property + def pixelscale(self): + return self._pixelscale + + @pixelscale.setter + def pixelscale(self, value): + self._pixelscale = value + self._build_grid() + + @property + def psf_shape(self): + return self._psf_shape + + @psf_shape.setter + def psf_shape(self, value): + self._psf_shape = value + self._build_grid() + + @property + def psf_mode(self): + return self._psf_mode + + @psf_mode.setter + def psf_mode(self, value): + self._psf_mode = value + self._build_grid() + + def _build_grid(self): + self._psf_pad = (self.psf_shape[1] // 2, self.psf_shape[0] // 2) + + self._n_pix = ( + self.pixels_x + self._psf_pad[0] * 2, + self.pixels_y + self._psf_pad[1] * 2, + ) + self._grid = meshgrid( + self.pixelscale / self.upsample_factor, + self._n_pix[0] * self.upsample_factor, + self._n_pix[1] * self.upsample_factor, + ) + self._weights = torch.ones( + (1, 1), dtype=self._grid[0].dtype, device=self._grid[0].device + ) + if self.quad_level is not None and self.quad_level > 1: + finegrid_x, finegrid_y, weights = gaussian_quadrature_grid( + self.pixelscale / self.upsample_factor, *self._grid, self.quad_level + ) + self._grid = (finegrid_x, finegrid_y) + self._weights = weights + else: + self._grid = (self._grid[0].unsqueeze(-1), self._grid[1].unsqueeze(-1)) + def _fft2_padded(self, x): """ Compute the 2D Fast Fourier Transform (FFT) of a tensor with zero-padding. @@ -179,7 +282,7 @@ def _fft2_padded(self, x): Returns: Tensor: The 2D FFT of the input tensor with zero-padding. """ - npix = copy(self.n_pix) + npix = copy(self._n_pix) npix = (next_fast_len(npix[0]), next_fast_len(npix[1])) self._s = npix @@ -199,8 +302,8 @@ def _unpad_fft(self, x): Tensor The input tensor without padding. """ - return torch.roll(x, (-self.psf_pad[0], -self.psf_pad[1]), dims=(-2, -1))[ - ..., : self.n_pix[0], : self.n_pix[1] + return torch.roll(x, (-self._psf_pad[0], -self._psf_pad[1]), dims=(-2, -1))[ + ..., : self._s[0], : self._s[1] ] def forward( @@ -210,8 +313,6 @@ def forward( lens_light=True, lens_source=True, psf_convolve=True, - quad_level=None, - **kwargs, ): """ forward function @@ -229,75 +330,63 @@ def forward( psf_convolve: boolean when true the image will be convolved with the psf """ - z_s, x0, y0 = self.unpack(params) + z_s, psf, x0, y0 = self.unpack(params) # Automatically turn off light for missing objects if self.source is None: source_light = False if self.lens_light is None: lens_light = False - if self.psf is None: + if psf.shape == (1, 1): psf_convolve = False - grid = (self.grid[0] + x0, self.grid[1] + y0) - - if quad_level is not None and quad_level > 1: - finegrid_x, finegrid_y, weights = gaussian_quadrature_grid( - self.pixelscale / self.upsample_factor, *grid, quad_level - ) + grid = (self._grid[0] + x0, self._grid[1] + y0) # Sample the source light if source_light: if lens_source: # Source is lensed by the lens mass distribution - if quad_level is not None and quad_level > 1: - bx, by = self.lens.raytrace(finegrid_x, finegrid_y, z_s, params) - mu_fine = self.source.brightness(bx, by, params) - mu = gaussian_quadrature_integrator(mu_fine, weights) - else: - bx, by = self.lens.raytrace(*grid, z_s, params) - mu = self.source.brightness(bx, by, params) + bx, by = self.lens.raytrace(*grid, z_s, params) + mu_fine = self.source.brightness(bx, by, params) + mu = gaussian_quadrature_integrator(mu_fine, self._weights) else: # Source is imaged without lensing - if quad_level is not None and quad_level > 1: - mu_fine = self.source.brightness(finegrid_x, finegrid_y, params) - mu = gaussian_quadrature_integrator(mu_fine, weights) - else: - mu = self.source.brightness(*grid, params) + mu_fine = self.source.brightness(*grid, params) + mu = gaussian_quadrature_integrator(mu_fine, self._weights) else: # Source is not added to the scene - mu = torch.zeros_like(grid[0]) + mu = torch.zeros_like(grid[0][..., 0]) # chop off quad dim # Sample the lens light if lens_light and self.lens_light is not None: - if quad_level is not None and quad_level > 1: - mu_fine = self.lens_light.brightness(finegrid_x, finegrid_y, params) - mu += gaussian_quadrature_integrator(mu_fine, weights) - else: - mu += self.lens_light.brightness(*grid, params) + mu_fine = self.lens_light.brightness(*grid, params) + mu += gaussian_quadrature_integrator(mu_fine, self._weights) # Convolve the PSF - if psf_convolve and self.psf is not None: + if psf_convolve: if self.psf_mode == "fft": mu_fft = self._fft2_padded(mu) - mu = self._unpad_fft( - torch.fft.irfft2(mu_fft * self.psf_fft, self._s).real - ) + psf_fft = self._fft2_padded(psf / psf.sum()) + mu = self._unpad_fft(torch.fft.irfft2(mu_fft * psf_fft, self._s).real) elif self.psf_mode == "conv2d": - mu = conv2d( - mu[None, None], self.psf[None, None], padding="same" - ).squeeze() + mu = ( + conv2d( + mu[None, None], (psf.T / psf.sum())[None, None], padding="same" + ) + .squeeze(0) + .squeeze(0) + ) else: raise ValueError( f"psf_mode should be one of 'fft' or 'conv2d', not {self.psf_mode}" ) # Return to the desired image - mu_native_resolution = avg_pool2d( - mu[None, None], self.upsample_factor, divisor_override=1 - ).squeeze() + mu_native_resolution = ( + avg_pool2d(mu[None, None], self.upsample_factor).squeeze(0).squeeze(0) + ) mu_clipped = mu_native_resolution[ - self.psf_pad[1] : self.gridding[1] + self.psf_pad[1], - self.psf_pad[0] : self.gridding[0] + self.psf_pad[0], + self._psf_pad[1] : self.pixels_y + self._psf_pad[1], + self._psf_pad[0] : self.pixels_x + self._psf_pad[0], ] return mu_clipped diff --git a/src/caustics/utils.py b/src/caustics/utils.py index ccbe24d5..f6399979 100644 --- a/src/caustics/utils.py +++ b/src/caustics/utils.py @@ -1037,7 +1037,8 @@ def _lm_step(f, X, Y, Cinv, L, Lup, Ldn, epsilon, L_min, L_max): chi2_new = (dYnew @ Cinv @ dYnew).sum(-1) # Test - rho = (chi2 - chi2_new) / torch.abs(h @ (L * torch.dot(torch.diag(hess), h) + grad)) # fmt: skip + expected_improvement = torch.dot(h, hess @ h) + 2 * torch.dot(h, grad) + rho = (chi2 - chi2_new) / torch.abs(expected_improvement) # fmt: skip # Update X = torch.where(rho >= epsilon, X + h, X) diff --git a/tests/models/test_mod_api.py b/tests/models/test_mod_api.py index 8c912a03..a8cb6a01 100644 --- a/tests/models/test_mod_api.py +++ b/tests/models/test_mod_api.py @@ -114,8 +114,8 @@ def sim_obj(): def test_build_simulator(sim_yaml_file, sim_obj, x_input): sim = caustics.build_simulator(sim_yaml_file) - result = sim(x_input, quad_level=3) - expected_result = sim_obj(x_input, quad_level=3) + result = sim(x_input) + expected_result = sim_obj(x_input) assert sim.graph(True, True) assert isinstance(result, torch.Tensor) assert torch.allclose(result, expected_result) @@ -149,7 +149,7 @@ def test_complex_build_simulator(): # Open the temp file and build the simulator sim = caustics.build_simulator(temp_file) - image = sim(x, quad_level=3) + image = sim(x) assert isinstance(image, torch.Tensor) # Remove the temp file @@ -197,8 +197,8 @@ def test_build_simulator_w_state(sim_yaml_file, sim_obj, x_input): # First remove the original sim del sim newsim = caustics.build_simulator(sim_yaml_file) - result = newsim(quad_level=3) - expected_result = sim_obj(x_input, quad_level=3) + result = newsim() + expected_result = sim_obj(x_input) assert newsim.graph(True, True) assert isinstance(result, torch.Tensor) assert torch.allclose(result, expected_result) @@ -217,7 +217,7 @@ def test_build_simulator_w_state(sim_yaml_file, sim_obj, x_input): "upsample": 2, }, }, - {"function": "caustics.utils.gaussian", "sigma": 0.2}, + # {"function": "caustics.utils.gaussian", "sigma": 0.2}, [[2.0], [2.0]], ], ) diff --git a/tests/test_epl.py b/tests/test_epl.py index f0d79491..1c854e13 100644 --- a/tests/test_epl.py +++ b/tests/test_epl.py @@ -8,9 +8,16 @@ from caustics.cosmology import FlatLambdaCDM from caustics.lenses import EPL +import numpy as np +import pytest -def test_lenstronomy(sim_source, device, lens_models): + +@pytest.mark.parametrize("q", [0.4, 0.7]) +@pytest.mark.parametrize("phi", [pi / 3, -pi / 4]) +@pytest.mark.parametrize("b", [0.1, 1.0]) +@pytest.mark.parametrize("t", [0.1, 1.0, 1.9]) +def test_lenstronomy_epl(sim_source, device, lens_models, q, phi, b, t): if sim_source == "yaml": yaml_str = """\ cosmology: &cosmology @@ -36,10 +43,10 @@ def test_lenstronomy(sim_source, device, lens_models): # Parameters z_s = torch.tensor(1.0, device=device) - x = torch.tensor([0.7, 0.912, -0.442, 0.7, pi / 3, 1.4, 1.35], device=device) + x = torch.tensor([0.7, 0.912, -0.442, q, phi, b, t], device=device) - e1, e2 = param_util.phi_q2_ellipticity(phi=x[4].item(), q=x[3].item()) - theta_E = (x[5] / x[3].sqrt()).item() + e1, e2 = param_util.phi_q2_ellipticity(phi=phi, q=q) + theta_E = b / np.sqrt(q) # (x[5] / x[3].sqrt()).item() kwargs_ls = [ { "theta_E": theta_E, @@ -47,19 +54,19 @@ def test_lenstronomy(sim_source, device, lens_models): "e2": e2, "center_x": x[1].item(), "center_y": x[2].item(), - "gamma": x[6].item() + 1, # important: add +1 + "gamma": t + 1, # important: add +1 } ] # Different tolerances for difference quantities alpha_test_helper( - lens, lens_ls, z_s, x, kwargs_ls, rtol=1e-100, atol=6e-5, device=device + lens, lens_ls, z_s, x, kwargs_ls, rtol=1e-100, atol=1e-3, device=device ) kappa_test_helper( lens, lens_ls, z_s, x, kwargs_ls, rtol=3e-5, atol=1e-100, device=device ) Psi_test_helper( - lens, lens_ls, z_s, x, kwargs_ls, rtol=3e-5, atol=1e-100, device=device + lens, lens_ls, z_s, x, kwargs_ls, rtol=1e-3, atol=1e-100, device=device ) @@ -98,8 +105,3 @@ def test_special_case_sie(device): Psi_test_helper( lens, lens_ls, z_s, x, kwargs_ls, rtol=3e-5, atol=1e-100, device=device ) - - -if __name__ == "__main__": - test_lenstronomy(None) - test_special_case_sie(None) diff --git a/tests/test_masssheet.py b/tests/test_masssheet.py index a1f79ef2..b455d05b 100644 --- a/tests/test_masssheet.py +++ b/tests/test_masssheet.py @@ -5,8 +5,11 @@ from caustics.lenses import MassSheet from caustics.utils import meshgrid +import pytest -def test(sim_source, device, lens_models): + +@pytest.mark.parametrize("convergence", [-1.0, 0.0, 1.0]) +def test_masssheet(sim_source, device, lens_models, convergence): if sim_source == "yaml": yaml_str = """\ cosmology: &cosmology @@ -30,12 +33,17 @@ def test(sim_source, device, lens_models): # Parameters z_s = torch.tensor(1.2) - x = torch.tensor([0.5, 0.0, 0.0, 0.7]) + x = torch.tensor([0.5, 0.0, 0.0, convergence]) thx, thy = meshgrid(0.01, 10, device=device) - lens.reduced_deflection_angle(thx, thy, z_s, x) + ax, ay = lens.reduced_deflection_angle(thx, thy, z_s, x) + + p = lens.potential(thx, thy, z_s, x) - lens.potential(thx, thy, z_s, x) + c = lens.convergence(thx, thy, z_s, x) - lens.convergence(thx, thy, z_s, x) + assert torch.all(torch.isfinite(ax)) + assert torch.all(torch.isfinite(ay)) + assert torch.all(torch.isfinite(p)) + assert torch.all(torch.isfinite(c)) diff --git a/tests/test_nfw.py b/tests/test_nfw.py index a9c30191..cd6c2984 100644 --- a/tests/test_nfw.py +++ b/tests/test_nfw.py @@ -14,12 +14,16 @@ from caustics.cosmology import FlatLambdaCDM as CausticFlatLambdaCDM from caustics.lenses import NFW +import pytest + h0_default = float(default_cosmology.get().h) Om0_default = float(default_cosmology.get().Om0) Ob0_default = float(default_cosmology.get().Ob0) -def test(sim_source, device, lens_models): +@pytest.mark.parametrize("m", [1e8, 1e10, 1e12]) +@pytest.mark.parametrize("c", [1.0, 8.0, 20.0]) +def test_nfw(sim_source, device, lens_models, m, c): atol = 1e-5 rtol = 3e-2 z_l = torch.tensor(0.1) @@ -54,8 +58,8 @@ def test(sim_source, device, lens_models): thx0 = 0.457 thy0 = 0.141 - m = 1e12 - c = 8.0 + # m = 1e12 + # c = 8.0 x = torch.tensor([thx0, thy0, m, c]) # Lenstronomy @@ -113,7 +117,3 @@ def test_runs(sim_source, device, lens_models): assert torch.all(torch.isfinite(alpha[1])) kappa = lens.convergence(thx, thy, z_s, x) assert torch.all(torch.isfinite(kappa)) - - -if __name__ == "__main__": - test() diff --git a/tests/test_point.py b/tests/test_point.py index 2584c944..022e3fcc 100644 --- a/tests/test_point.py +++ b/tests/test_point.py @@ -6,8 +6,11 @@ from caustics.cosmology import FlatLambdaCDM from caustics.lenses import Point +import pytest -def test(sim_source, device, lens_models): + +@pytest.mark.parametrize("th_ein", [0.1, 1.0, 2.0]) +def test_point_lens(sim_source, device, lens_models, th_ein): atol = 1e-5 rtol = 1e-5 z_l = torch.tensor(0.9) @@ -37,13 +40,7 @@ def test(sim_source, device, lens_models): # Parameters z_s = torch.tensor(1.2) - x = torch.tensor([0.912, -0.442, 1.1]) - kwargs_ls = [ - {"center_x": x[0].item(), "center_y": x[1].item(), "theta_E": x[2].item()} - ] + x = torch.tensor([0.912, -0.442, th_ein]) + kwargs_ls = [{"center_x": x[0].item(), "center_y": x[1].item(), "theta_E": th_ein}] lens_test_helper(lens, lens_ls, z_s, x, kwargs_ls, rtol, atol, device=device) - - -if __name__ == "__main__": - test(None) diff --git a/tests/test_pseudo_jaffe.py b/tests/test_pseudo_jaffe.py index 84d75fe7..6e45b84e 100644 --- a/tests/test_pseudo_jaffe.py +++ b/tests/test_pseudo_jaffe.py @@ -6,8 +6,12 @@ from caustics.cosmology import FlatLambdaCDM from caustics.lenses import PseudoJaffe +import pytest -def test(sim_source, device, lens_models): + +@pytest.mark.parametrize("mass", [1e8, 1e10, 1e12]) +@pytest.mark.parametrize("Rc,Rs", [[1.0, 10.0], [1e-2, 1.0], [0.5, 1.0]]) +def test_pseudo_jaffe(sim_source, device, lens_models, mass, Rc, Rs): atol = 1e-5 rtol = 1e-5 @@ -35,26 +39,9 @@ def test(sim_source, device, lens_models): # Parameters, computing kappa_0 with a helper function z_s = torch.tensor(2.1) - x = torch.tensor([0.5, 0.071, 0.023, -1e100, 0.5, 1.5]) - d_l = cosmology.angular_diameter_distance(x[0]) - arcsec_to_rad = 1 / (180 / torch.pi * 60**2) - kappa_0 = lens.central_convergence( - x[0], - z_s, - torch.tensor(2e11), - x[4] * d_l * arcsec_to_rad, - x[5] * d_l * arcsec_to_rad, - cosmology.critical_surface_density(x[0], z_s), - ) - x[3] = ( - 2 - * torch.pi - * kappa_0 - * cosmology.critical_surface_density(x[0], z_s) - * x[4] - * x[5] - * (d_l * arcsec_to_rad) ** 2 - ) + x = torch.tensor([0.5, 0.071, 0.023, mass, Rc, Rs]) + kappa_0 = lens.get_convergence_0(z_s, x) + kwargs_ls = [ { "sigma0": kappa_0.item(), @@ -97,7 +84,3 @@ def test_massenclosed(device): masses = lens.mass_enclosed_2d(xx, z_s, x) assert torch.all(masses < x[3]) - - -if __name__ == "__main__": - test(None) diff --git a/tests/test_sersic.py b/tests/test_sersic.py index c215bbb3..0c8089e9 100644 --- a/tests/test_sersic.py +++ b/tests/test_sersic.py @@ -8,8 +8,13 @@ from caustics.light import Sersic from caustics.utils import meshgrid +import pytest -def test(sim_source, device, light_models): + +@pytest.mark.parametrize("q", [0.2, 0.7]) +@pytest.mark.parametrize("n", [1.0, 2.0, 3.0]) +@pytest.mark.parametrize("th_e", [1.0, 10.0]) +def test_sersic(sim_source, device, light_models, q, n, th_e): # Caustics setup res = 0.05 nx = 200 @@ -48,9 +53,9 @@ def test(sim_source, device, light_models): thx0_src = 0.05 thy0_src = 0.01 phi_src = 0.0 - q_src = 0.5 - index_src = 1.5 - th_e_src = 0.1 + q_src = q + index_src = n + th_e_src = th_e I_e_src = 100 # NOTE: in several places we use np.sqrt(q_src) in order to match # the definition used by lenstronomy. This only works when phi = 0. @@ -86,7 +91,3 @@ def test(sim_source, device, light_models): brightness_ls = sersic_ls.surface_brightness(x_ls, y_ls, kwargs_light_source) assert np.allclose(brightness.cpu().numpy(), brightness_ls) - - -if __name__ == "__main__": - test(None) diff --git a/tests/test_sie.py b/tests/test_sie.py index f4d06e3a..8b379936 100644 --- a/tests/test_sie.py +++ b/tests/test_sie.py @@ -10,10 +10,15 @@ from caustics.lenses import SIE from caustics.utils import meshgrid +import pytest -def test(sim_source, device, lens_models): + +@pytest.mark.parametrize("q", [0.5, 0.7, 0.9]) +@pytest.mark.parametrize("phi", [pi / 3, -pi / 4, pi / 6]) +@pytest.mark.parametrize("th_ein", [0.1, 1.0, 2.5]) +def test_sie(sim_source, device, lens_models, q, phi, th_ein): atol = 1e-5 - rtol = 1e-5 + rtol = 1e-3 if sim_source == "yaml": yaml_str = """\ @@ -38,11 +43,11 @@ def test(sim_source, device, lens_models): # Parameters z_s = torch.tensor(1.2) - x = torch.tensor([0.5, 0.912, -0.442, 0.7, pi / 3, 1.4]) - e1, e2 = param_util.phi_q2_ellipticity(phi=x[4].item(), q=x[3].item()) + x = torch.tensor([0.5, 0.912, -0.442, q, phi, th_ein]) + e1, e2 = param_util.phi_q2_ellipticity(phi=phi, q=q) kwargs_ls = [ { - "theta_E": x[5].item(), + "theta_E": th_ein, "e1": e1, "e2": e2, "center_x": x[1].item(), @@ -97,7 +102,3 @@ def test_sie_time_delay(): ) ) ) - - -if __name__ == "__main__": - test(None) diff --git a/tests/test_simulator_runs.py b/tests/test_simulator_runs.py index 0efa49ff..23816889 100644 --- a/tests/test_simulator_runs.py +++ b/tests/test_simulator_runs.py @@ -40,7 +40,7 @@ def test_simulator_runs(sim_source, device, mocker): y0: -0.03 q: 0.6 phi: -pi / 4 - n: 2.0 + n: 1.5 Re: 0.5 Ie: 1.0 @@ -61,7 +61,7 @@ def test_simulator_runs(sim_source, device, mocker): kwargs: pixelscale: 0.05 nx: 11 - ny: 12 + ny: 11 sigma: 0.2 upsample: 2 @@ -71,16 +71,20 @@ def test_simulator_runs(sim_source, device, mocker): params: z_s: 2.0 init_kwargs: - # Single lense + # Single lens lens: *lensmass source: *source lens_light: *lenslight pixelscale: 0.05 pixels_x: 50 - psf: *psf + psf: *psf{quad_level} """ - mock_from_file(mocker, yaml_str) + mock_from_file( + mocker, yaml_str.format(quad_level="") + ) # fixme, yaml should be able to accept None sim = build_simulator("/path/to/sim.yaml") # Path doesn't actually exists + mock_from_file(mocker, yaml_str.format(quad_level="\n quad_level: 3")) + sim_q3 = build_simulator("/path/to/sim.yaml") # Path doesn't actually exists else: # Model cosmology = FlatLambdaCDM(name="cosmo") @@ -115,7 +119,29 @@ def test_simulator_runs(sim_source, device, mocker): z_s=2.0, ) + sim_q3 = LensSource( + name="simulator", + lens=lensmass, + source=source, + pixelscale=0.05, + pixels_x=50, + lens_light=lenslight, + psf=psf, + z_s=2.0, + quad_level=3, + ) + sim.to(device=device) + sim_q3.to(device=device) + + # Test setters + sim.pixelscale = 0.05 + sim.pixels_x = 50 + sim.pixels_y = 50 + sim_q3.quad_level = 3 + sim.upsample_factor = 1 + sim.psf_shape = (11, 11) + sim.psf_mode = "conv2d" assert torch.all(torch.isfinite(sim())) assert torch.all( @@ -164,8 +190,7 @@ def test_simulator_runs(sim_source, device, mocker): ) # Check quadrature integration is accurate - assert torch.allclose(sim(), sim(quad_level=3), rtol=1e-1) - assert torch.allclose(sim(quad_level=3), sim(quad_level=5), rtol=1e-2) + assert torch.allclose(sim(), sim_q3(), rtol=1e-1) def test_microlens_simulator_runs(): diff --git a/tests/test_sis.py b/tests/test_sis.py index a4594e55..cf170620 100644 --- a/tests/test_sis.py +++ b/tests/test_sis.py @@ -6,8 +6,11 @@ from caustics.cosmology import FlatLambdaCDM from caustics.lenses import SIS +import pytest -def test(sim_source, device, lens_models): + +@pytest.mark.parametrize("th_ein", [0.1, 1.0, 2.0]) +def test(sim_source, device, lens_models, th_ein): atol = 1e-5 rtol = 1e-5 z_l = torch.tensor(0.5) @@ -37,7 +40,7 @@ def test(sim_source, device, lens_models): # Parameters z_s = torch.tensor(1.2) - x = torch.tensor([-0.342, 0.51, 1.4]) + x = torch.tensor([-0.342, 0.51, th_ein]) kwargs_ls = [ {"center_x": x[0].item(), "center_y": x[1].item(), "theta_E": x[2].item()} ] diff --git a/tests/test_tnfw.py b/tests/test_tnfw.py index 33a2b5eb..46e1f6e1 100644 --- a/tests/test_tnfw.py +++ b/tests/test_tnfw.py @@ -13,13 +13,20 @@ from caustics.cosmology import FlatLambdaCDM as CausticFlatLambdaCDM from caustics.lenses import TNFW +import pytest + h0_default = float(default_cosmology.get().h) Om0_default = float(default_cosmology.get().Om0) Ob0_default = float(default_cosmology.get().Ob0) -def test(sim_source, device, lens_models): +@pytest.mark.parametrize( + "m", [1e8, 1e10, 1e12] +) # Note with m=1e14 the test fails, due to the Rs_angle becoming too large (pytorch is unstable) +@pytest.mark.parametrize("c", [1.0, 8.0, 40.0]) +@pytest.mark.parametrize("t", [2.0, 5.0, 20.0]) +def test(sim_source, device, lens_models, m, c, t): atol = 1e-5 rtol = 3e-2 z_l = torch.tensor(0.1) @@ -51,16 +58,11 @@ def test(sim_source, device, lens_models): lens_model_list = ["TNFW"] lens_ls = LensModel(lens_model_list=lens_model_list) - print(lens) - # Parameters z_s = torch.tensor(0.5) thx0 = 0.457 thy0 = 0.141 - m = 1e12 - c = 8.0 - t = 3.0 x = torch.tensor([thx0, thy0, m, c, t]) # Lenstronomy @@ -118,7 +120,3 @@ def test_runs(device): Rt = lens.get_truncation_radius(x) assert Rt == (rs * t) - - -if __name__ == "__main__": - test(None) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index d1121b0d..b18b7605 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -1,6 +1,7 @@ """ Utilities for testing """ + from typing import Any, Dict, List, Union import torch @@ -16,11 +17,12 @@ from caustics.cosmology import FlatLambdaCDM from .models import mock_from_file -__all__ = ( - "mock_from_file", -) +__all__ = ("mock_from_file",) + -def setup_simulator(cosmo_static=False, use_nfw=True, simulator_static=False, batched_params=False, device=None): +def setup_simulator( + cosmo_static=False, use_nfw=True, simulator_static=False, batched_params=False, device=None +): n_pix = 20 class Sim(Simulator): @@ -88,7 +90,7 @@ def forward(self, params): cosmo_params = [_x[0] for _x in cosmo_params] lens_params = [_x[0] for _x in lens_params] source_params = [_x[0] for _x in source_params] - + sim = Sim() # Set device when not None if device is not None: @@ -97,7 +99,7 @@ def forward(self, params): cosmo_params = [_p.to(device=device) for _p in cosmo_params] lens_params = [_p.to(device=device) for _p in lens_params] source_params = [_p.to(device=device) for _p in source_params] - + return sim, (sim_params, cosmo_params, lens_params, source_params) @@ -161,7 +163,7 @@ def forward(self, params): lens_params = [_x[0] for _x in lens_params] kappa = kappa[0] source = source[0] - + sim = Sim() # Set device when not None if device is not None: @@ -177,7 +179,7 @@ def forward(self, params): def get_default_cosmologies(device=None): cosmology = FlatLambdaCDM("cosmo") cosmology_ap = FlatLambdaCDM_AP(100 * cosmology.h0.value, cosmology.Om0.value, Tcmb0=0) - + if device is not None: cosmology = cosmology.to(device=device) return cosmology, cosmology_ap @@ -210,6 +212,7 @@ def alpha_test_helper(lens, lens_ls, z_s, x, kwargs_ls, atol, rtol, device=None) thx, thy, thx_ls, thy_ls = setup_grids(device=device) alpha_x, alpha_y = lens.reduced_deflection_angle(thx, thy, z_s, x) alpha_x_ls, alpha_y_ls = lens_ls.alpha(thx_ls, thy_ls, kwargs_ls) + print(np.sum(np.abs(1 - alpha_x.cpu().numpy() / alpha_x_ls) > 1e-3)) assert np.allclose(alpha_x.cpu().numpy(), alpha_x_ls, rtol, atol) assert np.allclose(alpha_y.cpu().numpy(), alpha_y_ls, rtol, atol)